// trillium_http/received_body.rs

1use crate::{Body, Buffer, Error, Headers, HttpConfig, MutCow, copy};
2use Poll::{Pending, Ready};
3use ReceivedBodyState::{Chunked, End, FixedLength, PartialChunkSize, Start};
4use encoding_rs::Encoding;
5use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, ready};
6use std::{
7    fmt::{self, Debug, Formatter},
8    io::{self, ErrorKind},
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13mod chunked;
14mod fixed_length;
15mod h3_data;
16
17/// A received http body
18///
19/// This type represents a body that will be read from the underlying transport, which it may either
20/// borrow from a [`Conn`](crate::Conn) or own.
21///
22/// ```rust
23/// # use trillium_testing::HttpTest;
24/// let app = HttpTest::new(|mut conn| async move {
25///     let body = conn.request_body();
26///     let body_string = body.read_string().await.unwrap();
27///     conn.with_response_body(format!("received: {body_string}"))
28/// });
29///
30/// app.get("/").block().assert_body("received: ");
31/// app.post("/")
32///     .with_body("hello")
33///     .block()
34///     .assert_body("received: hello");
35/// ```
36///
37/// ## Bounds checking
38///
39/// Every `ReceivedBody` has a maximum length beyond which it will return an error, expressed as a
40/// u64. To override this on the specific `ReceivedBody`, use [`ReceivedBody::with_max_len`] or
41/// [`ReceivedBody::set_max_len`]
42///
43/// The default maximum length is currently set to 500mb. In the next semver-minor release, this
44/// value will decrease substantially.
45///
46/// ## Large chunks, small read buffers
47///
48/// Attempting to read a chunked body with a buffer that is shorter than the chunk size in hex will
49/// result in an error. This limitation is temporary.
50#[derive(fieldwork::Fieldwork)]
51pub struct ReceivedBody<'conn, Transport> {
52    /// The content-length of this body, if available. This
53    /// usually is derived from the content-length header. If the http
54    /// request or response that this body is attached to uses
55    /// transfer-encoding chunked, this will be None.
56    ///
57    /// ```rust
58    /// # use trillium_testing::HttpTest;
59    /// HttpTest::new(|mut conn| async move {
60    ///     let body = conn.request_body();
61    ///     assert_eq!(body.content_length(), Some(5));
62    ///     let body_string = body.read_string().await.unwrap();
63    ///     conn.with_status(200)
64    ///         .with_response_body(format!("received: {body_string}"))
65    /// })
66    /// .post("/")
67    /// .with_body("hello")
68    /// .block()
69    /// .assert_ok()
70    /// .assert_body("received: hello");
71    /// ```
72    #[field(get)]
73    content_length: Option<u64>,
74
75    buffer: MutCow<'conn, Buffer>,
76
77    transport: Option<MutCow<'conn, Transport>>,
78
79    state: MutCow<'conn, ReceivedBodyState>,
80
81    on_completion: Option<Box<dyn FnOnce(Transport) + Send + Sync + 'static>>,
82
83    /// the character encoding of this body, usually determined from the content type
84    /// (mime-type) of the associated Conn.
85    #[field(get)]
86    encoding: &'static Encoding,
87
88    /// The maximum length that can be read from this body before error
89    ///
90    /// See also [`HttpConfig::received_body_max_len`]
91    #[field(with, get, set)]
92    max_len: u64,
93
94    /// The initial buffer capacity allocated when reading the body to bytes or a string
95    ///
96    /// See [`HttpConfig::received_body_initial_len`]
97    #[field(with, get, set)]
98    initial_len: usize,
99
100    /// The maximum number of read loops that reading this received body will perform before
101    /// yielding back to the runtime
102    ///
103    /// See [`HttpConfig::copy_loops_per_yield`]
104    #[field(with, get, set)]
105    copy_loops_per_yield: usize,
106
107    /// Maximum size to pre-allocate based on content-length for buffering this received body
108    ///
109    /// See [`HttpConfig::received_body_max_preallocate`]
110    #[field(with, get, set)]
111    max_preallocate: usize,
112
113    h3_max_field_section_size: u64,
114
115    trailers: MutCow<'conn, Option<Headers>>,
116
117    /// Byte offset into `b"HTTP/1.1 100 Continue\r\n\r\n"` that remains to be written before the
118    /// first read. `None` means no pending write.
119    send_100_continue_offset: Option<usize>,
120}
121
/// Returns the non-empty tail of `buf` starting at byte offset `min`.
///
/// Yields `None` when `min` is at or beyond the end of `buf` (including offsets that do not fit
/// in a `usize`).
fn slice_from(min: u64, buf: &[u8]) -> Option<&[u8]> {
    let start = usize::try_from(min).unwrap_or(usize::MAX);
    match buf.get(start..) {
        Some(tail) if !tail.is_empty() => Some(tail),
        _ => None,
    }
}
126
127impl<'conn, Transport> ReceivedBody<'conn, Transport>
128where
129    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
130{
131    #[allow(missing_docs)]
132    #[doc(hidden)]
133    pub fn new(
134        content_length: Option<u64>,
135        buffer: impl Into<MutCow<'conn, Buffer>>,
136        transport: impl Into<MutCow<'conn, Transport>>,
137        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
138        on_completion: Option<Box<dyn FnOnce(Transport) + Send + Sync + 'static>>,
139        encoding: &'static Encoding,
140    ) -> Self {
141        Self::new_with_config(
142            content_length,
143            buffer,
144            transport,
145            state,
146            on_completion,
147            encoding,
148            &HttpConfig::DEFAULT,
149        )
150    }
151
152    #[allow(missing_docs)]
153    #[doc(hidden)]
154    pub(crate) fn new_with_config(
155        content_length: Option<u64>,
156        buffer: impl Into<MutCow<'conn, Buffer>>,
157        transport: impl Into<MutCow<'conn, Transport>>,
158        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
159        on_completion: Option<Box<dyn FnOnce(Transport) + Send + Sync + 'static>>,
160        encoding: &'static Encoding,
161        config: &HttpConfig,
162    ) -> Self {
163        Self {
164            content_length,
165            buffer: buffer.into(),
166            transport: Some(transport.into()),
167            state: state.into(),
168            on_completion,
169            encoding,
170            max_len: config.received_body_max_len,
171            initial_len: config.received_body_initial_len,
172            copy_loops_per_yield: config.copy_loops_per_yield,
173            max_preallocate: config.received_body_max_preallocate,
174            h3_max_field_section_size: config.h3_max_field_section_size,
175            trailers: None.into(),
176            send_100_continue_offset: None,
177        }
178    }
179
180    /// Sets the destination for trailers decoded from the request body.
181    ///
182    /// When the body is fully read, any trailers will be written to the provided storage.
183    #[doc(hidden)]
184    #[must_use]
185    pub fn with_trailers(mut self, trailers: impl Into<MutCow<'conn, Option<Headers>>>) -> Self {
186        self.trailers = trailers.into();
187        self
188    }
189
190    /// Arranges for `HTTP/1.1 100 Continue\r\n\r\n` to be written to the transport before the
191    /// first body read. Used to implement lazy 100-continue for HTTP/1.1 request bodies.
192    #[must_use]
193    pub(crate) fn with_send_100_continue(mut self) -> Self {
194        self.send_100_continue_offset = Some(0);
195        self
196    }
197
198    // pub fn content_length(&self) -> Option<u64> {
199    //     self.content_length
200    // }
201
202    /// # Reads entire body to String.
203    ///
204    /// This uses the encoding determined by the content-type (mime)
205    /// charset. If an encoding problem is encountered, the String
206    /// returned by [`ReceivedBody::read_string`] will contain utf8
207    /// replacement characters.
208    ///
209    /// Note that this can only be performed once per Conn, as the
210    /// underlying data is not cached anywhere. This is the only copy of
211    /// the body contents.
212    ///
213    /// # Errors
214    ///
215    /// This will return an error if there is an IO error on the
216    /// underlying transport such as a disconnect
217    ///
218    /// This will also return an error if the length exceeds the maximum length. To override this
219    /// value on this specific body, use [`ReceivedBody::with_max_len`] or
220    /// [`ReceivedBody::set_max_len`]
221    pub async fn read_string(self) -> crate::Result<String> {
222        let encoding = self.encoding();
223        let bytes = self.read_bytes().await?;
224        let (s, _, _) = encoding.decode(&bytes);
225        Ok(s.to_string())
226    }
227
228    fn owns_transport(&self) -> bool {
229        self.transport.as_ref().is_some_and(MutCow::is_owned)
230    }
231
232    /// Similar to [`ReceivedBody::read_string`], but returns the raw bytes. This is useful for
233    /// bodies that are not text.
234    ///
235    /// You can use this in conjunction with `encoding` if you need different handling of malformed
236    /// character encoding than the lossy conversion provided by [`ReceivedBody::read_string`].
237    ///
238    /// # Errors
239    ///
240    /// This will return an error if there is an IO error on the underlying transport such as a
241    /// disconnect
242    ///
243    /// This will also return an error if the length exceeds
244    /// [`received_body_max_len`][HttpConfig::with_received_body_max_len]. To override this value on
245    /// this specific body, use [`ReceivedBody::with_max_len`] or [`ReceivedBody::set_max_len`]
246    pub async fn read_bytes(mut self) -> crate::Result<Vec<u8>> {
247        let mut vec = if let Some(len) = self.content_length {
248            if len > self.max_len {
249                return Err(Error::ReceivedBodyTooLong(self.max_len));
250            }
251
252            let len = usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len))?;
253
254            Vec::with_capacity(len.min(self.max_preallocate))
255        } else {
256            Vec::with_capacity(self.initial_len)
257        };
258
259        self.read_to_end(&mut vec).await?;
260        Ok(vec)
261    }
262
263    // pub fn encoding(&self) -> &'static Encoding {
264    //     self.encoding
265    // }
266
267    fn read_raw(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
268        if let Some(transport) = self.transport.as_deref_mut() {
269            read_buffered(&mut self.buffer, transport, cx, buf)
270        } else {
271            Ready(Err(ErrorKind::NotConnected.into()))
272        }
273    }
274
275    /// Consumes the remainder of this body from the underlying transport by reading it to the end
276    /// and discarding the contents. This is important for http1.1 keepalive, but most of the
277    /// time you do not need to directly call this. It returns the number of bytes consumed.
278    ///
279    /// # Errors
280    ///
281    /// This will return an [`std::io::Result::Err`] if there is an io error on the underlying
282    /// transport, such as a disconnect
283    #[allow(clippy::missing_errors_doc)] // false positive
284    pub async fn drain(self) -> io::Result<u64> {
285        let copy_loops_per_yield = self.copy_loops_per_yield;
286        copy(self, futures_lite::io::sink(), copy_loops_per_yield).await
287    }
288}
289
290impl<T> ReceivedBody<'static, T> {
291    /// takes the static transport from this received body
292    pub fn take_transport(&mut self) -> Option<T> {
293        self.transport.take().map(MutCow::unwrap_owned)
294    }
295}
296
297pub(crate) fn read_buffered<Transport>(
298    buffer: &mut Buffer,
299    transport: &mut Transport,
300    cx: &mut Context<'_>,
301    buf: &mut [u8],
302) -> Poll<io::Result<usize>>
303where
304    Transport: AsyncRead + Unpin,
305{
306    if buffer.is_empty() {
307        Pin::new(transport).poll_read(cx, buf)
308    } else if buffer.len() >= buf.len() {
309        let len = buf.len();
310        buf.copy_from_slice(&buffer[..len]);
311        buffer.ignore_front(len);
312        Ready(Ok(len))
313    } else {
314        let self_buffer_len = buffer.len();
315        buf[..self_buffer_len].copy_from_slice(buffer);
316        buffer.truncate(0);
317        match Pin::new(transport).poll_read(cx, &mut buf[self_buffer_len..]) {
318            Ready(Ok(additional)) => Ready(Ok(additional + self_buffer_len)),
319            Pending => Ready(Ok(self_buffer_len)),
320            other @ Ready(_) => other,
321        }
322    }
323}
324
325type StateOutput = Poll<io::Result<(ReceivedBodyState, usize)>>;
326
327impl<Transport> ReceivedBody<'_, Transport>
328where
329    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
330{
331    #[inline]
332    fn handle_start(&mut self) -> StateOutput {
333        Ready(Ok((
334            match self.content_length {
335                Some(0) => End,
336
337                Some(total_length) if total_length <= self.max_len => FixedLength {
338                    current_index: 0,
339                    total: total_length,
340                },
341
342                Some(_) => {
343                    return Ready(Err(io::Error::new(
344                        ErrorKind::Unsupported,
345                        "content too long",
346                    )));
347                }
348
349                None => Chunked {
350                    remaining: 0,
351                    total: 0,
352                },
353            },
354            0,
355        )))
356    }
357}
358
359impl<Transport> AsyncRead for ReceivedBody<'_, Transport>
360where
361    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
362{
363    fn poll_read(
364        mut self: Pin<&mut Self>,
365        cx: &mut Context<'_>,
366        buf: &mut [u8],
367    ) -> Poll<io::Result<usize>> {
368        const CONTINUE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
369        while let Some(offset) = self.send_100_continue_offset {
370            let n = {
371                let Some(transport) = self.transport.as_deref_mut() else {
372                    return Ready(Err(ErrorKind::NotConnected.into()));
373                };
374                if offset == 0 {
375                    log::trace!("sending 100-continue");
376                }
377                ready!(Pin::new(transport).poll_write(cx, &CONTINUE[offset..]))?
378            };
379            if n == 0 {
380                return Ready(Err(ErrorKind::WriteZero.into()));
381            }
382            let new_offset = offset + n;
383            self.send_100_continue_offset = if new_offset >= CONTINUE.len() {
384                None
385            } else {
386                Some(new_offset)
387            };
388        }
389
390        for _ in 0..self.copy_loops_per_yield {
391            let (new_body_state, bytes) = ready!(match *self.state {
392                Start => self.handle_start(),
393                Chunked { remaining, total } => self.handle_chunked(cx, buf, remaining, total),
394                PartialChunkSize { total } => self.handle_partial(cx, buf, total),
395                FixedLength {
396                    current_index,
397                    total,
398                } => self.handle_fixed_length(cx, buf, current_index, total),
399                ReceivedBodyState::H3Data {
400                    remaining_in_frame,
401                    total,
402                    frame_type,
403                    partial_frame_header,
404                } => self.handle_h3_data(
405                    cx,
406                    buf,
407                    remaining_in_frame,
408                    total,
409                    frame_type,
410                    partial_frame_header
411                ),
412                ReceivedBodyState::ReadingH1Trailers { total } => {
413                    self.handle_reading_h1_trailers(cx, buf, total)
414                }
415                End => Ready(Ok((End, 0))),
416            })?;
417
418            *self.state = new_body_state;
419
420            if *self.state == End {
421                if self.on_completion.is_some() && self.owns_transport() {
422                    let transport = self.transport.take().unwrap().unwrap_owned();
423                    let on_completion = self.on_completion.take().unwrap();
424                    on_completion(transport);
425                }
426                return Ready(Ok(bytes));
427            } else if bytes != 0 {
428                return Ready(Ok(bytes));
429            }
430        }
431
432        cx.waker().wake_by_ref();
433        Pending
434    }
435}
436
437impl<Transport> Debug for ReceivedBody<'_, Transport> {
438    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
439        f.debug_struct("RequestBody")
440            .field("state", &*self.state)
441            .field("content_length", &self.content_length)
442            .field("buffer", &format_args!(".."))
443            .field("on_completion", &self.on_completion.is_some())
444            .finish()
445    }
446}
447
448/// The type of H3 frame currently being processed in [`ReceivedBodyState::H3Data`].
449#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
450#[allow(missing_docs)]
451#[doc(hidden)]
452pub enum H3BodyFrameType {
453    /// Initial state — no frame decoded yet.
454    #[default]
455    Start,
456    /// Inside a DATA frame — body bytes to keep.
457    Data,
458    /// Inside an unknown frame — payload bytes to discard.
459    Unknown,
460    /// Inside a trailing HEADERS frame — accumulate into buffer for parsing.
461    Trailers,
462}
463
464#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
465#[allow(missing_docs)]
466#[doc(hidden)]
467pub enum ReceivedBodyState {
468    /// initial state
469    #[default]
470    Start,
471
472    /// read state for a chunked-encoded body. the number of bytes that have been read from the
473    /// current chunk is the difference between remaining and total.
474    Chunked {
475        /// remaining indicates the bytes left _in the current
476        /// chunk_. initial state is zero.
477        remaining: u64,
478
479        /// total indicates the absolute number of bytes read from all chunks
480        total: u64,
481    },
482
483    /// read state when we have buffered content between subsequent polls because chunk framing
484    /// overlapped a buffer boundary
485    PartialChunkSize { total: u64 },
486
487    /// read state for a fixed-length body.
488    FixedLength {
489        /// current index represents the bytes that have already been
490        /// read. initial state is zero
491        current_index: u64,
492
493        /// total length indicates the claimed length, usually
494        /// determined by the content-length header
495        total: u64,
496    },
497
498    /// read state for an H3 body framed as DATA frames.
499    H3Data {
500        /// bytes remaining in the current frame (DATA, Unknown, or Trailers). zero means we need
501        /// to read the next frame header.
502        remaining_in_frame: u64,
503
504        /// total body bytes read across all DATA frames.
505        total: u64,
506
507        /// what kind of frame we're currently inside.
508        frame_type: H3BodyFrameType,
509
510        /// when true, a partial frame header is sitting in `self.buffer` and needs more bytes
511        /// before we can decode it.
512        partial_frame_header: bool,
513    },
514
515    /// accumulating the HTTP/1.1 chunked trailer-section after the last-chunk (`0\r\n`).
516    ///
517    /// The trailer bytes (including any partially-received trailer headers) live in
518    /// `ReceivedBody::buffer` until a final empty line (`\r\n\r\n` or bare `\r\n`) is found.
519    ReadingH1Trailers {
520        /// total body bytes read across all chunks (for bounds-checking)
521        total: u64,
522    },
523
524    /// the terminal read state
525    End,
526}
527
528impl ReceivedBodyState {
529    pub fn new_h3() -> Self {
530        Self::H3Data {
531            remaining_in_frame: 0,
532            total: 0,
533            frame_type: H3BodyFrameType::Start,
534            partial_frame_header: false,
535        }
536    }
537}
538
539impl<Transport> From<ReceivedBody<'static, Transport>> for Body
540where
541    Transport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
542{
543    fn from(rb: ReceivedBody<'static, Transport>) -> Self {
544        let len = rb.content_length;
545        Body::new_streaming(rb, len)
546    }
547}