trillium_native_tls/
client.rs1use async_native_tls::{TlsConnector, TlsStream};
2use std::{
3 fmt::{Debug, Formatter},
4 future::Future,
5 io::{Error, ErrorKind, IoSlice, IoSliceMut, Result},
6 net::SocketAddr,
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10};
11use trillium_server_common::{async_trait, AsyncRead, AsyncWrite, Connector, Transport, Url};
12
13#[derive(Clone)]
17pub struct NativeTlsConfig<Config> {
18 pub tcp_config: Config,
20
21 pub tls_connector: Arc<TlsConnector>,
28}
29
30impl<C: Connector> NativeTlsConfig<C> {
31 pub fn with_tcp_config(mut self, config: C) -> Self {
33 self.tcp_config = config;
34 self
35 }
36}
37
38impl<C: Connector> From<C> for NativeTlsConfig<C> {
39 fn from(tcp_config: C) -> Self {
40 Self {
41 tcp_config,
42 tls_connector: Arc::new(TlsConnector::default()),
43 }
44 }
45}
46
47impl<Config: Debug> Debug for NativeTlsConfig<Config> {
48 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("NativeTlsConfig")
50 .field("tcp_config", &self.tcp_config)
51 .field("tls_connector", &"..")
52 .finish()
53 }
54}
55
56impl<Config: Default> Default for NativeTlsConfig<Config> {
57 fn default() -> Self {
58 Self {
59 tcp_config: Config::default(),
60 tls_connector: Arc::new(TlsConnector::default()),
61 }
62 }
63}
64
65impl<Config> AsRef<Config> for NativeTlsConfig<Config> {
66 fn as_ref(&self) -> &Config {
67 &self.tcp_config
68 }
69}
70
71#[async_trait]
72impl<T: Connector> Connector for NativeTlsConfig<T> {
73 type Transport = NativeTlsClientTransport<T::Transport>;
74
75 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
76 match url.scheme() {
77 "https" => {
78 let mut http = url.clone();
79 http.set_scheme("http").ok();
80 http.set_port(url.port_or_known_default()).ok();
81 let inner_stream = self.tcp_config.connect(&http).await?;
82
83 self.tls_connector
84 .connect(url, inner_stream)
85 .await
86 .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
87 .map(NativeTlsClientTransport::from)
88 }
89
90 "http" => self
91 .tcp_config
92 .connect(url)
93 .await
94 .map(NativeTlsClientTransport::from),
95
96 unknown => Err(Error::new(
97 ErrorKind::InvalidInput,
98 format!("unknown scheme {unknown}"),
99 )),
100 }
101 }
102
103 fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
104 self.tcp_config.spawn(fut)
105 }
106}
107
108#[derive(Debug)]
116pub struct NativeTlsClientTransport<T>(NativeTlsClientTransportInner<T>);
117
118impl<T: AsyncWrite + AsyncRead + Unpin> NativeTlsClientTransport<T> {
119 pub fn as_tls(&self) -> Option<&TlsStream<T>> {
123 match &self.0 {
124 Tcp(_) => None,
125 Tls(tls) => Some(tls),
126 }
127 }
128}
129
130impl<T> From<T> for NativeTlsClientTransport<T> {
131 fn from(value: T) -> Self {
132 Self(Tcp(value))
133 }
134}
135
136impl<T> From<TlsStream<T>> for NativeTlsClientTransport<T> {
137 fn from(value: TlsStream<T>) -> Self {
138 Self(Tls(value))
139 }
140}
141
142impl<T: Transport> AsRef<T> for NativeTlsClientTransport<T> {
143 fn as_ref(&self) -> &T {
144 match &self.0 {
145 Tcp(transport) => transport,
146 Tls(tls_stream) => tls_stream.get_ref(),
147 }
148 }
149}
150
151#[derive(Debug)]
152enum NativeTlsClientTransportInner<T> {
153 Tcp(T),
154 Tls(TlsStream<T>),
155}
156use NativeTlsClientTransportInner::{Tcp, Tls};
157
158impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsClientTransport<T> {
159 fn poll_read(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 buf: &mut [u8],
163 ) -> Poll<Result<usize>> {
164 match &mut self.0 {
165 Tcp(t) => Pin::new(t).poll_read(cx, buf),
166 Tls(t) => Pin::new(t).poll_read(cx, buf),
167 }
168 }
169
170 fn poll_read_vectored(
171 mut self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 bufs: &mut [IoSliceMut<'_>],
174 ) -> Poll<Result<usize>> {
175 match &mut self.0 {
176 Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
177 Tls(t) => Pin::new(t).poll_read_vectored(cx, bufs),
178 }
179 }
180}
181
182impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsClientTransport<T> {
183 fn poll_write(
184 mut self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &[u8],
187 ) -> Poll<Result<usize>> {
188 match &mut self.0 {
189 Tcp(t) => Pin::new(t).poll_write(cx, buf),
190 Tls(t) => Pin::new(t).poll_write(cx, buf),
191 }
192 }
193
194 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
195 match &mut self.0 {
196 Tcp(t) => Pin::new(t).poll_flush(cx),
197 Tls(t) => Pin::new(t).poll_flush(cx),
198 }
199 }
200
201 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
202 match &mut self.0 {
203 Tcp(t) => Pin::new(t).poll_close(cx),
204 Tls(t) => Pin::new(t).poll_close(cx),
205 }
206 }
207
208 fn poll_write_vectored(
209 mut self: Pin<&mut Self>,
210 cx: &mut Context<'_>,
211 bufs: &[IoSlice<'_>],
212 ) -> Poll<Result<usize>> {
213 match &mut self.0 {
214 Tcp(t) => Pin::new(t).poll_write_vectored(cx, bufs),
215 Tls(t) => Pin::new(t).poll_write_vectored(cx, bufs),
216 }
217 }
218}
219
220impl<T: Transport> Transport for NativeTlsClientTransport<T> {
221 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
222 self.as_ref().peer_addr()
223 }
224}