Skip to main content

trillium_http/
copy.rs

1use futures_lite::{io::BufReader, ready, AsyncBufRead, AsyncRead, AsyncWrite};
2use std::{
3    future::Future,
4    io::{ErrorKind, Result},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9/// copy from the `reader` to the `writer`, yielding back to the runtime every `loops_per_yield`
10///
11/// # Errors
12///
13/// This returns any io error encountered in reading or writing
14pub async fn copy<R, W>(reader: R, writer: W, loops_per_yield: usize) -> Result<u64>
15where
16    R: AsyncRead + Unpin,
17    W: AsyncWrite + Unpin,
18{
19    struct CopyFuture<R, W> {
20        reader: BufReader<R>,
21        writer: W,
22        amt: u64,
23        loops_per_yield: usize,
24    }
25
26    impl<R, W> Future for CopyFuture<R, W>
27    where
28        R: AsyncRead + Unpin,
29        W: AsyncWrite + Unpin,
30    {
31        type Output = Result<u64>;
32
33        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
34            for loop_number in 0..self.loops_per_yield {
35                log::trace!("copy loop number: {loop_number}");
36                let CopyFuture {
37                    reader,
38                    writer,
39                    amt,
40                    ..
41                } = &mut *self;
42
43                let writer = Pin::new(writer);
44                let mut reader = Pin::new(reader);
45                let buffer = ready!(reader.as_mut().poll_fill_buf(cx))?;
46                if buffer.is_empty() {
47                    ready!(writer.poll_flush(cx))?;
48                    return Poll::Ready(Ok(self.amt));
49                }
50
51                let i = ready!(writer.poll_write(cx, buffer))?;
52                if i == 0 {
53                    return Poll::Ready(Err(ErrorKind::WriteZero.into()));
54                }
55                *amt += i as u64;
56                reader.consume(i);
57            }
58
59            cx.waker().wake_by_ref();
60            Poll::Pending
61        }
62    }
63
64    let future = CopyFuture {
65        reader: BufReader::new(reader),
66        writer,
67        amt: 0,
68        loops_per_yield,
69    };
70    future.await
71}