trillium_testing/
test_transport.rs1use async_dup::Arc;
2use futures_lite::{AsyncRead, AsyncWrite};
3use std::{
4 fmt::{Debug, Display},
5 future::Future,
6 io,
7 net::{IpAddr, Shutdown, SocketAddr},
8 pin::Pin,
9 sync::RwLock,
10 task::{Context, Poll, Waker},
11};
12use trillium::TypeSet;
13use trillium_macros::{AsyncRead, AsyncWrite};
14
/// An in-memory transport for tests.
///
/// [`TestTransport::new`] returns a connected pair in which bytes written to
/// one transport can be read from the other.
#[derive(Default, Clone, Debug, AsyncRead, AsyncWrite, fieldwork::Fieldwork)]
pub struct TestTransport {
    /// The cursor this transport reads from; for a pair built by `new`, it is
    /// the other transport's write side. `#[async_read]` selects it for the
    /// derived `AsyncRead`.
    #[async_read]
    #[field(get = read_side)]
    read: Arc<CloseableCursor>,

    /// The cursor this transport writes to; for a pair built by `new`, it is
    /// the other transport's read side. `#[async_write]` selects it for the
    /// derived `AsyncWrite`.
    #[async_write]
    #[field(get = write_side)]
    write: Arc<CloseableCursor>,

    /// State shared by both transports of a pair built by `new`.
    #[field(vis = "pub(crate)", get)]
    state: Arc<RwLock<TypeSet>>,

