trillium_rustls/
server.rs1use futures_rustls::{
2 rustls::{ServerConfig, ServerConnection},
3 server::TlsStream,
4 TlsAcceptor,
5};
6use std::{
7 fmt::{Debug, Formatter},
8 io,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use trillium_server_common::{async_trait, Acceptor, AsyncRead, AsyncWrite, Transport};
14
15use crate::crypto_provider;
16
17#[derive(Clone)]
22pub struct RustlsAcceptor(TlsAcceptor);
23impl Debug for RustlsAcceptor {
24 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25 f.debug_tuple("Rustls").field(&"<<TlsAcceptor>>").finish()
26 }
27}
28
29impl RustlsAcceptor {
30 pub fn new(t: impl Into<Self>) -> Self {
34 t.into()
35 }
36
37 pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
64 use std::io::Cursor;
65
66 let cert_chain = rustls_pemfile::certs(&mut Cursor::new(cert))
67 .collect::<Result<_, _>>()
68 .expect("could not read certificate");
69
70 let key_der = rustls_pemfile::private_key(&mut Cursor::new(key))
71 .expect("could not read key pemfile")
72 .expect("no private key found in `key`");
73
74 ServerConfig::builder_with_provider(crypto_provider())
75 .with_safe_default_protocol_versions()
76 .expect("crypto provider did not support safe default protocol versions")
77 .with_no_client_auth()
78 .with_single_cert(cert_chain, key_der)
79 .expect("could not create a rustls ServerConfig from the supplied cert and key")
80 .into()
81 }
82}
83
84impl From<ServerConfig> for RustlsAcceptor {
85 fn from(sc: ServerConfig) -> Self {
86 Self(Arc::new(sc).into())
87 }
88}
89
90impl From<TlsAcceptor> for RustlsAcceptor {
91 fn from(ta: TlsAcceptor) -> Self {
92 Self(ta)
93 }
94}
95
96#[derive(Debug)]
98pub struct RustlsServerTransport<T>(TlsStream<T>);
99
100impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsServerTransport<T> {
101 fn poll_read(
102 mut self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &mut [u8],
105 ) -> Poll<io::Result<usize>> {
106 Pin::new(&mut self.0).poll_read(cx, buf)
107 }
108}
109
110impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for RustlsServerTransport<T> {
111 fn poll_write(
112 mut self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 buf: &[u8],
115 ) -> Poll<io::Result<usize>> {
116 Pin::new(&mut self.0).poll_write(cx, buf)
117 }
118
119 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
120 Pin::new(&mut self.0).poll_flush(cx)
121 }
122
123 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
124 Pin::new(&mut self.0).poll_close(cx)
125 }
126
127 fn poll_write_vectored(
128 mut self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 bufs: &[io::IoSlice<'_>],
131 ) -> Poll<io::Result<usize>> {
132 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
133 }
134}
135
136impl<T: Transport> Transport for RustlsServerTransport<T> {
137 fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
138 self.inner_transport().peer_addr()
139 }
140}
141
142impl<T> RustlsServerTransport<T> {
143 pub fn inner_transport(&self) -> &T {
145 self.0.get_ref().0
146 }
147
148 pub fn inner_transport_mut(&mut self) -> &mut T {
150 self.0.get_mut().0
151 }
152}
153
154impl<T> AsRef<ServerConnection> for RustlsServerTransport<T> {
155 fn as_ref(&self) -> &ServerConnection {
156 self.0.get_ref().1
157 }
158}
159
160impl<T> AsMut<ServerConnection> for RustlsServerTransport<T> {
161 fn as_mut(&mut self) -> &mut ServerConnection {
162 self.0.get_mut().1
163 }
164}
165
166impl<T> From<TlsStream<T>> for RustlsServerTransport<T> {
167 fn from(value: TlsStream<T>) -> Self {
168 Self(value)
169 }
170}
171
172impl<T> From<RustlsServerTransport<T>> for TlsStream<T> {
173 fn from(RustlsServerTransport(value): RustlsServerTransport<T>) -> Self {
174 value
175 }
176}
177
178#[async_trait]
179impl<Input> Acceptor<Input> for RustlsAcceptor
180where
181 Input: Transport,
182{
183 type Output = RustlsServerTransport<Input>;
184 type Error = io::Error;
185 async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
186 self.0.accept(input).await.map(RustlsServerTransport)
187 }
188}