Skip to main content

trillium_compression/
lib.rs

1//! Body compression for trillium.rs
2//!
3//! Currently, this crate only supports compressing outbound bodies with
4//! the zstd, brotli, and gzip algorithms (in order of preference),
5//! although more algorithms may be added in the future. The correct
6//! algorithm will be selected based on the Accept-Encoding header sent by
7//! the client, if one exists.
8#![forbid(unsafe_code)]
9#![deny(
10    missing_copy_implementations,
11    rustdoc::missing_crate_level_docs,
12    missing_debug_implementations,
13    nonstandard_style,
14    unused_qualifications
15)]
16#![warn(missing_docs)]
17
// Pulls README.md in as documentation so its code samples run as doctests.
// rustdoc sets `cfg(doctest)` (not `cfg(test)`) while collecting doctests, so
// gating on `test` would hide this module from rustdoc and the README examples
// would never be exercised.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
21
22use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
23use futures_lite::{
24    AsyncReadExt,
25    io::{BufReader, Cursor},
26};
27use std::{
28    collections::BTreeSet,
29    fmt::{self, Display, Formatter},
30    str::FromStr,
31};
32use trillium::{
33    Body, Conn, Handler, HeaderValues,
34    KnownHeaderName::{AcceptEncoding, ContentEncoding, Vary},
35    conn_try, conn_unwrap,
36};
37
/// Algorithms supported by this crate.
///
/// Each variant is a content coding that can be applied to response bodies.
/// The derived ordering follows declaration order: `Brotli < Gzip < Zstd`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum CompressionAlgorithm {
    /// Brotli algorithm (content-coding token `br`)
    Brotli,

    /// Gzip algorithm (content-coding token `gzip`)
    Gzip,

    /// Zstd algorithm (content-coding token `zstd`)
    Zstd,
}
51
52impl CompressionAlgorithm {
53    fn as_str(&self) -> &'static str {
54        match self {
55            CompressionAlgorithm::Brotli => "br",
56            CompressionAlgorithm::Gzip => "gzip",
57            CompressionAlgorithm::Zstd => "zstd",
58        }
59    }
60
61    fn from_str_exact(s: &str) -> Option<Self> {
62        match s {
63            "br" => Some(CompressionAlgorithm::Brotli),
64            "gzip" => Some(CompressionAlgorithm::Gzip),
65            "x-gzip" => Some(CompressionAlgorithm::Gzip),
66            "zstd" => Some(CompressionAlgorithm::Zstd),
67            _ => None,
68        }
69    }
70}
71
72impl AsRef<str> for CompressionAlgorithm {
73    fn as_ref(&self) -> &str {
74        self.as_str()
75    }
76}
77
78impl Display for CompressionAlgorithm {
79    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
80        f.write_str(self.as_str())
81    }
82}
83
84impl FromStr for CompressionAlgorithm {
85    type Err = String;
86
87    fn from_str(s: &str) -> Result<Self, Self::Err> {
88        Self::from_str_exact(s)
89            .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
90            .ok_or_else(|| format!("unrecognized coding {s}"))
91    }
92}
93
94/// Trillium handler for compression
95#[derive(Clone, Debug)]
96pub struct Compression {
97    algorithms: BTreeSet<CompressionAlgorithm>,
98}
99
100impl Default for Compression {
101    fn default() -> Self {
102        use CompressionAlgorithm::*;
103        Self {
104            algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
105        }
106    }
107}
108
109impl Compression {
110    /// constructs a new compression handler
111    pub fn new() -> Self {
112        Self::default()
113    }
114
115    fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
116        self.algorithms = algos.iter().copied().collect();
117    }
118
119    /// sets the compression algorithms that this handler will
120    /// use. the default of Zstd, Brotli, Gzip is recommended. Note that the
121    /// order is ignored.
122    pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
123        self.set_algorithms(algorithms);
124        self
125    }
126
127    fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
128        parse_accept_encoding(header)
129            .into_iter()
130            .find_map(|(algo, _)| {
131                if self.algorithms.contains(&algo) {
132                    Some(algo)
133                } else {
134                    None
135                }
136            })
137    }
138}
139
140fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
141    let mut vec = header
142        .split(',')
143        .filter_map(|s| {
144            let mut iter = s.trim().split(';');
145            let (algo, q) = (iter.next()?, iter.next());
146            let algo = algo.trim().parse().ok()?;
147            let q = q
148                .and_then(|q| {
149                    q.trim()
150                        .strip_prefix("q=")
151                        .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
152                })
153                .unwrap_or(100u8);
154            Some((algo, q))
155        })
156        .collect::<Vec<(CompressionAlgorithm, u8)>>();
157
158    vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
159        std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
160        other => other,
161    });
162
163    vec
164}
165
166impl Handler for Compression {
167    async fn run(&self, mut conn: Conn) -> Conn {
168        if let Some(header) = conn
169            .request_headers()
170            .get_str(AcceptEncoding)
171            .and_then(|h| self.negotiate(h))
172        {
173            conn.insert_state(header);
174        }
175        conn
176    }
177
178    async fn before_send(&self, mut conn: Conn) -> Conn {
179        if let Some(algo) = conn.state::<CompressionAlgorithm>().copied() {
180            let mut body = conn_unwrap!(conn.take_response_body(), conn);
181            let mut compression_used = false;
182
183            if body.is_static() {
184                match algo {
185                    CompressionAlgorithm::Zstd => {
186                        let bytes = body.static_bytes().unwrap();
187                        let mut data = vec![];
188                        let mut encoder = ZstdEncoder::new(Cursor::new(bytes));
189                        conn_try!(encoder.read_to_end(&mut data).await, conn);
190                        if data.len() < bytes.len() {
191                            log::trace!("zstd body from {} to {}", bytes.len(), data.len());
192                            compression_used = true;
193                            body = Body::new_static(data);
194                        }
195                    }
196
197                    CompressionAlgorithm::Brotli => {
198                        let bytes = body.static_bytes().unwrap();
199                        let mut data = vec![];
200                        let mut encoder = BrotliEncoder::new(Cursor::new(bytes));
201                        conn_try!(encoder.read_to_end(&mut data).await, conn);
202                        if data.len() < bytes.len() {
203                            log::trace!("brotli'd body from {} to {}", bytes.len(), data.len());
204                            compression_used = true;
205                            body = Body::new_static(data);
206                        }
207                    }
208
209                    CompressionAlgorithm::Gzip => {
210                        let bytes = body.static_bytes().unwrap();
211                        let mut data = vec![];
212                        let mut encoder = GzipEncoder::new(Cursor::new(bytes));
213                        conn_try!(encoder.read_to_end(&mut data).await, conn);
214                        if data.len() < bytes.len() {
215                            log::trace!("gzipped body from {} to {}", bytes.len(), data.len());
216                            body = Body::new_static(data);
217                            compression_used = true;
218                        }
219                    }
220                }
221            } else if body.is_streaming() {
222                compression_used = true;
223                match algo {
224                    CompressionAlgorithm::Zstd => {
225                        body = Body::new_streaming(
226                            ZstdEncoder::new(BufReader::new(body.into_reader())),
227                            None,
228                        );
229                    }
230
231                    CompressionAlgorithm::Brotli => {
232                        body = Body::new_streaming(
233                            BrotliEncoder::new(BufReader::new(body.into_reader())),
234                            None,
235                        );
236                    }
237
238                    CompressionAlgorithm::Gzip => {
239                        body = Body::new_streaming(
240                            GzipEncoder::new(BufReader::new(body.into_reader())),
241                            None,
242                        );
243                    }
244                }
245            }
246
247            if compression_used {
248                let vary = conn
249                    .response_headers()
250                    .get_str(Vary)
251                    .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
252                    .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
253
254                conn.response_headers_mut().extend([
255                    (ContentEncoding, HeaderValues::from(algo.as_str())),
256                    (Vary, vary),
257                ]);
258            }
259
260            conn.with_body(body)
261        } else {
262            conn
263        }
264    }
265}
266
267/// Alias for [`Compression::new`](crate::Compression::new)
268pub fn compression() -> Compression {
269    Compression::new()
270}