trillium_testing/
server_connector.rs1use crate::{Runtime, TestTransport};
2use async_channel::Receiver;
3use std::{
4 io,
5 net::{IpAddr, SocketAddr},
6 sync::Arc,
7};
8use trillium::{Handler, Transport};
9use trillium_http::ServerConfig;
10use trillium_server_common::Connector;
11use url::Url;
12
/// A [`Connector`] that connects clients directly to an in-process [`Handler`].
///
/// It does this through a [`TestTransport`] pair rather than a network listener. Each call to
/// `ServerConnector::connect` creates a new transport pair, spawns the server half on `runtime`
/// (driven by `server_config` and `handler`), and returns the client half to the caller.
#[derive(Debug, fieldwork::Fieldwork)]
pub struct ServerConnector<H> {
    /// The handler that serves every connection; shared between clones through the [`Arc`].
    #[field(get, deref = false)]
    handler: Arc<H>,

    /// Runtime on which the server side of each connection is spawned.
    #[field(with, set, get, into)]
    runtime: Runtime,

    /// HTTP server configuration whose `run` drives each server-side connection.
    #[field(with, set, get(deref = false), into)]
    server_config: Arc<ServerConfig>,

    /// Optional queue of IP addresses. On each connect, at most one address is taken without
    /// blocking and set as the peer ip of the client-side transport.
    pub(crate) client_peer_ips_receiver: Option<Receiver<IpAddr>>,
    /// Optional queue of IP addresses. On each connect, at most one address is taken without
    /// blocking and set as the peer ip of the server-side transport (and thus of the conn).
    pub(crate) server_peer_ips_receiver: Option<Receiver<IpAddr>>,
}
31
32impl<H> Clone for ServerConnector<H> {
33 fn clone(&self) -> Self {
34 Self {
35 handler: self.handler.clone(),
36 runtime: self.runtime.clone(),
37 server_config: self.server_config.clone(),
38 client_peer_ips_receiver: self.client_peer_ips_receiver.clone(),
39 server_peer_ips_receiver: self.server_peer_ips_receiver.clone(),
40 }
41 }
42}
43
44impl<H: Handler> ServerConnector<H> {
45 pub fn new(handler: H) -> Self {
47 Self {
48 handler: Arc::new(handler),
49 runtime: crate::runtime().into(),
50 server_config: Arc::default(),
51 client_peer_ips_receiver: None,
52 server_peer_ips_receiver: None,
53 }
54 }
55
56 pub async fn connect(&self, secure: bool) -> TestTransport {
58 let (mut client_transport, mut server_transport) = TestTransport::new();
59 if let Some(server_ip) = self
60 .client_peer_ips_receiver
61 .as_ref()
62 .and_then(|channel| channel.try_recv().ok())
63 {
64 client_transport.set_peer_ip(server_ip);
65 }
66
67 if let Some(client_ip) = self
68 .server_peer_ips_receiver
69 .as_ref()
70 .and_then(|channel| channel.try_recv().ok())
71 {
72 server_transport.set_peer_ip(client_ip);
73 }
74
75 let handler = Arc::clone(&self.handler);
76 let server_config = Arc::clone(&self.server_config);
77
78 let peer_ip = server_transport
79 .peer_addr()
80 .ok()
81 .flatten()
82 .map(|addr| addr.ip());
83
84 self.runtime.spawn(async move {
85 server_config
86 .run(server_transport, |mut conn| {
87 let handler = Arc::clone(&handler);
88 async move {
89 conn.set_peer_ip(peer_ip).set_secure(secure);
90 let conn = handler.run(conn.into()).await;
91 let conn = handler.before_send(conn).await;
92 let mut inner = conn.into_inner::<TestTransport>();
93 let state = std::mem::take(inner.state_mut());
94 *inner.transport().state().write().unwrap() = state;
95 inner
96 }
97 })
98 .await
99 .unwrap();
100 });
101
102 client_transport
103 }
104}
105
106impl<H: Handler> Connector for ServerConnector<H> {
107 type Runtime = Runtime;
108 type Transport = TestTransport;
109 type Udp = ();
110
111 async fn connect(&self, url: &Url) -> io::Result<Self::Transport> {
112 Ok(self.connect(url.scheme() == "https").await)
113 }
114
115 fn runtime(&self) -> Self::Runtime {
116 #[allow(clippy::clone_on_copy)]
117 self.runtime.clone()
118 }
119
120 async fn resolve(&self, _host: &str, _port: u16) -> io::Result<Vec<SocketAddr>> {
121 Ok(vec![SocketAddr::from(([0, 0, 0, 0], 0))])
122 }
123}
124
/// Build a [`Connector`] that sends every connection to `handler` in-process.
///
/// This is shorthand for `ServerConnector::new(handler)` with the concrete type hidden behind
/// `impl Connector`.
pub fn connector(handler: impl Handler) -> impl Connector {
    ServerConnector::new(handler)
}