trillium_websockets/
websocket_connection.rs1use crate::{Result, Role, WebSocketConfig};
2use async_tungstenite::{
3 tungstenite::{self, Message},
4 WebSocketStream,
5};
6use futures_util::{
7 stream::{SplitSink, SplitStream, Stream},
8 SinkExt, StreamExt,
9};
10use std::{
11 net::IpAddr,
12 pin::Pin,
13 task::{Context, Poll},
14};
15use stopper::{Stopper, StreamStopper};
16use trillium::{Headers, Method, StateSet, Upgrade};
17use trillium_http::transport::BoxedTransport;
18
19#[derive(Debug)]
31pub struct WebSocketConn {
32 request_headers: Headers,
33 path: String,
34 method: Method,
35 state: StateSet,
36 peer_ip: Option<IpAddr>,
37 stopper: Stopper,
38 sink: SplitSink<Wss, Message>,
39 stream: Option<WStream>,
40}
41
42type Wss = WebSocketStream<BoxedTransport>;
43
44impl WebSocketConn {
45 pub async fn send_string(&mut self, string: String) -> Result<()> {
47 self.send(Message::Text(string)).await.map_err(Into::into)
48 }
49
50 pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
52 self.send(Message::Binary(bin)).await.map_err(Into::into)
53 }
54
55 #[cfg(feature = "json")]
56 pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
59 self.send_string(serde_json::to_string(json)?).await
60 }
61
62 pub async fn send(&mut self, message: Message) -> Result<()> {
64 self.sink.send(message).await.map_err(Into::into)
65 }
66
67 #[doc(hidden)]
72 pub async fn new(upgrade: Upgrade, config: Option<WebSocketConfig>, role: Role) -> Self {
73 let Upgrade {
74 request_headers,
75 path,
76 method,
77 state,
78 buffer,
79 transport,
80 stopper,
81 } = upgrade;
82
83 let wss = if let Some(vec) = buffer {
84 WebSocketStream::from_partially_read(transport, vec, role, config).await
85 } else {
86 WebSocketStream::from_raw_socket(transport, role, config).await
87 };
88
89 let (sink, stream) = wss.split();
90 let stream = Some(WStream {
91 stream: stopper.stop_stream(stream),
92 });
93
94 Self {
95 request_headers,
96 path,
97 method,
98 state,
99 peer_ip: None,
100 sink,
101 stream,
102 stopper,
103 }
104 }
105
106 pub fn stopper(&self) -> Stopper {
108 self.stopper.clone()
109 }
110
111 pub async fn close(&mut self) -> Result<()> {
113 self.send(Message::Close(None)).await
114 }
115
116 pub fn headers(&self) -> &Headers {
118 &self.request_headers
119 }
120
121 pub fn peer_ip(&self) -> Option<IpAddr> {
123 self.peer_ip
124 }
125
126 pub fn set_peer_ip(&mut self, peer_ip: Option<IpAddr>) {
128 self.peer_ip = peer_ip
129 }
130
131 pub fn path(&self) -> &str {
136 self.path.split('?').next().unwrap_or_default()
137 }
138
139 pub fn querystring(&self) -> &str {
144 self.path
145 .split_once('?')
146 .map(|(_, q)| q)
147 .unwrap_or_default()
148 }
149
150 pub fn method(&self) -> Method {
152 self.method
153 }
154
155 pub fn state<T: 'static>(&self) -> Option<&T> {
162 self.state.get()
163 }
164
165 pub fn state_mut<T: 'static>(&mut self) -> Option<&mut T> {
169 self.state.get_mut()
170 }
171
172 #[deprecated = "use WebsocketConn::insert_state"]
174 pub fn set_state<T: Send + Sync + 'static>(&mut self, state: T) {
175 self.insert_state(state);
176 }
177
178 pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
182 self.state.insert(state)
183 }
184
185 pub fn take_state<T: 'static>(&mut self) -> Option<T> {
192 self.state.take()
193 }
194
195 pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult>> {
197 self.stream.take()
198 }
199
200 pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
202 self.stream.as_mut()
203 }
204}
205
206type MessageResult = std::result::Result<Message, tungstenite::Error>;
207
208#[derive(Debug)]
209pub struct WStream {
210 stream: StreamStopper<SplitStream<Wss>>,
211}
212
213impl Stream for WStream {
214 type Item = MessageResult;
215
216 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
217 self.stream.poll_next_unpin(cx)
218 }
219}
220
221impl AsMut<StateSet> for WebSocketConn {
222 fn as_mut(&mut self) -> &mut StateSet {
223 &mut self.state
224 }
225}
226
227impl AsRef<StateSet> for WebSocketConn {
228 fn as_ref(&self) -> &StateSet {
229 &self.state
230 }
231}
232
233impl Stream for WebSocketConn {
234 type Item = MessageResult;
235
236 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237 match self.stream.as_mut() {
238 Some(stream) => stream.poll_next_unpin(cx),
239 None => Poll::Ready(None),
240 }
241 }
242}