Skip to main content

trillium_proxy/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
//! HTTP reverse and forward proxy handler for trillium
13
// Pull the README in as documentation so its code blocks are run as doctests and the
// examples there cannot silently drift from the crate's API. Rustdoc sets `cfg(doctest)`
// (not `cfg(test)`) while collecting doctests, so this module must be gated on `doctest`;
// a `cfg(test)` gate hides it from rustdoc and the README examples are never exercised.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
17
18mod body_streamer;
19mod forward_proxy_connect;
20pub mod upstream;
21
22use body_streamer::stream_body;
23pub use forward_proxy_connect::ForwardProxyConnect;
24use full_duplex_async_copy::full_duplex_copy;
25use futures_lite::future::zip;
26use size::{Base, Size};
27use std::{borrow::Cow, fmt::Debug, future::IntoFuture};
28use trillium::{
29    Conn, Handler, KnownHeaderName,
30    Status::{NotFound, SwitchingProtocols},
31    Upgrade,
32};
33pub use trillium_client::{Client, Connector};
34use trillium_forwarding::Forwarded;
35use trillium_http::{HeaderName, Headers, HttpContext, Status, Version};
36use upstream::{IntoUpstreamSelector, UpstreamSelector};
37pub use url::Url;
38
39/// constructs a new [`Proxy`]. alias of [`Proxy::new`]
40pub fn proxy<I>(client: impl Into<Client>, upstream: I) -> Proxy<I::UpstreamSelector>
41where
42    I: IntoUpstreamSelector,
43{
44    Proxy::new(client, upstream)
45}
46
/// the proxy handler
///
/// Build one with [`Proxy::new`] (or the [`proxy`] function), then adjust it with the
/// chainable configuration methods such as [`Proxy::proxy_not_found`],
/// [`Proxy::without_halting`], [`Proxy::with_via_pseudonym`] and
/// [`Proxy::with_websocket_upgrades`].
#[derive(Debug)]
pub struct Proxy<U> {
    // Chooses the upstream url for each incoming conn; when it yields none, the conn is
    // returned from `run` untouched.
    upstream: U,
    // Outbound client used to send the proxied request upstream.
    client: Client,
    // When true (the default), an upstream 404 leaves the conn unmodified so later
    // handlers can respond.
    pass_through_not_found: bool,
    // When true (the default), the conn is halted after a proxied response is copied in.
    halt: bool,
    // If set, a `Via` entry with this pseudonym is added to proxied requests and responses.
    via_pseudonym: Option<Cow<'static, str>>,
    // When true, `Connection: Upgrade` + `Upgrade: websocket` requests are forwarded with
    // their upgrade headers intact. Defaults to false.
    allow_websocket_upgrade: bool,
}
57
58impl<U: UpstreamSelector> Proxy<U> {
59    /// construct a new proxy handler that sends all requests to the upstream
60    /// provided
61    ///
62    /// ```
63    /// use trillium_proxy::Proxy;
64    /// use trillium_smol::ClientConfig;
65    ///
66    /// let proxy = Proxy::new(
67    ///     ClientConfig::default(),
68    ///     "http://docs.trillium.rs/trillium_proxy",
69    /// );
70    /// ```
71    pub fn new<I>(client: impl Into<Client>, upstream: I) -> Self
72    where
73        I: IntoUpstreamSelector<UpstreamSelector = U>,
74    {
75        let client = client
76            .into()
77            .without_default_header(KnownHeaderName::UserAgent)
78            .without_default_header(KnownHeaderName::Accept);
79
80        Self {
81            upstream: upstream.into_upstream(),
82            client,
83            pass_through_not_found: true,
84            halt: true,
85            via_pseudonym: None,
86            allow_websocket_upgrade: false,
87        }
88    }
89
90    /// chainable constructor to set the 404 Not Found handling
91    /// behavior. By default, this proxy will pass through the trillium
92    /// Conn unmodified if the proxy response is a 404 not found, allowing
93    /// it to be chained in a tuple handler. To modify this behavior, call
94    /// proxy_not_found, and the full 404 response will be forwarded. The
95    /// Conn will be halted unless [`Proxy::without_halting`] was
96    /// configured
97    ///
98    /// ```
99    /// # use trillium_smol::ClientConfig;
100    /// # use trillium_proxy::Proxy;
101    /// let proxy = Proxy::new(ClientConfig::default(), "http://trillium.rs").proxy_not_found();
102    /// ```
103    pub fn proxy_not_found(mut self) -> Self {
104        self.pass_through_not_found = false;
105        self
106    }
107
108    /// The default behavior for this handler is to halt the conn on any
109    /// response other than a 404. If [`Proxy::proxy_not_found`] has been
110    /// configured, the default behavior for all response statuses is to
111    /// halt the trillium conn. To change this behavior, call
112    /// without_halting when constructing the proxy, and it will not halt
113    /// the conn. This is useful when passing the proxy reply through
114    /// [`trillium_html_rewriter`](https://docs.trillium.rs/trillium_html_rewriter).
115    ///
116    /// ```
117    /// # use trillium_smol::ClientConfig;
118    /// # use trillium_proxy::Proxy;
119    /// let proxy = Proxy::new(ClientConfig::default(), "http://trillium.rs").without_halting();
120    /// ```
121    pub fn without_halting(mut self) -> Self {
122        self.halt = false;
123        self
124    }
125
126    /// populate the pseudonym for a
127    /// [`Via`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Via)
128    /// header. If no pseudonym is provided, no via header will be
129    /// inserted.
130    pub fn with_via_pseudonym(mut self, via_pseudonym: impl Into<Cow<'static, str>>) -> Self {
131        self.via_pseudonym = Some(via_pseudonym.into());
132        self
133    }
134
135    /// Allow websockets to be proxied
136    ///
137    /// This is not currently the default, but that may change at some (semver-minor) point in the
138    /// future
139    pub fn with_websocket_upgrades(mut self) -> Self {
140        self.allow_websocket_upgrade = true;
141        self
142    }
143
144    fn set_via_pseudonym(&self, headers: &mut Headers, version: Version) {
145        if self.via_pseudonym.is_none() {
146            return;
147        }
148
149        use std::fmt::Write;
150        let mut via = String::new();
151        let _ = write!(&mut via, "{version}");
152
153        if let Some(pseudonym) = &self.via_pseudonym {
154            let _ = write!(&mut via, " {pseudonym}");
155        }
156
157        if let Some(old_via) = headers.get_values(KnownHeaderName::Via) {
158            for old_via in old_via {
159                let _ = write!(&mut via, ", {old_via}");
160            }
161        }
162
163        headers.insert(KnownHeaderName::Via, via);
164    }
165}
166
/// Conn-state marker carrying the upstream side of a `101 Switching Protocols` response.
///
/// `Proxy::run` stores it in the conn's state when the upstream switches protocols.
/// `Proxy::has_upgrade` checks for it, and `Proxy::upgrade` takes it back out and copies
/// bytes between it and the downstream connection.
#[derive(Debug)]
struct UpstreamUpgrade(Upgrade);
169
impl<U: UpstreamSelector> Handler for Proxy<U> {
    /// Rebuild the client's `HttpContext` with its existing config plus the server's
    /// swansong, then log the configured upstream.
    async fn init(&mut self, info: &mut trillium::Info) {
        // this little dance is necessary to set the swansong on the client currently.
        // this is only necessary because we're not wiring together the client.
        let old_context = self.client.context();
        let new_context = HttpContext::default()
            .with_config(*old_context.config())
            .with_swansong(info.swansong().clone());
        self.client.set_context(new_context);
        log::info!("proxying to {:?}", self.upstream);
    }

