1use crate::{
2 connection::QuinnConnection,
3 runtime::{SocketTransport, TrilliumRuntime},
4};
5use std::{io, net::SocketAddr, sync::Arc};
6use trillium_server_common::{Info, QuicConfig as QuicConfigTrait, QuicEndpoint, Server};
7
8pub struct QuicConfig(quinn::ServerConfig);
20
21impl QuicConfig {
22 pub fn from_single_cert(cert_pem: &[u8], key_pem: &[u8]) -> Self {
27 let certs: Vec<_> = rustls_pemfile::certs(&mut io::BufReader::new(cert_pem))
28 .collect::<Result<_, _>>()
29 .expect("parsing certificate PEM");
30
31 let key = rustls_pemfile::private_key(&mut io::BufReader::new(key_pem))
32 .expect("parsing private key PEM")
33 .expect("no private key found in PEM");
34
35 let mut tls_config =
36 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
37 .with_safe_default_protocol_versions()
38 .expect("building TLS config with protocol versions")
39 .with_no_client_auth()
40 .with_single_cert(certs, key)
41 .expect("building TLS config");
42
43 tls_config.alpn_protocols = vec![b"h3".to_vec()];
44
45 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
46 .expect("building QUIC TLS config");
47
48 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
49 }
50
51 pub fn from_rustls_server_config(tls_config: rustls::ServerConfig) -> Self {
56 let mut tls_config = tls_config;
57 if !tls_config.alpn_protocols.contains(&b"h3".to_vec()) {
58 tls_config.alpn_protocols.push(b"h3".to_vec());
59 }
60 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
61 .expect("building QUIC TLS config");
62 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
63 }
64
65 pub fn from_quinn_server_config(config: quinn::ServerConfig) -> Self {
71 Self(config)
72 }
73}
74
75impl<S> QuicConfigTrait<S> for QuicConfig
76where
77 S: Server,
78 S::Runtime: Unpin,
79 S::UdpTransport: SocketTransport,
80{
81 type Endpoint = QuinnEndpoint;
82
83 fn bind(
84 self,
85 addr: SocketAddr,
86 runtime: S::Runtime,
87 _info: &mut Info,
88 ) -> Option<io::Result<Self::Endpoint>> {
89 let quinn_runtime = TrilliumRuntime::<S::Runtime, S::UdpTransport>::new(runtime);
90 let socket = match std::net::UdpSocket::bind(addr) {
91 Ok(s) => s,
92 Err(e) => return Some(Err(e)),
93 };
94
95 Some(
96 quinn::Endpoint::new(
97 quinn::EndpointConfig::default(),
98 Some(self.0),
99 socket,
100 quinn_runtime,
101 )
102 .map(QuinnEndpoint::new),
103 )
104 }
105}
106
107pub struct QuinnEndpoint(quinn::Endpoint);
109
110impl QuinnEndpoint {
111 pub(crate) fn new(endpoint: quinn::Endpoint) -> Self {
113 Self(endpoint)
114 }
115}
116
117impl QuicEndpoint for QuinnEndpoint {
118 type Connection = QuinnConnection;
119
120 async fn accept(&self) -> Option<Self::Connection> {
121 loop {
122 let incoming = self.0.accept().await?;
123 match incoming.await {
124 Ok(connection) => return Some(QuinnConnection::new(connection)),
125 Err(e) => log::error!("QUIC accept failed: {e}"),
126 }
127 }
128 }
129
130 async fn connect(&self, addr: SocketAddr, server_name: &str) -> io::Result<Self::Connection> {
131 let connection = self
132 .0
133 .connect(addr, server_name)
134 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
135 .await
136 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
137 Ok(QuinnConnection::new(connection))
138 }
139}