Skip to main content

trillium_proxy/
forward_proxy_connect.rs

1use crate::bytes;
2use full_duplex_async_copy::full_duplex_copy;
3use std::fmt::Debug;
4use trillium::{async_trait, Conn, Handler, Upgrade};
5use trillium_client::{Connector, ObjectSafeConnector};
6use trillium_http::{Method, Status};
7use url::Url;
8
9#[derive(Debug)]
10/// trillium handler to implement Connect proxying
11pub struct ForwardProxyConnect(Box<dyn ObjectSafeConnector>);
12
13#[derive(Debug)]
14struct ForwardUpgrade(trillium_http::transport::BoxedTransport);
15
16impl ForwardProxyConnect {
17    /// construct a new ForwardProxyConnect
18    pub fn new(connector: impl Connector) -> Self {
19        Self(connector.boxed())
20    }
21}
22#[async_trait]
23impl Handler for ForwardProxyConnect {
24    async fn run(&self, conn: Conn) -> Conn {
25        if conn.method() == Method::Connect {
26            let Ok(url) = Url::parse(&format!("http://{}", conn.path())) else {
27                return conn.with_status(Status::BadRequest).halt();
28            };
29
30            if url.cannot_be_a_base() {
31                return conn.with_status(Status::BadRequest).halt();
32            }
33
34            let Ok(tcp) = Connector::connect(&self.0, &url).await else {
35                return conn.with_status(Status::BadGateway).halt();
36            };
37
38            conn.with_status(Status::Ok)
39                .with_state(ForwardUpgrade(tcp))
40                .halt()
41        } else {
42            conn
43        }
44    }
45
46    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
47        upgrade.state.contains::<ForwardUpgrade>()
48    }
49
50    async fn upgrade(&self, mut upgrade: Upgrade) {
51        let Some(ForwardUpgrade(upstream)) = upgrade.state.take() else {
52            return;
53        };
54        let downstream = upgrade;
55        match full_duplex_copy(upstream, downstream).await {
56            Err(e) => log::error!("upgrade stream error: {:?}", e),
57            Ok((up, down)) => {
58                log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
59            }
60        }
61    }
62}