Skip to main content

trillium_rustls/
server.rs

1use crate::crypto_provider;
2use futures_rustls::{
3    TlsAcceptor,
4    rustls::{ServerConfig, ServerConnection},
5    server::TlsStream,
6};
7use std::{
8    fmt::{Debug, Formatter},
9    io,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13};
14use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
15
16/// trillium [`Acceptor`] for Rustls
17
18#[derive(Clone)]
19pub struct RustlsAcceptor(TlsAcceptor);
20impl Debug for RustlsAcceptor {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        f.debug_tuple("Rustls").field(&"<<TlsAcceptor>>").finish()
23    }
24}
25
26impl RustlsAcceptor {
27    /// build a new RustlsAcceptor from a [`ServerConfig`] or a [`TlsAcceptor`]
28    pub fn new(t: impl Into<Self>) -> Self {
29        t.into()
30    }
31
32    /// build a new RustlsAcceptor from a cert chain (pem) and private key.
33    ///
34    /// See
35    /// [`ConfigBuilder::with_single_cert`][`crate::rustls::ConfigBuilder::with_single_cert`]
36    /// for accepted formats. If you need to customize the
37    /// [`ServerConfig`], use ServerConfig's `Into<RustlsAcceptor>`, eg
38    ///
39    /// ```rust,no_run
40    /// use trillium_rustls::{rustls::ServerConfig, RustlsAcceptor};
41    /// # let certs = vec![];
42    /// # let mut private_key = rustls_pemfile::private_key(&mut std::io::Cursor::new(b"")).unwrap().unwrap();
43    /// let rustls_acceptor: RustlsAcceptor = ServerConfig::builder()
44    ///     .with_no_client_auth()
45    ///     .with_single_cert(certs, private_key)
46    ///     .expect("could not build rustls ServerConfig")
47    ///     .into();
48    /// ```
49    ///
50    /// # Example
51    ///
52    /// ```rust,no_run
53    /// use trillium_rustls::RustlsAcceptor;
54    /// const KEY: &[u8] = include_bytes!("../examples/key.pem");
55    /// const CERT: &[u8] = include_bytes!("../examples/cert.pem");
56    /// let rustls_acceptor = RustlsAcceptor::from_single_cert(CERT, KEY);
57    /// ```
58    pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
59        use std::io::Cursor;
60
61        let cert_chain = rustls_pemfile::certs(&mut Cursor::new(cert))
62            .collect::<Result<_, _>>()
63            .expect("could not read certificate");
64
65        let key_der = rustls_pemfile::private_key(&mut Cursor::new(key))
66            .expect("could not read key pemfile")
67            .expect("no private key found in `key`");
68
69        ServerConfig::builder_with_provider(crypto_provider())
70            .with_safe_default_protocol_versions()
71            .expect("crypto provider did not support safe default protocol versions")
72            .with_no_client_auth()
73            .with_single_cert(cert_chain, key_der)
74            .expect("could not create a rustls ServerConfig from the supplied cert and key")
75            .into()
76    }
77}
78
79impl From<ServerConfig> for RustlsAcceptor {
80    fn from(sc: ServerConfig) -> Self {
81        Self(Arc::new(sc).into())
82    }
83}
84
85impl From<TlsAcceptor> for RustlsAcceptor {
86    fn from(ta: TlsAcceptor) -> Self {
87        Self(ta)
88    }
89}
90
91/// Transport for rustls server acceptor
92#[derive(Debug)]
93pub struct RustlsServerTransport<T>(TlsStream<T>);
94
95impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsServerTransport<T> {
96    fn poll_read(
97        mut self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99        buf: &mut [u8],
100    ) -> Poll<io::Result<usize>> {
101        Pin::new(&mut self.0).poll_read(cx, buf)
102    }
103}
104
105impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for RustlsServerTransport<T> {
106    fn poll_write(
107        mut self: Pin<&mut Self>,
108        cx: &mut Context<'_>,
109        buf: &[u8],
110    ) -> Poll<io::Result<usize>> {
111        Pin::new(&mut self.0).poll_write(cx, buf)
112    }
113
114    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
115        Pin::new(&mut self.0).poll_flush(cx)
116    }
117
118    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119        Pin::new(&mut self.0).poll_close(cx)
120    }
121
122    fn poll_write_vectored(
123        mut self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125        bufs: &[io::IoSlice<'_>],
126    ) -> Poll<io::Result<usize>> {
127        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
128    }
129}
130
131impl<T: Transport> Transport for RustlsServerTransport<T> {
132    fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
133        self.inner_transport().peer_addr()
134    }
135}
136
137impl<T> RustlsServerTransport<T> {
138    /// access the contained transport type (eg TcpStream)
139    pub fn inner_transport(&self) -> &T {
140        self.0.get_ref().0
141    }
142
143    /// mutably access the contained transport type (eg TcpStream)
144    pub fn inner_transport_mut(&mut self) -> &mut T {
145        self.0.get_mut().0
146    }
147}
148
149impl<T> AsRef<ServerConnection> for RustlsServerTransport<T> {
150    fn as_ref(&self) -> &ServerConnection {
151        self.0.get_ref().1
152    }
153}
154
155impl<T> AsMut<ServerConnection> for RustlsServerTransport<T> {
156    fn as_mut(&mut self) -> &mut ServerConnection {
157        self.0.get_mut().1
158    }
159}
160
161impl<T> From<TlsStream<T>> for RustlsServerTransport<T> {
162    fn from(value: TlsStream<T>) -> Self {
163        Self(value)
164    }
165}
166
167impl<T> From<RustlsServerTransport<T>> for TlsStream<T> {
168    fn from(RustlsServerTransport(value): RustlsServerTransport<T>) -> Self {
169        value
170    }
171}
172
173impl<Input> Acceptor<Input> for RustlsAcceptor
174where
175    Input: Transport,
176{
177    type Error = io::Error;
178    type Output = RustlsServerTransport<Input>;
179
180    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
181        self.0.accept(input).await.map(RustlsServerTransport)
182    }
183}