Skip to main content

trillium_webtransport/
lib.rs

1//! WebTransport support for Trillium.
2//!
3//! This crate provides a [`WebTransport`] handler that accepts WebTransport sessions over
4//! HTTP/3, and a [`WebTransportConnection`] handle for sending and receiving streams and
5//! datagrams within each session.
6//!
7//! WebTransport requires an HTTP/3-capable server adapter configured with a QUIC endpoint
8//! and TLS.
9
// Pulls README.md in as the doc comment of an otherwise empty module so that its code
// samples are compiled and run by rustdoc's doctest pass (`cargo test --doc`).
// `cfg(doctest)` is the cfg rustdoc sets while collecting doctests; `cfg(test)` is not set
// during that pass, so gating on it would strip the module and the README examples would
// never be exercised. The module is also absent from normal builds and generated docs.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
13
14mod session_router;
15mod stream;
16
17use crate::session_router::Router;
18pub use crate::stream::{
19    Datagram, InboundBidiStream, InboundStream, InboundUniStream, OutboundBidiStream,
20    OutboundUniStream,
21};
22use async_channel::Receiver;
23use futures_lite::AsyncWriteExt;
24use std::{
25    io,
26    sync::{Arc, OnceLock},
27};
28use swansong::Swansong;
29use trillium::{Conn, Handler, Info, Method, Status, Transport, Upgrade};
30use trillium_http::h3::{H3Connection, quic_varint};
31use trillium_server_common::{
32    QuicConnection, Runtime,
33    h3::{
34        StreamId,
35        web_transport::{WebTransportDispatcher, WebTransportStream},
36    },
37};
38
/// A handle to an active WebTransport session.
///
/// Passed to your [`WebTransportHandler`] when a client opens a WebTransport session.
/// Use it to accept streams from the client, open server-initiated streams, and exchange
/// datagrams.
pub struct WebTransportConnection {
    // Identifier of this session, converted from the `StreamId` of the CONNECT request
    // that opened it. Written into outbound stream headers and (divided by four) used as
    // the prefix of outgoing datagrams.
    session_id: u64,
    // Per-session inbound queues handed out by the session router's `register`; drained by
    // `accept_bidi`, `accept_uni` and `recv_datagram` respectively.
    bidi_rx: Receiver<InboundBidiStream>,
    uni_rx: Receiver<InboundUniStream>,
    datagram_rx: Receiver<Datagram>,
    // Child of the HTTP/3 connection's swansong; used to interrupt pending receives when
    // the session is shutting down.
    swansong: Swansong,
    // The HTTP CONNECT upgrade that initiated this session.
    upgrade: Upgrade,
    // HTTP/3 connection this session belongs to, shared with other sessions.
    h3_connection: Arc<H3Connection>,
    // QUIC connection used to open outbound streams and send datagrams.
    quic_connection: QuicConnection,
    // Server runtime, exposed through `runtime()`.
    runtime: Runtime,
}
55
56impl WebTransportConnection {
57    /// Accept the next inbound bidirectional stream for this session.
58    ///
59    /// Returns `None` when the session is shutting down or the QUIC connection has closed.
60    pub async fn accept_bidi(&self) -> Option<InboundBidiStream> {
61        self.swansong.interrupt(self.bidi_rx.recv()).await?.ok()
62    }
63
64    /// Returns the async runtime for this server.
65    pub fn runtime(&self) -> &Runtime {
66        &self.runtime
67    }
68
69    /// Returns the underlying HTTP/3 connection.
70    pub fn h3_connection(&self) -> &H3Connection {
71        &self.h3_connection
72    }
73
74    /// Returns the HTTP CONNECT upgrade that initiated this WebTransport session.
75    ///
76    /// Provides access to request headers, connection state, and peer information from
77    /// the CONNECT request.
78    pub fn upgrade(&self) -> &Upgrade {
79        &self.upgrade
80    }
81
82    /// Returns a mutable reference to the HTTP CONNECT upgrade that initiated this session.
83    pub fn upgrade_mut(&mut self) -> &mut Upgrade {
84        &mut self.upgrade
85    }
86
87    /// Accept the next inbound unidirectional stream for this session.
88    ///
89    /// Returns `None` when the session is shutting down or the QUIC connection has closed.
90    pub async fn accept_uni(&self) -> Option<InboundUniStream> {
91        self.swansong.interrupt(self.uni_rx.recv()).await?.ok()
92    }
93
94    /// Receive the next datagram for this session.
95    ///
96    /// Returns `None` when the session is shutting down or the QUIC connection has closed.
97    pub async fn recv_datagram(&self) -> Option<Datagram> {
98        self.swansong.interrupt(self.datagram_rx.recv()).await?.ok()
99    }
100
101    /// Accept the next inbound stream for this session.
102    ///
103    /// Races the bidi and uni stream channels and returns whichever arrives first.
104    /// Returns `None` when the session ends.
105    ///
106    /// Datagrams are intentionally excluded — use [`recv_datagram`](Self::recv_datagram)
107    /// in a separate concurrent loop, as datagrams typically require lower latency
108    /// than stream acceptance.
109    pub async fn accept_next_stream(&self) -> Option<InboundStream> {
110        futures_lite::future::race(
111            async { self.accept_bidi().await.map(InboundStream::Bidi) },
112            async { self.accept_uni().await.map(InboundStream::Uni) },
113        )
114        .await
115    }
116
117    /// Send an unreliable datagram to the client.
118    ///
119    /// Returns an error if the QUIC connection does not support datagrams or the payload is
120    /// too large.
121    pub fn send_datagram(&self, payload: &[u8]) -> io::Result<()> {
122        let quarter_id = self.session_id / 4;
123        let header_len = quic_varint::encoded_len(quarter_id);
124        let mut buf = vec![0u8; header_len + payload.len()];
125        quic_varint::encode(quarter_id, &mut buf).unwrap();
126        buf[header_len..].copy_from_slice(payload);
127        self.quic_connection.send_datagram(&buf)
128    }
129
130    /// Open a new server-initiated bidirectional stream for this session.
131    pub async fn open_bidi(&self) -> io::Result<OutboundBidiStream> {
132        let (_stream_id, mut transport) = self.quic_connection.open_bidi().await?;
133        transport
134            .write_all(&wt_bidi_header(self.session_id))
135            .await?;
136        Ok(OutboundBidiStream::new(transport))
137    }
138
139    /// Open a new server-initiated unidirectional stream for this session.
140    pub async fn open_uni(&self) -> io::Result<OutboundUniStream> {
141        let (_stream_id, mut stream) = self.quic_connection.open_uni().await?;
142        stream.write_all(&wt_uni_header(self.session_id)).await?;
143        Ok(OutboundUniStream::new(stream))
144    }
145}
146
/// One unit of work for the per-connection routing task spawned in `upgrade`.
///
/// The task races the routing channel against the QUIC datagram receiver and dispatches
/// whichever produced a value to the session registry.
enum RoutingAction {
    /// A WebTransport stream received from the routing channel, to be routed to its session.
    Stream(WebTransportStream),
    /// The bytes of one datagram received on the QUIC connection, to be routed by its prefix.
    Datagram(Vec<u8>),
}
151
152/// Encode the bidi stream header: signal value 0x41 + session_id.
153fn wt_bidi_header(session_id: u64) -> Vec<u8> {
154    let mut buf =
155        vec![0u8; quic_varint::encoded_len(0x41u64) + quic_varint::encoded_len(session_id)];
156    let mut offset = quic_varint::encode(0x41u64, &mut buf).unwrap();
157    offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
158    buf.truncate(offset);
159    buf
160}
161
162/// Encode the uni stream header: stream type 0x54 + session_id.
163fn wt_uni_header(session_id: u64) -> Vec<u8> {
164    let mut buf =
165        vec![0u8; quic_varint::encoded_len(0x54u64) + quic_varint::encoded_len(session_id)];
166    let mut offset = quic_varint::encode(0x54u64, &mut buf).unwrap();
167    offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
168    buf.truncate(offset);
169    buf
170}
171
/// Default number of datagrams buffered per session. Used by [`WebTransport::new`] and
/// overridable with [`WebTransport::with_max_datagram_buffer`].
const DEFAULT_MAX_DATAGRAM_BUFFER: usize = 16;

