Skip to main content

trillium_webtransport/
stream.rs

1use std::{
2    fmt::{self, Debug, Formatter},
3    io,
4    ops::Deref,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use trillium_macros::{AsyncRead, AsyncWrite};
9use trillium_server_common::{
10    AsyncRead, AsyncWrite, QuicTransportBidi, QuicTransportReceive, QuicTransportSend,
11};
12
/// The payload of a single received WebTransport datagram.
///
/// This is a newtype over `Vec<u8>`. It dereferences to `[u8]`, can be
/// borrowed through `AsRef<[u8]>`, is constructed `From<Vec<u8>>`, and
/// converts back `Into<Vec<u8>>`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Datagram(Vec<u8>);
18
19impl Deref for Datagram {
20    type Target = [u8];
21
22    fn deref(&self) -> &[u8] {
23        &self.0
24    }
25}
26
27impl AsRef<[u8]> for Datagram {
28    fn as_ref(&self) -> &[u8] {
29        &self.0
30    }
31}
32
33impl From<Vec<u8>> for Datagram {
34    fn from(v: Vec<u8>) -> Self {
35        Self(v)
36    }
37}
38
39impl From<Datagram> for Vec<u8> {
40    fn from(d: Datagram) -> Self {
41        d.0
42    }
43}
44
45/// An inbound WebTransport stream, yielded by
46/// [`WebTransportConnection::accept_next_stream`](crate::WebTransportConnection::accept_next_stream).
47///
48///
49/// Datagrams are handled separately via
50/// [`WebTransportConnection::recv_datagram`](crate::WebTransportConnection::recv_datagram), as they
51/// typically require a dedicated low-latency loop rather than sharing one with stream acceptance.
52#[derive(Debug)]
53pub enum InboundStream {
54    /// An inbound bidirectional stream opened by the client.
55    Bidi(InboundBidiStream),
56    /// An inbound unidirectional stream opened by the client.
57    Uni(InboundUniStream),
58}
59
60pub(crate) type BoxedRecvStream = Box<dyn QuicTransportReceive + Unpin + Send + Sync>;
61type BoxedSendStream = Box<dyn QuicTransportSend + Unpin + Send + Sync>;
62
63/// An inbound bidirectional WebTransport stream opened by the client.
64///
65/// Implements [`AsyncRead`] and [`AsyncWrite`].
66#[derive(AsyncWrite)]
67pub struct InboundBidiStream {
68    buffer: Vec<u8>,
69    offset: usize,
70    #[async_write]
71    stream: Box<dyn QuicTransportBidi>,
72}
73
74impl Debug for InboundBidiStream {
75    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
76        f.debug_struct("InboundBidiStream")
77            .field("buffer", &self.buffer)
78            .field("offset", &self.offset)
79            .field("transport", &format_args!("Box<dyn QuicTransportBidi>"))
80            .finish()
81    }
82}
83
84impl InboundBidiStream {
85    pub(crate) fn new(transport: Box<dyn QuicTransportBidi>, buffer: Vec<u8>) -> Self {
86        Self {
87            buffer,
88            offset: 0,
89            stream: transport,
90        }
91    }
92
93    pub fn reset(&mut self, code: Option<u64>) {
94        self.stream.reset(code.unwrap_or(0));
95    }
96
97    pub fn stop(&mut self, code: Option<u64>) {
98        self.stream.stop(code.unwrap_or(0));
99    }
100}
101
102impl AsyncRead for InboundBidiStream {
103    fn poll_read(
104        mut self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106        buf: &mut [u8],
107    ) -> Poll<io::Result<usize>> {
108        let this = &mut *self;
109        read_buffered(
110            &mut this.buffer,
111            &mut this.offset,
112            &mut this.stream,
113            cx,
114            buf,
115        )
116    }
117}
118
119/// An inbound unidirectional WebTransport stream opened by the client.
120///
121/// Implements [`AsyncRead`].
122pub struct InboundUniStream {
123    buffer: Vec<u8>,
124    offset: usize,
125    stream: BoxedRecvStream,
126}
127
128impl Debug for InboundUniStream {
129    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
130        f.debug_struct("InboundUniStream")
131            .field("buffer", &self.buffer)
132            .field("offset", &self.offset)
133            .finish_non_exhaustive()
134    }
135}
136
137impl InboundUniStream {
138    pub(crate) fn new(stream: BoxedRecvStream, buffer: Vec<u8>) -> Self {
139        Self {
140            buffer,
141            offset: 0,
142            stream,
143        }
144    }
145
146    pub fn stop(&mut self, code: Option<u64>) {
147        self.stream.stop(code.unwrap_or(0));
148    }
149}
150
151impl AsyncRead for InboundUniStream {
152    fn poll_read(
153        mut self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut [u8],
156    ) -> Poll<io::Result<usize>> {
157        let this = &mut *self;
158        read_buffered(
159            &mut this.buffer,
160            &mut this.offset,
161            &mut this.stream,
162            cx,
163            buf,
164        )
165    }
166}
167
168/// A server-initiated bidirectional WebTransport stream.
169///
170/// Implements [`AsyncRead`] and [`AsyncWrite`].
171#[derive(AsyncRead, AsyncWrite)]
172pub struct OutboundBidiStream(Box<dyn QuicTransportBidi>);
173
174impl Debug for OutboundBidiStream {
175    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
176        f.debug_tuple("OutboundBidiStream").finish_non_exhaustive()
177    }
178}
179
180impl OutboundBidiStream {
181    pub(crate) fn new(transport: Box<dyn QuicTransportBidi>) -> Self {
182        Self(transport)
183    }
184
185    pub fn stop(&mut self, code: Option<u64>) {
186        self.0.stop(code.unwrap_or(0));
187    }
188
189    pub fn reset(&mut self, code: Option<u64>) {
190        self.0.reset(code.unwrap_or(0));
191    }
192}
193
194/// A server-initiated unidirectional WebTransport stream.
195///
196/// Implements [`AsyncWrite`].
197#[derive(AsyncWrite)]
198pub struct OutboundUniStream(BoxedSendStream);
199
200impl Debug for OutboundUniStream {
201    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
202        f.debug_tuple("OutboundUniStream").finish_non_exhaustive()
203    }
204}
205
206impl OutboundUniStream {
207    pub(crate) fn new(stream: BoxedSendStream) -> Self {
208        Self(stream)
209    }
210
211    pub fn reset(&mut self, code: Option<u64>) {
212        self.0.reset(code.unwrap_or(0));
213    }
214}
215
216fn read_buffered(
217    buffer: &mut Vec<u8>,
218    offset: &mut usize,
219    transport: &mut (impl AsyncRead + Unpin),
220    cx: &mut Context<'_>,
221    buf: &mut [u8],
222) -> Poll<io::Result<usize>> {
223    let remaining = buffer.len() - *offset;
224    if remaining == 0 {
225        return Pin::new(transport).poll_read(cx, buf);
226    }
227
228    let n = remaining.min(buf.len());
229    buf[..n].copy_from_slice(&buffer[*offset..*offset + n]);
230    *offset += n;
231
232    if *offset == buffer.len() {
233        *buffer = Vec::new();
234        *offset = 0;
235    }
236
237    Poll::Ready(Ok(n))
238}