Skip to main content

trillium_api/
cancel_on_disconnect.rs

use crate::TryFromConn;
use std::{any::type_name, fmt, future::Future, marker::PhantomData, sync::Arc};
use trillium::{async_trait, Conn, Handler, Info, Status, Upgrade};
4
/// A handler that runs an async function and cancels it if the client disconnects.
///
/// Note that the conn is not available to the wrapped function, and any properties of the
/// request needed for execution must be extracted through [`FromConn`](crate::FromConn) or
/// [`TryFromConn`] arguments. The value the function resolves to is then run as a [`Handler`].
pub struct CancelOnDisconnect<F, OutputHandler, TFC>(
    F,
    PhantomData<OutputHandler>,
    PhantomData<TFC>,
);

// `Debug` is implemented by hand because `#[derive(Debug)]` would require `F: Debug`. The
// wrapped function is almost always a closure, and closures never implement `Debug`. It would
// also require `OutputHandler: Debug` and `TFC: Debug`, even though only `PhantomData` markers
// of those types are stored.
impl<F, OH, TFC> fmt::Debug for CancelOnDisconnect<F, OH, TFC> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("CancelOnDisconnect")
            // Print the function's type name, since the function itself usually cannot be printed.
            .field(&format_args!("{}", type_name::<F>()))
            .field(&self.1)
            .field(&self.2)
            .finish()
    }
}
15
16impl<F, OH, TFC, Fut> CancelOnDisconnect<F, OH, TFC>
17where
18    F: Fn(TFC) -> Fut + Send + Sync + 'static,
19    Fut: Future<Output = OH> + Send + 'static,
20    OH: Handler,
21    TFC: TryFromConn,
22    TFC::Error: Handler,
23{
24    /// Construct a new CancelOnDisconnect handler
25    pub fn new(handler: F) -> Self {
26        CancelOnDisconnect(handler, PhantomData, PhantomData)
27    }
28}
29
30/// Construct a new [`CancelOnDisconnect`] handler.
31///
32/// Alias for [`CancelOnDisconnect::new`]
33pub fn cancel_on_disconnect<F, OH, TFC, Fut>(handler: F) -> CancelOnDisconnect<F, OH, TFC>
34where
35    F: Fn(TFC) -> Fut + Send + Sync + 'static,
36    Fut: Future<Output = OH> + Send + 'static,
37    OH: Handler,
38    TFC: TryFromConn,
39    TFC::Error: Handler,
40{
41    CancelOnDisconnect(handler, PhantomData, PhantomData)
42}
43
44#[async_trait]
45impl<F, OutputHandler, TFC, Fut> Handler for CancelOnDisconnect<F, OutputHandler, TFC>
46where
47    F: Fn(TFC) -> Fut + Send + Sync + 'static,
48    Fut: Future<Output = OutputHandler> + Send + 'static,
49    OutputHandler: Handler,
50    TFC: TryFromConn,
51    TFC::Error: Handler,
52{
53    async fn run(&self, mut conn: Conn) -> Conn {
54        let mut output_handler: Result<OutputHandler, <TFC as TryFromConn>::Error> =
55            match TFC::try_from_conn(&mut conn).await {
56                Ok(extracted) => {
57                    let Some(ret) = conn.cancel_on_disconnect(self.0(extracted)).await else {
58                        log::info!("client disconnected");
59                        return conn;
60                    };
61                    Ok(ret)
62                }
63                Err(error_handler) => Err(error_handler),
64            };
65
66        if let Some(info) = conn.state_mut::<Info>() {
67            output_handler.init(info).await;
68        } else {
69            output_handler.init(&mut Info::default()).await;
70        }
71        let mut conn = output_handler.run(conn).await;
72        if conn.status().is_none() && conn.inner().response_body().is_some() {
73            conn.set_status(Status::Ok);
74        }
75        conn.with_state(OutputHandlerWrapper(
76            Arc::new(output_handler),
77            PhantomData::<Self>,
78        ))
79    }
80
81    async fn before_send(&self, conn: Conn) -> Conn {
82        if let Some(OutputHandlerWrapper(handler, _)) = conn
83            .state::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
84            .cloned()
85        {
86            handler.before_send(conn).await
87        } else {
88            conn
89        }
90    }
91
92    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
93        upgrade
94            .state()
95            .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
96            .cloned()
97            .map_or(false, |OutputHandlerWrapper(handler, _)| {
98                handler.has_upgrade(upgrade)
99            })
100    }
101
102    async fn upgrade(&self, upgrade: Upgrade) {
103        if let Some(OutputHandlerWrapper(handler, _)) = upgrade
104            .state()
105            .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
106            .cloned()
107        {
108            handler.upgrade(upgrade).await
109        }
110    }
111}
112
/// Conn/upgrade state holding the handler that [`CancelOnDisconnect`]'s `run` produced.
///
/// `Owner` is a type-level marker only; it is instantiated with the owning handler type, which
/// keeps the state type distinct per owning handler. `OH` is the success handler type and `EH`
/// is the extraction-error handler type.
struct OutputHandlerWrapper<Owner, OH, EH>(Arc<Result<OH, EH>>, PhantomData<Owner>);

// Written by hand because `#[derive(Clone)]` would add `Owner: Clone, OH: Clone, EH: Clone`
// bounds. Cloning only needs to share the existing `Arc`.
impl<Owner, OH, EH> Clone for OutputHandlerWrapper<Owner, OH, EH> {
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.0);
        OutputHandlerWrapper(shared, PhantomData)
    }
}
119}