Skip to main content

trillium_native_tls/
server.rs

1use std::{
2    io::{self, IoSlice, IoSliceMut},
3    net::SocketAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use crate::Identity;
9use async_native_tls::{Error, TlsAcceptor, TlsStream};
10use trillium_server_common::{async_trait, Acceptor, AsyncRead, AsyncWrite, Transport};
11
12/**
13trillium [`Acceptor`] for native-tls
14*/
15
16#[derive(Clone, Debug)]
17pub struct NativeTlsAcceptor(TlsAcceptor);
18
19impl NativeTlsAcceptor {
20    /**
21    constructs a NativeTlsAcceptor from a [`native_tls::TlsAcceptor`],
22    an [`async_native_tls::TlsAcceptor`], or an [`Identity`]
23    */
24    pub fn new(t: impl Into<Self>) -> Self {
25        t.into()
26    }
27
28    /**
29    constructs a NativeTlsAcceptor from a pkcs12 key and password. See
30    See [`Identity::from_pkcs8`]
31    */
32    pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
33        Identity::from_pkcs12(der, password)
34            .expect("could not build Identity from provided pkcs12 key and password")
35            .into()
36    }
37
38    /**
39    constructs a NativeTlsAcceptor from a pkcs8 pem and private
40    key. See [`Identity::from_pkcs8`]
41    */
42    pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
43        Identity::from_pkcs8(pem, key)
44            .expect("could not build Identity from provided pem and key")
45            .into()
46    }
47}
48
49impl From<Identity> for NativeTlsAcceptor {
50    fn from(i: Identity) -> Self {
51        native_tls::TlsAcceptor::new(i).unwrap().into()
52    }
53}
54
55impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
56    fn from(i: native_tls::TlsAcceptor) -> Self {
57        Self(i.into())
58    }
59}
60
61impl From<TlsAcceptor> for NativeTlsAcceptor {
62    fn from(i: TlsAcceptor) -> Self {
63        Self(i)
64    }
65}
66
67impl From<(&[u8], &str)> for NativeTlsAcceptor {
68    fn from(i: (&[u8], &str)) -> Self {
69        Self::from_pkcs12(i.0, i.1)
70    }
71}
72
73#[async_trait]
74impl<Input> Acceptor<Input> for NativeTlsAcceptor
75where
76    Input: Transport,
77{
78    type Output = NativeTlsServerTransport<Input>;
79    type Error = Error;
80    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
81        self.0.accept(input).await.map(NativeTlsServerTransport)
82    }
83}
84
85/// Server Tls Transport
86///
87/// A wrapper type around [`TlsStream`] that also implements [`Transport`]
88#[derive(Debug)]
89pub struct NativeTlsServerTransport<T>(TlsStream<T>);
90
91impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
92    fn as_ref(&self) -> &T {
93        self.0.get_ref()
94    }
95}
96impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
97    fn as_mut(&mut self) -> &mut T {
98        self.0.get_mut()
99    }
100}
101
102impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
103    fn as_ref(&self) -> &TlsStream<T> {
104        &self.0
105    }
106}
107impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
108    fn as_mut(&mut self) -> &mut TlsStream<T> {
109        &mut self.0
110    }
111}
112
113impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
114    fn poll_read(
115        mut self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117        buf: &mut [u8],
118    ) -> Poll<io::Result<usize>> {
119        Pin::new(&mut self.0).poll_read(cx, buf)
120    }
121
122    fn poll_read_vectored(
123        mut self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125        bufs: &mut [IoSliceMut<'_>],
126    ) -> Poll<io::Result<usize>> {
127        Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
128    }
129}
130
131impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
132    fn poll_write(
133        mut self: Pin<&mut Self>,
134        cx: &mut Context<'_>,
135        buf: &[u8],
136    ) -> Poll<io::Result<usize>> {
137        Pin::new(&mut self.0).poll_write(cx, buf)
138    }
139
140    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
141        Pin::new(&mut self.0).poll_flush(cx)
142    }
143
144    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145        Pin::new(&mut self.0).poll_close(cx)
146    }
147
148    fn poll_write_vectored(
149        mut self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        bufs: &[IoSlice<'_>],
152    ) -> Poll<io::Result<usize>> {
153        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
154    }
155}
156
157impl<T: Transport> Transport for NativeTlsServerTransport<T> {
158    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
159        self.0.get_ref().peer_addr()
160    }
161}