Skip to main content

trillium_rustls/
server.rs

1use futures_rustls::{
2    rustls::{ServerConfig, ServerConnection},
3    server::TlsStream,
4    TlsAcceptor,
5};
6use std::{
7    fmt::{Debug, Formatter},
8    io,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use trillium_server_common::{async_trait, Acceptor, AsyncRead, AsyncWrite, Transport};
14
15use crate::crypto_provider;
16
17/**
18trillium [`Acceptor`] for Rustls
19*/
20
21#[derive(Clone)]
22pub struct RustlsAcceptor(TlsAcceptor);
23impl Debug for RustlsAcceptor {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        f.debug_tuple("Rustls").field(&"<<TlsAcceptor>>").finish()
26    }
27}
28
29impl RustlsAcceptor {
30    /**
31    build a new RustlsAcceptor from a [`ServerConfig`] or a [`TlsAcceptor`]
32    */
33    pub fn new(t: impl Into<Self>) -> Self {
34        t.into()
35    }
36
37    /**
38    build a new RustlsAcceptor from a cert chain (pem) and private key.
39
40    See
41    [`ConfigBuilder::with_single_cert`][`crate::rustls::ConfigBuilder::with_single_cert`]
42    for accepted formats. If you need to customize the
43    [`ServerConfig`], use ServerConfig's Into RustlsAcceptor, eg
44
45    ```rust,ignore
46    use trillium_rustls::{rustls::ServerConfig, RustlsAcceptor};
47    let rustls_acceptor: RustlsAcceptor = ServerConfig::builder()
48        .with_no_client_auth()
49        .with_single_cert(certs, private_key)
50        .expect("could not build rustls ServerConfig")
51        .into();
52    ```
53
54    # Example
55
56    ```rust,no_run
57    use trillium_rustls::RustlsAcceptor;
58    const KEY: &[u8] = include_bytes!("../examples/key.pem");
59    const CERT: &[u8] = include_bytes!("../examples/cert.pem");
60    let rustls_acceptor = RustlsAcceptor::from_single_cert(CERT, KEY);
61    ```
62    */
63    pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
64        use std::io::Cursor;
65
66        let cert_chain = rustls_pemfile::certs(&mut Cursor::new(cert))
67            .collect::<Result<_, _>>()
68            .expect("could not read certificate");
69
70        let key_der = rustls_pemfile::private_key(&mut Cursor::new(key))
71            .expect("could not read key pemfile")
72            .expect("no private key found in `key`");
73
74        ServerConfig::builder_with_provider(crypto_provider())
75            .with_safe_default_protocol_versions()
76            .expect("crypto provider did not support safe default protocol versions")
77            .with_no_client_auth()
78            .with_single_cert(cert_chain, key_der)
79            .expect("could not create a rustls ServerConfig from the supplied cert and key")
80            .into()
81    }
82}
83
84impl From<ServerConfig> for RustlsAcceptor {
85    fn from(sc: ServerConfig) -> Self {
86        Self(Arc::new(sc).into())
87    }
88}
89
90impl From<TlsAcceptor> for RustlsAcceptor {
91    fn from(ta: TlsAcceptor) -> Self {
92        Self(ta)
93    }
94}
95
/// Transport for rustls server acceptor
///
/// Newtype wrapper around a [`TlsStream`] over the inner transport `T`.
/// This is the output type of [`RustlsAcceptor`]'s `Acceptor` implementation.
#[derive(Debug)]
pub struct RustlsServerTransport<T>(TlsStream<T>);
99
100impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsServerTransport<T> {
101    fn poll_read(
102        mut self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &mut [u8],
105    ) -> Poll<io::Result<usize>> {
106        Pin::new(&mut self.0).poll_read(cx, buf)
107    }
108}
109
110impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for RustlsServerTransport<T> {
111    fn poll_write(
112        mut self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &[u8],
115    ) -> Poll<io::Result<usize>> {
116        Pin::new(&mut self.0).poll_write(cx, buf)
117    }
118
119    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
120        Pin::new(&mut self.0).poll_flush(cx)
121    }
122
123    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
124        Pin::new(&mut self.0).poll_close(cx)
125    }
126
127    fn poll_write_vectored(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        bufs: &[io::IoSlice<'_>],
131    ) -> Poll<io::Result<usize>> {
132        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
133    }
134}
135
136impl<T: Transport> Transport for RustlsServerTransport<T> {
137    fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
138        self.inner_transport().peer_addr()
139    }
140}
141
142impl<T> RustlsServerTransport<T> {
143    /// access the contained transport type (eg TcpStream)
144    pub fn inner_transport(&self) -> &T {
145        self.0.get_ref().0
146    }
147
148    /// mutably access the contained transport type (eg TcpStream)
149    pub fn inner_transport_mut(&mut self) -> &mut T {
150        self.0.get_mut().0
151    }
152}
153
154impl<T> AsRef<ServerConnection> for RustlsServerTransport<T> {
155    fn as_ref(&self) -> &ServerConnection {
156        self.0.get_ref().1
157    }
158}
159
160impl<T> AsMut<ServerConnection> for RustlsServerTransport<T> {
161    fn as_mut(&mut self) -> &mut ServerConnection {
162        self.0.get_mut().1
163    }
164}
165
166impl<T> From<TlsStream<T>> for RustlsServerTransport<T> {
167    fn from(value: TlsStream<T>) -> Self {
168        Self(value)
169    }
170}
171
172impl<T> From<RustlsServerTransport<T>> for TlsStream<T> {
173    fn from(RustlsServerTransport(value): RustlsServerTransport<T>) -> Self {
174        value
175    }
176}
177
178#[async_trait]
179impl<Input> Acceptor<Input> for RustlsAcceptor
180where
181    Input: Transport,
182{
183    type Output = RustlsServerTransport<Input>;
184    type Error = io::Error;
185    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
186        self.0.accept(input).await.map(RustlsServerTransport)
187    }
188}