Skip to main content

trillium_html_rewriter/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11#![doc = include_str!("../README.md")]
12
use lol_async::rewrite;
pub use lol_async::{Settings, html};
use mime::Mime;
use std::{
    fmt::{self, Debug, Formatter},
    str::FromStr,
    sync::Arc,
};
use trillium::{
    Body, Conn, Handler,
    KnownHeaderName::{ContentEncoding, ContentLength, ContentType},
};
25
26/**
27trillium handler for html rewriting
28*/
29pub struct HtmlRewriter {
30    settings: Arc<dyn Fn() -> Settings<'static, 'static> + Send + Sync + 'static>,
31}
32
33impl Debug for HtmlRewriter {
34    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35        f.debug_struct("HtmlRewriter").finish()
36    }
37}
38
39impl Handler for HtmlRewriter {
40    async fn run(&self, mut conn: Conn) -> Conn {
41        let html = conn
42            .response_headers()
43            .get_str(ContentType)
44            .and_then(|c| Mime::from_str(c).ok())
45            .map(|m| m.subtype() == "html")
46            .unwrap_or_default();
47
48        if html && let Some(body) = conn.take_response_body() {
49            let reader = rewrite(body, (self.settings)());
50            conn.response_headers_mut().remove(ContentLength); // we no longer know the content length, if we ever did
51            conn.with_body(Body::new_streaming(reader, None))
52        } else {
53            conn
54        }
55    }
56}
57
58impl HtmlRewriter {
59    /**
60    construct a new html rewriter from the provided `fn() -> Settings`. See
61    [`lol_async::html::Settings`] for more information.
62     */
63    pub fn new(f: impl Fn() -> Settings<'static, 'static> + Send + Sync + 'static) -> Self {
64        Self {
65            settings: Arc::new(f)
66                as Arc<dyn Fn() -> Settings<'static, 'static> + Send + Sync + 'static>,
67        }
68    }
69}