Skip to main content

trillium_native_tls/
server.rs

1use crate::Identity;
2use async_native_tls::{Error, TlsAcceptor, TlsStream};
3use std::{
4    io::{self, IoSlice, IoSliceMut},
5    net::SocketAddr,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
10
11/// trillium [`Acceptor`] for native-tls
12
13#[derive(Clone, Debug)]
14pub struct NativeTlsAcceptor(TlsAcceptor);
15
16impl NativeTlsAcceptor {
17    /// constructs a NativeTlsAcceptor from a [`native_tls::TlsAcceptor`],
18    /// an [`async_native_tls::TlsAcceptor`], or an [`Identity`]
19    pub fn new(t: impl Into<Self>) -> Self {
20        t.into()
21    }
22
23    /// constructs a NativeTlsAcceptor from a pkcs12 key and password. See
24    /// See [`Identity::from_pkcs8`]
25    pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
26        Identity::from_pkcs12(der, password)
27            .expect("could not build Identity from provided pkcs12 key and password")
28            .into()
29    }
30
31    /// constructs a NativeTlsAcceptor from a pkcs8 pem and private
32    /// key. See [`Identity::from_pkcs8`]
33    pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
34        Identity::from_pkcs8(pem, key)
35            .expect("could not build Identity from provided pem and key")
36            .into()
37    }
38}
39
40impl From<Identity> for NativeTlsAcceptor {
41    fn from(i: Identity) -> Self {
42        native_tls::TlsAcceptor::new(i).unwrap().into()
43    }
44}
45
46impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
47    fn from(i: native_tls::TlsAcceptor) -> Self {
48        Self(i.into())
49    }
50}
51
52impl From<TlsAcceptor> for NativeTlsAcceptor {
53    fn from(i: TlsAcceptor) -> Self {
54        Self(i)
55    }
56}
57
58impl From<(&[u8], &str)> for NativeTlsAcceptor {
59    fn from(i: (&[u8], &str)) -> Self {
60        Self::from_pkcs12(i.0, i.1)
61    }
62}
63
64impl<Input> Acceptor<Input> for NativeTlsAcceptor
65where
66    Input: Transport,
67{
68    type Error = Error;
69    type Output = NativeTlsServerTransport<Input>;
70
71    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
72        self.0.accept(input).await.map(NativeTlsServerTransport)
73    }
74}
75
76/// Server Tls Transport
77///
78/// A wrapper type around [`TlsStream`] that also implements [`Transport`]
79#[derive(Debug)]
80pub struct NativeTlsServerTransport<T>(TlsStream<T>);
81
82impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
83    fn as_ref(&self) -> &T {
84        self.0.get_ref()
85    }
86}
87impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
88    fn as_mut(&mut self) -> &mut T {
89        self.0.get_mut()
90    }
91}
92
93impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
94    fn as_ref(&self) -> &TlsStream<T> {
95        &self.0
96    }
97}
98impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
99    fn as_mut(&mut self) -> &mut TlsStream<T> {
100        &mut self.0
101    }
102}
103
104impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
105    fn poll_read(
106        mut self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        buf: &mut [u8],
109    ) -> Poll<io::Result<usize>> {
110        Pin::new(&mut self.0).poll_read(cx, buf)
111    }
112
113    fn poll_read_vectored(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116        bufs: &mut [IoSliceMut<'_>],
117    ) -> Poll<io::Result<usize>> {
118        Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
119    }
120}
121
122impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
123    fn poll_write(
124        mut self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126        buf: &[u8],
127    ) -> Poll<io::Result<usize>> {
128        Pin::new(&mut self.0).poll_write(cx, buf)
129    }
130
131    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
132        Pin::new(&mut self.0).poll_flush(cx)
133    }
134
135    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136        Pin::new(&mut self.0).poll_close(cx)
137    }
138
139    fn poll_write_vectored(
140        mut self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142        bufs: &[IoSlice<'_>],
143    ) -> Poll<io::Result<usize>> {
144        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
145    }
146}
147
148impl<T: Transport> Transport for NativeTlsServerTransport<T> {
149    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
150        self.0.get_ref().peer_addr()
151    }
152}