1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 rustdoc::missing_crate_level_docs,
6 missing_debug_implementations,
7 missing_docs,
8 nonstandard_style,
9 unused_qualifications
10)]
11
// Attaches README.md as the documentation of this test-only module,
// presumably so that the README's code examples are compiled and run as
// doctests.
// NOTE(review): rustdoc's documented pattern for README doctests is
// `#[cfg(doctest)]`; confirm the README examples are actually collected
// when this module is gated on `#[cfg(test)]`.
#[cfg(test)]
#[doc = include_str!("../README.md")]
mod readme {}
17
18mod body_streamer;
19mod forward_proxy_connect;
20pub mod upstream;
21
22use body_streamer::stream_body;
23pub use forward_proxy_connect::ForwardProxyConnect;
24use full_duplex_async_copy::full_duplex_copy;
25use futures_lite::future::zip;
26use size::{Base, Size};
27use std::{borrow::Cow, fmt::Debug, future::IntoFuture};
28use trillium::{
29 Conn, Handler, KnownHeaderName,
30 Status::{NotFound, SwitchingProtocols},
31 Upgrade,
32};
33pub use trillium_client::{Client, Connector};
34use trillium_forwarding::Forwarded;
35use trillium_http::{HeaderName, Headers, HttpContext, Status, Version};
36use upstream::{IntoUpstreamSelector, UpstreamSelector};
37pub use url::Url;
38
39pub fn proxy<I>(client: impl Into<Client>, upstream: I) -> Proxy<I::UpstreamSelector>
41where
42 I: IntoUpstreamSelector,
43{
44 Proxy::new(client, upstream)
45}
46
/// A trillium [`Handler`] that forwards each request to an upstream chosen by
/// an [`UpstreamSelector`] and relays the upstream response back.
///
/// Construct with [`Proxy::new`] or [`proxy`], then adjust behavior with the
/// builder methods on this type.
#[derive(Debug)]
pub struct Proxy<U> {
    /// Chooses the upstream URL for each incoming conn.
    upstream: U,
    /// HTTP client used to send requests to the upstream.
    client: Client,
    /// When true, an upstream `404 Not Found` is not relayed; the conn is
    /// returned unhalted so subsequent handlers can respond.
    pass_through_not_found: bool,
    /// Whether the conn is halted after a response has been proxied.
    halt: bool,
    /// When set, used to add a `Via` entry to proxied requests and responses.
    via_pseudonym: Option<Cow<'static, str>>,
    /// Whether websocket upgrade requests are forwarded to the upstream.
    allow_websocket_upgrade: bool,
}
57
58impl<U: UpstreamSelector> Proxy<U> {
59 pub fn new<I>(client: impl Into<Client>, upstream: I) -> Self
72 where
73 I: IntoUpstreamSelector<UpstreamSelector = U>,
74 {
75 let client = client
76 .into()
77 .without_default_header(KnownHeaderName::UserAgent)
78 .without_default_header(KnownHeaderName::Accept);
79
80 Self {
81 upstream: upstream.into_upstream(),
82 client,
83 pass_through_not_found: true,
84 halt: true,
85 via_pseudonym: None,
86 allow_websocket_upgrade: false,
87 }
88 }
89
90 pub fn proxy_not_found(mut self) -> Self {
104 self.pass_through_not_found = false;
105 self
106 }
107
108 pub fn without_halting(mut self) -> Self {
122 self.halt = false;
123 self
124 }
125
126 pub fn with_via_pseudonym(mut self, via_pseudonym: impl Into<Cow<'static, str>>) -> Self {
131 self.via_pseudonym = Some(via_pseudonym.into());
132 self
133 }
134
135 pub fn with_websocket_upgrades(mut self) -> Self {
140 self.allow_websocket_upgrade = true;
141 self
142 }
143
144 fn set_via_pseudonym(&self, headers: &mut Headers, version: Version) {
145 if self.via_pseudonym.is_none() {
146 return;
147 }
148
149 use std::fmt::Write;
150 let mut via = String::new();
151 let _ = write!(&mut via, "{version}");
152
153 if let Some(pseudonym) = &self.via_pseudonym {
154 let _ = write!(&mut via, " {pseudonym}");
155 }
156
157 if let Some(old_via) = headers.get_values(KnownHeaderName::Via) {
158 for old_via in old_via {
159 let _ = write!(&mut via, ", {old_via}");
160 }
161 }
162
163 headers.insert(KnownHeaderName::Via, via);
164 }
165}
166
/// Conn state holding the upstream connection after the upstream responded
/// `101 Switching Protocols`.
///
/// `Proxy`'s `Handler::has_upgrade` checks for this state, and
/// `Handler::upgrade` takes it out and copies bytes between the upstream
/// connection and the downstream client.
#[derive(Debug)]
struct UpstreamUpgrade(Upgrade);
169
impl<U: UpstreamSelector> Handler for Proxy<U> {
    /// Replaces the client's [`HttpContext`] with a fresh one that keeps the
    /// existing config but uses the server's swansong, then logs the
    /// configured upstream.
    // Presumably so outbound proxy connections take part in the server's
    // graceful shutdown — confirm against trillium_http's swansong usage.
    async fn init(&mut self, info: &mut trillium::Info) {
        let old_context = self.client.context();
        let new_context = HttpContext::default()
            .with_config(*old_context.config())
            .with_swansong(info.swansong().clone());
        self.client.set_context(new_context);
        log::info!("proxying to {:?}", self.upstream);
    }

