1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3use futures_rustls::{
4 TlsConnector,
5 client::TlsStream,
6 rustls::{
7 ClientConfig, ClientConnection, client::danger::ServerCertVerifier, crypto::CryptoProvider,
8 pki_types::ServerName,
9 },
10};
11use std::{
12 fmt::{self, Debug, Formatter},
13 io::{Error, ErrorKind, IoSlice, Result},
14 net::SocketAddr,
15 pin::Pin,
16 sync::Arc,
17 task::{Context, Poll},
18};
19use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
20
21#[derive(Clone, Debug)]
22pub struct RustlsClientConfig(Arc<ClientConfig>);
23
24#[derive(Clone, Default)]
26pub struct RustlsConfig<Config> {
27 pub rustls_config: RustlsClientConfig,
29
30 pub tcp_config: Config,
32}
33
34impl<C: Connector> RustlsConfig<C> {
35 pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
37 Self {
38 rustls_config: rustls_config.into(),
39 tcp_config,
40 }
41 }
42}
43
44impl Default for RustlsClientConfig {
45 fn default() -> Self {
46 Self(Arc::new(default_client_config()))
47 }
48}
49
/// Builds the server certificate verifier from `rustls_platform_verifier`,
/// using the supplied crypto `provider`.
///
/// Only compiled when the `platform-verifier` feature is enabled.
///
/// # Panics
///
/// Panics if `rustls_platform_verifier::Verifier::new` returns an error.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform_verifier = rustls_platform_verifier::Verifier::new(provider).unwrap();
    Arc::new(platform_verifier)
}
54
55#[cfg(not(feature = "platform-verifier"))]
56fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
57 let roots = Arc::new(futures_rustls::rustls::RootCertStore::from_iter(
58 webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
59 ));
60 futures_rustls::rustls::client::WebPkiServerVerifier::builder_with_provider(roots, provider)
61 .build()
62 .unwrap()
63}
64
65fn default_client_config() -> ClientConfig {
66 let provider = crypto_provider();
67 let verifier = verifier(Arc::clone(&provider));
68
69 ClientConfig::builder_with_provider(provider)
70 .with_safe_default_protocol_versions()
71 .expect("crypto provider did not support safe default protocol versions")
72 .dangerous()
73 .with_custom_certificate_verifier(verifier)
74 .with_no_client_auth()
75}
76
77impl From<ClientConfig> for RustlsClientConfig {
78 fn from(rustls_config: ClientConfig) -> Self {
79 Self(Arc::new(rustls_config))
80 }
81}
82
83impl From<Arc<ClientConfig>> for RustlsClientConfig {
84 fn from(rustls_config: Arc<ClientConfig>) -> Self {
85 Self(rustls_config)
86 }
87}
88
89impl<C: Connector> RustlsConfig<C> {
90 pub fn with_tcp_config(mut self, config: C) -> Self {
92 self.tcp_config = config;
93 self
94 }
95}
96
97impl<Config: Debug> Debug for RustlsConfig<Config> {
98 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
99 f.debug_struct("RustlsConfig")
100 .field("rustls_config", &format_args!(".."))
101 .field("tcp_config", &self.tcp_config)
102 .finish()
103 }
104}
105
106impl<C: Connector> Connector for RustlsConfig<C> {
107 type Runtime = C::Runtime;
108 type Transport = RustlsClientTransport<C::Transport>;
109 type Udp = C::Udp;
110
111 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
112 match url.scheme() {
113 "https" => {
114 let mut http = url.clone();
115 http.set_scheme("http").ok();
116 http.set_port(url.port_or_known_default()).ok();
117
118 let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
119 let domain = url
120 .domain()
121 .and_then(|dns_name| ServerName::try_from(dns_name.to_string()).ok())
122 .ok_or_else(|| Error::other("missing domain"))?;
123
124 connector
125 .connect(domain, self.tcp_config.connect(&http).await?)
126 .await
127 .map_err(|e| Error::other(e.to_string()))
128 .map(Into::into)
129 }
130
131 "http" => self.tcp_config.connect(url).await.map(Into::into),
132
133 unknown => Err(Error::new(
134 ErrorKind::InvalidInput,
135 format!("unknown scheme {unknown}"),
136 )),
137 }
138 }
139
140 fn runtime(&self) -> Self::Runtime {
141 self.tcp_config.runtime()
142 }
143
144 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
145 self.tcp_config.resolve(host, port).await
146 }
147}
148
149#[derive(Debug)]
150enum RustlsClientTransportInner<T> {
151 Tcp(T),
152 Tls(Box<TlsStream<T>>),
153}
154
155#[derive(Debug)]
160pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
161impl<T> From<T> for RustlsClientTransport<T> {
162 fn from(value: T) -> Self {
163 Self(Tcp(value))
164 }
165}
166
167impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
168 fn from(value: TlsStream<T>) -> Self {
169 Self(Tls(Box::new(value)))
170 }
171}
172
173impl<C> AsyncRead for RustlsClientTransport<C>
174where
175 C: AsyncWrite + AsyncRead + Unpin,
176{
177 fn poll_read(
178 mut self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &mut [u8],
181 ) -> Poll<Result<usize>> {
182 match &mut self.0 {
183 Tcp(c) => Pin::new(c).poll_read(cx, buf),
184 Tls(c) => Pin::new(c).poll_read(cx, buf),
185 }
186 }
187
188 fn poll_read_vectored(
189 mut self: Pin<&mut Self>,
190 cx: &mut Context<'_>,
191 bufs: &mut [std::io::IoSliceMut<'_>],
192 ) -> Poll<Result<usize>> {
193 match &mut self.0 {
194 Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
195 Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
196 }
197 }
198}
199
200impl<C> AsyncWrite for RustlsClientTransport<C>
201where
202 C: AsyncRead + AsyncWrite + Unpin,
203{
204 fn poll_write(
205 mut self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 buf: &[u8],
208 ) -> Poll<Result<usize>> {
209 match &mut self.0 {
210 Tcp(c) => Pin::new(c).poll_write(cx, buf),
211 Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
212 }
213 }
214
215 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
216 match &mut self.0 {
217 Tcp(c) => Pin::new(c).poll_flush(cx),
218 Tls(c) => Pin::new(&mut *c).poll_flush(cx),
219 }
220 }
221
222 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
223 match &mut self.0 {
224 Tcp(c) => Pin::new(c).poll_close(cx),
225 Tls(c) => Pin::new(&mut *c).poll_close(cx),
226 }
227 }
228
229 fn poll_write_vectored(
230 mut self: Pin<&mut Self>,
231 cx: &mut Context<'_>,
232 bufs: &[IoSlice<'_>],
233 ) -> Poll<Result<usize>> {
234 match &mut self.0 {
235 Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
236 Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
237 }
238 }
239}
240
241impl<T: Transport> Transport for RustlsClientTransport<T> {
242 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
243 self.as_ref().peer_addr()
244 }
245}
246
247impl<T> AsRef<T> for RustlsClientTransport<T> {
248 fn as_ref(&self) -> &T {
249 match &self.0 {
250 Tcp(x) => x,
251 Tls(x) => x.get_ref().0,
252 }
253 }
254}
255
256impl<T> RustlsClientTransport<T> {
257 pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
259 match &mut self.0 {
260 Tls(x) => Some(x.get_mut().1),
261 _ => None,
262 }
263 }
264
265 pub fn tls_state(&self) -> Option<&ClientConnection> {
267 match &self.0 {
268 Tls(x) => Some(x.get_ref().1),
269 _ => None,
270 }
271 }
272}