Skip to main content

trillium_native_tls/
client.rs

1use async_native_tls::{TlsConnector, TlsStream};
2use std::{
3    fmt::{Debug, Formatter},
4    future::Future,
5    io::{Error, ErrorKind, IoSlice, IoSliceMut, Result},
6    net::SocketAddr,
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10};
11use trillium_server_common::{async_trait, AsyncRead, AsyncWrite, Connector, Transport, Url};
12
13/**
14Configuration for the native tls client connector
15*/
16#[derive(Clone)]
17pub struct NativeTlsConfig<Config> {
18    /// configuration for the inner Connector (usually tcp)
19    pub tcp_config: Config,
20
21    /**
22    native tls configuration
23
24    Although async_native_tls calls this
25    a TlsConnector, it's actually a builder ¯\_(ツ)_/¯
26    */
27    pub tls_connector: Arc<TlsConnector>,
28}
29
30impl<C: Connector> NativeTlsConfig<C> {
31    /// replace the tcp config
32    pub fn with_tcp_config(mut self, config: C) -> Self {
33        self.tcp_config = config;
34        self
35    }
36}
37
38impl<C: Connector> From<C> for NativeTlsConfig<C> {
39    fn from(tcp_config: C) -> Self {
40        Self {
41            tcp_config,
42            tls_connector: Arc::new(TlsConnector::default()),
43        }
44    }
45}
46
47impl<Config: Debug> Debug for NativeTlsConfig<Config> {
48    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("NativeTlsConfig")
50            .field("tcp_config", &self.tcp_config)
51            .field("tls_connector", &"..")
52            .finish()
53    }
54}
55
56impl<Config: Default> Default for NativeTlsConfig<Config> {
57    fn default() -> Self {
58        Self {
59            tcp_config: Config::default(),
60            tls_connector: Arc::new(TlsConnector::default()),
61        }
62    }
63}
64
65impl<Config> AsRef<Config> for NativeTlsConfig<Config> {
66    fn as_ref(&self) -> &Config {
67        &self.tcp_config
68    }
69}
70
71#[async_trait]
72impl<T: Connector> Connector for NativeTlsConfig<T> {
73    type Transport = NativeTlsClientTransport<T::Transport>;
74
75    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
76        match url.scheme() {
77            "https" => {
78                let mut http = url.clone();
79                http.set_scheme("http").ok();
80                http.set_port(url.port_or_known_default()).ok();
81                let inner_stream = self.tcp_config.connect(&http).await?;
82
83                self.tls_connector
84                    .connect(url, inner_stream)
85                    .await
86                    .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
87                    .map(NativeTlsClientTransport::from)
88            }
89
90            "http" => self
91                .tcp_config
92                .connect(url)
93                .await
94                .map(NativeTlsClientTransport::from),
95
96            unknown => Err(Error::new(
97                ErrorKind::InvalidInput,
98                format!("unknown scheme {unknown}"),
99            )),
100        }
101    }
102
103    fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
104        self.tcp_config.spawn(fut)
105    }
106}
107
108/**
109Client [`Transport`] for the native tls connector
110
111This may represent either an encrypted tls connection or a plaintext
112connection
113*/
114
115#[derive(Debug)]
116pub struct NativeTlsClientTransport<T>(NativeTlsClientTransportInner<T>);
117
118impl<T: AsyncWrite + AsyncRead + Unpin> NativeTlsClientTransport<T> {
119    /// Borrow the TlsStream, if this connection is tls.
120    ///
121    /// Returns None otherwise
122    pub fn as_tls(&self) -> Option<&TlsStream<T>> {
123        match &self.0 {
124            Tcp(_) => None,
125            Tls(tls) => Some(tls),
126        }
127    }
128}
129
130impl<T> From<T> for NativeTlsClientTransport<T> {
131    fn from(value: T) -> Self {
132        Self(Tcp(value))
133    }
134}
135
136impl<T> From<TlsStream<T>> for NativeTlsClientTransport<T> {
137    fn from(value: TlsStream<T>) -> Self {
138        Self(Tls(value))
139    }
140}
141
142impl<T: Transport> AsRef<T> for NativeTlsClientTransport<T> {
143    fn as_ref(&self) -> &T {
144        match &self.0 {
145            Tcp(transport) => transport,
146            Tls(tls_stream) => tls_stream.get_ref(),
147        }
148    }
149}
150
151#[derive(Debug)]
152enum NativeTlsClientTransportInner<T> {
153    Tcp(T),
154    Tls(TlsStream<T>),
155}
156use NativeTlsClientTransportInner::{Tcp, Tls};
157
158impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsClientTransport<T> {
159    fn poll_read(
160        mut self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        buf: &mut [u8],
163    ) -> Poll<Result<usize>> {
164        match &mut self.0 {
165            Tcp(t) => Pin::new(t).poll_read(cx, buf),
166            Tls(t) => Pin::new(t).poll_read(cx, buf),
167        }
168    }
169
170    fn poll_read_vectored(
171        mut self: Pin<&mut Self>,
172        cx: &mut Context<'_>,
173        bufs: &mut [IoSliceMut<'_>],
174    ) -> Poll<Result<usize>> {
175        match &mut self.0 {
176            Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
177            Tls(t) => Pin::new(t).poll_read_vectored(cx, bufs),
178        }
179    }
180}
181
182impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsClientTransport<T> {
183    fn poll_write(
184        mut self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186        buf: &[u8],
187    ) -> Poll<Result<usize>> {
188        match &mut self.0 {
189            Tcp(t) => Pin::new(t).poll_write(cx, buf),
190            Tls(t) => Pin::new(t).poll_write(cx, buf),
191        }
192    }
193
194    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
195        match &mut self.0 {
196            Tcp(t) => Pin::new(t).poll_flush(cx),
197            Tls(t) => Pin::new(t).poll_flush(cx),
198        }
199    }
200
201    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
202        match &mut self.0 {
203            Tcp(t) => Pin::new(t).poll_close(cx),
204            Tls(t) => Pin::new(t).poll_close(cx),
205        }
206    }
207
208    fn poll_write_vectored(
209        mut self: Pin<&mut Self>,
210        cx: &mut Context<'_>,
211        bufs: &[IoSlice<'_>],
212    ) -> Poll<Result<usize>> {
213        match &mut self.0 {
214            Tcp(t) => Pin::new(t).poll_write_vectored(cx, bufs),
215            Tls(t) => Pin::new(t).poll_write_vectored(cx, bufs),
216        }
217    }
218}
219
220impl<T: Transport> Transport for NativeTlsClientTransport<T> {
221    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
222        self.as_ref().peer_addr()
223    }
224}