trillium_rustls/
server.rs1use crate::crypto_provider;
2use futures_rustls::{
3 TlsAcceptor,
4 rustls::{ServerConfig, ServerConnection},
5 server::TlsStream,
6};
7use std::{
8 fmt::{Debug, Formatter},
9 io,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13};
14use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
15
16#[derive(Clone)]
19pub struct RustlsAcceptor(TlsAcceptor);
20impl Debug for RustlsAcceptor {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 f.debug_tuple("Rustls").field(&"<<TlsAcceptor>>").finish()
23 }
24}
25
26impl RustlsAcceptor {
27 pub fn new(t: impl Into<Self>) -> Self {
29 t.into()
30 }
31
32 pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
59 use std::io::Cursor;
60
61 let cert_chain = rustls_pemfile::certs(&mut Cursor::new(cert))
62 .collect::<Result<_, _>>()
63 .expect("could not read certificate");
64
65 let key_der = rustls_pemfile::private_key(&mut Cursor::new(key))
66 .expect("could not read key pemfile")
67 .expect("no private key found in `key`");
68
69 ServerConfig::builder_with_provider(crypto_provider())
70 .with_safe_default_protocol_versions()
71 .expect("crypto provider did not support safe default protocol versions")
72 .with_no_client_auth()
73 .with_single_cert(cert_chain, key_der)
74 .expect("could not create a rustls ServerConfig from the supplied cert and key")
75 .into()
76 }
77}
78
79impl From<ServerConfig> for RustlsAcceptor {
80 fn from(sc: ServerConfig) -> Self {
81 Self(Arc::new(sc).into())
82 }
83}
84
85impl From<TlsAcceptor> for RustlsAcceptor {
86 fn from(ta: TlsAcceptor) -> Self {
87 Self(ta)
88 }
89}
90
91#[derive(Debug)]
93pub struct RustlsServerTransport<T>(TlsStream<T>);
94
95impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsServerTransport<T> {
96 fn poll_read(
97 mut self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 buf: &mut [u8],
100 ) -> Poll<io::Result<usize>> {
101 Pin::new(&mut self.0).poll_read(cx, buf)
102 }
103}
104
105impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for RustlsServerTransport<T> {
106 fn poll_write(
107 mut self: Pin<&mut Self>,
108 cx: &mut Context<'_>,
109 buf: &[u8],
110 ) -> Poll<io::Result<usize>> {
111 Pin::new(&mut self.0).poll_write(cx, buf)
112 }
113
114 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
115 Pin::new(&mut self.0).poll_flush(cx)
116 }
117
118 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119 Pin::new(&mut self.0).poll_close(cx)
120 }
121
122 fn poll_write_vectored(
123 mut self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 bufs: &[io::IoSlice<'_>],
126 ) -> Poll<io::Result<usize>> {
127 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
128 }
129}
130
131impl<T: Transport> Transport for RustlsServerTransport<T> {
132 fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
133 self.inner_transport().peer_addr()
134 }
135}
136
137impl<T> RustlsServerTransport<T> {
138 pub fn inner_transport(&self) -> &T {
140 self.0.get_ref().0
141 }
142
143 pub fn inner_transport_mut(&mut self) -> &mut T {
145 self.0.get_mut().0
146 }
147}
148
149impl<T> AsRef<ServerConnection> for RustlsServerTransport<T> {
150 fn as_ref(&self) -> &ServerConnection {
151 self.0.get_ref().1
152 }
153}
154
155impl<T> AsMut<ServerConnection> for RustlsServerTransport<T> {
156 fn as_mut(&mut self) -> &mut ServerConnection {
157 self.0.get_mut().1
158 }
159}
160
161impl<T> From<TlsStream<T>> for RustlsServerTransport<T> {
162 fn from(value: TlsStream<T>) -> Self {
163 Self(value)
164 }
165}
166
167impl<T> From<RustlsServerTransport<T>> for TlsStream<T> {
168 fn from(RustlsServerTransport(value): RustlsServerTransport<T>) -> Self {
169 value
170 }
171}
172
173impl<Input> Acceptor<Input> for RustlsAcceptor
174where
175 Input: Transport,
176{
177 type Error = io::Error;
178 type Output = RustlsServerTransport<Input>;
179
180 async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
181 self.0.accept(input).await.map(RustlsServerTransport)
182 }
183}