Skip to main content

trillium_compression/
lib.rs

1/*!
2Body compression for trillium.rs
3
4Currently, this crate only supports compressing outbound bodies with
5the zstd, brotli, and gzip algorithms (in order of preference),
6although more algorithms may be added in the future. The correct
7algorithm will be selected based on the Accept-Encoding header sent by
8the client, if one exists.
9*/
10#![forbid(unsafe_code)]
11#![deny(
12    missing_copy_implementations,
13    rustdoc::missing_crate_level_docs,
14    missing_debug_implementations,
15    nonstandard_style,
16    unused_qualifications
17)]
18#![warn(missing_docs)]
19
20use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
21use futures_lite::{
22    io::{BufReader, Cursor},
23    AsyncReadExt,
24};
25use std::{
26    collections::BTreeSet,
27    fmt::{self, Display, Formatter},
28    str::FromStr,
29};
30use trillium::{
31    async_trait, conn_try, conn_unwrap, Body, Conn, Handler, HeaderValues,
32    KnownHeaderName::{AcceptEncoding, ContentEncoding, Vary},
33};
34
/// Content codings this crate knows how to apply to a response body.
///
/// The derived ordering follows declaration order
/// (`Brotli < Gzip < Zstd`), so the variants must not be reordered
/// without considering code that compares or sorts algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum CompressionAlgorithm {
    /// The Brotli coding
    Brotli,

    /// The gzip coding
    Gzip,

    /// The Zstandard coding
    Zstd,
}
48
49impl CompressionAlgorithm {
50    fn as_str(&self) -> &'static str {
51        match self {
52            CompressionAlgorithm::Brotli => "br",
53            CompressionAlgorithm::Gzip => "gzip",
54            CompressionAlgorithm::Zstd => "zstd",
55        }
56    }
57
58    fn from_str_exact(s: &str) -> Option<Self> {
59        match s {
60            "br" => Some(CompressionAlgorithm::Brotli),
61            "gzip" => Some(CompressionAlgorithm::Gzip),
62            "x-gzip" => Some(CompressionAlgorithm::Gzip),
63            "zstd" => Some(CompressionAlgorithm::Zstd),
64            _ => None,
65        }
66    }
67}
68
69impl AsRef<str> for CompressionAlgorithm {
70    fn as_ref(&self) -> &str {
71        self.as_str()
72    }
73}
74
75impl Display for CompressionAlgorithm {
76    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
77        f.write_str(self.as_str())
78    }
79}
80
81impl FromStr for CompressionAlgorithm {
82    type Err = String;
83
84    fn from_str(s: &str) -> Result<Self, Self::Err> {
85        Self::from_str_exact(s)
86            .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
87            .ok_or_else(|| format!("unrecognized coding {s}"))
88    }
89}
90
91/**
92Trillium handler for compression
93*/
94#[derive(Clone, Debug)]
95pub struct Compression {
96    algorithms: BTreeSet<CompressionAlgorithm>,
97}
98
99impl Default for Compression {
100    fn default() -> Self {
101        use CompressionAlgorithm::*;
102        Self {
103            algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
104        }
105    }
106}
107
108impl Compression {
109    /// constructs a new compression handler
110    pub fn new() -> Self {
111        Self::default()
112    }
113
114    fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
115        self.algorithms = algos.iter().copied().collect();
116    }
117
118    /**
119    sets the compression algorithms that this handler will
120    use. the default of Zstd, Brotli, Gzip is recommended. Note that the
121    order is ignored.
122    */
123    pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
124        self.set_algorithms(algorithms);
125        self
126    }
127
128    fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
129        parse_accept_encoding(header)
130            .into_iter()
131            .find_map(|(algo, _)| {
132                if self.algorithms.contains(&algo) {
133                    Some(algo)
134                } else {
135                    None
136                }
137            })
138    }
139}
140
141fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
142    let mut vec = header
143        .split(',')
144        .filter_map(|s| {
145            let mut iter = s.trim().split(';');
146            let (algo, q) = (iter.next()?, iter.next());
147            let algo = algo.trim().parse().ok()?;
148            let q = q
149                .and_then(|q| {
150                    q.trim()
151                        .strip_prefix("q=")
152                        .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
153                })
154                .unwrap_or(100u8);
155            Some((algo, q))
156        })
157        .collect::<Vec<(CompressionAlgorithm, u8)>>();
158
159    vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
160        std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
161        other => other,
162    });
163
164    vec
165}
166
167#[async_trait]
168impl Handler for Compression {
169    async fn run(&self, mut conn: Conn) -> Conn {
170        if let Some(header) = conn
171            .request_headers()
172            .get_str(AcceptEncoding)
173            .and_then(|h| self.negotiate(h))
174        {
175            conn.insert_state(header);
176        }
177        conn
178    }
179
180    async fn before_send(&self, mut conn: Conn) -> Conn {
181        if let Some(algo) = conn.state::<CompressionAlgorithm>().copied() {
182            let mut body = conn_unwrap!(conn.inner_mut().take_response_body(), conn);
183            let mut compression_used = false;
184
185            if body.is_static() {
186                match algo {
187                    CompressionAlgorithm::Zstd => {
188                        let bytes = body.static_bytes().unwrap();
189                        let mut data = vec![];
190                        let mut encoder = ZstdEncoder::new(Cursor::new(bytes));
191                        conn_try!(encoder.read_to_end(&mut data).await, conn);
192                        if data.len() < bytes.len() {
193                            log::trace!("zstd body from {} to {}", bytes.len(), data.len());
194                            compression_used = true;
195                            body = Body::new_static(data);
196                        }
197                    }
198
199                    CompressionAlgorithm::Brotli => {
200                        let bytes = body.static_bytes().unwrap();
201                        let mut data = vec![];
202                        let mut encoder = BrotliEncoder::new(Cursor::new(bytes));
203                        conn_try!(encoder.read_to_end(&mut data).await, conn);
204                        if data.len() < bytes.len() {
205                            log::trace!("brotli'd body from {} to {}", bytes.len(), data.len());
206                            compression_used = true;
207                            body = Body::new_static(data);
208                        }
209                    }
210
211                    CompressionAlgorithm::Gzip => {
212                        let bytes = body.static_bytes().unwrap();
213                        let mut data = vec![];
214                        let mut encoder = GzipEncoder::new(Cursor::new(bytes));
215                        conn_try!(encoder.read_to_end(&mut data).await, conn);
216                        if data.len() < bytes.len() {
217                            log::trace!("gzipped body from {} to {}", bytes.len(), data.len());
218                            body = Body::new_static(data);
219                            compression_used = true;
220                        }
221                    }
222                }
223            } else if body.is_streaming() {
224                compression_used = true;
225                match algo {
226                    CompressionAlgorithm::Zstd => {
227                        body = Body::new_streaming(
228                            ZstdEncoder::new(BufReader::new(body.into_reader())),
229                            None,
230                        );
231                    }
232
233                    CompressionAlgorithm::Brotli => {
234                        body = Body::new_streaming(
235                            BrotliEncoder::new(BufReader::new(body.into_reader())),
236                            None,
237                        );
238                    }
239
240                    CompressionAlgorithm::Gzip => {
241                        body = Body::new_streaming(
242                            GzipEncoder::new(BufReader::new(body.into_reader())),
243                            None,
244                        );
245                    }
246                }
247            }
248
249            if compression_used {
250                let vary = conn
251                    .response_headers()
252                    .get_str(Vary)
253                    .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
254                    .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
255
256                conn.response_headers_mut().extend([
257                    (ContentEncoding, HeaderValues::from(algo.as_str())),
258                    (Vary, vary),
259                ]);
260            }
261
262            conn.with_body(body)
263        } else {
264            conn
265        }
266    }
267}
268
269/// Alias for [`Compression::new`](crate::Compression::new)
270pub fn compression() -> Compression {
271    Compression::new()
272}