Skip to main content

trillium_websockets/
json.rs

1//! # websocket json adapter
2//!
3//! See the documentation for [`JsonWebSocketHandler`]
4
5use crate::{WebSocket, WebSocketConn, WebSocketHandler};
6use async_tungstenite::tungstenite::{Message, protocol::CloseFrame};
7use futures_lite::{Stream, ready};
8use serde::{Serialize, de::DeserializeOwned};
9use std::{
10    fmt::Debug,
11    future::Future,
12    ops::{Deref, DerefMut},
13    pin::Pin,
14    task::{Context, Poll},
15};
16
17/// # Implement this trait to use websockets with a json handler
18///
19/// JsonWebSocketHandler provides a small layer of abstraction on top of
20/// [`WebSocketHandler`], serializing and deserializing messages for
21/// you. This may eventually move to a crate of its own.
22///
23/// ## ℹ️ In order to use this trait, the `json` crate feature must be enabled.
24///
25/// ```
26/// use async_channel::{Receiver, Sender, unbounded};
27/// use serde::{Deserialize, Serialize};
28/// use std::pin::Pin;
29/// use trillium::log_error;
30/// use trillium_websockets::{JsonWebSocketHandler, Result, WebSocketConn, json_websocket};
31///
32/// #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
33/// struct Response {
34///     inbound_message: Inbound,
35/// }
36///
37/// #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
38/// struct Inbound {
39///     message: String,
40/// }
41///
42/// struct SomeJsonChannel;
43///
44/// impl JsonWebSocketHandler for SomeJsonChannel {
45///     type InboundMessage = Inbound;
46///     type OutboundMessage = Response;
47///     type StreamType = Pin<Box<Receiver<Self::OutboundMessage>>>;
48///
49///     async fn connect(&self, conn: &mut WebSocketConn) -> Self::StreamType {
50///         let (s, r) = unbounded();
51///         conn.insert_state(s);
52///         Box::pin(r)
53///     }
54///
55///     async fn receive_message(
56///         &self,
57///         inbound_message: Result<Self::InboundMessage>,
58///         conn: &mut WebSocketConn,
59///     ) {
60///         if let Ok(inbound_message) = inbound_message {
61///             log_error!(
62///                 conn.state::<Sender<Response>>()
63///                     .unwrap()
64///                     .send(Response { inbound_message })
65///                     .await
66///             );
67///         }
68///     }
69/// }
70///
71/// // fn main() {
72/// //    trillium_smol::run(json_websocket(SomeJsonChannel));
73/// // }
74/// ```
75#[allow(unused_variables)]
76pub trait JsonWebSocketHandler: Send + Sync + 'static {
77    /// A type that can be deserialized from the json sent from the
78    /// connected clients
79    type InboundMessage: DeserializeOwned + Send + 'static;
80
81    /// A serializable type that will be sent in the StreamType and
82    /// received by the connected websocket clients
83    type OutboundMessage: Serialize + Send + 'static;
84
85    /// A type that implements a stream of
86    /// [`Self::OutboundMessage`]s. This can be
87    /// futures_lite::stream::Pending if you never need to send an
88    /// outbound message.
89    type StreamType: Stream<Item = Self::OutboundMessage> + Send + Sync + 'static;
90
91    /// `connect` is called once for each upgraded websocket
92    /// connection, and returns a Self::StreamType.
93    fn connect(&self, conn: &mut WebSocketConn) -> impl Future<Output = Self::StreamType> + Send;
94
95    /// `receive_message` is called once for each successfully deserialized
96    /// InboundMessage along with the websocket conn that it was received
97    /// from.
98    fn receive_message(
99        &self,
100        message: crate::Result<Self::InboundMessage>,
101        conn: &mut WebSocketConn,
102    ) -> impl Future<Output = ()> + Send;
103
104    /// `disconnect` is called when websocket clients disconnect, along
105    /// with a CloseFrame, if one was provided. Implementing `disconnect`
106    /// is optional.
107    fn disconnect(
108        &self,
109        conn: &mut WebSocketConn,
110        close_frame: Option<CloseFrame>,
111    ) -> impl Future<Output = ()> + Send {
112        async {}
113    }
114}
115
/// Adapter that wraps a [`JsonWebSocketHandler`]
///
/// This type is not normally constructed by hand. Build one through
/// [`WebSocket::new_json`] or [`json_websocket`] instead.
pub struct JsonHandler<T> {
    // the wrapped user-provided handler
    pub(crate) handler: T,
}
123
124impl<T> Deref for JsonHandler<T> {
125    type Target = T;
126
127    fn deref(&self) -> &Self::Target {
128        &self.handler
129    }
130}
131
132impl<T> DerefMut for JsonHandler<T> {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        &mut self.handler
135    }
136}
137
138impl<T: JsonWebSocketHandler> JsonHandler<T> {
139    pub(crate) fn new(handler: T) -> Self {
140        Self { handler }
141    }
142}
143
144impl<T> Debug for JsonHandler<T> {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("JsonWebSocketHandler").finish()
147    }
148}
149
150pin_project_lite::pin_project! {
151    /// A stream for internal use that attempts to serialize the items in the wrapped stream to a
152    /// [`Message::Text`]
153    #[derive(Debug)]
154    pub struct SerializedStream<T> {
155        #[pin] inner: T
156    }
157}
158
159impl<T> Stream for SerializedStream<T>
160where
161    T: Stream,
162    T::Item: Serialize,
163{
164    type Item = Message;
165
166    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167        Poll::Ready(
168            ready!(self.project().inner.poll_next(cx))
169                .and_then(|i| match serde_json::to_string(&i) {
170                    Ok(j) => Some(j),
171                    Err(e) => {
172                        log::error!("serialization error: {e}");
173                        None
174                    }
175                })
176                .map(Message::text),
177        )
178    }
179}
180
181impl<T> WebSocketHandler for JsonHandler<T>
182where
183    T: JsonWebSocketHandler,
184{
185    type OutboundStream = SerializedStream<T::StreamType>;
186
187    async fn connect(
188        &self,
189        mut conn: WebSocketConn,
190    ) -> Option<(WebSocketConn, Self::OutboundStream)> {
191        let stream = SerializedStream {
192            inner: self.handler.connect(&mut conn).await,
193        };
194        Some((conn, stream))
195    }
196
197    async fn inbound(&self, message: Message, conn: &mut WebSocketConn) {
198        self.handler
199            .receive_message(
200                message
201                    .to_text()
202                    .map_err(Into::into)
203                    .and_then(|m| serde_json::from_str(m).map_err(Into::into)),
204                conn,
205            )
206            .await;
207    }
208
209    async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame>) {
210        self.handler.disconnect(conn, close_frame).await
211    }
212}
213
214impl<T> WebSocket<JsonHandler<T>>
215where
216    T: JsonWebSocketHandler,
217{
218    /// Build a new trillium WebSocket handler from the provided
219    /// [`JsonWebSocketHandler`]
220    pub fn new_json(handler: T) -> Self {
221        Self::new(JsonHandler::new(handler))
222    }
223}
224
225/// builds a new trillium handler from the provided
226/// [`JsonWebSocketHandler`]. Alias for [`WebSocket::new_json`]
227pub fn json_websocket<T>(json_websocket_handler: T) -> WebSocket<JsonHandler<T>>
228where
229    T: JsonWebSocketHandler,
230{
231    WebSocket::new_json(json_websocket_handler)
232}