Skip to main content

trillium_proxy/
body_streamer.rs

1use event_listener::Event;
2
3use futures_lite::AsyncRead;
4
5use sluice::pipe::PipeReader;
6use std::{
7    future::Future,
8    pin::Pin,
9    sync::{
10        atomic::{AtomicBool, Ordering},
11        Arc,
12    },
13    task::{Context, Poll},
14};
15use trillium::{Conn, KnownHeaderName};
16
17use trillium_http::Body;
18
19use crate::bytes;
20
21struct BodyProxyReader {
22    reader: PipeReader,
23    started: Option<Arc<(Event, AtomicBool)>>,
24}
25
26impl Drop for BodyProxyReader {
27    fn drop(&mut self) {
28        // if we haven't started yet, notify the copy future that we're not going to
29        if let Some(started) = self.started.take() {
30            started.0.notify(usize::MAX);
31        }
32    }
33}
34
35impl AsyncRead for BodyProxyReader {
36    fn poll_read(
37        mut self: Pin<&mut Self>,
38        cx: &mut Context<'_>,
39        buf: &mut [u8],
40    ) -> Poll<std::io::Result<usize>> {
41        if let Some(started) = self.started.take() {
42            started.1.store(true, Ordering::SeqCst);
43            started.0.notify(usize::MAX);
44        }
45        Pin::new(&mut self.reader).poll_read(cx, buf)
46    }
47}
48
49pub(crate) fn stream_body(conn: &mut Conn) -> (impl Future<Output = ()> + Send + Sync + '_, Body) {
50    let started = Arc::new((Event::new(), AtomicBool::from(false)));
51    let started_clone = started.clone();
52    let (reader, writer) = sluice::pipe::pipe();
53    let len = conn
54        .request_headers()
55        .get_str(KnownHeaderName::ContentLength)
56        .and_then(|s| s.parse().ok());
57
58    (
59        async move {
60            log::trace!("waiting to stream request body");
61            started_clone.0.listen().await;
62            if started_clone.1.load(Ordering::SeqCst) {
63                log::trace!("started to stream request body");
64                let received_body = conn.request_body().await;
65                match trillium_http::copy(received_body, writer, 4).await {
66                    Ok(streamed) => {
67                        log::trace!("streamed {} request body bytes", bytes(streamed))
68                    }
69                    Err(e) => log::error!("request body stream error: {e}"),
70                };
71            } else {
72                log::trace!("not streaming request body");
73            }
74        },
75        Body::new_streaming(
76            BodyProxyReader {
77                started: Some(started),
78                reader,
79            },
80            len,
81        ),
82    )
83}