1#![forbid(unsafe_code)]
11#![deny(
12 missing_copy_implementations,
13 rustdoc::missing_crate_level_docs,
14 missing_debug_implementations,
15 nonstandard_style,
16 unused_qualifications
17)]
18#![warn(missing_docs)]
19
20use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
21use futures_lite::{
22 io::{BufReader, Cursor},
23 AsyncReadExt,
24};
25use std::{
26 collections::BTreeSet,
27 fmt::{self, Display, Formatter},
28 str::FromStr,
29};
30use trillium::{
31 async_trait, conn_try, conn_unwrap, Body, Conn, Handler, HeaderValues,
32 KnownHeaderName::{AcceptEncoding, ContentEncoding, Vary},
33};
34
/// A content coding that can be applied to a response body.
///
/// The derived [`Ord`] follows declaration order (`Brotli < Gzip < Zstd`). That order is what
/// breaks ties between codings a client weights equally, so the variants must not be reordered
/// without considering the effect on negotiation.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompressionAlgorithm {
    /// Brotli, identified by the `br` token.
    Brotli,

    /// Gzip, identified by the `gzip` token.
    Gzip,

    /// Zstandard, identified by the `zstd` token.
    Zstd,
}
48
49impl CompressionAlgorithm {
50 fn as_str(&self) -> &'static str {
51 match self {
52 CompressionAlgorithm::Brotli => "br",
53 CompressionAlgorithm::Gzip => "gzip",
54 CompressionAlgorithm::Zstd => "zstd",
55 }
56 }
57
58 fn from_str_exact(s: &str) -> Option<Self> {
59 match s {
60 "br" => Some(CompressionAlgorithm::Brotli),
61 "gzip" => Some(CompressionAlgorithm::Gzip),
62 "x-gzip" => Some(CompressionAlgorithm::Gzip),
63 "zstd" => Some(CompressionAlgorithm::Zstd),
64 _ => None,
65 }
66 }
67}
68
69impl AsRef<str> for CompressionAlgorithm {
70 fn as_ref(&self) -> &str {
71 self.as_str()
72 }
73}
74
75impl Display for CompressionAlgorithm {
76 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
77 f.write_str(self.as_str())
78 }
79}
80
81impl FromStr for CompressionAlgorithm {
82 type Err = String;
83
84 fn from_str(s: &str) -> Result<Self, Self::Err> {
85 Self::from_str_exact(s)
86 .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
87 .ok_or_else(|| format!("unrecognized coding {s}"))
88 }
89}
90
91#[derive(Clone, Debug)]
95pub struct Compression {
96 algorithms: BTreeSet<CompressionAlgorithm>,
97}
98
99impl Default for Compression {
100 fn default() -> Self {
101 use CompressionAlgorithm::*;
102 Self {
103 algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
104 }
105 }
106}
107
108impl Compression {
109 pub fn new() -> Self {
111 Self::default()
112 }
113
114 fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
115 self.algorithms = algos.iter().copied().collect();
116 }
117
118 pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
124 self.set_algorithms(algorithms);
125 self
126 }
127
128 fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
129 parse_accept_encoding(header)
130 .into_iter()
131 .find_map(|(algo, _)| {
132 if self.algorithms.contains(&algo) {
133 Some(algo)
134 } else {
135 None
136 }
137 })
138 }
139}
140
141fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
142 let mut vec = header
143 .split(',')
144 .filter_map(|s| {
145 let mut iter = s.trim().split(';');
146 let (algo, q) = (iter.next()?, iter.next());
147 let algo = algo.trim().parse().ok()?;
148 let q = q
149 .and_then(|q| {
150 q.trim()
151 .strip_prefix("q=")
152 .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
153 })
154 .unwrap_or(100u8);
155 Some((algo, q))
156 })
157 .collect::<Vec<(CompressionAlgorithm, u8)>>();
158
159 vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
160 std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
161 other => other,
162 });
163
164 vec
165}
166
167#[async_trait]
168impl Handler for Compression {
169 async fn run(&self, mut conn: Conn) -> Conn {
170 if let Some(header) = conn
171 .request_headers()
172 .get_str(AcceptEncoding)
173 .and_then(|h| self.negotiate(h))
174 {
175 conn.insert_state(header);
176 }
177 conn
178 }
179
180 async fn before_send(&self, mut conn: Conn) -> Conn {
181 if let Some(algo) = conn.state::<CompressionAlgorithm>().copied() {
182 let mut body = conn_unwrap!(conn.inner_mut().take_response_body(), conn);
183 let mut compression_used = false;
184
185 if body.is_static() {
186 match algo {
187 CompressionAlgorithm::Zstd => {
188 let bytes = body.static_bytes().unwrap();
189 let mut data = vec![];
190 let mut encoder = ZstdEncoder::new(Cursor::new(bytes));
191 conn_try!(encoder.read_to_end(&mut data).await, conn);
192 if data.len() < bytes.len() {
193 log::trace!("zstd body from {} to {}", bytes.len(), data.len());
194 compression_used = true;
195 body = Body::new_static(data);
196 }
197 }
198
199 CompressionAlgorithm::Brotli => {
200 let bytes = body.static_bytes().unwrap();
201 let mut data = vec![];
202 let mut encoder = BrotliEncoder::new(Cursor::new(bytes));
203 conn_try!(encoder.read_to_end(&mut data).await, conn);
204 if data.len() < bytes.len() {
205 log::trace!("brotli'd body from {} to {}", bytes.len(), data.len());
206 compression_used = true;
207 body = Body::new_static(data);
208 }
209 }
210
211 CompressionAlgorithm::Gzip => {
212 let bytes = body.static_bytes().unwrap();
213 let mut data = vec![];
214 let mut encoder = GzipEncoder::new(Cursor::new(bytes));
215 conn_try!(encoder.read_to_end(&mut data).await, conn);
216 if data.len() < bytes.len() {
217 log::trace!("gzipped body from {} to {}", bytes.len(), data.len());
218 body = Body::new_static(data);
219 compression_used = true;
220 }
221 }
222 }
223 } else if body.is_streaming() {
224 compression_used = true;
225 match algo {
226 CompressionAlgorithm::Zstd => {
227 body = Body::new_streaming(
228 ZstdEncoder::new(BufReader::new(body.into_reader())),
229 None,
230 );
231 }
232
233 CompressionAlgorithm::Brotli => {
234 body = Body::new_streaming(
235 BrotliEncoder::new(BufReader::new(body.into_reader())),
236 None,
237 );
238 }
239
240 CompressionAlgorithm::Gzip => {
241 body = Body::new_streaming(
242 GzipEncoder::new(BufReader::new(body.into_reader())),
243 None,
244 );
245 }
246 }
247 }
248
249 if compression_used {
250 let vary = conn
251 .response_headers()
252 .get_str(Vary)
253 .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
254 .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
255
256 conn.response_headers_mut().extend([
257 (ContentEncoding, HeaderValues::from(algo.as_str())),
258 (Vary, vary),
259 ]);
260 }
261
262 conn.with_body(body)
263 } else {
264 conn
265 }
266 }
267}
268
269pub fn compression() -> Compression {
271 Compression::new()
272}