trillium_webtransport/
// Compiles and runs the code blocks in README.md as doctests.
//
// rustdoc sets `cfg(doctest)` while it collects doctests, but it does not set
// `cfg(test)`. Gating this module on `cfg(test)` hid it from rustdoc, so the
// README examples were never checked.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
13
14mod session_router;
15mod stream;
16
17use crate::session_router::Router;
18pub use crate::stream::{
19 Datagram, InboundBidiStream, InboundStream, InboundUniStream, OutboundBidiStream,
20 OutboundUniStream,
21};
22use async_channel::Receiver;
23use futures_lite::AsyncWriteExt;
24use std::{
25 io,
26 sync::{Arc, OnceLock},
27};
28use swansong::Swansong;
29use trillium::{Conn, Handler, Info, Method, Status, Transport, Upgrade};
30use trillium_http::h3::{H3Connection, quic_varint};
31use trillium_server_common::{
32 QuicConnection, Runtime,
33 h3::{
34 StreamId,
35 web_transport::{WebTransportDispatcher, WebTransportStream},
36 },
37};
38
/// One WebTransport session, handed to [`WebTransportHandler::run`].
///
/// Inbound streams and datagrams that the connection's router assigns to this
/// session arrive on the receivers below. Outbound streams and datagrams are
/// sent over the underlying QUIC connection.
pub struct WebTransportConnection {
    /// Id of this session, converted from the stream id of the upgraded request.
    /// It is written into outbound stream headers and, divided by four, used
    /// as the datagram prefix.
    session_id: u64,
    /// Inbound bidirectional streams routed to this session.
    bidi_rx: Receiver<InboundBidiStream>,
    /// Inbound unidirectional streams routed to this session.
    uni_rx: Receiver<InboundUniStream>,
    /// Inbound datagrams routed to this session.
    datagram_rx: Receiver<Datagram>,
    /// Child of the HTTP/3 connection's swansong; it interrupts the
    /// accept/receive methods and is shut down when the handler returns.
    swansong: Swansong,
    /// The upgrade this session was created from.
    upgrade: Upgrade,
    /// HTTP/3 connection that carries this session.
    h3_connection: Arc<H3Connection>,
    /// QUIC connection used to open streams and send datagrams.
    quic_connection: QuicConnection,
    /// Runtime captured when the `WebTransport` handler was initialized.
    runtime: Runtime,
}
55
56impl WebTransportConnection {
57 pub async fn accept_bidi(&self) -> Option<InboundBidiStream> {
61 self.swansong.interrupt(self.bidi_rx.recv()).await?.ok()
62 }
63
64 pub fn runtime(&self) -> &Runtime {
66 &self.runtime
67 }
68
69 pub fn h3_connection(&self) -> &H3Connection {
71 &self.h3_connection
72 }
73
74 pub fn upgrade(&self) -> &Upgrade {
79 &self.upgrade
80 }
81
82 pub fn upgrade_mut(&mut self) -> &mut Upgrade {
84 &mut self.upgrade
85 }
86
87 pub async fn accept_uni(&self) -> Option<InboundUniStream> {
91 self.swansong.interrupt(self.uni_rx.recv()).await?.ok()
92 }
93
94 pub async fn recv_datagram(&self) -> Option<Datagram> {
98 self.swansong.interrupt(self.datagram_rx.recv()).await?.ok()
99 }
100
101 pub async fn accept_next_stream(&self) -> Option<InboundStream> {
110 futures_lite::future::race(
111 async { self.accept_bidi().await.map(InboundStream::Bidi) },
112 async { self.accept_uni().await.map(InboundStream::Uni) },
113 )
114 .await
115 }
116
117 pub fn send_datagram(&self, payload: &[u8]) -> io::Result<()> {
122 let quarter_id = self.session_id / 4;
123 let header_len = quic_varint::encoded_len(quarter_id);
124 let mut buf = vec![0u8; header_len + payload.len()];
125 quic_varint::encode(quarter_id, &mut buf).unwrap();
126 buf[header_len..].copy_from_slice(payload);
127 self.quic_connection.send_datagram(&buf)
128 }
129
130 pub async fn open_bidi(&self) -> io::Result<OutboundBidiStream> {
132 let (_stream_id, mut transport) = self.quic_connection.open_bidi().await?;
133 transport
134 .write_all(&wt_bidi_header(self.session_id))
135 .await?;
136 Ok(OutboundBidiStream::new(transport))
137 }
138
139 pub async fn open_uni(&self) -> io::Result<OutboundUniStream> {
141 let (_stream_id, mut stream) = self.quic_connection.open_uni().await?;
142 stream.write_all(&wt_uni_header(self.session_id)).await?;
143 Ok(OutboundUniStream::new(stream))
144 }
145}
146
147enum RoutingAction {
148 Stream(WebTransportStream),
149 Datagram(Vec<u8>),
150}
151
152fn wt_bidi_header(session_id: u64) -> Vec<u8> {
154 let mut buf =
155 vec![0u8; quic_varint::encoded_len(0x41u64) + quic_varint::encoded_len(session_id)];
156 let mut offset = quic_varint::encode(0x41u64, &mut buf).unwrap();
157 offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
158 buf.truncate(offset);
159 buf
160}
161
162fn wt_uni_header(session_id: u64) -> Vec<u8> {
164 let mut buf =
165 vec![0u8; quic_varint::encoded_len(0x54u64) + quic_varint::encoded_len(session_id)];
166 let mut offset = quic_varint::encode(0x54u64, &mut buf).unwrap();
167 offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
168 buf.truncate(offset);
169 buf
170}
171
/// Value of `max_datagram_buffer` used by `WebTransport::new` unless it is
/// overridden with `with_max_datagram_buffer`; it is passed to `Router::new`.
/// NOTE(review): presumably the number of inbound datagrams queued per session —
/// confirm in `session_router`.
const DEFAULT_MAX_DATAGRAM_BUFFER: usize = 16_usize;
173
/// A trillium [`Handler`] that accepts WebTransport sessions over HTTP/3 and runs
/// each one through the wrapped [`WebTransportHandler`].
pub struct WebTransport<H> {
    /// Runtime taken from server state in `Handler::init`. It is used to spawn the
    /// routing task and is handed to each session.
    runtime: OnceLock<Runtime>,
    /// Passed to `Router::new` when a dispatcher's router is first initialized.
    max_datagram_buffer: usize,
    /// User-supplied session handler.
    handler: H,
}
196
/// User code that services WebTransport sessions.
///
/// This trait is implemented automatically for any `Send + Sync + 'static`
/// function or closure that takes a [`WebTransportConnection`] and returns a
/// `Send` future with output `()`.
pub trait WebTransportHandler: Send + Sync + 'static {
    /// Services one session.
    ///
    /// After the returned future completes, the session's swansong is shut down
    /// and the session is unregistered from the connection's router.
    fn run(
        &self,
        web_transport_connection: WebTransportConnection,
    ) -> impl Future<Output = ()> + Send;
}
208
209impl<Fun, Fut> WebTransportHandler for Fun
210where
211 Fun: Fn(WebTransportConnection) -> Fut + Send + Sync + 'static,
212 Fut: Future<Output = ()> + Send,
213{
214 async fn run(&self, web_transport_connection: WebTransportConnection) {
215 self(web_transport_connection).await
216 }
217}
218
219impl<H> WebTransport<H>
220where
221 H: WebTransportHandler,
222{
223 pub fn new(handler: H) -> Self {
225 Self {
226 handler,
227 runtime: Default::default(),
228 max_datagram_buffer: DEFAULT_MAX_DATAGRAM_BUFFER,
229 }
230 }
231
232 pub fn with_max_datagram_buffer(mut self, max: usize) -> Self {
245 self.max_datagram_buffer = max;
246 self
247 }
248
249 fn runtime(&self) -> &Runtime {
250 self.runtime.get().unwrap()
251 }
252}
253
/// Marker that `Handler::run` stores in the conn state of requests it accepts.
/// `Handler::has_upgrade` checks for it to claim the resulting upgrade.
struct WTUpgrade;
255
/// Connects WebTransport to trillium. `run` tags eligible requests for upgrade,
/// and `upgrade` turns each tagged request into a [`WebTransportConnection`]
/// that is passed to the wrapped [`WebTransportHandler`].
impl<H> Handler for WebTransport<H>
where
    H: WebTransportHandler,
{
    /// Accepts CONNECT requests that arrived over QUIC, meaning a `QuicConnection`
    /// is present in the inner conn state. Accepted requests are tagged with
    /// `WTUpgrade`, answered with 200 and halted. Anything else passes through
    /// unchanged.
    // NOTE(review): only the method and the transport are checked here; confirm
    // whether the extended-CONNECT `:protocol` (expected `webtransport`) should
    // also be verified.
    async fn run(&self, conn: Conn) -> Conn {
        let inner: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
        if inner.state().contains::<QuicConnection>() && conn.method() == Method::Connect
        {
            conn.with_state(WTUpgrade).with_status(Status::Ok).halt()
        } else {
            conn
        }
    }

    /// Captures the server's `Runtime` and enables HTTP/3 datagrams and
    /// WebTransport in the HTTP config.
    ///
    /// Panics if no `Runtime` is present in `Info` state. If a runtime was already
    /// captured, it is kept.
    async fn init(&mut self, info: &mut Info) {
        self.runtime.get_or_init(|| {
            info.state::<Runtime>()
                .cloned()
                .expect("webtransport requires a Runtime")
        });

        info.http_config_mut()
            .set_h3_datagrams_enabled(true)
            .set_webtransport_enabled(true);
    }

    /// Claims only the upgrades that `run` tagged with `WTUpgrade`.
    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
        upgrade.state().get::<WTUpgrade>().is_some()
    }

    /// Runs one WebTransport session to completion:
    /// 1. Takes the H3 connection, QUIC connection, request stream id and
    ///    dispatcher from the upgrade. If any is missing, it logs an error and
    ///    returns.
    /// 2. Gets or creates the dispatcher's `Router`. If this call obtains the
    ///    router's routing receiver, it spawns the task that routes streams and
    ///    datagrams to sessions.
    /// 3. Registers this session with the router and awaits the user handler.
    /// 4. After the handler returns, shuts down the session swansong and
    ///    unregisters the session.
    async fn upgrade(&self, mut upgrade: Upgrade) {
        let Some(h3_connection) = upgrade.h3_connection() else {
            log::error!("missing H3Connection in upgrade state");
            return;
        };
        let Some(quic_connection) = upgrade.state_mut().take::<QuicConnection>() else {
            log::error!("missing QuicConnection in upgrade state");
            return;
        };
        let Some(stream_id) = upgrade.state_mut().take::<StreamId>() else {
            log::error!("missing StreamId in upgrade state");
            return;
        };
        let Some(dispatcher) = upgrade.state().get::<WebTransportDispatcher>().cloned() else {
            log::error!("missing WebTransportDispatcher in upgrade state");
            return;
        };

        // Copied out so the init closure does not borrow `self`.
        let max_datagram_buffer = self.max_datagram_buffer;
        // `None` means the dispatcher already holds a handler that is not a `Router`.
        let Some(router) = dispatcher.get_or_init_with(|| Router::new(max_datagram_buffer)) else {
            log::error!("WebTransportDispatcher has a handler of an unexpected type");
            return;
        };

        // `take_routing_rx` presumably yields the receiver only once, so only the
        // first session to get here spawns the routing task — confirm in
        // `session_router`.
        if let Some(routing_rx) = router.take_routing_rx() {
            let router = router.clone();
            let quic = quic_connection.clone();
            self.runtime().clone().spawn(async move {
                loop {
                    // Wait for whichever arrives first: a WebTransport stream on the
                    // routing channel, or an inbound QUIC datagram copied into `data`.
                    // NOTE(review): the losing `recv_datagram` future is dropped on
                    // each iteration; confirm it is cancel-safe so no datagram is lost.
                    let action = futures_lite::future::race(
                        async { routing_rx.recv().await.ok().map(RoutingAction::Stream) },
                        async {
                            let mut data = Vec::new();
                            quic.recv_datagram(|d| data.extend_from_slice(d))
                                .await
                                .ok()
                                .map(|()| RoutingAction::Datagram(data))
                        },
                    )
                    .await;
                    match action {
                        Some(RoutingAction::Stream(stream)) => {
                            router.sessions.lock().await.route(stream);
                        }
                        Some(RoutingAction::Datagram(data)) => {
                            router.sessions.lock().await.route_datagram(&data);
                        }
                        // Either source failing ends routing for both streams and
                        // datagrams. NOTE(review): confirm `recv_datagram` only errors
                        // once the connection is closed.
                        None => break,
                    }
                }
            });
        }

        // The session id is the stream id of the upgraded request.
        let session_id = stream_id.into();
        log::trace!("starting webtransport session {session_id}");
        let session_swansong = h3_connection.swansong().child();
        let (bidi_rx, uni_rx, datagram_rx) = router.sessions.lock().await.register(session_id);

        let runtime = self.runtime().clone();

        self.handler
            .run(WebTransportConnection {
                session_id,
                bidi_rx,
                uni_rx,
                datagram_rx,
                swansong: session_swansong.clone(),
                upgrade,
                h3_connection,
                quic_connection,
                runtime,
            })
            .await;

        log::trace!("finished handler, cleaning up");

        // Shut down this session's swansong before removing it from the router.
        session_swansong.shut_down().await;
        router.sessions.lock().await.unregister(session_id);
    }
}
369}