1use crate::TryFromConn;
2use std::{future::Future, marker::PhantomData, sync::Arc};
3use trillium::{async_trait, Conn, Handler, Info, Status, Upgrade};
4
5pub trait MutBorrowConn<'conn, ReturnType, Additional>: Send + Sync + 'conn {
7 type Fut: Future<Output = ReturnType> + Send + 'conn;
9 fn call(&self, conn: &'conn mut Conn, additional: Additional) -> Self::Fut;
11}
12
13impl<'conn, Fun, Fut, ReturnType, Additional> MutBorrowConn<'conn, ReturnType, Additional> for Fun
14where
15 Fun: Fn(&'conn mut Conn, Additional) -> Fut + Send + Sync + 'conn,
16 Fut: Future<Output = ReturnType> + Send + 'conn,
17{
18 type Fut = Fut;
19 fn call(&self, conn: &'conn mut Conn, additional: Additional) -> Fut {
20 self(conn, additional)
21 }
22}
23
/// A handler built from an async function that receives `&mut Conn` plus a value
/// extracted from the conn.
///
/// Type parameters:
/// - `F`: the wrapped function (see [`MutBorrowConn`]).
/// - `OutputHandler`: the handler type the function's future resolves to.
/// - `Extracted`: the [`TryFromConn`] type pulled from the conn before `F` is called.
///   This parameter was previously named `TryFromConn`, which shadowed the imported
///   trait of the same name. It is positional, so callers are unaffected.
///
/// The two marker fields carry only type information.
#[derive(Debug)]
pub struct ApiHandler<F, OutputHandler, Extracted>(
    F,
    PhantomData<OutputHandler>,
    PhantomData<Extracted>,
);
39
40impl<TryFromConnHandler, OutputHandler, Extracted>
41 ApiHandler<TryFromConnHandler, OutputHandler, Extracted>
42where
43 TryFromConnHandler: for<'a> MutBorrowConn<'a, OutputHandler, Extracted>,
44 OutputHandler: Handler,
45 Extracted: TryFromConn,
46{
47 pub fn new(api_handler: TryFromConnHandler) -> Self {
50 Self::from(api_handler)
51 }
52}
53
54impl<TryFromConnHandler, OutputHandler, Extracted> From<TryFromConnHandler>
55 for ApiHandler<TryFromConnHandler, OutputHandler, Extracted>
56where
57 TryFromConnHandler: for<'a> MutBorrowConn<'a, OutputHandler, Extracted>,
58 OutputHandler: Handler,
59 Extracted: TryFromConn,
60{
61 fn from(value: TryFromConnHandler) -> Self {
62 Self(value, PhantomData, PhantomData)
63 }
64}
65
66pub fn api<TryFromConnHandler, OutputHandler, Extracted>(
71 api_handler: TryFromConnHandler,
72) -> ApiHandler<TryFromConnHandler, OutputHandler, Extracted>
73where
74 TryFromConnHandler: for<'a> MutBorrowConn<'a, OutputHandler, Extracted>,
75 Extracted: TryFromConn,
76 OutputHandler: Handler,
77{
78 ApiHandler::from(api_handler)
79}
80
81#[async_trait]
82impl<TryFromConnHandler, OutputHandler, Extracted> Handler
83 for ApiHandler<TryFromConnHandler, OutputHandler, Extracted>
84where
85 TryFromConnHandler: for<'a> MutBorrowConn<'a, OutputHandler, Extracted>,
86 Extracted: TryFromConn,
87 Extracted::Error: Handler,
88 OutputHandler: Handler,
89{
90 async fn run(&self, mut conn: Conn) -> Conn {
91 let mut output_handler: Result<OutputHandler, <Extracted as TryFromConn>::Error> =
92 match Extracted::try_from_conn(&mut conn).await {
93 Ok(extracted) => Ok(self.0.call(&mut conn, extracted).await),
94 Err(error_handler) => Err(error_handler),
95 };
96
97 if let Some(info) = conn.state_mut::<Info>() {
98 output_handler.init(info).await;
99 } else {
100 output_handler.init(&mut Info::default()).await;
101 }
102 let mut conn = output_handler.run(conn).await;
103 if conn.status().is_none() && conn.inner().response_body().is_some() {
104 conn.set_status(Status::Ok);
105 }
106 conn.with_state(OutputHandlerWrapper(
107 Arc::new(output_handler),
108 PhantomData::<Self>,
109 ))
110 }
111
112 async fn before_send(&self, conn: Conn) -> Conn {
113 if let Some(OutputHandlerWrapper(handler, _)) = conn
114 .state::<OutputHandlerWrapper<Self, OutputHandler, <Extracted as TryFromConn>::Error>>()
115 .cloned()
116 {
117 handler.before_send(conn).await
118 } else {
119 conn
120 }
121 }
122
123 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
124 upgrade
125 .state()
126 .get::<OutputHandlerWrapper<Self, OutputHandler, <Extracted as TryFromConn>::Error>>()
127 .cloned()
128 .map_or(false, |OutputHandlerWrapper(handler, _)| {
129 handler.has_upgrade(upgrade)
130 })
131 }
132
133 async fn upgrade(&self, upgrade: Upgrade) {
134 if let Some(OutputHandlerWrapper(handler, _)) = upgrade
135 .state()
136 .get::<OutputHandlerWrapper<Self, OutputHandler, <Extracted as TryFromConn>::Error>>()
137 .cloned()
138 {
139 handler.upgrade(upgrade).await
140 }
141 }
142}
143
/// Shared storage for a per-request handler result: either the success handler
/// (`OH`) or the error handler (`EH`).
///
/// `TFC` is a type-level tag only. Conn state is looked up by type, so tagging the
/// wrapper with its owner type keeps each owner's entry distinct.
struct OutputHandlerWrapper<TFC, OH, EH>(Arc<Result<OH, EH>>, PhantomData<TFC>);

/// Hand-written rather than derived: `#[derive(Clone)]` would add `TFC: Clone`,
/// `OH: Clone`, and `EH: Clone` bounds, even though cloning only bumps the `Arc`
/// reference count and copies the marker.
impl<TFC, OH, EH> Clone for OutputHandlerWrapper<TFC, OH, EH> {
    fn clone(&self) -> Self {
        let Self(shared, marker) = self;
        Self(Arc::clone(shared), *marker)
    }
}