    /// Forwards the conn to the upstream chosen by the upstream selector.
    ///
    /// - If the selector returns `None`, the conn is returned unchanged.
    /// - Otherwise the request is sent upstream without hop-by-hop headers
    ///   and with a `Forwarded` header describing this hop.
    /// - The upstream status, headers and body are copied onto the conn.
    /// - If the outbound request fails, the conn is halted with
    ///   `503 Service Unavailable` and the error is stored in conn state.
    /// - A response without a status also yields a halted 503.
    async fn run(&self, mut conn: Conn) -> Conn {
        let Some(request_url) = self.upstream.determine_upstream(&mut conn) else {
            return conn;
        };

        log::debug!("proxying to {}", request_url.as_str());

        // Start from any Forwarded header the client sent (a parse failure is
        // treated as absent), then record this hop's peer address and host.
        let mut forwarded = Forwarded::from_headers(conn.request_headers())
            .ok()
            .flatten()
            .unwrap_or_default()
            .into_owned();

        if let Some(peer_ip) = conn.peer_ip() {
            forwarded.add_for(peer_ip.to_string());
        };

        if let Some(host) = conn.host() {
            forwarded.set_host(host);
        }

        // Copy the downstream request headers, dropping the following:
        // - hop-by-hop headers;
        // - Host (presumably so the client derives it from the upstream
        //   URL — confirm);
        // - the X-Forwarded-* family and Alt-Used.
        // Forwarding information is carried instead in the single Forwarded
        // header built above.
        let mut request_headers = conn
            .request_headers()
            .clone()
            .without_headers([
                KnownHeaderName::Connection,
                KnownHeaderName::KeepAlive,
                KnownHeaderName::ProxyAuthenticate,
                KnownHeaderName::ProxyAuthorization,
                KnownHeaderName::Te,
                KnownHeaderName::Trailer,
                KnownHeaderName::TransferEncoding,
                KnownHeaderName::Upgrade,
                KnownHeaderName::Host,
                KnownHeaderName::XforwardedBy,
                KnownHeaderName::XforwardedFor,
                KnownHeaderName::XforwardedHost,
                KnownHeaderName::XforwardedProto,
                KnownHeaderName::XforwardedSsl,
                KnownHeaderName::AltUsed,
            ])
            .with_inserted_header(KnownHeaderName::Forwarded, forwarded.to_string());

        // Headers named in the downstream Connection header are hop-by-hop
        // for this request and are not forwarded. Also record whether the
        // client asked for a protocol upgrade.
        let mut connection_is_upgrade = false;
        for header in conn
            .request_headers()
            .get_str(KnownHeaderName::Connection)
            .unwrap_or_default()
            .split(',')
            .map(|h| HeaderName::from(h.trim()))
        {
            if header == KnownHeaderName::Upgrade {
                connection_is_upgrade = true;
            }
            request_headers.remove(header);
        }

        // The upgrade headers stripped above are re-added only for websocket
        // upgrade requests, and only when websocket upgrades are enabled.
        if self.allow_websocket_upgrade
            && connection_is_upgrade
            && conn
                .request_headers()
                .eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")
        {
            request_headers.extend([
                (KnownHeaderName::Upgrade, "WebSocket"),
                (KnownHeaderName::Connection, "Upgrade"),
            ]);
        }

        self.set_via_pseudonym(&mut request_headers, conn.http_version());

        // A request body is streamed upstream only when the downstream request
        // declares one: a non-zero Content-Length or chunked Transfer-Encoding.
        let content_length = !matches!(
            conn.request_headers()
                .get_str(KnownHeaderName::ContentLength),
            Some("0") | None
        );

        let chunked = conn
            .request_headers()
            .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked");

        let method = conn.method();
        let conn_result = if chunked || content_length {
            // `body_fut` and the outbound request are polled together. This
            // presumably lets `body_fut` feed the downstream body into
            // `request_body` while the client sends it — see body_streamer.
            let (body_fut, request_body) = stream_body(&mut conn);

            let client_fut = self
                .client
                .build_conn(method, request_url)
                .with_request_headers(request_headers)
                .with_body(request_body)
                .into_future();

            zip(body_fut, client_fut).await.1
        } else {
            self.client
                .build_conn(method, request_url)
                .with_request_headers(request_headers)
                .await
        };

        let mut client_conn = match conn_result {
            Ok(client_conn) => client_conn,
            Err(e) => {
                return conn
                    .with_status(Status::ServiceUnavailable)
                    .halt()
                    .with_state(e);
            }
        };

        let client_conn_version = client_conn.http_version();

        let mut conn = match client_conn.status() {
            // Upstream accepted the upgrade: move its response headers over
            // and stash its transport so `upgrade` can splice the two streams.
            Some(SwitchingProtocols) => {
                conn.response_headers_mut()
                    .extend(std::mem::take(client_conn.response_headers_mut()));

                conn.with_state(UpstreamUpgrade(
                    trillium_http::Upgrade::from(client_conn).into(),
                ))
                .with_status(SwitchingProtocols)
            }

            // Release the upstream connection and return the conn unhalted so
            // subsequent handlers can respond.
            Some(NotFound) if self.pass_through_not_found => {
                client_conn.recycle().await;
                return conn;
            }

            // Drop this server's own Server header, copy the upstream
            // response headers (which may include Server), and use the
            // upstream response as the body.
            Some(status) => {
                conn.response_headers_mut().remove(KnownHeaderName::Server);
                conn.response_headers_mut()
                    .append_all(client_conn.response_headers().clone());
                conn.with_body(client_conn).with_status(status)
            }

            None => return conn.with_status(Status::ServiceUnavailable).halt(),
        };

        // Unless this is a 101 response that carries `Connection: Upgrade`,
        // the response's Connection header is removed together with every
        // header it names. This keeps upstream hop-by-hop headers from
        // reaching the client.
        if Some(SwitchingProtocols) != conn.status()
            || !conn
                .response_headers()
                .eq_ignore_ascii_case(KnownHeaderName::Connection, "Upgrade")
        {
            let connection = conn
                .response_headers_mut()
                .remove(KnownHeaderName::Connection);

            conn.response_headers_mut().remove_all(
                connection
                    .iter()
                    .flatten()
                    .filter_map(|s| s.as_str())
                    .flat_map(|s| s.split(','))
                    .map(|t| HeaderName::from(t.trim()).into_owned()),
            );
        }

        // Remove the remaining fixed hop-by-hop headers from the response.
        conn.response_headers_mut().remove_all([
            KnownHeaderName::KeepAlive,
            KnownHeaderName::ProxyAuthenticate,
            KnownHeaderName::ProxyAuthorization,
            KnownHeaderName::Te,
            KnownHeaderName::Trailer,
            KnownHeaderName::TransferEncoding,
        ]);

        self.set_via_pseudonym(conn.response_headers_mut(), client_conn_version);

        if self.halt { conn.halt() } else { conn }
    }

    /// Claims the upgrade only when `run` stored an upstream connection in
    /// the conn state.
    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
        upgrade.state().contains::<UpstreamUpgrade>()
    }

    /// Copies bytes in both directions between the upstream connection and
    /// the downstream client until the copy completes.
    ///
    /// Logs the byte totals on success, or the error on failure. Returns
    /// immediately if no upstream connection was stored.
    async fn upgrade(&self, mut upgrade: Upgrade) {
        let Some(UpstreamUpgrade(upstream)) = upgrade.state_mut().take() else {
            return;
        };
        let downstream = upgrade;
        match full_duplex_copy(upstream, downstream).await {
            Err(e) => log::error!("upgrade stream error: {:?}", e),
            Ok((up, down)) => {
                log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
            }
        }
    }
}
370
371fn bytes(bytes: u64) -> String {
372 Size::from_bytes(bytes)
373 .format()
374 .with_base(Base::Base10)
375 .to_string()
376}