Skip to main content

trillium_websockets/
websocket_connection.rs

1use crate::{Result, Role, WebSocketConfig};
2use async_tungstenite::{
3    tungstenite::{self, Message},
4    WebSocketStream,
5};
6use futures_util::{
7    stream::{SplitSink, SplitStream, Stream},
8    SinkExt, StreamExt,
9};
10use std::{
11    net::IpAddr,
12    pin::Pin,
13    task::{Context, Poll},
14};
15use stopper::{Stopper, StreamStopper};
16use trillium::{Headers, Method, StateSet, Upgrade};
17use trillium_http::transport::BoxedTransport;
18
19/**
20A struct that represents an specific websocket connection.
21
22This can be thought of as a combination of a [`async_tungstenite::WebSocketStream`] and a
23[`trillium::Conn`], as it contains a combination of their fields and
24associated functions.
25
26The WebSocketConn implements `Stream<Item=Result<Message, Error>>`,
27and can be polled with `StreamExt::next`
28 */
29
30#[derive(Debug)]
31pub struct WebSocketConn {
32    request_headers: Headers,
33    path: String,
34    method: Method,
35    state: StateSet,
36    peer_ip: Option<IpAddr>,
37    stopper: Stopper,
38    sink: SplitSink<Wss, Message>,
39    stream: Option<WStream>,
40}
41
42type Wss = WebSocketStream<BoxedTransport>;
43
44impl WebSocketConn {
45    /// send a [`Message::Text`] variant
46    pub async fn send_string(&mut self, string: String) -> Result<()> {
47        self.send(Message::Text(string)).await.map_err(Into::into)
48    }
49
50    /// send a [`Message::Binary`] variant
51    pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
52        self.send(Message::Binary(bin)).await.map_err(Into::into)
53    }
54
55    #[cfg(feature = "json")]
56    /// send a [`Message::Text`] that contains json
57    /// note that json messages are not actually part of the websocket specification
58    pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
59        self.send_string(serde_json::to_string(json)?).await
60    }
61
62    /// Sends a [`Message`] to the client
63    pub async fn send(&mut self, message: Message) -> Result<()> {
64        self.sink.send(message).await.map_err(Into::into)
65    }
66
67    /// Create a `WebSocketConn` from an HTTP upgrade, with optional config and the specified role
68    ///
69    /// You should not typically need to call this; the trillium client and server both provide
70    /// your code with a `WebSocketConn`.
71    #[doc(hidden)]
72    pub async fn new(upgrade: Upgrade, config: Option<WebSocketConfig>, role: Role) -> Self {
73        let Upgrade {
74            request_headers,
75            path,
76            method,
77            state,
78            buffer,
79            transport,
80            stopper,
81        } = upgrade;
82
83        let wss = if let Some(vec) = buffer {
84            WebSocketStream::from_partially_read(transport, vec, role, config).await
85        } else {
86            WebSocketStream::from_raw_socket(transport, role, config).await
87        };
88
89        let (sink, stream) = wss.split();
90        let stream = Some(WStream {
91            stream: stopper.stop_stream(stream),
92        });
93
94        Self {
95            request_headers,
96            path,
97            method,
98            state,
99            peer_ip: None,
100            sink,
101            stream,
102            stopper,
103        }
104    }
105
106    /// retrieve a clone of the server's [`Stopper`]
107    pub fn stopper(&self) -> Stopper {
108        self.stopper.clone()
109    }
110
111    /// close the websocket connection gracefully
112    pub async fn close(&mut self) -> Result<()> {
113        self.send(Message::Close(None)).await
114    }
115
116    /// retrieve the request headers for this conn
117    pub fn headers(&self) -> &Headers {
118        &self.request_headers
119    }
120
121    /// retrieves the peer ip for this conn, if available
122    pub fn peer_ip(&self) -> Option<IpAddr> {
123        self.peer_ip
124    }
125
126    /// Sets the peer ip for this conn
127    pub fn set_peer_ip(&mut self, peer_ip: Option<IpAddr>) {
128        self.peer_ip = peer_ip
129    }
130
131    /**
132    retrieves the path part of the request url, up to and excluding
133    any query component
134     */
135    pub fn path(&self) -> &str {
136        self.path.split('?').next().unwrap_or_default()
137    }
138
139    /**
140    Retrieves the query component of the path, excluding `?`. Returns
141    an empty string if there is no query component.
142     */
143    pub fn querystring(&self) -> &str {
144        self.path
145            .split_once('?')
146            .map(|(_, q)| q)
147            .unwrap_or_default()
148    }
149
150    /// retrieve the request method for this conn
151    pub fn method(&self) -> Method {
152        self.method
153    }
154
155    /**
156    retrieve state from the state set that has been accumulated by
157    trillium handlers run on the [`trillium::Conn`] before it
158    became a websocket. see [`trillium::Conn::state`] for more
159    information
160     */
161    pub fn state<T: 'static>(&self) -> Option<&T> {
162        self.state.get()
163    }
164
165    /**
166    retrieve a mutable borrow of the state from the state set
167     */
168    pub fn state_mut<T: 'static>(&mut self) -> Option<&mut T> {
169        self.state.get_mut()
170    }
171
172    /// see [`insert_state`]
173    #[deprecated = "use WebsocketConn::insert_state"]
174    pub fn set_state<T: Send + Sync + 'static>(&mut self, state: T) {
175        self.insert_state(state);
176    }
177
178    /// inserts new state
179    ///
180    /// returns the previously set state of the same type, if any existed
181    pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
182        self.state.insert(state)
183    }
184
185    /**
186    take some type T out of the state set that has been
187    accumulated by trillium handlers run on the [`trillium::Conn`]
188    before it became a websocket. see [`trillium::Conn::take_state`]
189    for more information
190     */
191    pub fn take_state<T: 'static>(&mut self) -> Option<T> {
192        self.state.take()
193    }
194
195    /// take the inbound Message stream from this conn
196    pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult>> {
197        self.stream.take()
198    }
199
200    /// borrow the inbound Message stream from this conn
201    pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
202        self.stream.as_mut()
203    }
204}
205
206type MessageResult = std::result::Result<Message, tungstenite::Error>;
207
208#[derive(Debug)]
209pub struct WStream {
210    stream: StreamStopper<SplitStream<Wss>>,
211}
212
213impl Stream for WStream {
214    type Item = MessageResult;
215
216    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
217        self.stream.poll_next_unpin(cx)
218    }
219}
220
221impl AsMut<StateSet> for WebSocketConn {
222    fn as_mut(&mut self) -> &mut StateSet {
223        &mut self.state
224    }
225}
226
227impl AsRef<StateSet> for WebSocketConn {
228    fn as_ref(&self) -> &StateSet {
229        &self.state
230    }
231}
232
233impl Stream for WebSocketConn {
234    type Item = MessageResult;
235
236    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237        match self.stream.as_mut() {
238            Some(stream) => stream.poll_next_unpin(cx),
239            None => Poll::Ready(None),
240        }
241    }
242}