trillium_api/
cancel_on_disconnect.rs1use crate::TryFromConn;
2use std::{future::Future, marker::PhantomData, sync::Arc};
3use trillium::{async_trait, Conn, Handler, Info, Status, Upgrade};
4
/// A handler that extracts `TryFromConn` from the conn, then awaits the future returned by the
/// wrapped function only for as long as the client stays connected.
///
/// Construct one with [`CancelOnDisconnect::new`] or [`cancel_on_disconnect`].
pub struct CancelOnDisconnect<F, OutputHandler, TryFromConn>(
    F,
    PhantomData<OutputHandler>,
    PhantomData<TryFromConn>,
);

// Written by hand instead of `#[derive(Debug)]`: the derive would require `F: Debug`, and the
// closures this type is built from never implement `Debug`, so the derived impl was effectively
// never available. Printing the type names keeps the output useful without any bounds.
impl<F, OH, TFC> std::fmt::Debug for CancelOnDisconnect<F, OH, TFC> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_tuple("CancelOnDisconnect")
            .field(&format_args!("{}", std::any::type_name::<F>()))
            .field(&format_args!("{}", std::any::type_name::<OH>()))
            .field(&format_args!("{}", std::any::type_name::<TFC>()))
            .finish()
    }
}
15
16impl<F, OH, TFC, Fut> CancelOnDisconnect<F, OH, TFC>
17where
18 F: Fn(TFC) -> Fut + Send + Sync + 'static,
19 Fut: Future<Output = OH> + Send + 'static,
20 OH: Handler,
21 TFC: TryFromConn,
22 TFC::Error: Handler,
23{
24 pub fn new(handler: F) -> Self {
26 CancelOnDisconnect(handler, PhantomData, PhantomData)
27 }
28}
29
30pub fn cancel_on_disconnect<F, OH, TFC, Fut>(handler: F) -> CancelOnDisconnect<F, OH, TFC>
34where
35 F: Fn(TFC) -> Fut + Send + Sync + 'static,
36 Fut: Future<Output = OH> + Send + 'static,
37 OH: Handler,
38 TFC: TryFromConn,
39 TFC::Error: Handler,
40{
41 CancelOnDisconnect(handler, PhantomData, PhantomData)
42}
43
44#[async_trait]
45impl<F, OutputHandler, TFC, Fut> Handler for CancelOnDisconnect<F, OutputHandler, TFC>
46where
47 F: Fn(TFC) -> Fut + Send + Sync + 'static,
48 Fut: Future<Output = OutputHandler> + Send + 'static,
49 OutputHandler: Handler,
50 TFC: TryFromConn,
51 TFC::Error: Handler,
52{
53 async fn run(&self, mut conn: Conn) -> Conn {
54 let mut output_handler: Result<OutputHandler, <TFC as TryFromConn>::Error> =
55 match TFC::try_from_conn(&mut conn).await {
56 Ok(extracted) => {
57 let Some(ret) = conn.cancel_on_disconnect(self.0(extracted)).await else {
58 log::info!("client disconnected");
59 return conn;
60 };
61 Ok(ret)
62 }
63 Err(error_handler) => Err(error_handler),
64 };
65
66 if let Some(info) = conn.state_mut::<Info>() {
67 output_handler.init(info).await;
68 } else {
69 output_handler.init(&mut Info::default()).await;
70 }
71 let mut conn = output_handler.run(conn).await;
72 if conn.status().is_none() && conn.inner().response_body().is_some() {
73 conn.set_status(Status::Ok);
74 }
75 conn.with_state(OutputHandlerWrapper(
76 Arc::new(output_handler),
77 PhantomData::<Self>,
78 ))
79 }
80
81 async fn before_send(&self, conn: Conn) -> Conn {
82 if let Some(OutputHandlerWrapper(handler, _)) = conn
83 .state::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
84 .cloned()
85 {
86 handler.before_send(conn).await
87 } else {
88 conn
89 }
90 }
91
92 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
93 upgrade
94 .state()
95 .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
96 .cloned()
97 .map_or(false, |OutputHandlerWrapper(handler, _)| {
98 handler.has_upgrade(upgrade)
99 })
100 }
101
102 async fn upgrade(&self, upgrade: Upgrade) {
103 if let Some(OutputHandlerWrapper(handler, _)) = upgrade
104 .state()
105 .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
106 .cloned()
107 {
108 handler.upgrade(upgrade).await
109 }
110 }
111}
112
/// Per-conn storage for the handler that a `CancelOnDisconnect` produced in `run`.
///
/// * field `0`: the produced handler (or extraction error handler), behind an `Arc` so the later
///   hooks can take a cheap clone out of the conn's state.
/// * field `1`: a zero-sized marker. `run` instantiates it with the `CancelOnDisconnect` type
///   itself, so the state entry's type is specific to that handler type.
///
/// NOTE(review): if conn state is keyed purely by type, two `CancelOnDisconnect` handlers of the
/// identical type on one conn would share this slot. Confirm whether that can happen.
struct OutputHandlerWrapper<Marker, OH, EH>(Arc<Result<OH, EH>>, PhantomData<Marker>);

// Hand-written rather than derived: `#[derive(Clone)]` would add `Marker: Clone`, `OH: Clone`
// and `EH: Clone` bounds, none of which are needed to clone an `Arc` and a `PhantomData`.
impl<Marker, OH, EH> Clone for OutputHandlerWrapper<Marker, OH, EH> {
    fn clone(&self) -> Self {
        let Self(shared, marker) = self;
        Self(Arc::clone(shared), *marker)
    }
}
119}