Skip to main content

trillium_rustls/
client.rs

1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3use futures_rustls::{
4    TlsConnector,
5    client::TlsStream,
6    rustls::{
7        ClientConfig, ClientConnection, client::danger::ServerCertVerifier, crypto::CryptoProvider,
8        pki_types::ServerName,
9    },
10};
11use std::{
12    fmt::{self, Debug, Formatter},
13    io::{Error, ErrorKind, IoSlice, Result},
14    net::SocketAddr,
15    pin::Pin,
16    sync::Arc,
17    task::{Context, Poll},
18};
19use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
20
21#[derive(Clone, Debug)]
22pub struct RustlsClientConfig(Arc<ClientConfig>);
23
24/// Client configuration for RustlsConnector
25#[derive(Clone, Default)]
26pub struct RustlsConfig<Config> {
27    /// configuration for rustls itself
28    pub rustls_config: RustlsClientConfig,
29
30    /// configuration for the inner transport
31    pub tcp_config: Config,
32}
33
34impl<C: Connector> RustlsConfig<C> {
35    /// build a new default rustls config with this tcp config
36    pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
37        Self {
38            rustls_config: rustls_config.into(),
39            tcp_config,
40        }
41    }
42}
43
44impl Default for RustlsClientConfig {
45    fn default() -> Self {
46        Self(Arc::new(default_client_config()))
47    }
48}
49
/// Build the server certificate verifier provided by `rustls_platform_verifier`,
/// configured with `provider`.
///
/// # Panics
///
/// Panics if the platform verifier cannot be constructed with `provider`.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform_verifier = rustls_platform_verifier::Verifier::new(provider)
        .expect("failed to initialize the platform certificate verifier");
    Arc::new(platform_verifier)
}
54
55#[cfg(not(feature = "platform-verifier"))]
56fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
57    let roots = Arc::new(futures_rustls::rustls::RootCertStore::from_iter(
58        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
59    ));
60    futures_rustls::rustls::client::WebPkiServerVerifier::builder_with_provider(roots, provider)
61        .build()
62        .unwrap()
63}
64
65fn default_client_config() -> ClientConfig {
66    let provider = crypto_provider();
67    let verifier = verifier(Arc::clone(&provider));
68
69    ClientConfig::builder_with_provider(provider)
70        .with_safe_default_protocol_versions()
71        .expect("crypto provider did not support safe default protocol versions")
72        .dangerous()
73        .with_custom_certificate_verifier(verifier)
74        .with_no_client_auth()
75}
76
77impl From<ClientConfig> for RustlsClientConfig {
78    fn from(rustls_config: ClientConfig) -> Self {
79        Self(Arc::new(rustls_config))
80    }
81}
82
83impl From<Arc<ClientConfig>> for RustlsClientConfig {
84    fn from(rustls_config: Arc<ClientConfig>) -> Self {
85        Self(rustls_config)
86    }
87}
88
89impl<C: Connector> RustlsConfig<C> {
90    /// replace the tcp config
91    pub fn with_tcp_config(mut self, config: C) -> Self {
92        self.tcp_config = config;
93        self
94    }
95}
96
97impl<Config: Debug> Debug for RustlsConfig<Config> {
98    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
99        f.debug_struct("RustlsConfig")
100            .field("rustls_config", &format_args!(".."))
101            .field("tcp_config", &self.tcp_config)
102            .finish()
103    }
104}
105
106impl<C: Connector> Connector for RustlsConfig<C> {
107    type Runtime = C::Runtime;
108    type Transport = RustlsClientTransport<C::Transport>;
109    type Udp = C::Udp;
110
111    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
112        match url.scheme() {
113            "https" => {
114                let mut http = url.clone();
115                http.set_scheme("http").ok();
116                http.set_port(url.port_or_known_default()).ok();
117
118                let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
119                let domain = url
120                    .domain()
121                    .and_then(|dns_name| ServerName::try_from(dns_name.to_string()).ok())
122                    .ok_or_else(|| Error::other("missing domain"))?;
123
124                connector
125                    .connect(domain, self.tcp_config.connect(&http).await?)
126                    .await
127                    .map_err(|e| Error::other(e.to_string()))
128                    .map(Into::into)
129            }
130
131            "http" => self.tcp_config.connect(url).await.map(Into::into),
132
133            unknown => Err(Error::new(
134                ErrorKind::InvalidInput,
135                format!("unknown scheme {unknown}"),
136            )),
137        }
138    }
139
140    fn runtime(&self) -> Self::Runtime {
141        self.tcp_config.runtime()
142    }
143
144    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
145        self.tcp_config.resolve(host, port).await
146    }
147}
148
149#[derive(Debug)]
150enum RustlsClientTransportInner<T> {
151    Tcp(T),
152    Tls(Box<TlsStream<T>>),
153}
154
155/// Transport for the rustls connector
156///
157/// This may represent either an encrypted tls connection or a plaintext
158/// connection, depending on the request schema
159#[derive(Debug)]
160pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
161impl<T> From<T> for RustlsClientTransport<T> {
162    fn from(value: T) -> Self {
163        Self(Tcp(value))
164    }
165}
166
167impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
168    fn from(value: TlsStream<T>) -> Self {
169        Self(Tls(Box::new(value)))
170    }
171}
172
173impl<C> AsyncRead for RustlsClientTransport<C>
174where
175    C: AsyncWrite + AsyncRead + Unpin,
176{
177    fn poll_read(
178        mut self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &mut [u8],
181    ) -> Poll<Result<usize>> {
182        match &mut self.0 {
183            Tcp(c) => Pin::new(c).poll_read(cx, buf),
184            Tls(c) => Pin::new(c).poll_read(cx, buf),
185        }
186    }
187
188    fn poll_read_vectored(
189        mut self: Pin<&mut Self>,
190        cx: &mut Context<'_>,
191        bufs: &mut [std::io::IoSliceMut<'_>],
192    ) -> Poll<Result<usize>> {
193        match &mut self.0 {
194            Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
195            Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
196        }
197    }
198}
199
200impl<C> AsyncWrite for RustlsClientTransport<C>
201where
202    C: AsyncRead + AsyncWrite + Unpin,
203{
204    fn poll_write(
205        mut self: Pin<&mut Self>,
206        cx: &mut Context<'_>,
207        buf: &[u8],
208    ) -> Poll<Result<usize>> {
209        match &mut self.0 {
210            Tcp(c) => Pin::new(c).poll_write(cx, buf),
211            Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
212        }
213    }
214
215    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
216        match &mut self.0 {
217            Tcp(c) => Pin::new(c).poll_flush(cx),
218            Tls(c) => Pin::new(&mut *c).poll_flush(cx),
219        }
220    }
221
222    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
223        match &mut self.0 {
224            Tcp(c) => Pin::new(c).poll_close(cx),
225            Tls(c) => Pin::new(&mut *c).poll_close(cx),
226        }
227    }
228
229    fn poll_write_vectored(
230        mut self: Pin<&mut Self>,
231        cx: &mut Context<'_>,
232        bufs: &[IoSlice<'_>],
233    ) -> Poll<Result<usize>> {
234        match &mut self.0 {
235            Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
236            Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
237        }
238    }
239}
240
241impl<T: Transport> Transport for RustlsClientTransport<T> {
242    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
243        self.as_ref().peer_addr()
244    }
245}
246
247impl<T> AsRef<T> for RustlsClientTransport<T> {
248    fn as_ref(&self) -> &T {
249        match &self.0 {
250            Tcp(x) => x,
251            Tls(x) => x.get_ref().0,
252        }
253    }
254}
255
256impl<T> RustlsClientTransport<T> {
257    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
258    pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
259        match &mut self.0 {
260            Tls(x) => Some(x.get_mut().1),
261            _ => None,
262        }
263    }
264
265    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
266    pub fn tls_state(&self) -> Option<&ClientConnection> {
267        match &self.0 {
268            Tls(x) => Some(x.get_ref().1),
269            _ => None,
270        }
271    }
272}