trillium_compression/
lib.rs1#![forbid(unsafe_code)]
9#![deny(
10 missing_copy_implementations,
11 rustdoc::missing_crate_level_docs,
12 missing_debug_implementations,
13 nonstandard_style,
14 unused_qualifications
15)]
16#![warn(missing_docs)]
17
// This empty module exists only when `cfg(test)` is set; its documentation is the text of
// `../README.md`, pulled in at compile time.
// NOTE(review): presumably this is meant to get the README's code examples exercised as
// doctests — confirm that the doctest build actually enables `cfg(test)`.
#[cfg(test)]
#[doc = include_str!("../README.md")]
mod readme {}
21
22use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
23use futures_lite::{
24 AsyncReadExt,
25 io::{BufReader, Cursor},
26};
27use std::{
28 collections::BTreeSet,
29 fmt::{self, Display, Formatter},
30 str::FromStr,
31};
32use trillium::{
33 Body, Conn, Handler, HeaderValues,
34 KnownHeaderName::{AcceptEncoding, ContentEncoding, Vary},
35 conn_try, conn_unwrap,
36};
37
/// A content coding that [`Compression`] can negotiate and apply to a response body.
///
/// The variant order is significant: the derived [`Ord`] (`Brotli < Gzip < Zstd`) is what orders
/// the enabled set and what breaks ties between codings a client weights equally.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum CompressionAlgorithm {
    /// Brotli, written as `br` in headers.
    Brotli,

    /// Gzip, written as `gzip` in headers (`x-gzip` is also recognized when parsing).
    Gzip,

    /// Zstandard, written as `zstd` in headers.
    Zstd,
}
51
52impl CompressionAlgorithm {
53 fn as_str(&self) -> &'static str {
54 match self {
55 CompressionAlgorithm::Brotli => "br",
56 CompressionAlgorithm::Gzip => "gzip",
57 CompressionAlgorithm::Zstd => "zstd",
58 }
59 }
60
61 fn from_str_exact(s: &str) -> Option<Self> {
62 match s {
63 "br" => Some(CompressionAlgorithm::Brotli),
64 "gzip" => Some(CompressionAlgorithm::Gzip),
65 "x-gzip" => Some(CompressionAlgorithm::Gzip),
66 "zstd" => Some(CompressionAlgorithm::Zstd),
67 _ => None,
68 }
69 }
70}
71
72impl AsRef<str> for CompressionAlgorithm {
73 fn as_ref(&self) -> &str {
74 self.as_str()
75 }
76}
77
78impl Display for CompressionAlgorithm {
79 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
80 f.write_str(self.as_str())
81 }
82}
83
84impl FromStr for CompressionAlgorithm {
85 type Err = String;
86
87 fn from_str(s: &str) -> Result<Self, Self::Err> {
88 Self::from_str_exact(s)
89 .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
90 .ok_or_else(|| format!("unrecognized coding {s}"))
91 }
92}
93
94#[derive(Clone, Debug)]
96pub struct Compression {
97 algorithms: BTreeSet<CompressionAlgorithm>,
98}
99
100impl Default for Compression {
101 fn default() -> Self {
102 use CompressionAlgorithm::*;
103 Self {
104 algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
105 }
106 }
107}
108
109impl Compression {
110 pub fn new() -> Self {
112 Self::default()
113 }
114
115 fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
116 self.algorithms = algos.iter().copied().collect();
117 }
118
119 pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
123 self.set_algorithms(algorithms);
124 self
125 }
126
127 fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
128 parse_accept_encoding(header)
129 .into_iter()
130 .find_map(|(algo, _)| {
131 if self.algorithms.contains(&algo) {
132 Some(algo)
133 } else {
134 None
135 }
136 })
137 }
138}
139
140fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
141 let mut vec = header
142 .split(',')
143 .filter_map(|s| {
144 let mut iter = s.trim().split(';');
145 let (algo, q) = (iter.next()?, iter.next());
146 let algo = algo.trim().parse().ok()?;
147 let q = q
148 .and_then(|q| {
149 q.trim()
150 .strip_prefix("q=")
151 .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
152 })
153 .unwrap_or(100u8);
154 Some((algo, q))
155 })
156 .collect::<Vec<(CompressionAlgorithm, u8)>>();
157
158 vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
159 std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
160 other => other,
161 });
162
163 vec
164}
165
166impl Handler for Compression {
167 async fn run(&self, mut conn: Conn) -> Conn {
168 if let Some(header) = conn
169 .request_headers()
170 .get_str(AcceptEncoding)
171 .and_then(|h| self.negotiate(h))
172 {
173 conn.insert_state(header);
174 }
175 conn
176 }
177
178 async fn before_send(&self, mut conn: Conn) -> Conn {
179 if let Some(algo) = conn.state::<CompressionAlgorithm>().copied() {
180 let mut body = conn_unwrap!(conn.take_response_body(), conn);
181 let mut compression_used = false;
182
183 if body.is_static() {
184 match algo {
185 CompressionAlgorithm::Zstd => {
186 let bytes = body.static_bytes().unwrap();
187 let mut data = vec![];
188 let mut encoder = ZstdEncoder::new(Cursor::new(bytes));
189 conn_try!(encoder.read_to_end(&mut data).await, conn);
190 if data.len() < bytes.len() {
191 log::trace!("zstd body from {} to {}", bytes.len(), data.len());
192 compression_used = true;
193 body = Body::new_static(data);
194 }
195 }
196
197 CompressionAlgorithm::Brotli => {
198 let bytes = body.static_bytes().unwrap();
199 let mut data = vec![];
200 let mut encoder = BrotliEncoder::new(Cursor::new(bytes));
201 conn_try!(encoder.read_to_end(&mut data).await, conn);
202 if data.len() < bytes.len() {
203 log::trace!("brotli'd body from {} to {}", bytes.len(), data.len());
204 compression_used = true;
205 body = Body::new_static(data);
206 }
207 }
208
209 CompressionAlgorithm::Gzip => {
210 let bytes = body.static_bytes().unwrap();
211 let mut data = vec![];
212 let mut encoder = GzipEncoder::new(Cursor::new(bytes));
213 conn_try!(encoder.read_to_end(&mut data).await, conn);
214 if data.len() < bytes.len() {
215 log::trace!("gzipped body from {} to {}", bytes.len(), data.len());
216 body = Body::new_static(data);
217 compression_used = true;
218 }
219 }
220 }
221 } else if body.is_streaming() {
222 compression_used = true;
223 match algo {
224 CompressionAlgorithm::Zstd => {
225 body = Body::new_streaming(
226 ZstdEncoder::new(BufReader::new(body.into_reader())),
227 None,
228 );
229 }
230
231 CompressionAlgorithm::Brotli => {
232 body = Body::new_streaming(
233 BrotliEncoder::new(BufReader::new(body.into_reader())),
234 None,
235 );
236 }
237
238 CompressionAlgorithm::Gzip => {
239 body = Body::new_streaming(
240 GzipEncoder::new(BufReader::new(body.into_reader())),
241 None,
242 );
243 }
244 }
245 }
246
247 if compression_used {
248 let vary = conn
249 .response_headers()
250 .get_str(Vary)
251 .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
252 .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
253
254 conn.response_headers_mut().extend([
255 (ContentEncoding, HeaderValues::from(algo.as_str())),
256 (Vary, vary),
257 ]);
258 }
259
260 conn.with_body(body)
261 } else {
262 conn
263 }
264 }
265}
266
267pub fn compression() -> Compression {
269 Compression::new()
270}