    /// Forward `conn` to the upstream chosen by the selector and copy the reply back.
    ///
    /// Outcomes:
    /// - the selector picks no upstream: `conn` is returned unchanged
    /// - the client request fails: 503, halted, with the client error stored in conn state
    /// - the upstream returns 404 and 404s pass through (the default): `conn` is returned
    ///   unchanged
    /// - the upstream returns 101: its headers are copied over and the upstream transport is
    ///   stashed as [`UpstreamUpgrade`] for [`Handler::upgrade`]
    /// - any other status: upstream status, headers and body are copied onto `conn`
    ///
    /// Connection-specific headers are stripped in both directions, a `Via` entry is added
    /// when configured, and the conn is halted unless `without_halting` was set.
    async fn run(&self, mut conn: Conn) -> Conn {
        let Some(request_url) = self.upstream.determine_upstream(&mut conn) else {
            return conn;
        };

        log::debug!("proxying to {}", request_url.as_str());

        // Start from any Forwarded information already on the request (an unparseable or
        // missing header yields the default), then add this hop's client ip and host.
        let mut forwarded = Forwarded::from_headers(conn.request_headers())
            .ok()
            .flatten()
            .unwrap_or_default()
            .into_owned();

        if let Some(peer_ip) = conn.peer_ip() {
            forwarded.add_for(peer_ip.to_string());
        };

        if let Some(host) = conn.host() {
            forwarded.set_host(host);
        }

        // Copy the request headers minus hop-by-hop headers, Host, the X-Forwarded-* family
        // and Alt-Used; the combined forwarding info is sent as a single Forwarded header.
        let mut request_headers = conn
            .request_headers()
            .clone()
            .without_headers([
                KnownHeaderName::Connection,
                KnownHeaderName::KeepAlive,
                KnownHeaderName::ProxyAuthenticate,
                KnownHeaderName::ProxyAuthorization,
                KnownHeaderName::Te,
                KnownHeaderName::Trailer,
                KnownHeaderName::TransferEncoding,
                KnownHeaderName::Upgrade,
                KnownHeaderName::Host,
                KnownHeaderName::XforwardedBy,
                KnownHeaderName::XforwardedFor,
                KnownHeaderName::XforwardedHost,
                KnownHeaderName::XforwardedProto,
                KnownHeaderName::XforwardedSsl,
                KnownHeaderName::AltUsed,
            ])
            .with_inserted_header(KnownHeaderName::Forwarded, forwarded.to_string());

        // Headers named in the incoming Connection header apply only to that connection,
        // so drop them too, remembering whether the client asked for an upgrade.
        let mut connection_is_upgrade = false;
        for header in conn
            .request_headers()
            .get_str(KnownHeaderName::Connection)
            .unwrap_or_default()
            .split(',')
            .map(|h| HeaderName::from(h.trim()))
        {
            if header == KnownHeaderName::Upgrade {
                connection_is_upgrade = true;
            }
            request_headers.remove(header);
        }

        // Re-add the upgrade headers stripped above, but only for websocket upgrades and
        // only when that was enabled with `with_websocket_upgrades`.
        if self.allow_websocket_upgrade
            && connection_is_upgrade
            && conn
                .request_headers()
                .eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")
        {
            request_headers.extend([
                (KnownHeaderName::Upgrade, "WebSocket"),
                (KnownHeaderName::Connection, "Upgrade"),
            ]);
        }

        self.set_via_pseudonym(&mut request_headers, conn.http_version());

        // A request body is forwarded only when Content-Length is present and not "0", or
        // the body is chunked.
        let content_length = !matches!(
            conn.request_headers()
                .get_str(KnownHeaderName::ContentLength),
            Some("0") | None
        );

        let chunked = conn
            .request_headers()
            .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked");

        let method = conn.method();
        let conn_result = if chunked || content_length {
            let (body_fut, request_body) = stream_body(&mut conn);

            let client_fut = self
                .client
                .build_conn(method, request_url)
                .with_request_headers(request_headers)
                .with_body(request_body)
                .into_future();

            // Poll the body-forwarding future and the client request together; only the
            // client request's result (`.1`) is kept.
            zip(body_fut, client_fut).await.1
        } else {
            self.client
                .build_conn(method, request_url)
                .with_request_headers(request_headers)
                .await
        };

        let mut client_conn = match conn_result {
            Ok(client_conn) => client_conn,
            Err(e) => {
                return conn
                    .with_status(Status::ServiceUnavailable)
                    .halt()
                    .with_state(e);
            }
        };

        // Captured before `client_conn` is moved below; used for the response Via entry.
        let client_conn_version = client_conn.http_version();

        let mut conn = match client_conn.status() {
            Some(SwitchingProtocols) => {
                conn.response_headers_mut()
                    .extend(std::mem::take(client_conn.response_headers_mut()));

                // Keep the upstream transport in conn state; `upgrade` joins it with the
                // downstream side once the 101 has been sent.
                conn.with_state(UpstreamUpgrade(
                    trillium_http::Upgrade::from(client_conn).into(),
                ))
                .with_status(SwitchingProtocols)
            }

            Some(NotFound) if self.pass_through_not_found => {
                // Hand the client conn back via `recycle` (presumably for connection reuse
                // — NOTE(review): confirm) and leave this conn for later handlers.
                client_conn.recycle().await;
                return conn;
            }

            Some(status) => {
                // Drop any Server header already set on this conn before copying upstream's
                // headers, so the two are not both present.
                conn.response_headers_mut().remove(KnownHeaderName::Server);
                conn.response_headers_mut()
                    .append_all(client_conn.response_headers().clone());
                conn.with_body(client_conn).with_status(status)
            }

            None => return conn.with_status(Status::ServiceUnavailable).halt(),
        };

        // Except for a 101 carrying `Connection: Upgrade` (which must reach the client
        // intact), strip the upstream Connection header and every header it names.
        if Some(SwitchingProtocols) != conn.status()
            || !conn
                .response_headers()
                .eq_ignore_ascii_case(KnownHeaderName::Connection, "Upgrade")
        {
            let connection = conn
                .response_headers_mut()
                .remove(KnownHeaderName::Connection);

            conn.response_headers_mut().remove_all(
                connection
                    .iter()
                    .flatten()
                    .filter_map(|s| s.as_str())
                    .flat_map(|s| s.split(','))
                    .map(|t| HeaderName::from(t.trim()).into_owned()),
            );
        }

        // Remaining hop-by-hop headers never go back to the client.
        conn.response_headers_mut().remove_all([
            KnownHeaderName::KeepAlive,
            KnownHeaderName::ProxyAuthenticate,
            KnownHeaderName::ProxyAuthorization,
            KnownHeaderName::Te,
            KnownHeaderName::Trailer,
            KnownHeaderName::TransferEncoding,
        ]);

        self.set_via_pseudonym(conn.response_headers_mut(), client_conn_version);

        if self.halt { conn.halt() } else { conn }
    }

    /// Claim an upgrade only if `run` stashed an [`UpstreamUpgrade`] in its state.
    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
        upgrade.state().contains::<UpstreamUpgrade>()
    }

    /// Copy bytes in both directions between the upstream and downstream connections
    /// until `full_duplex_copy` finishes, then log either the error or the byte totals.
    async fn upgrade(&self, mut upgrade: Upgrade) {
        let Some(UpstreamUpgrade(upstream)) = upgrade.state_mut().take() else {
            return;
        };
        let downstream = upgrade;
        match full_duplex_copy(upstream, downstream).await {
            Err(e) => log::error!("upgrade stream error: {:?}", e),
            Ok((up, down)) => {
                log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
            }
        }
    }
}
370
371fn bytes(bytes: u64) -> String {
372    Size::from_bytes(bytes)
373        .format()
374        .with_base(Base::Base10)
375        .to_string()
376}