Skip to main content

trillium_testing/
test_transport.rs

1use async_dup::Arc;
2use futures_lite::{AsyncRead, AsyncWrite};
3use std::{
4    fmt::{Debug, Display},
5    future::Future,
6    io,
7    net::{IpAddr, Shutdown, SocketAddr},
8    pin::Pin,
9    sync::RwLock,
10    task::{Context, Poll, Waker},
11};
12use trillium::TypeSet;
13use trillium_macros::{AsyncRead, AsyncWrite};
14
15/// a readable and writable transport for testing
16#[derive(Default, Clone, Debug, AsyncRead, AsyncWrite, fieldwork::Fieldwork)]
17pub struct TestTransport {
18    /// the read side of this transport
19    #[async_read]
20    #[field(get = read_side)]
21    read: Arc<CloseableCursor>,
22
23    /// the write side of this transport
24    #[async_write]
25    #[field(get = write_side)]
26    write: Arc<CloseableCursor>,
27
28    /// State that can be shared with the other side of this transport
29    #[field(vis = "pub(crate)", get)]
30    state: Arc<RwLock<TypeSet>>,
31
32    /// peer ip for the read side
33    #[field(get, set, option_set_some)]
34    peer_ip: Option<IpAddr>,
35}
36
37impl trillium::Transport for TestTransport {
38    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
39        if let Some(ip) = self.peer_ip {
40            Ok(Some(SocketAddr::from((ip, 0))))
41        } else {
42            Ok(None)
43        }
44    }
45}
46
47impl TestTransport {
48    /// constructs a new test transport pair, representing two ends of
49    /// a connection. either of them can be written to, and the
50    /// content will be readable from the other. either of them can
51    /// also be closed.
52    pub fn new() -> (TestTransport, TestTransport) {
53        let a = Arc::new(CloseableCursor::default());
54        let b = Arc::new(CloseableCursor::default());
55        let state: Arc<RwLock<TypeSet>> = Default::default();
56
57        (
58            TestTransport {
59                read: a.clone(),
60                write: b.clone(),
61                state: state.clone(),
62                peer_ip: None,
63            },
64            TestTransport {
65                read: b,
66                write: a,
67                state,
68                peer_ip: None,
69            },
70        )
71    }
72
73    /// close this transport, representing a disconnection
74    pub fn close(&mut self) {
75        self.read.close();
76        self.write.close();
77    }
78
79    /// Shuts down the read, write, or both halves of this connection.
80    // This function will cause all pending and future I/O on the specified portions to return
81    // immediately with an appropriate value (see the documentation of Shutdown).
82    pub fn shutdown(&self, how: Shutdown) {
83        match how {
84            Shutdown::Read => self.read.close(),
85            Shutdown::Write => self.write.close(),
86            Shutdown::Both => {
87                self.read.close();
88                self.write.close();
89            }
90        }
91    }
92
93    /// take an owned snapshot of the received data
94    pub fn snapshot(&self) -> Vec<u8> {
95        self.read.snapshot()
96    }
97
98    /// synchronously append the supplied bytes to the write side of this transport, notifying the
99    /// read side of the other end
100    pub fn write_all(&self, bytes: impl AsRef<[u8]>) {
101        io::Write::write_all(&mut &*self.write, bytes.as_ref()).unwrap();
102    }
103
104    /// waits until there is content and then reads that content to a vec until there is no
105    /// further content immediately available
106    pub async fn read_available(&self) -> Vec<u8> {
107        self.read.read_available().await
108    }
109
110    /// waits until there is content and then reads that content to a string until there is no
111    /// further content immediately available
112    pub async fn read_available_string(&self) -> String {
113        self.read.read_available_string().await
114    }
115}
116
117impl Drop for TestTransport {
118    fn drop(&mut self) {
119        self.close();
120    }
121}
122
/// Mutable state of a [`CloseableCursor`], guarded by its lock.
#[derive(Default)]
struct CloseableCursorInner {
    /// Every byte written so far; reads never remove data, they only advance `cursor`.
    data: Vec<u8>,
    /// Index into `data` of the next byte to be read.
    cursor: usize,
    /// Waker of the most recent read that found no data; woken by the next write or by
    /// `close`.
    waker: Option<Waker>,
    /// Set by `close`; afterwards reads report EOF once drained and writes are refused.
    closed: bool,
}

