Skip to main content

trillium_testing/
server_connector.rs

1use crate::{Runtime, TestTransport};
2use async_channel::Receiver;
3use std::{
4    io,
5    net::{IpAddr, SocketAddr},
6    sync::Arc,
7};
8use trillium::{Handler, Transport};
9use trillium_http::ServerConfig;
10use trillium_server_common::Connector;
11use url::Url;
12
/// a bridge between trillium servers and clients
///
/// Each call to `connect` creates a fresh [`TestTransport`] pair and serves
/// the server half with `handler` on a task spawned onto `runtime`. The
/// client half is handed back to the caller. This type also implements
/// [`Connector`], so it can stand in wherever a client connector is expected.
#[derive(Debug, fieldwork::Fieldwork)]
pub struct ServerConnector<H> {
    /// the handler that serves every connection opened through this connector
    #[field(get, deref = false)]
    handler: Arc<H>,

    /// the runtime onto which the server side of each connection is spawned
    #[field(with, set, get, into)]
    runtime: Runtime,

    /// the server config used to run the server side of each connection
    #[field(with, set, get(deref = false), into)]
    server_config: Arc<ServerConfig>,

    /// Optional source of peer ips for the client transport. When present,
    /// `connect` takes at most one ip from it without waiting and assigns it
    /// as the client transport's peer ip.
    pub(crate) client_peer_ips_receiver: Option<Receiver<IpAddr>>,
    /// Optional source of peer ips for the server transport. When present,
    /// `connect` takes at most one ip from it without waiting and assigns it
    /// as the server transport's peer ip.
    pub(crate) server_peer_ips_receiver: Option<Receiver<IpAddr>>,
}
31
32impl<H> Clone for ServerConnector<H> {
33    fn clone(&self) -> Self {
34        Self {
35            handler: self.handler.clone(),
36            runtime: self.runtime.clone(),
37            server_config: self.server_config.clone(),
38            client_peer_ips_receiver: self.client_peer_ips_receiver.clone(),
39            server_peer_ips_receiver: self.server_peer_ips_receiver.clone(),
40        }
41    }
42}
43
44impl<H: Handler> ServerConnector<H> {
45    /// builds a new ServerConnector
46    pub fn new(handler: H) -> Self {
47        Self {
48            handler: Arc::new(handler),
49            runtime: crate::runtime().into(),
50            server_config: Arc::default(),
51            client_peer_ips_receiver: None,
52            server_peer_ips_receiver: None,
53        }
54    }
55
56    /// opens a new connection to this virtual server, returning the client transport
57    pub async fn connect(&self, secure: bool) -> TestTransport {
58        let (mut client_transport, mut server_transport) = TestTransport::new();
59        if let Some(server_ip) = self
60            .client_peer_ips_receiver
61            .as_ref()
62            .and_then(|channel| channel.try_recv().ok())
63        {
64            client_transport.set_peer_ip(server_ip);
65        }
66
67        if let Some(client_ip) = self
68            .server_peer_ips_receiver
69            .as_ref()
70            .and_then(|channel| channel.try_recv().ok())
71        {
72            server_transport.set_peer_ip(client_ip);
73        }
74
75        let handler = Arc::clone(&self.handler);
76        let server_config = Arc::clone(&self.server_config);
77
78        let peer_ip = server_transport
79            .peer_addr()
80            .ok()
81            .flatten()
82            .map(|addr| addr.ip());
83
84        self.runtime.spawn(async move {
85            server_config
86                .run(server_transport, |mut conn| {
87                    let handler = Arc::clone(&handler);
88                    async move {
89                        conn.set_peer_ip(peer_ip).set_secure(secure);
90                        let conn = handler.run(conn.into()).await;
91                        let conn = handler.before_send(conn).await;
92                        let mut inner = conn.into_inner::<TestTransport>();
93                        let state = std::mem::take(inner.state_mut());
94                        *inner.transport().state().write().unwrap() = state;
95                        inner
96                    }
97                })
98                .await
99                .unwrap();
100        });
101
102        client_transport
103    }
104}
105
106impl<H: Handler> Connector for ServerConnector<H> {
107    type Runtime = Runtime;
108    type Transport = TestTransport;
109    type Udp = ();
110
111    async fn connect(&self, url: &Url) -> io::Result<Self::Transport> {
112        Ok(self.connect(url.scheme() == "https").await)
113    }
114
115    fn runtime(&self) -> Self::Runtime {
116        #[allow(clippy::clone_on_copy)]
117        self.runtime.clone()
118    }
119
120    async fn resolve(&self, _host: &str, _port: u16) -> io::Result<Vec<SocketAddr>> {
121        Ok(vec![SocketAddr::from(([0, 0, 0, 0], 0))])
122    }
123}
124
125/// build a connector from this handler
126pub fn connector(handler: impl Handler) -> impl Connector {
127    ServerConnector::new(handler)
128}