    /// The peer ip reported by `peer_addr` (with port `0`), if one is set.
    #[field(get, set, option_set_some)]
    peer_ip: Option<IpAddr>,
}
36
37impl trillium::Transport for TestTransport {
38 fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
39 if let Some(ip) = self.peer_ip {
40 Ok(Some(SocketAddr::from((ip, 0))))
41 } else {
42 Ok(None)
43 }
44 }
45}
46
47impl TestTransport {
48 pub fn new() -> (TestTransport, TestTransport) {
53 let a = Arc::new(CloseableCursor::default());
54 let b = Arc::new(CloseableCursor::default());
55 let state: Arc<RwLock<TypeSet>> = Default::default();
56
57 (
58 TestTransport {
59 read: a.clone(),
60 write: b.clone(),
61 state: state.clone(),
62 peer_ip: None,
63 },
64 TestTransport {
65 read: b,
66 write: a,
67 state,
68 peer_ip: None,
69 },
70 )
71 }
72
73 pub fn close(&mut self) {
75 self.read.close();
76 self.write.close();
77 }
78
79 pub fn shutdown(&self, how: Shutdown) {
83 match how {
84 Shutdown::Read => self.read.close(),
85 Shutdown::Write => self.write.close(),
86 Shutdown::Both => {
87 self.read.close();
88 self.write.close();
89 }
90 }
91 }
92
93 pub fn snapshot(&self) -> Vec<u8> {
95 self.read.snapshot()
96 }
97
98 pub fn write_all(&self, bytes: impl AsRef<[u8]>) {
101 io::Write::write_all(&mut &*self.write, bytes.as_ref()).unwrap();
102 }
103
104 pub async fn read_available(&self) -> Vec<u8> {
107 self.read.read_available().await
108 }
109
110 pub async fn read_available_string(&self) -> String {
113 self.read.read_available_string().await
114 }
115}
116
117impl Drop for TestTransport {
118 fn drop(&mut self) {
119 self.close();
120 }
121}
122
/// The lock-protected state behind a [`CloseableCursor`].
#[derive(Default)]
struct CloseableCursorInner {
    /// Set by `close`; afterwards writes are refused and readers see
    /// end-of-stream once `data` has been fully consumed.
    closed: bool,
    /// Every byte written so far; reading never removes bytes.
    data: Vec<u8>,
    /// Index into `data` of the next byte to be read.
    cursor: usize,
    /// Waker of the last reader that found no unread data and had to wait.
    waker: Option<Waker>,
}
130
/// A shared, closeable, append-only byte buffer with a read position.
///
/// Writes append to the buffer and wake a waiting reader; reads copy bytes
/// starting at the read position and advance it. Reads never remove bytes, so
/// [`CloseableCursor::snapshot`] returns everything ever written. All state
/// lives behind an `RwLock`, so every operation works through `&self`.
#[derive(Default)]
pub struct CloseableCursor(RwLock<CloseableCursorInner>);
133
134impl CloseableCursor {
135 pub fn len(&self) -> usize {
137 self.0.read().unwrap().data.len()
138 }
139
140 pub fn cursor(&self) -> usize {
142 self.0.read().unwrap().cursor
143 }
144
145 pub fn is_empty(&self) -> bool {
147 self.len() == 0
148 }
149
150 pub fn snapshot(&self) -> Vec<u8> {
152 self.0.read().unwrap().data.clone()
153 }
154
155 pub fn current(&self) -> bool {
157 let inner = self.0.read().unwrap();
158 inner.data.len() == inner.cursor
159 }
160
161 pub fn close(&self) {
163 let mut inner = self.0.write().unwrap();
164 inner.closed = true;
165 if let Some(waker) = inner.waker.take() {
166 waker.wake();
167 }
168 }
169
170 pub async fn read_available(&self) -> Vec<u8> {
172 ReadAvailable(self).await.unwrap()
173 }
174
175 pub async fn read_available_string(&self) -> String {
177 String::from_utf8(self.read_available().await).unwrap()
178 }
179}
180
/// Future resolving to all bytes a reader can supply without waiting.
///
/// It stays pending only while nothing at all has been read; once at least one
/// byte has been read, the next `Pending` (or end-of-stream) ends the read.
struct ReadAvailable<T>(T);
182
183impl<T: AsyncRead + Unpin> Future for ReadAvailable<T> {
184 type Output = io::Result<Vec<u8>>;
185
186 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
187 let mut buf = vec![];
188 let mut bytes_read = 0;
189 loop {
190 if buf.len() == bytes_read {
191 buf.reserve(32);
192 buf.resize(buf.capacity(), 0);
193 }
194 match Pin::new(&mut self.0).poll_read(cx, &mut buf[bytes_read..]) {
195 Poll::Ready(Ok(0)) => break,
196 Poll::Ready(Ok(new_bytes)) => {
197 bytes_read += new_bytes;
198 }
199 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
200 Poll::Pending if bytes_read == 0 => return Poll::Pending,
201 Poll::Pending => break,
202 }
203 }
204
205 buf.truncate(bytes_read);
206 Poll::Ready(Ok(buf))
207 }
208}
209
210impl Display for CloseableCursor {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 let inner = self.0.read().unwrap();
213 write!(f, "{}", String::from_utf8_lossy(&inner.data))
214 }
215}
216
217impl Debug for CloseableCursor {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 let inner = self.0.read().unwrap();
220 f.debug_struct("CloseableCursor")
221 .field(
222 "data",
223 &std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
224 )
225 .field("closed", &inner.closed)
226 .field("cursor", &inner.cursor)
227 .finish()
228 }
229}
230
231impl AsyncRead for CloseableCursor {
232 fn poll_read(
233 self: Pin<&mut Self>,
234 cx: &mut Context<'_>,
235 buf: &mut [u8],
236 ) -> Poll<io::Result<usize>> {
237 Pin::new(&mut &*self).poll_read(cx, buf)
238 }
239}
240
241impl AsyncRead for &CloseableCursor {
242 fn poll_read(
243 self: Pin<&mut Self>,
244 cx: &mut Context<'_>,
245 buf: &mut [u8],
246 ) -> Poll<io::Result<usize>> {
247 let mut inner = self.0.write().unwrap();
248 if inner.cursor < inner.data.len() {
249 let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
250 buf[..bytes_to_copy]
251 .copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
252 inner.cursor += bytes_to_copy;
253 Poll::Ready(Ok(bytes_to_copy))
254 } else if inner.closed {
255 Poll::Ready(Ok(0))
256 } else {
257 inner.waker = Some(cx.waker().clone());
258 Poll::Pending
259 }
260 }
261}
262
263impl AsyncWrite for &CloseableCursor {
264 fn poll_write(
265 self: Pin<&mut Self>,
266 _cx: &mut Context<'_>,
267 buf: &[u8],
268 ) -> Poll<io::Result<usize>> {
269 let mut inner = self.0.write().unwrap();
270 if inner.closed {
271 Poll::Ready(Ok(0))
272 } else {
273 inner.data.extend_from_slice(buf);
274 if let Some(waker) = inner.waker.take() {
275 waker.wake();
276 }
277 Poll::Ready(Ok(buf.len()))
278 }
279 }
280
281 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
282 Poll::Ready(Ok(()))
283 }
284
285 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286 self.close();
287 Poll::Ready(Ok(()))
288 }
289}
290
291impl io::Write for CloseableCursor {
292 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
293 io::Write::write(&mut &*self, buf)
294 }
295
296 fn flush(&mut self) -> io::Result<()> {
297 Ok(())
298 }
299}
300
301impl io::Write for &CloseableCursor {
302 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
303 let mut inner = self.0.write().unwrap();
304 if inner.closed {
305 Ok(0)
306 } else {
307 inner.data.extend_from_slice(buf);
308 if let Some(waker) = inner.waker.take() {
309 waker.wake();
310 }
311 Ok(buf.len())
312 }
313 }
314
315 fn flush(&mut self) -> io::Result<()> {
316 Ok(())
317 }
318}