Skip to main content

trillium_server_common/
binding.rs

1use crate::Transport;
2use futures_lite::{AsyncRead, AsyncWrite, Stream};
3use std::{
4    io::{IoSlice, Result},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
/// Holds either a TCP-flavored value or a Unix-domain-socket-flavored value.
///
/// A single `Binding` type can stand in for a listener (like `TcpListener`),
/// a stream of incoming connections (like `Incoming`), or a connected byte
/// stream (like `TcpStream`). Trait implementations such as `TryFrom`,
/// `Stream`, `AsyncRead`, and `AsyncWrite` forward to whichever variant is
/// currently held, when the wrapped types support them.
#[derive(Clone, Debug)]
pub enum Binding<T, U> {
    /// The tcp variant: a tcp listener, incoming-connection stream, or byte stream.
    Tcp(T),

    /// The unix variant: a unix-socket listener, incoming-connection stream, or byte stream.
    Unix(U),
}
21
22use Binding::{Tcp, Unix};
23
24impl<T, U> Binding<T, U> {
25    /// borrows the tcp stream or listener, if this is a tcp variant
26    pub fn get_tcp(&self) -> Option<&T> {
27        if let Tcp(t) = self { Some(t) } else { None }
28    }
29
30    /// borrows the unix stream or listener, if this is unix variant
31    pub fn get_unix(&self) -> Option<&U> {
32        if let Unix(u) = self { Some(u) } else { None }
33    }
34
35    /// mutably borrows the tcp stream or listener, if this is tcp variant
36    pub fn get_tcp_mut(&mut self) -> Option<&mut T> {
37        if let Tcp(t) = self { Some(t) } else { None }
38    }
39
40    /// mutably borrows the unix stream or listener, if this is unix variant
41    pub fn get_unix_mut(&mut self) -> Option<&mut U> {
42        if let Unix(u) = self { Some(u) } else { None }
43    }
44}
45
46impl<T: TryFrom<std::net::TcpListener>, U> TryFrom<std::net::TcpListener> for Binding<T, U> {
47    type Error = <T as TryFrom<std::net::TcpListener>>::Error;
48
49    fn try_from(value: std::net::TcpListener) -> std::result::Result<Self, Self::Error> {
50        Ok(Self::Tcp(value.try_into()?))
51    }
52}
53
54#[cfg(unix)]
55impl<T, U: TryFrom<std::os::unix::net::UnixListener>> TryFrom<std::os::unix::net::UnixListener>
56    for Binding<T, U>
57{
58    type Error = <U as TryFrom<std::os::unix::net::UnixListener>>::Error;
59
60    fn try_from(value: std::os::unix::net::UnixListener) -> std::result::Result<Self, Self::Error> {
61        Ok(Self::Unix(value.try_into()?))
62    }
63}
64
65impl<T, U, TI, UI> Stream for Binding<T, U>
66where
67    T: Stream<Item = Result<TI>> + Unpin,
68    U: Stream<Item = Result<UI>> + Unpin,
69{
70    type Item = Result<Binding<TI, UI>>;
71
72    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
73        match &mut *self {
74            Tcp(t) => Pin::new(t).poll_next(cx).map(|i| i.map(|x| x.map(Tcp))),
75            Unix(u) => Pin::new(u).poll_next(cx).map(|i| i.map(|x| x.map(Unix))),
76        }
77    }
78}
79
80impl<T, U> Binding<T, U>
81where
82    T: AsyncRead + Unpin,
83    U: AsyncRead + Unpin,
84{
85    fn as_async_read(&mut self) -> Pin<&mut (dyn AsyncRead + Unpin)> {
86        Pin::new(match self {
87            Tcp(t) => t as &mut (dyn AsyncRead + Unpin),
88            Unix(u) => u as &mut (dyn AsyncRead + Unpin),
89        })
90    }
91}
92
93impl<T, U> Binding<T, U>
94where
95    T: AsyncWrite + Unpin,
96    U: AsyncWrite + Unpin,
97{
98    fn as_async_write(&mut self) -> Pin<&mut (dyn AsyncWrite + Unpin)> {
99        Pin::new(match self {
100            Tcp(t) => t as &mut (dyn AsyncWrite + Unpin),
101            Unix(u) => u as &mut (dyn AsyncWrite + Unpin),
102        })
103    }
104}
105
106impl<T, U> AsyncRead for Binding<T, U>
107where
108    T: AsyncRead + Unpin,
109    U: AsyncRead + Unpin,
110{
111    fn poll_read(
112        mut self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut [u8],
115    ) -> Poll<Result<usize>> {
116        self.as_async_read().poll_read(cx, buf)
117    }
118
119    fn poll_read_vectored(
120        mut self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122        bufs: &mut [std::io::IoSliceMut<'_>],
123    ) -> Poll<Result<usize>> {
124        self.as_async_read().poll_read_vectored(cx, bufs)
125    }
126}
127
128impl<T, U> AsyncWrite for Binding<T, U>
129where
130    T: AsyncWrite + Unpin,
131    U: AsyncWrite + Unpin,
132{
133    fn poll_write(
134        mut self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136        buf: &[u8],
137    ) -> Poll<Result<usize>> {
138        self.as_async_write().poll_write(cx, buf)
139    }
140
141    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
142        self.as_async_write().poll_flush(cx)
143    }
144
145    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
146        self.as_async_write().poll_close(cx)
147    }
148
149    fn poll_write_vectored(
150        mut self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        bufs: &[IoSlice<'_>],
153    ) -> Poll<Result<usize>> {
154        self.as_async_write().poll_write_vectored(cx, bufs)
155    }
156}
157
158impl<T, U> Binding<T, U>
159where
160    T: Transport,
161    U: Transport,
162{
163    fn as_transport_mut(&mut self) -> &mut dyn Transport {
164        match self {
165            Tcp(t) => t as &mut dyn Transport,
166            Unix(u) => u as &mut dyn Transport,
167        }
168    }
169
170    fn as_transport(&self) -> &dyn Transport {
171        match self {
172            Tcp(t) => t as &dyn Transport,
173            Unix(u) => u as &dyn Transport,
174        }
175    }
176}
177
178impl<T, U> Transport for Binding<T, U>
179where
180    T: Transport,
181    U: Transport,
182{
183    fn set_linger(&mut self, linger: Option<std::time::Duration>) -> Result<()> {
184        self.as_transport_mut().set_linger(linger)
185    }
186
187    fn set_nodelay(&mut self, nodelay: bool) -> Result<()> {
188        self.as_transport_mut().set_nodelay(nodelay)
189    }
190
191    fn set_ip_ttl(&mut self, ttl: u32) -> Result<()> {
192        self.as_transport_mut().set_ip_ttl(ttl)
193    }
194
195    fn peer_addr(&self) -> Result<Option<std::net::SocketAddr>> {
196        self.as_transport().peer_addr()
197    }
198}