/// An in-memory, append-only byte buffer with a read position, which can be closed.
///
/// Writes (via `io::Write` or `AsyncWrite` on `&CloseableCursor`) append to the buffer and
/// wake a pending reader. Reads (via `AsyncRead`) consume bytes from the read position. When
/// no unread bytes are left, they wait until more data arrives or the cursor is closed. All
/// state lives behind a `RwLock`, so a shared reference suffices for every operation.
#[derive(Default)]
pub struct CloseableCursor(RwLock<CloseableCursorInner>);
133
134impl CloseableCursor {
135    /// the length of the content
136    pub fn len(&self) -> usize {
137        self.0.read().unwrap().data.len()
138    }
139
140    /// the current read position
141    pub fn cursor(&self) -> usize {
142        self.0.read().unwrap().cursor
143    }
144
145    /// does what it says on the tin
146    pub fn is_empty(&self) -> bool {
147        self.len() == 0
148    }
149
150    /// take a snapshot of the data
151    pub fn snapshot(&self) -> Vec<u8> {
152        self.0.read().unwrap().data.clone()
153    }
154
155    /// have we read to the end of the available content
156    pub fn current(&self) -> bool {
157        let inner = self.0.read().unwrap();
158        inner.data.len() == inner.cursor
159    }
160
161    /// close this cursor, waking any pending polls
162    pub fn close(&self) {
163        let mut inner = self.0.write().unwrap();
164        inner.closed = true;
165        if let Some(waker) = inner.waker.take() {
166            waker.wake();
167        }
168    }
169
170    /// read any available bytes
171    pub async fn read_available(&self) -> Vec<u8> {
172        ReadAvailable(self).await.unwrap()
173    }
174
175    /// read any available bytes as a string
176    pub async fn read_available_string(&self) -> String {
177        String::from_utf8(self.read_available().await).unwrap()
178    }
179}
180
181struct ReadAvailable<T>(T);
182
183impl<T: AsyncRead + Unpin> Future for ReadAvailable<T> {
184    type Output = io::Result<Vec<u8>>;
185
186    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
187        let mut buf = vec![];
188        let mut bytes_read = 0;
189        loop {
190            if buf.len() == bytes_read {
191                buf.reserve(32);
192                buf.resize(buf.capacity(), 0);
193            }
194            match Pin::new(&mut self.0).poll_read(cx, &mut buf[bytes_read..]) {
195                Poll::Ready(Ok(0)) => break,
196                Poll::Ready(Ok(new_bytes)) => {
197                    bytes_read += new_bytes;
198                }
199                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
200                Poll::Pending if bytes_read == 0 => return Poll::Pending,
201                Poll::Pending => break,
202            }
203        }
204
205        buf.truncate(bytes_read);
206        Poll::Ready(Ok(buf))
207    }
208}
209
210impl Display for CloseableCursor {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        let inner = self.0.read().unwrap();
213        write!(f, "{}", String::from_utf8_lossy(&inner.data))
214    }
215}
216
217impl Debug for CloseableCursor {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        let inner = self.0.read().unwrap();
220        f.debug_struct("CloseableCursor")
221            .field(
222                "data",
223                &std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
224            )
225            .field("closed", &inner.closed)
226            .field("cursor", &inner.cursor)
227            .finish()
228    }
229}
230
231impl AsyncRead for CloseableCursor {
232    fn poll_read(
233        self: Pin<&mut Self>,
234        cx: &mut Context<'_>,
235        buf: &mut [u8],
236    ) -> Poll<io::Result<usize>> {
237        Pin::new(&mut &*self).poll_read(cx, buf)
238    }
239}
240
241impl AsyncRead for &CloseableCursor {
242    fn poll_read(
243        self: Pin<&mut Self>,
244        cx: &mut Context<'_>,
245        buf: &mut [u8],
246    ) -> Poll<io::Result<usize>> {
247        let mut inner = self.0.write().unwrap();
248        if inner.cursor < inner.data.len() {
249            let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
250            buf[..bytes_to_copy]
251                .copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
252            inner.cursor += bytes_to_copy;
253            Poll::Ready(Ok(bytes_to_copy))
254        } else if inner.closed {
255            Poll::Ready(Ok(0))
256        } else {
257            inner.waker = Some(cx.waker().clone());
258            Poll::Pending
259        }
260    }
261}
262
263impl AsyncWrite for &CloseableCursor {
264    fn poll_write(
265        self: Pin<&mut Self>,
266        _cx: &mut Context<'_>,
267        buf: &[u8],
268    ) -> Poll<io::Result<usize>> {
269        let mut inner = self.0.write().unwrap();
270        if inner.closed {
271            Poll::Ready(Ok(0))
272        } else {
273            inner.data.extend_from_slice(buf);
274            if let Some(waker) = inner.waker.take() {
275                waker.wake();
276            }
277            Poll::Ready(Ok(buf.len()))
278        }
279    }
280
281    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
282        Poll::Ready(Ok(()))
283    }
284
285    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286        self.close();
287        Poll::Ready(Ok(()))
288    }
289}
290
291impl io::Write for CloseableCursor {
292    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
293        io::Write::write(&mut &*self, buf)
294    }
295
296    fn flush(&mut self) -> io::Result<()> {
297        Ok(())
298    }
299}
300
301impl io::Write for &CloseableCursor {
302    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
303        let mut inner = self.0.write().unwrap();
304        if inner.closed {
305            Ok(0)
306        } else {
307            inner.data.extend_from_slice(buf);
308            if let Some(waker) = inner.waker.take() {
309                waker.wake();
310            }
311            Ok(buf.len())
312        }
313    }
314
315    fn flush(&mut self) -> io::Result<()> {
316        Ok(())
317    }
318}