1use crate::parse_utils::{parse_quoted_string, parse_token};
2use std::{borrow::Cow, fmt::Write, net::IpAddr};
3use trillium::{
4 Headers,
5 KnownHeaderName::{
6 Forwarded as ForwardedHeader, XforwardedBy, XforwardedFor, XforwardedHost, XforwardedProto,
7 XforwardedSsl,
8 },
9};
10
/// Proxy/forwarding information carried by a `Forwarded` header (RFC 7239)
/// or reconstructed from the `X-Forwarded-*` headers.
///
/// String data is stored as [`Cow`] so that parsing can borrow directly from
/// the header value; `into_owned` detaches it from that lifetime.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Forwarded<'a> {
    /// Value of the `by=` parameter, if present.
    by: Option<Cow<'a, str>>,
    /// Every `for=` value, in the order it appeared.
    forwarded_for: Vec<Cow<'a, str>>,
    /// Value of the `host=` parameter, if present.
    host: Option<Cow<'a, str>>,
    /// Value of the `proto=` parameter, if present.
    proto: Option<Cow<'a, str>>,
}
20
21impl<'a> Forwarded<'a> {
22 pub fn from_headers(headers: &'a Headers) -> Result<Option<Self>, ParseError> {
79 if let Some(forwarded) = Self::from_forwarded_header(headers)? {
80 Ok(Some(forwarded))
81 } else {
82 Self::from_x_headers(headers)
83 }
84 }
85
86 pub fn from_forwarded_header(headers: &'a Headers) -> Result<Option<Self>, ParseError> {
113 if let Some(headers) = headers.get_str(ForwardedHeader) {
114 Ok(Some(Self::parse(headers)?))
115 } else {
116 Ok(None)
117 }
118 }
119
120 pub fn from_x_headers(headers: &'a Headers) -> Result<Option<Self>, ParseError> {
148 let forwarded_for: Vec<Cow<'a, str>> = headers
149 .get_str(XforwardedFor)
150 .map(|hv| {
151 hv.split(',')
152 .map(|v| {
153 let v = v.trim();
154 match v.parse::<IpAddr>().ok() {
155 Some(IpAddr::V6(v6)) => Cow::Owned(format!(r#"[{v6}]"#)),
156 _ => Cow::Borrowed(v),
157 }
158 })
159 .collect()
160 })
161 .unwrap_or_default();
162
163 let by = headers.get_str(XforwardedBy).map(Cow::Borrowed);
164
165 let proto = headers
166 .get_str(XforwardedProto)
167 .map(Cow::Borrowed)
168 .or_else(|| {
169 if headers.eq_ignore_ascii_case(XforwardedSsl, "on") {
170 Some(Cow::Borrowed("https"))
171 } else {
172 None
173 }
174 });
175
176 let host = headers.get_str(XforwardedHost).map(Cow::Borrowed);
177
178 if !forwarded_for.is_empty() || by.is_some() || proto.is_some() || host.is_some() {
179 Ok(Some(Self {
180 forwarded_for,
181 by,
182 proto,
183 host,
184 }))
185 } else {
186 Ok(None)
187 }
188 }
189
190 pub fn parse(input: &'a str) -> Result<Self, ParseError> {
210 let mut input = input;
211 let mut forwarded = Forwarded::new();
212
213 while !input.is_empty() {
214 input = if starts_with_ignore_case("for=", input) {
215 forwarded.parse_for(input)?
216 } else {
217 forwarded.parse_forwarded_pair(input)?
218 }
219 }
220
221 Ok(forwarded)
222 }
223
224 fn parse_forwarded_pair(&mut self, input: &'a str) -> Result<&'a str, ParseError> {
225 let (key, value, rest) = match parse_token(input) {
226 (Some(key), rest) if rest.starts_with('=') => match parse_value(&rest[1..]) {
227 (Some(value), rest) => Some((key, value, rest)),
228 (None, _) => None,
229 },
230 _ => None,
231 }
232 .ok_or_else(|| ParseError::new("parse error in forwarded-pair"))?;
233
234 match key {
235 "by" => {
236 if self.by.is_some() {
237 return Err(ParseError::new("parse error, duplicate `by` key"));
238 }
239 self.by = Some(value);
240 }
241
242 "host" => {
243 if self.host.is_some() {
244 return Err(ParseError::new("parse error, duplicate `host` key"));
245 }
246 self.host = Some(value);
247 }
248
249 "proto" => {
250 if self.proto.is_some() {
251 return Err(ParseError::new("parse error, duplicate `proto` key"));
252 }
253 self.proto = Some(value);
254 }
255
256 _ => { }
257 }
258
259 match rest.strip_prefix(';') {
260 Some(rest) => Ok(rest),
261 None => Ok(rest),
262 }
263 }
264
265 fn parse_for(&mut self, input: &'a str) -> Result<&'a str, ParseError> {
266 let mut rest = input;
267
268 loop {
269 rest = match match_ignore_case("for=", rest) {
270 (true, rest) => rest,
271 (false, _) => return Err(ParseError::new("http list must start with for=")),
272 };
273
274 let (value, rest_) = parse_value(rest);
275 rest = rest_;
276
277 if let Some(value) = value {
278 self.forwarded_for.push(value);
280 } else {
281 return Err(ParseError::new("for= without valid value"));
282 }
283
284 match rest.chars().next() {
285 Some(',') => {
287 rest = rest[1..].trim_start();
288 }
289
290 Some(';') => return Ok(&rest[1..]),
292
293 None => return Ok(rest),
295
296 _ => return Err(ParseError::new("unexpected character after for= section")),
298 }
299 }
300 }
301
302 pub fn into_owned(self) -> Forwarded<'static> {
305 Forwarded {
306 by: self.by.map(|by| Cow::Owned(by.into_owned())),
307 forwarded_for: self
308 .forwarded_for
309 .into_iter()
310 .map(|ff| Cow::Owned(ff.into_owned()))
311 .collect(),
312 host: self.host.map(|h| Cow::Owned(h.into_owned())),
313 proto: self.proto.map(|p| Cow::Owned(p.into_owned())),
314 }
315 }
316
317 pub fn new() -> Self {
319 Self::default()
320 }
321
322 pub fn add_for(&mut self, forwarded_for: impl Into<Cow<'a, str>>) {
324 self.forwarded_for.push(forwarded_for.into());
325 }
326
327 pub fn forwarded_for(&self) -> Vec<&str> {
329 self.forwarded_for.iter().map(|x| x.as_ref()).collect()
330 }
331
332 pub fn set_host(&mut self, host: impl Into<Cow<'a, str>>) {
334 self.host = Some(host.into());
335 }
336
337 pub fn host(&self) -> Option<&str> {
339 self.host.as_deref()
340 }
341
342 pub fn set_proto(&mut self, proto: impl Into<Cow<'a, str>>) {
344 self.proto = Some(proto.into())
345 }
346
347 pub fn proto(&self) -> Option<&str> {
349 self.proto.as_deref()
350 }
351
352 pub fn set_by(&mut self, by: impl Into<Cow<'a, str>>) {
354 self.by = Some(by.into());
355 }
356
357 pub fn by(&self) -> Option<&str> {
359 self.by.as_deref()
360 }
361}
362
363impl std::fmt::Display for Forwarded<'_> {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 let mut needs_semi = false;
366 if let Some(by) = self.by() {
367 needs_semi = true;
368 write!(f, "by={}", format_value(by))?;
369 }
370
371 if !self.forwarded_for.is_empty() {
372 if needs_semi {
373 f.write_char(';')?;
374 }
375 needs_semi = true;
376 f.write_str(
377 &self
378 .forwarded_for
379 .iter()
380 .map(|f| format!("for={}", format_value(f)))
381 .collect::<Vec<_>>()
382 .join(", "),
383 )?;
384 }
385
386 if let Some(host) = self.host() {
387 if needs_semi {
388 f.write_char(';')?;
389 }
390 needs_semi = true;
391 write!(f, "host={}", format_value(host))?
392 }
393
394 if let Some(proto) = self.proto() {
395 if needs_semi {
396 f.write_char(';')?;
397 }
398 write!(f, "proto={}", format_value(proto))?
399 }
400
401 Ok(())
402 }
403}
404
405fn parse_value(input: &str) -> (Option<Cow<'_, str>>, &str) {
406 match parse_token(input) {
407 (Some(token), rest) => (Some(Cow::Borrowed(token)), rest),
408 (None, rest) => parse_quoted_string(rest),
409 }
410}
411
412fn format_value(input: &str) -> Cow<'_, str> {
413 match parse_token(input) {
414 (_, "") => input.into(),
415 _ => {
416 let mut string = String::from("\"");
417 for ch in input.chars() {
418 if let '\\' | '"' = ch {
419 string.push('\\');
420 }
421 string.push(ch);
422 }
423 string.push('"');
424 string.into()
425 }
426 }
427}
428
/// Case-insensitively (ASCII) strips the prefix `start` from `input`.
///
/// Returns `(true, remainder)` when `input` begins with `start`, otherwise
/// `(false, input)` unchanged. Never panics: `input` may be shorter than
/// `start`, or have a multi-byte character straddling `start.len()`. Both can
/// happen with untrusted header text (e.g. a trailing comma in `for=a,`).
fn match_ignore_case<'a>(start: &'static str, input: &'a str) -> (bool, &'a str) {
    // `str::get` returns `None` instead of panicking when the range is out of
    // bounds or not on a char boundary.
    match input.get(..start.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(start) => (true, &input[start.len()..]),
        _ => (false, input),
    }
}
437
/// Returns whether `input` begins with `start`, comparing ASCII
/// case-insensitively.
///
/// Never panics. A length check alone is not enough: a multi-byte character
/// straddling `start.len()` (e.g. `"abcé=x"` against `"for="`) would make
/// byte slicing panic, so the prefix is taken with `str::get`.
fn starts_with_ignore_case(start: &'static str, input: &str) -> bool {
    matches!(
        input.get(..start.len()),
        Some(prefix) if prefix.eq_ignore_ascii_case(start)
    )
}
446
/// Error returned when a forwarding header value cannot be parsed.
///
/// Wraps a static description of the failure; its
/// [`Display`](std::fmt::Display) output prefixes that description with
/// `unable to parse forwarded header: `.
#[derive(Clone, Copy, Debug)]
pub struct ParseError(&'static str);

impl ParseError {
    /// Creates an error carrying the given static description.
    pub fn new(msg: &'static str) -> Self {
        ParseError(msg)
    }
}

impl std::error::Error for ParseError {}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("unable to parse forwarded header: ")?;
        f.write_str(self.0)
    }
}
461
462impl<'a> TryFrom<&'a str> for Forwarded<'a> {
463 type Error = ParseError;
464 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
465 Self::parse(value)
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Lets tests propagate `ParseError` with `?` instead of unwrapping.
    type Result = std::result::Result<(), ParseError>;

    #[test]
    fn starts_with_ignore_case_can_handle_short_inputs() {
        // Prefix longer than the input: must report "no match", not panic.
        let matched = starts_with_ignore_case("helloooooo", "h");
        assert!(!matched);
    }

    #[test]
    fn parsing_for() -> Result {
        let expectations = [
            (r#"for="_gazonk""#, vec!["_gazonk"]),
            (
                r#"For="[2001:db8:cafe::17]:4711""#,
                vec!["[2001:db8:cafe::17]:4711"],
            ),
            (
                "for=192.0.2.60;proto=http;by=203.0.113.43",
                vec!["192.0.2.60"],
            ),
            (
                "for=192.0.2.43, for=198.51.100.17",
                vec!["192.0.2.43", "198.51.100.17"],
            ),
        ];
        for (header, clients) in expectations {
            assert_eq!(
                Forwarded::parse(header)?.forwarded_for(),
                clients,
                "header: {header}"
            );
        }

        // Whitespace after the separating comma is optional.
        let compact =
            Forwarded::parse(r#"for=192.0.2.43,for="[2001:db8:cafe::17]",for=unknown"#)?;
        let spaced =
            Forwarded::parse(r#"for=192.0.2.43, for="[2001:db8:cafe::17]", for=unknown"#)?;
        assert_eq!(compact.forwarded_for(), spaced.forwarded_for());

        // Commas and escapes inside a quoted-string belong to the value.
        let quoted = Forwarded::parse(
            r#"for=192.0.2.43,for="this is a valid quoted-string, \" \\",for=unknown"#,
        )?;
        assert_eq!(
            quoted.forwarded_for(),
            vec![
                "192.0.2.43",
                r#"this is a valid quoted-string, " \"#,
                "unknown",
            ]
        );

        Ok(())
    }

    #[test]
    fn basic_parse() -> Result {
        let parsed = Forwarded::parse("for=client.com;by=proxy.com;host=host.com;proto=https")?;

        assert_eq!(parsed.forwarded_for(), vec!["client.com"]);
        assert_eq!(parsed.by(), Some("proxy.com"));
        assert_eq!(parsed.host(), Some("host.com"));
        assert_eq!(parsed.proto(), Some("https"));
        assert!(matches!(parsed, Forwarded { .. }));
        Ok(())
    }

    #[test]
    fn bad_parse() {
        let failures = [
            (
                "by=proxy.com;for=client;host=example.com;host",
                "parse error in forwarded-pair",
            ),
            ("by;for;host;proto", "parse error in forwarded-pair"),
            ("for=for, key=value", "http list must start with for="),
            (r#"for="unterminated string"#, "for= without valid value"),
            (r#"for=, for=;"#, "for= without valid value"),
        ];
        for (header, reason) in failures {
            let error = Forwarded::parse(header).unwrap_err();
            assert_eq!(
                error.to_string(),
                format!("unable to parse forwarded header: {reason}"),
                "header: {header}"
            );
        }
    }

    #[test]
    fn bad_parse_from_headers() -> Result {
        let mut malformed = Headers::new();
        malformed.append("forwarded", "uh oh");
        let error = Forwarded::from_headers(&malformed).unwrap_err();
        assert_eq!(
            error.to_string(),
            "unable to parse forwarded header: parse error in forwarded-pair"
        );

        let empty = Headers::new();
        assert!(Forwarded::from_headers(&empty)?.is_none());
        Ok(())
    }

    #[test]
    fn from_x_headers() -> Result {
        let mut headers = Headers::new();
        headers.append(XforwardedFor, "192.0.2.43, 2001:db8:cafe::17");
        headers.append(XforwardedProto, "gopher");
        headers.append(XforwardedHost, "example.com");

        let rendered = Forwarded::from_headers(&headers)?.unwrap().to_string();
        assert_eq!(
            rendered,
            r#"for=192.0.2.43, for="[2001:db8:cafe::17]";host=example.com;proto=gopher"#
        );
        Ok(())
    }

    #[test]
    fn from_x_headers_with_ssl_on() -> Result {
        let mut headers = Headers::new();
        headers.append(XforwardedFor, "192.0.2.43, 2001:db8:cafe::17");
        headers.append(XforwardedHost, "example.com");
        headers.append(XforwardedSsl, "on");

        let rendered = Forwarded::from_headers(&headers)?.unwrap().to_string();
        assert_eq!(
            rendered,
            r#"for=192.0.2.43, for="[2001:db8:cafe::17]";host=example.com;proto=https"#
        );
        Ok(())
    }

    #[test]
    fn formatting_edge_cases() {
        let mut needs_escaping = Forwarded::new();
        needs_escaping.add_for(r#"quote: " backslash: \"#);
        needs_escaping.add_for(";proto=https");
        assert_eq!(
            needs_escaping.to_string(),
            r#"for="quote: \" backslash: \\", for=";proto=https""#
        );

        let mut non_tokens = Forwarded::new();
        non_tokens.set_host("localhost:8080");
        non_tokens.set_proto("not:normal");
        non_tokens.set_by("localhost:8081");
        assert_eq!(
            non_tokens.to_string(),
            r#"by="localhost:8081";host="localhost:8080";proto="not:normal""#
        );
    }

    #[test]
    fn parse_edge_cases() -> Result {
        // Separators inside quoted-strings must not terminate the value.
        let tricky =
            Forwarded::parse(r#"for=";", for=",", for="\"", for=unquoted;by=";proto=https""#)?;
        assert_eq!(tricky.forwarded_for(), vec![";", ",", "\"", "unquoted"]);
        assert_eq!(tricky.by(), Some(";proto=https"));
        assert!(tricky.proto().is_none());

        let proto_only = Forwarded::parse("proto=https")?;
        assert_eq!(proto_only.proto(), Some("https"));
        Ok(())
    }

    #[test]
    fn owned_parse() -> Result {
        let owned =
            Forwarded::parse("for=client;by=proxy.com;host=example.com;proto=https")?.into_owned();

        assert_eq!(owned.forwarded_for(), vec!["client"]);
        assert_eq!(owned.by(), Some("proxy.com"));
        assert_eq!(owned.host(), Some("example.com"));
        assert_eq!(owned.proto(), Some("https"));
        assert!(matches!(owned, Forwarded { .. }));
        Ok(())
    }

    #[test]
    fn from_headers() -> Result {
        let mut headers = Headers::new();
        headers.append("Forwarded", "for=for");

        let parsed = Forwarded::from_headers(&headers)?.unwrap();
        assert_eq!(parsed.forwarded_for(), vec!["for"]);
        Ok(())
    }

    #[test]
    fn owned_can_outlive_headers() -> Result {
        let detached: Forwarded<'static> = {
            let mut headers = Headers::new();
            headers.append("Forwarded", "for=for;by=by;host=host;proto=proto");
            Forwarded::from_headers(&headers)?.unwrap().into_owned()
        };
        assert_eq!(detached.by(), Some("by"));
        Ok(())
    }

    #[test]
    fn round_trip() -> Result {
        let headers = [
            "for=client,for=b,for=c;by=proxy.com;host=example.com;proto=https",
            "by=proxy.com;proto=https;host=example.com;for=a,for=b",
            "by=proxy.com",
            "proto=https",
            "host=example.com",
            "for=a,for=b",
            r#"by="localhost:8081";host="localhost:8080";proto="not:normal""#,
        ];
        for header in headers {
            let original = Forwarded::parse(header)?;
            let serialized = original.to_string();
            assert_eq!(Forwarded::parse(&serialized)?, original, "header: {header}");
        }
        Ok(())
    }
}