// trillium_proxy/forward_proxy_connect.rs
use crate::bytes;
use full_duplex_async_copy::full_duplex_copy;
use std::fmt::Debug;
use trillium::{async_trait, Conn, Handler, Upgrade};
use trillium_client::{Connector, ObjectSafeConnector};
use trillium_http::{Method, Status};
use url::Url;

9#[derive(Debug)]
10pub struct ForwardProxyConnect(Box<dyn ObjectSafeConnector>);
12
13#[derive(Debug)]
14struct ForwardUpgrade(trillium_http::transport::BoxedTransport);
15
16impl ForwardProxyConnect {
17 pub fn new(connector: impl Connector) -> Self {
19 Self(connector.boxed())
20 }
21}
22#[async_trait]
23impl Handler for ForwardProxyConnect {
24 async fn run(&self, conn: Conn) -> Conn {
25 if conn.method() == Method::Connect {
26 let Ok(url) = Url::parse(&format!("http://{}", conn.path())) else {
27 return conn.with_status(Status::BadRequest).halt();
28 };
29
30 if url.cannot_be_a_base() {
31 return conn.with_status(Status::BadRequest).halt();
32 }
33
34 let Ok(tcp) = Connector::connect(&self.0, &url).await else {
35 return conn.with_status(Status::BadGateway).halt();
36 };
37
38 conn.with_status(Status::Ok)
39 .with_state(ForwardUpgrade(tcp))
40 .halt()
41 } else {
42 conn
43 }
44 }
45
46 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
47 upgrade.state.contains::<ForwardUpgrade>()
48 }
49
50 async fn upgrade(&self, mut upgrade: Upgrade) {
51 let Some(ForwardUpgrade(upstream)) = upgrade.state.take() else {
52 return;
53 };
54 let downstream = upgrade;
55 match full_duplex_copy(upstream, downstream).await {
56 Err(e) => log::error!("upgrade stream error: {:?}", e),
57 Ok((up, down)) => {
58 log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
59 }
60 }
61 }
62}