1use crate::{
37 CacheKey, CacheOptions, CachePolicy, CacheStorage, StoredEntry,
38 tee::TeeingReader,
39 validation::{AfterResponse, BeforeRequest},
40};
41use std::{sync::Arc, time::SystemTime};
42use trillium::{Body, Conn, Handler, KnownHeaderName, Method};
43use url::Url;
44
/// Default upper bound, in bytes, on a response body the cache will store
/// (16 MiB). Responses whose body is known to exceed it are still served, just
/// not cached; override per handler with `Cache::with_max_cacheable_size`.
const DEFAULT_MAX_CACHEABLE_SIZE: u64 = 16 << 20;
46
47#[derive(Debug)]
50pub struct Cache<S: CacheStorage> {
51 storage: Arc<S>,
52 options: CacheOptions,
53 max_cacheable_size: u64,
54}
55
56impl<S: CacheStorage> Clone for Cache<S> {
57 fn clone(&self) -> Self {
58 Self {
59 storage: Arc::clone(&self.storage),
60 options: self.options,
61 max_cacheable_size: self.max_cacheable_size,
62 }
63 }
64}
65
66impl<S: CacheStorage> Cache<S> {
67 pub fn new(storage: S) -> Self {
70 Self {
71 storage: Arc::new(storage),
72 options: CacheOptions::default(),
73 max_cacheable_size: DEFAULT_MAX_CACHEABLE_SIZE,
74 }
75 }
76
77 pub fn with_options(mut self, options: CacheOptions) -> Self {
79 self.options = options;
80 self
81 }
82
83 pub fn shared(mut self) -> Self {
86 self.options.shared = true;
87 self
88 }
89
90 pub fn with_max_cacheable_size(mut self, max: u64) -> Self {
95 self.max_cacheable_size = max;
96 self
97 }
98
99 pub fn storage(&self) -> &S {
101 &self.storage
102 }
103}
104
105enum CacheCtx<E: StoredEntry> {
107 Hit,
109 Revalidation { stored: E, key: CacheKey },
113 Miss { key: CacheKey },
116 Unsafe { url: Url },
119}
120
121impl<E: StoredEntry> std::fmt::Debug for CacheCtx<E> {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Self::Hit => f.write_str("Hit"),
125 Self::Revalidation { key, .. } => f
126 .debug_struct("Revalidation")
127 .field("key", key)
128 .finish_non_exhaustive(),
129 Self::Miss { key } => f.debug_struct("Miss").field("key", key).finish(),
130 Self::Unsafe { url } => f.debug_struct("Unsafe").field("url", url).finish(),
131 }
132 }
133}
134
135fn url_from_conn(conn: &Conn) -> Option<Url> {
139 let scheme = if conn.is_secure() { "https" } else { "http" };
140 let host = conn.host()?;
141 let path_and_query = conn.path_and_query();
142 Url::parse(&format!("{scheme}://{host}{path_and_query}")).ok()
143}
144
145impl<S: CacheStorage> Handler for Cache<S> {
146 async fn run(&self, mut conn: Conn) -> Conn {
147 let method = conn.method();
148 let Some(url) = url_from_conn(&conn) else {
149 log::trace!("cache: no host on request, passing through without caching");
150 return conn;
151 };
152 let key = CacheKey::new(method, url.clone());
153 log::trace!("cache: run {method} {url}");
154
155 if !method.is_safe() {
158 log::trace!("cache: unsafe method {method}, bypassing cache read");
159 return conn.with_state(CacheCtx::<S::StoredEntry>::Unsafe { url });
160 }
161
162 let now = SystemTime::now();
163 let entries = self.storage.get(&key).await;
164 log::trace!("cache: {} stored candidate(s) for {key}", entries.len());
165
166 for entry in entries {
167 match entry.policy().before_request(conn.request_headers(), now) {
168 BeforeRequest::Fresh(cached) => {
169 log::trace!("cache: hit (fresh) for {key}, serving cached response");
170 *conn.response_headers_mut() = cached.headers;
171 let body = match entry.open().await {
172 Ok(b) => b,
173 Err(e) => {
174 log::warn!(
175 "cache: open for hit failed for {key}: {e}, passing through"
176 );
177 return conn;
178 }
179 };
180 return conn
181 .with_state(CacheCtx::<S::StoredEntry>::Hit)
182 .with_status(cached.status)
183 .with_body(body)
184 .halt();
185 }
186
187 BeforeRequest::NotModified(cached) => {
188 log::trace!("cache: hit (fresh, conditional matches) for {key}, serving 304");
191 *conn.response_headers_mut() = cached.headers;
192 return conn
193 .with_state(CacheCtx::<S::StoredEntry>::Hit)
194 .with_status(cached.status)
195 .with_body(Body::default())
196 .halt();
197 }
198
199 BeforeRequest::Stale {
200 request_headers,
201 matches: true,
202 } => {
203 log::trace!("cache: stale for {key}, sending conditional revalidation request");
209 *conn.request_headers_mut() = request_headers;
210 return conn.with_state(CacheCtx::Revalidation { stored: entry, key });
211 }
212
213 BeforeRequest::Stale { matches: false, .. } => {
214 log::trace!("cache: candidate vary-mismatch for {key}, trying next");
215 continue;
216 }
217 }
218 }
219
220 log::trace!("cache: miss for {key}, forwarding to downstream handler");
221 conn.with_state(CacheCtx::<S::StoredEntry>::Miss { key })
222 }
223
224 async fn before_send(&self, mut conn: Conn) -> Conn {
225 let Some(ctx) = conn.take_state::<CacheCtx<S::StoredEntry>>() else {
226 return conn;
227 };
228
229 match ctx {
230 CacheCtx::Hit => conn,
231 CacheCtx::Revalidation { stored, key } => {
232 let now = SystemTime::now();
233 let origin_failed = conn.status().is_some_and(|s| s.is_server_error());
234 if origin_failed && stored.policy().is_sie_eligible(now) {
235 log::trace!(
236 "cache: stale-if-error recovery for {} (downstream {:?}), serving stale",
237 conn.method(),
238 conn.status()
239 );
240 return apply_stale(conn, stored, now).await;
241 }
242 if conn.status().is_none() {
243 log::trace!("cache: downstream produced no status, passing through");
244 return conn;
245 }
246 self.handle_revalidation(conn, stored, key).await
247 }
248 CacheCtx::Miss { key } => {
249 if conn.status().is_none() {
250 log::trace!("cache: downstream produced no status, passing through");
251 return conn;
252 }
253 self.handle_miss(conn, key).await
254 }
255 CacheCtx::Unsafe { url } => {
256 let Some(status) = conn.status() else {
257 return conn;
258 };
259 if status.is_success() || status.is_redirection() {
260 log::trace!(
261 "cache: unsafe method {} → {}, invalidating GET and HEAD entries for {url}",
262 conn.method(),
263 status
264 );
265 self.invalidate_url(&url).await;
266
267 for header in [KnownHeaderName::Location, KnownHeaderName::ContentLocation] {
270 let Some(value) = conn.response_headers().get_str(header) else {
271 continue;
272 };
273 let Ok(target) = url.join(value) else {
274 continue;
275 };
276 if target.host_str() != url.host_str() {
277 continue;
278 }
279 log::trace!(
280 "cache: unsafe method secondary invalidation via {header}: {target}"
281 );
282 self.invalidate_url(&target).await;
283 }
284 }
285 conn
286 }
287 }
288 }
289}
290
291impl<S: CacheStorage> Cache<S> {
292 async fn invalidate_url(&self, url: &Url) {
293 self.storage
294 .invalidate(&CacheKey::new(Method::Get, url.clone()))
295 .await;
296 self.storage
297 .invalidate(&CacheKey::new(Method::Head, url.clone()))
298 .await;
299 }
300
301 async fn handle_revalidation(
302 &self,
303 mut conn: Conn,
304 mut stored: S::StoredEntry,
305 key: CacheKey,
306 ) -> Conn {
307 let now = SystemTime::now();
308 let status = conn.status().expect("checked above");
309 match stored.policy().after_response(
310 conn.request_headers(),
311 status,
312 conn.response_headers(),
313 now,
314 ) {
315 AfterResponse::NotModified(new_policy, cached_response) => {
316 log::trace!(
317 "cache: revalidation 304 for {key}, reusing stored body and refreshing entry"
318 );
319 if let Err(e) = stored.refresh_policy(new_policy).await {
320 log::warn!("cache: refresh_policy failed for {key}: {e}");
321 }
322 let body = match stored.open().await {
323 Ok(b) => b,
324 Err(e) => {
325 log::warn!("cache: open after 304 failed for {key}: {e}, passing through");
326 return conn;
327 }
328 };
329 *conn.response_headers_mut() = cached_response.headers;
330 conn.set_status(cached_response.status);
331 conn.set_body(body);
332 conn
333 }
334 AfterResponse::Modified => {
335 drop(stored);
338 self.handle_miss(conn, key).await
339 }
340 }
341 }
342
343 async fn handle_miss(&self, mut conn: Conn, key: CacheKey) -> Conn {
344 let status = conn.status().expect("checked above");
345 if !CachePolicy::is_storable(
346 conn.method(),
347 conn.request_headers(),
348 status,
349 conn.response_headers(),
350 &self.options,
351 ) {
352 log::trace!("cache: miss for {key}, response not storable, passing through");
353 return conn;
354 }
355
356 if let Some(body_ref) = conn.response_body()
358 && let Some(len) = body_ref.len()
359 && len > self.max_cacheable_size
360 {
361 log::trace!(
362 "cache: miss for {key}, body {len} > max {}, not caching",
363 self.max_cacheable_size
364 );
365 return conn;
366 }
367
368 let policy = CachePolicy::new(
369 conn.method(),
370 conn.request_headers(),
371 status,
372 conn.response_headers().clone(),
373 SystemTime::now(),
374 self.options,
375 );
376 let put_handle = match self.storage.put(key.clone(), policy).await {
377 Ok(h) => h,
378 Err(e) => {
379 log::warn!("cache: put({key}) failed: {e}, passing through");
380 return conn;
381 }
382 };
383
384 let Some(body) = conn.take_response_body() else {
385 log::trace!("cache: miss for {key}, no body, passing through");
386 return conn;
387 };
388 let len = body.len();
389 log::trace!("cache: miss for {key}, streaming through tee");
390 let body = body.without_chunked_framing();
391 let tee = TeeingReader::new(body, put_handle, self.max_cacheable_size);
392 conn.set_body(Body::new_with_trailers(tee, len));
393 conn
394 }
395}
396
397async fn apply_stale<E: StoredEntry>(mut conn: Conn, stored: E, now: SystemTime) -> Conn {
399 let cached = stored.policy().cached_response(now);
400 let body = match stored.open().await {
401 Ok(b) => b,
402 Err(e) => {
403 log::warn!("cache: open for stale serve failed: {e}, passing through");
404 return conn;
405 }
406 };
407 *conn.response_headers_mut() = cached.headers;
408 conn.set_status(cached.status);
409 conn.set_body(body);
410 conn
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryStorage;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use trillium_testing::{TestResult, TestServer, harness, test};

    /// Test origin handler. Every call increments `counter`. On the normal path
    /// it responds 200 with body `body-{n}` (where `n` is the number of earlier
    /// calls) and the configured `Cache-Control` value.
    #[derive(Debug, Clone)]
    struct CountingHandler {
        counter: Arc<AtomicUsize>,
        cache_control: &'static str,
        // When set, sent as the `ETag` of 200 responses and compared against the
        // request's `If-None-Match`.
        etag: Option<&'static str>,
    }

    impl CountingHandler {
        /// A handler with a fresh counter, the given `Cache-Control`, and no ETag.
        fn new(cache_control: &'static str) -> Self {
            Self {
                counter: Arc::new(AtomicUsize::new(0)),
                cache_control,
                etag: None,
            }
        }

        /// Enables ETag handling with the given tag (callers pass it already quoted).
        fn with_etag(mut self, etag: &'static str) -> Self {
            self.etag = Some(etag);
            self
        }
    }

    impl Handler for CountingHandler {
        async fn run(&self, conn: Conn) -> Conn {
            // Counted before anything else, so 304 answers below also count.
            let n = self.counter.fetch_add(1, Ordering::SeqCst);
            // A conditional request carrying the current tag gets a bodiless 304.
            if let Some(etag) = self.etag
                && conn.request_headers().get_str(KnownHeaderName::IfNoneMatch) == Some(etag)
            {
                return conn
                    .with_response_header(KnownHeaderName::Etag, etag)
                    .with_status(304)
                    .halt();
            }
            let mut conn = conn
                .with_response_header(KnownHeaderName::CacheControl, self.cache_control)
                .ok(format!("body-{n}"));
            if let Some(etag) = self.etag {
                conn.response_headers_mut()
                    .insert(KnownHeaderName::Etag, etag);
            }
            conn
        }
    }

    /// The cache (in-memory storage, default options) in front of `inner`.
    fn cache_app(inner: CountingHandler) -> impl Handler {
        (Cache::new(InMemoryStorage::new()), inner)
    }

    // A fresh (max-age=600) response is stored on the first GET and served from
    // storage on the second, so the origin runs only once.
    #[test(harness)]
    async fn first_request_misses_subsequent_request_hits() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = inner.counter.clone();
        let app = TestServer::new(cache_app(inner)).await;

        let r1 = app.get("/x").await;
        r1.assert_ok().assert_body("body-0");

        let r2 = app.get("/x").await;
        r2.assert_ok().assert_body("body-0");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "inner handler only hit once"
        );
        Ok(())
    }

    // Entries are keyed by URL: /a and /b each reach the origin.
    #[test(harness)]
    async fn different_urls_dont_collide() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = inner.counter.clone();
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/a").await.assert_body("body-0");
        app.get("/b").await.assert_body("body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    // `no-store` responses are never stored, so both GETs reach the origin.
    #[test(harness)]
    async fn no_store_response_is_not_cached() -> TestResult {
        let inner = CountingHandler::new("no-store");
        let counter = inner.counter.clone();
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/x").await.assert_body("body-0");
        app.get("/x").await.assert_body("body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    // A successful POST to /x (origin call #2) invalidates the stored GET entry,
    // so the following GET reaches the origin again (call #3, body-2).
    #[test(harness)]
    async fn post_invalidates_existing_entry() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = inner.counter.clone();
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/x").await.assert_body("body-0");
        let _ = app.post("/x").await;
        app.get("/x").await.assert_body("body-2");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        Ok(())
    }

    // With max-age=0 the stored entry is immediately stale. The second GET is
    // sent downstream as a conditional request, and the origin answers 304.
    // The cache then serves the stored body (body-0) with status 200.
    #[test(harness)]
    async fn stale_with_etag_revalidates_to_304() -> TestResult {
        let inner = CountingHandler::new("max-age=0").with_etag(r#""v1""#);
        let counter = inner.counter.clone();
        let app = TestServer::new(cache_app(inner)).await;

        app.get("/x").await.assert_body("body-0");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let r2 = app.get("/x").await;
        r2.assert_ok().assert_body("body-0");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    // `Vary: Accept-Encoding` keeps gzip and br responses apart. The repeated
    // gzip request is served from storage, so the origin runs twice, not three times.
    #[test(harness)]
    async fn vary_isolates_entries_by_request_header() -> TestResult {
        #[derive(Debug, Clone, Default)]
        struct VaryHandler(Arc<AtomicUsize>);
        impl Handler for VaryHandler {
            async fn run(&self, conn: Conn) -> Conn {
                self.0.fetch_add(1, Ordering::SeqCst);
                let ae = conn
                    .request_headers()
                    .get_str(KnownHeaderName::AcceptEncoding)
                    .unwrap_or("none")
                    .to_string();
                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
                    .with_response_header(KnownHeaderName::Vary, "Accept-Encoding")
                    .ok(format!("body-for-{ae}"))
            }
        }

        let inner = VaryHandler::default();
        let counter = inner.0.clone();
        let app = TestServer::new((Cache::new(InMemoryStorage::new()), inner)).await;

        app.get("/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
            .await
            .assert_body("body-for-gzip");
        app.get("/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "br")
            .await
            .assert_body("body-for-br");
        app.get("/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
            .await
            .assert_body("body-for-gzip");

        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    // With a 3-byte limit, the 6-byte body ("body-0") is delivered to the client
    // but never stored, so the second GET reaches the origin again.
    #[test(harness)]
    async fn oversized_body_is_served_but_not_cached() -> TestResult {
        let inner = CountingHandler::new("max-age=600");
        let counter = inner.counter.clone();
        let app = TestServer::new((
            Cache::new(InMemoryStorage::new()).with_max_cacheable_size(3),
            inner,
        ))
        .await;

        app.get("/x").await.assert_body("body-0");
        app.get("/x").await.assert_body("body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    // The first response is immediately stale but allows stale-if-error. When
    // revalidation hits a 500, the cache serves the stale "stable" body instead.
    #[test(harness)]
    async fn sie_serves_stale_on_5xx() -> TestResult {
        // Succeeds on its first call only; every later call returns a bare 500.
        #[derive(Debug, Clone)]
        struct FlakyHandler(Arc<AtomicUsize>);
        impl Handler for FlakyHandler {
            async fn run(&self, conn: Conn) -> Conn {
                let n = self.0.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    conn.with_response_header(
                        KnownHeaderName::CacheControl,
                        "max-age=0, stale-if-error=3600",
                    )
                    .ok("stable")
                } else {
                    conn.with_status(500).halt()
                }
            }
        }

        let inner = FlakyHandler(Arc::new(AtomicUsize::new(0)));
        let counter = inner.0.clone();
        let app = TestServer::new((Cache::new(InMemoryStorage::new()), inner)).await;

        app.get("/x").await.assert_ok().assert_body("stable");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let r2 = app.get("/x").await;
        r2.assert_ok().assert_body("stable");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }
}