/// A Trillium [`Handler`] that accepts WebTransport sessions.
///
/// Add this to your handler chain and provide a [`WebTransportHandler`] (or a closure) to
/// process each session. It claims `CONNECT` requests that arrive on a QUIC connection.
/// During [`Handler::init`] it enables HTTP/3 datagrams and WebTransport in the server's
/// HTTP config.
///
/// # Example
///
/// ```no_run
/// use trillium_webtransport::{WebTransport, WebTransportConnection};
///
/// let handler = WebTransport::new(|conn: WebTransportConnection| async move {
///     while let Some(stream) = conn.accept_next_stream().await {
///         // handle stream...
/// # drop(stream);
///     }
/// });
/// ```
pub struct WebTransport<H> {
    // Server runtime, captured once during `Handler::init` from the `Info` state and used
    // to spawn each QUIC connection's routing task.
    runtime: OnceLock<Runtime>,
    // Per-session datagram buffer size handed to the session router.
    max_datagram_buffer: usize,
    // User-supplied session callback, run once for every accepted session.
    handler: H,
}
196
197/// A handler for WebTransport sessions.
198///
199/// Any `Fn(WebTransportConnection) -> impl Future<Output = ()>` automatically implements this
200/// trait, so you can pass a closure or async function directly to [`WebTransport::new`].
201pub trait WebTransportHandler: Send + Sync + 'static {
202    /// Handle a WebTransport session. Called once per client-initiated session.
203    fn run(
204        &self,
205        web_transport_connection: WebTransportConnection,
206    ) -> impl Future<Output = ()> + Send;
207}
208
209impl<Fun, Fut> WebTransportHandler for Fun
210where
211    Fun: Fn(WebTransportConnection) -> Fut + Send + Sync + 'static,
212    Fut: Future<Output = ()> + Send,
213{
214    async fn run(&self, web_transport_connection: WebTransportConnection) {
215        self(web_transport_connection).await
216    }
217}
218
219impl<H> WebTransport<H>
220where
221    H: WebTransportHandler,
222{
223    /// Create a new `WebTransport` handler that passes each session to `handler`.
224    pub fn new(handler: H) -> Self {
225        Self {
226            handler,
227            runtime: Default::default(),
228            max_datagram_buffer: DEFAULT_MAX_DATAGRAM_BUFFER,
229        }
230    }
231
232    /// Set the maximum number of datagrams to buffer per session.
233    ///
234    /// When the buffer is full, the oldest datagram is dropped to make room for the newest.
235    ///
236    /// - **`max > 1`** — FIFO ring-buffer that tolerates bursts up to `max` datagrams before
237    ///   dropping. Good for ordered event streams where some loss is acceptable.
238    /// - **`max = 1`** — "latest-only" semantics: if multiple datagrams arrive while your
239    ///   [`recv_datagram`](WebTransportConnection::recv_datagram) loop is busy, only the most
240    ///   recent is retained. Good for streaming state (positions, sensor readings) where older
241    ///   values are invalidated by newer ones.
242    ///
243    /// Default: 16.
244    pub fn with_max_datagram_buffer(mut self, max: usize) -> Self {
245        self.max_datagram_buffer = max;
246        self
247    }
248
249    fn runtime(&self) -> &Runtime {
250        self.runtime.get().unwrap()
251    }
252}
253
/// Marker placed in a conn's state by [`WebTransport`]'s `run` when it claims a `CONNECT`
/// request on a QUIC connection; `has_upgrade` checks for it to route the conn to `upgrade`.
struct WTUpgrade;
255
256impl<H> Handler for WebTransport<H>
257where
258    H: WebTransportHandler,
259{
260    async fn run(&self, conn: Conn) -> Conn {
261        let inner: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
262        if inner.state().contains::<QuicConnection>() && conn.method() == Method::Connect
263        // todo(jbr): try to figure out why chrome isn't sending a protocol
264        //            && inner.protocol() == Some("webtransport-h3")
265        //            && inner.authority().is_some(/*and something else?*/)
266        {
267            conn.with_state(WTUpgrade).with_status(Status::Ok).halt()
268        } else {
269            conn
270        }
271    }
272
273    async fn init(&mut self, info: &mut Info) {
274        self.runtime.get_or_init(|| {
275            info.state::<Runtime>()
276                .cloned()
277                .expect("webtransport requires a Runtime")
278        });
279
280        info.http_config_mut()
281            .set_h3_datagrams_enabled(true)
282            .set_webtransport_enabled(true);
283    }
284
285    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
286        upgrade.state().get::<WTUpgrade>().is_some()
287    }
288
289    async fn upgrade(&self, mut upgrade: Upgrade) {
290        let Some(h3_connection) = upgrade.h3_connection() else {
291            log::error!("missing H3Connection in upgrade state");
292            return;
293        };
294        let Some(quic_connection) = upgrade.state_mut().take::<QuicConnection>() else {
295            log::error!("missing QuicConnection in upgrade state");
296            return;
297        };
298        let Some(stream_id) = upgrade.state_mut().take::<StreamId>() else {
299            log::error!("missing StreamId in upgrade state");
300            return;
301        };
302        let Some(dispatcher) = upgrade.state().get::<WebTransportDispatcher>().cloned() else {
303            log::error!("missing WebTransportDispatcher in upgrade state");
304            return;
305        };
306
307        let max_datagram_buffer = self.max_datagram_buffer;
308        let Some(router) = dispatcher.get_or_init_with(|| Router::new(max_datagram_buffer)) else {
309            log::error!("WebTransportDispatcher has a handler of an unexpected type");
310            return;
311        };
312
313        // Spawn the routing task if we're the first session on this connection.
314        if let Some(routing_rx) = router.take_routing_rx() {
315            let router = router.clone();
316            let quic = quic_connection.clone();
317            self.runtime().clone().spawn(async move {
318                loop {
319                    let action = futures_lite::future::race(
320                        async { routing_rx.recv().await.ok().map(RoutingAction::Stream) },
321                        async {
322                            let mut data = Vec::new();
323                            quic.recv_datagram(|d| data.extend_from_slice(d))
324                                .await
325                                .ok()
326                                .map(|()| RoutingAction::Datagram(data))
327                        },
328                    )
329                    .await;
330                    match action {
331                        Some(RoutingAction::Stream(stream)) => {
332                            router.sessions.lock().await.route(stream);
333                        }
334                        Some(RoutingAction::Datagram(data)) => {
335                            router.sessions.lock().await.route_datagram(&data);
336                        }
337                        None => break,
338                    }
339                }
340            });
341        }
342
343        let session_id = stream_id.into();
344        log::trace!("starting webtransport session {session_id}");
345        let session_swansong = h3_connection.swansong().child();
346        let (bidi_rx, uni_rx, datagram_rx) = router.sessions.lock().await.register(session_id);
347
348        let runtime = self.runtime().clone();
349
350        self.handler
351            .run(WebTransportConnection {
352                session_id,
353                bidi_rx,
354                uni_rx,
355                datagram_rx,
356                swansong: session_swansong.clone(),
357                upgrade,
358                h3_connection,
359                quic_connection,
360                runtime,
361            })
362            .await;
363
364        log::trace!("finished handler, cleaning up");
365
366        session_swansong.shut_down().await;
367        router.sessions.lock().await.unregister(session_id);
368    }
369}