1use crate::{
29 CacheKey, CacheOptions, CachePolicy, CacheStorage, PutHandle, StoredEntry,
30 tee::TeeingReader,
31 validation::{AfterResponse, BeforeRequest},
32};
33use futures_lite::{AsyncReadExt, AsyncWriteExt};
34use std::{sync::Arc, time::SystemTime};
35use trillium_client::{
36 Body, Client, ClientHandler, Conn, ConnExt, Headers, KnownHeaderName, Method, ResponseBody,
37 Result, Url,
38};
39
/// Default upper bound, in bytes (16 MiB), on a response body that will be written to storage.
const DEFAULT_MAX_CACHEABLE_SIZE: u64 = 16 << 20;
41
42#[derive(Debug)]
48pub struct Cache<S: CacheStorage> {
49 storage: Arc<S>,
50 options: CacheOptions,
51 max_cacheable_size: u64,
52}
53
54impl<S: CacheStorage> Clone for Cache<S> {
55 fn clone(&self) -> Self {
56 Self {
57 storage: Arc::clone(&self.storage),
58 options: self.options,
59 max_cacheable_size: self.max_cacheable_size,
60 }
61 }
62}
63
64impl<S: CacheStorage> Cache<S> {
65 pub fn new(storage: S) -> Self {
68 Self {
69 storage: Arc::new(storage),
70 options: CacheOptions::default(),
71 max_cacheable_size: DEFAULT_MAX_CACHEABLE_SIZE,
72 }
73 }
74
75 pub fn with_options(mut self, options: CacheOptions) -> Self {
77 self.options = options;
78 self
79 }
80
81 pub fn shared(mut self) -> Self {
84 self.options.shared = true;
85 self
86 }
87
88 pub fn with_max_cacheable_size(mut self, max: u64) -> Self {
93 self.max_cacheable_size = max;
94 self
95 }
96
97 pub fn storage(&self) -> &S {
99 &self.storage
100 }
101}
102
103enum CacheCtx<E: StoredEntry> {
106 Hit,
109 Revalidation { stored: E, key: CacheKey },
113 Miss { key: CacheKey },
116 Unsafe { url: Url },
119}
120
121impl<E: StoredEntry> std::fmt::Debug for CacheCtx<E> {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 Self::Hit => f.write_str("Hit"),
125 Self::Revalidation { key, .. } => f
126 .debug_struct("Revalidation")
127 .field("key", key)
128 .finish_non_exhaustive(),
129 Self::Miss { key } => f.debug_struct("Miss").field("key", key).finish(),
130 Self::Unsafe { url } => f.debug_struct("Unsafe").field("url", url).finish(),
131 }
132 }
133}
134
135impl<S: CacheStorage> ClientHandler for Cache<S> {
136 async fn run(&self, conn: &mut Conn) -> Result<()> {
137 let method = conn.method();
138 let key = CacheKey::new(method, conn.url().clone());
139 log::trace!("cache: run {method} {}", conn.url());
140
141 if !method.is_safe() {
144 log::trace!("cache: unsafe method {method}, bypassing cache read");
145 conn.insert_state(CacheCtx::<S::StoredEntry>::Unsafe {
146 url: conn.url().clone(),
147 });
148 return Ok(());
149 }
150
151 let now = SystemTime::now();
152 let entries = self.storage.get(&key).await;
153 log::trace!("cache: {} stored candidate(s) for {key}", entries.len());
154
155 for entry in entries {
156 match entry.policy().before_request(conn.request_headers(), now) {
157 BeforeRequest::Fresh(cached) => {
158 log::trace!("cache: hit (fresh) for {key}, serving cached response");
159 *conn.response_headers_mut() = cached.headers;
160 let body = match entry.open().await {
161 Ok(b) => b,
162 Err(e) => {
163 log::warn!(
164 "cache: open for hit failed for {key}: {e}, passing through"
165 );
166 return Ok(());
168 }
169 };
170 conn.set_status(cached.status)
171 .set_response_body(body)
172 .halt()
173 .insert_state(CacheCtx::<S::StoredEntry>::Hit);
174 return Ok(());
175 }
176
177 BeforeRequest::NotModified(cached) => {
178 log::trace!("cache: hit (fresh, conditional matches) for {key}, serving 304");
179 *conn.response_headers_mut() = cached.headers;
180 conn.set_status(cached.status)
181 .set_response_body(b"" as &[u8])
182 .halt()
183 .insert_state(CacheCtx::<S::StoredEntry>::Hit);
184 return Ok(());
185 }
186
187 BeforeRequest::Stale {
188 request_headers,
189 matches: true,
190 } => {
191 if entry.policy().is_swr_eligible(now) {
195 log::trace!(
196 "cache: stale-while-revalidate for {key}, serving stale + spawning \
197 background revalidation"
198 );
199 let entry_for_bg = entry.clone();
200 self.spawn_background_revalidation(
201 conn,
202 entry_for_bg,
203 key.clone(),
204 request_headers,
205 );
206 match self.serve_stale(conn, entry, now).await {
207 Ok(()) => {
208 conn.halt();
209 conn.insert_state(CacheCtx::<S::StoredEntry>::Hit);
210 }
211 Err(e) => {
212 log::warn!(
213 "cache: open for stale serve failed for {key}: {e}, passing \
214 through"
215 );
216 }
217 }
218 return Ok(());
219 }
220 log::trace!("cache: stale for {key}, sending conditional revalidation request");
222 *conn.request_headers_mut() = request_headers;
223 conn.insert_state(CacheCtx::Revalidation { stored: entry, key });
224 return Ok(());
225 }
226
227 BeforeRequest::Stale { matches: false, .. } => {
228 log::trace!("cache: candidate vary-mismatch for {key}, trying next");
229 continue;
230 }
231 }
232 }
233
234 log::trace!("cache: miss for {key}, forwarding to origin");
235 conn.insert_state(CacheCtx::<S::StoredEntry>::Miss { key });
236 Ok(())
237 }
238
239 async fn after_response(&self, conn: &mut Conn) -> Result<()> {
240 let Some(ctx) = conn.take_state::<CacheCtx<S::StoredEntry>>() else {
241 log::trace!("cache: after_response with no CacheCtx, nothing to do");
242 return Ok(());
243 };
244
245 if let CacheCtx::Revalidation { ref stored, .. } = ctx {
248 let now = SystemTime::now();
249 let origin_failed =
250 conn.error().is_some() || conn.status().is_some_and(|s| s.is_server_error());
251 if origin_failed && stored.policy().is_sie_eligible(now) {
252 log::trace!(
253 "cache: stale-if-error recovery for {} (origin error/{:?}), serving stale",
254 conn.url(),
255 conn.status()
256 );
257 if let Err(e) = self.serve_stale(conn, stored.clone(), now).await {
258 log::warn!(
259 "cache: open for stale serve failed for {}: {e}, propagating error",
260 conn.url()
261 );
262 return Ok(());
263 }
264 conn.take_error();
265 return Ok(());
266 }
267 }
268
269 if conn.status().is_none() {
270 log::trace!(
271 "cache: transport error with no SIE recovery for {}, propagating",
272 conn.url()
273 );
274 return Ok(());
275 }
276
277 match ctx {
278 CacheCtx::Hit => {
279 log::trace!("cache: hit confirmed in after_response for {}", conn.url());
280 Ok(())
281 }
282 CacheCtx::Revalidation { stored, key } => {
283 self.handle_revalidation(conn, stored, key).await
284 }
285 CacheCtx::Miss { key } => self.handle_miss(conn, key).await,
286 CacheCtx::Unsafe { url } => {
287 let status = conn.status().expect("checked above");
288 if status.is_success() || status.is_redirection() {
289 log::trace!(
290 "cache: unsafe method {} → {}, invalidating GET and HEAD entries for {url}",
291 conn.method(),
292 status
293 );
294 self.invalidate_url(&url).await;
295
296 for header in [KnownHeaderName::Location, KnownHeaderName::ContentLocation] {
297 let Some(value) = conn.response_headers().get_str(header) else {
298 continue;
299 };
300 let Ok(target) = url.join(value) else {
301 log::trace!(
302 "cache: unsafe method secondary invalidation: {header} value \
303 {value:?} did not resolve, skipping"
304 );
305 continue;
306 };
307 if target.host_str() != url.host_str() {
308 log::trace!(
309 "cache: unsafe method secondary invalidation: {header} target \
310 {target} differs in host from request URL, skipping (§4.4 DoS \
311 guard)"
312 );
313 continue;
314 }
315 log::trace!(
316 "cache: unsafe method secondary invalidation via {header}: {target}"
317 );
318 self.invalidate_url(&target).await;
319 }
320 } else {
321 log::trace!(
322 "cache: unsafe method {} → {} for {url}, no invalidation",
323 conn.method(),
324 status
325 );
326 }
327 Ok(())
328 }
329 }
330 }
331}
332
333impl<S: CacheStorage> Cache<S> {
334 async fn invalidate_url(&self, url: &Url) {
337 self.storage
338 .invalidate(&CacheKey::new(Method::Get, url.clone()))
339 .await;
340 self.storage
341 .invalidate(&CacheKey::new(Method::Head, url.clone()))
342 .await;
343 }
344
345 async fn serve_stale(
349 &self,
350 conn: &mut Conn,
351 stored: S::StoredEntry,
352 now: SystemTime,
353 ) -> std::io::Result<()> {
354 let cached = stored.policy().cached_response(now);
355 let body = stored.open().await?;
356 conn.set_status(cached.status);
357 *conn.response_headers_mut() = cached.headers;
358 conn.set_response_body(body);
359 Ok(())
360 }
361
362 fn spawn_background_revalidation(
370 &self,
371 conn: &Conn,
372 stored: S::StoredEntry,
373 key: CacheKey,
374 request_headers: Headers,
375 ) {
376 let runtime = conn.client().connector().runtime();
377 let bypass_client = conn.client().clone().with_handler(());
378 let cache = self.clone();
379 let method = conn.method();
380 let url = conn.url().clone();
381 log::trace!("cache: spawning background revalidation for {key}");
382
383 let _detached = runtime.spawn(async move {
384 cache
385 .background_revalidation(bypass_client, method, url, request_headers, stored, key)
386 .await;
387 });
388 }
389
390 async fn background_revalidation(
391 self,
392 client: Client,
393 method: Method,
394 url: Url,
395 request_headers: Headers,
396 mut stored: S::StoredEntry,
397 key: CacheKey,
398 ) {
399 let mut new_conn = client.build_conn(method, url);
400 *new_conn.request_headers_mut() = request_headers;
401
402 if let Err(e) = (&mut new_conn).await {
403 log::trace!(
404 "cache: background revalidation transport error for {key} ({e}), leaving stored \
405 entry"
406 );
407 return;
408 }
409
410 let now = SystemTime::now();
411 let new_status = new_conn
412 .status()
413 .expect("background revalidation: response not yet received");
414 match stored.policy().after_response(
415 new_conn.request_headers(),
416 new_status,
417 new_conn.response_headers(),
418 now,
419 ) {
420 AfterResponse::NotModified(new_policy, _) => {
421 log::trace!("cache: background revalidation 304 for {key}, refreshing entry");
422 if let Err(e) = stored.refresh_policy(new_policy).await {
423 log::warn!("cache: background refresh_policy failed for {key}: {e}");
424 }
425 }
426 AfterResponse::Modified => {
427 let new_request_method = new_conn.method();
428 let new_request_headers = new_conn.request_headers().clone();
429 let new_response_headers = new_conn.response_headers().clone();
430 if !CachePolicy::is_storable(
431 new_request_method,
432 &new_request_headers,
433 new_status,
434 &new_response_headers,
435 &self.options,
436 ) {
437 log::trace!(
438 "cache: background revalidation 200 for {key}, response not storable, \
439 dropping"
440 );
441 return;
442 }
443 let new_policy = CachePolicy::new(
444 new_request_method,
445 &new_request_headers,
446 new_status,
447 new_response_headers,
448 now,
449 self.options,
450 );
451 let put_handle = match self.storage.put(key.clone(), new_policy).await {
452 Ok(h) => h,
453 Err(e) => {
454 log::warn!(
455 "cache: background put({key}) failed: {e}, leaving stored entry"
456 );
457 return;
458 }
459 };
460 let Some(body) = new_conn.take_response_body() else {
461 log::trace!(
462 "cache: background revalidation 200 for {key}, no body, leaving stored \
463 entry"
464 );
465 return;
466 };
467 if let Err(e) = copy_into_storage(body, put_handle, self.max_cacheable_size).await {
468 log::warn!(
469 "cache: background copy into storage failed for {key}: {e}, leaving \
470 stored entry"
471 );
472 }
473 }
474 }
475 }
476
477 async fn handle_revalidation(
478 &self,
479 conn: &mut Conn,
480 mut stored: S::StoredEntry,
481 key: CacheKey,
482 ) -> Result<()> {
483 let now = SystemTime::now();
484 let new_status = conn.status().expect("checked above");
485 match stored.policy().after_response(
486 conn.request_headers(),
487 new_status,
488 conn.response_headers(),
489 now,
490 ) {
491 AfterResponse::NotModified(new_policy, cached_response) => {
492 log::trace!(
493 "cache: revalidation 304 for {key}, reusing stored body and refreshing entry"
494 );
495 if let Err(e) = stored.refresh_policy(new_policy).await {
496 log::warn!("cache: refresh_policy failed for {key}: {e}");
497 }
498 let body = match stored.open().await {
499 Ok(b) => b,
500 Err(e) => {
501 log::warn!("cache: open after 304 failed for {key}: {e}, passing through");
502 return Ok(());
503 }
504 };
505 conn.set_status(cached_response.status);
506 *conn.response_headers_mut() = cached_response.headers;
507 conn.set_response_body(body);
508 Ok(())
509 }
510 AfterResponse::Modified => {
511 drop(stored);
514 self.handle_miss(conn, key).await
515 }
516 }
517 }
518
519 async fn handle_miss(&self, conn: &mut Conn, key: CacheKey) -> Result<()> {
520 let status = conn.status().expect("checked above");
521 if !CachePolicy::is_storable(
522 conn.method(),
523 conn.request_headers(),
524 status,
525 conn.response_headers(),
526 &self.options,
527 ) {
528 log::trace!("cache: miss for {key}, response not storable, passing through");
529 return Ok(());
530 }
531
532 if let Some(len) = conn
534 .response_headers()
535 .get_str(KnownHeaderName::ContentLength)
536 .and_then(|s| s.parse::<u64>().ok())
537 && len > self.max_cacheable_size
538 {
539 log::trace!(
540 "cache: miss for {key}, body {len} > max {}, not caching",
541 self.max_cacheable_size
542 );
543 return Ok(());
544 }
545
546 let policy = CachePolicy::new(
547 conn.method(),
548 conn.request_headers(),
549 status,
550 conn.response_headers().clone(),
551 SystemTime::now(),
552 self.options,
553 );
554 let put_handle = match self.storage.put(key.clone(), policy).await {
555 Ok(h) => h,
556 Err(e) => {
557 log::warn!("cache: put({key}) failed: {e}, passing through");
558 return Ok(());
559 }
560 };
561
562 let Some(response_body) = conn.take_response_body() else {
563 log::trace!("cache: miss for {key}, no body, passing through");
564 return Ok(());
565 };
566 let len = response_body.content_length();
567 let upstream = Body::new_with_trailers(response_body, len).without_chunked_framing();
570 log::trace!("cache: miss for {key}, streaming through tee");
571 let tee = TeeingReader::new(upstream, put_handle, self.max_cacheable_size);
572 conn.set_response_body(Body::new_with_trailers(tee, len));
573 Ok(())
574 }
575}
576
577async fn copy_into_storage<P: PutHandle>(
581 body: ResponseBody<'static>,
582 mut put: P,
583 cap: u64,
584) -> std::io::Result<()> {
585 let len = body.content_length();
586 let mut body = Body::new_with_trailers(body, len).without_chunked_framing();
588 let mut buf = [0u8; 8192];
589 let mut total: u64 = 0;
590 loop {
591 let n = body.read(&mut buf).await?;
592 if n == 0 {
593 break;
594 }
595 total = total.saturating_add(n as u64);
596 if total > cap {
597 drop(put);
599 log::trace!("cache: background copy exceeded cap {cap}, aborting cache write");
600 return Ok(());
601 }
602 put.write_all(&buf[..n]).await?;
603 }
604 let trailers = body.trailers();
605 put.finalize(trailers).await
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemoryStorage;
    use crate::test_helpers::exchange;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use std::time::Duration;
    use std::{io, net::SocketAddr};
    use trillium::{Conn as ServerConn, Handler as ServerHandler, KnownHeaderName, Status};
    use trillium_client::Client;
    use trillium_client::{Connector, Url};
    use trillium_testing::{ServerConnector, TestResult, harness, test};

    /// Origin that counts requests, answers with `body-{n}` and a fixed Cache-Control,
    /// and (optionally) supports ETag-based 304s.
    #[derive(Debug, Clone)]
    struct CountingServer {
        counter: Arc<AtomicUsize>,
        cache_control: &'static str,
        etag: Option<&'static str>,
    }

    impl CountingServer {
        fn new(cache_control: &'static str) -> Self {
            Self {
                counter: Arc::new(AtomicUsize::new(0)),
                cache_control,
                etag: None,
            }
        }

        fn with_etag(mut self, etag: &'static str) -> Self {
            self.etag = Some(etag);
            self
        }
    }

    impl ServerHandler for CountingServer {
        async fn run(&self, conn: ServerConn) -> ServerConn {
            let n = self.counter.fetch_add(1, Ordering::SeqCst);

            if let Some(etag) = self.etag {
                if conn.request_headers().get_str(KnownHeaderName::IfNoneMatch) == Some(etag) {
                    return conn
                        .with_status(Status::NotModified)
                        .with_response_header(KnownHeaderName::Etag, etag)
                        .halt();
                }
            }

            let mut conn = conn
                .with_response_header(KnownHeaderName::CacheControl, self.cache_control)
                .ok(format!("body-{n}"));
            if let Some(etag) = self.etag {
                conn.response_headers_mut()
                    .insert(KnownHeaderName::Etag, etag);
            }
            conn
        }
    }

    fn cache_client(server: CountingServer) -> (Client, Arc<AtomicUsize>) {
        let counter = server.counter.clone();
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()));
        (client, counter)
    }

    #[test(harness)]
    async fn first_request_misses_subsequent_request_hits() -> TestResult {
        let (client, counter) = cache_client(CountingServer::new("max-age=600"));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(r1.status(), Some(Status::Ok));
        assert_eq!(r1.response_body().read_string().await?, "body-0");

        let mut r2 = client.get("http://example.com/x").await?;
        assert_eq!(r2.status(), Some(Status::Ok));
        assert_eq!(r2.response_body().read_string().await?, "body-0");
        assert_eq!(counter.load(Ordering::SeqCst), 1, "server only hit once");
        Ok(())
    }

    #[test(harness)]
    async fn different_urls_dont_collide() -> TestResult {
        let (client, counter) = cache_client(CountingServer::new("max-age=600"));

        let mut r1 = client.get("http://example.com/a").await?;
        let mut r2 = client.get("http://example.com/b").await?;
        assert_eq!(r1.response_body().read_string().await?, "body-0");
        assert_eq!(r2.response_body().read_string().await?, "body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    #[test(harness)]
    async fn no_store_response_is_not_cached() -> TestResult {
        let (client, counter) = cache_client(CountingServer::new("no-store"));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(r1.response_body().read_string().await?, "body-0");

        let mut r2 = client.get("http://example.com/x").await?;
        assert_eq!(r2.response_body().read_string().await?, "body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    #[test(harness)]
    async fn post_invalidates_existing_entry() -> TestResult {
        let (client, counter) = cache_client(CountingServer::new("max-age=600"));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(r1.response_body().read_string().await?, "body-0");

        let _ = client.post("http://example.com/x").await?;

        let mut r3 = client.get("http://example.com/x").await?;
        assert_eq!(r3.response_body().read_string().await?, "body-2");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        Ok(())
    }

    #[test(harness)]
    async fn post_invalidates_location_and_content_location_targets() -> TestResult {
        #[derive(Debug, Clone, Default)]
        struct LclServer(Arc<AtomicUsize>);
        impl ServerHandler for LclServer {
            async fn run(&self, conn: ServerConn) -> ServerConn {
                let n = self.0.fetch_add(1, Ordering::SeqCst);
                if conn.method() == Method::Post {
                    conn.with_response_header(KnownHeaderName::Location, "/loc")
                        .with_response_header(KnownHeaderName::ContentLocation, "/cl")
                        .ok(format!("post-body-{n}"))
                } else {
                    conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
                        .ok(format!("get-body-{n}"))
                }
            }
        }

        let server = LclServer::default();
        let counter = Arc::clone(&server.0);
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()));

        let mut loc = client.get("http://example.com/loc").await?;
        let _ = loc.response_body().read_string().await?;
        let mut cl = client.get("http://example.com/cl").await?;
        let _ = cl.response_body().read_string().await?;
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        let _ = client.post("http://example.com/anything").await?;

        let _ = client.get("http://example.com/loc").await?;
        let _ = client.get("http://example.com/cl").await?;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            5,
            "POST + 2 re-fetches should hit the origin again"
        );
        Ok(())
    }

    #[test(harness)]
    async fn cross_host_location_does_not_invalidate() -> TestResult {
        #[derive(Debug, Clone, Default)]
        struct CrossHostServer(Arc<AtomicUsize>);
        impl ServerHandler for CrossHostServer {
            async fn run(&self, conn: ServerConn) -> ServerConn {
                let n = self.0.fetch_add(1, Ordering::SeqCst);
                if conn.method() == Method::Post {
                    conn.with_response_header(KnownHeaderName::Location, "http://other.example/loc")
                        .ok(format!("post-{n}"))
                } else {
                    conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
                        .ok(format!("get-{n}"))
                }
            }
        }

        let server = CrossHostServer::default();
        let counter = Arc::clone(&server.0);
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()));

        let mut populating = client.get("http://other.example/loc").await?;
        let _ = populating.response_body().read_string().await?;
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let _ = client.post("http://example.com/anything").await?;

        let mut r = client.get("http://other.example/loc").await?;
        assert_eq!(r.response_body().read_string().await?, "get-0");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "no extra GET to other.example"
        );
        Ok(())
    }

    #[test(harness)]
    async fn stale_with_etag_revalidates_to_304() -> TestResult {
        let (client, counter) = cache_client(CountingServer::new("max-age=0").with_etag(r#""v1""#));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(r1.response_body().read_string().await?, "body-0");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let mut r2 = client.get("http://example.com/x").await?;
        assert_eq!(r2.status(), Some(Status::Ok));
        assert_eq!(r2.response_body().read_string().await?, "body-0");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    #[test(harness)]
    async fn stale_with_mismatching_etag_replaces_body() -> TestResult {
        #[derive(Debug, Clone)]
        struct AlwaysFresh {
            counter: Arc<AtomicUsize>,
        }
        impl ServerHandler for AlwaysFresh {
            async fn run(&self, conn: ServerConn) -> ServerConn {
                let n = self.counter.fetch_add(1, Ordering::SeqCst);
                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=0")
                    .with_response_header(KnownHeaderName::Etag, r#""rolling""#)
                    .ok(format!("body-{n}"))
            }
        }
        let counter = Arc::new(AtomicUsize::new(0));
        let server = AlwaysFresh {
            counter: counter.clone(),
        };
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(r1.response_body().read_string().await?, "body-0");

        let mut r2 = client.get("http://example.com/x").await?;
        assert_eq!(r2.response_body().read_string().await?, "body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    #[test(harness)]
    async fn vary_isolates_entries_by_request_header() -> TestResult {
        #[derive(Debug, Clone)]
        struct VaryServer {
            counter: Arc<AtomicUsize>,
        }
        impl ServerHandler for VaryServer {
            async fn run(&self, conn: ServerConn) -> ServerConn {
                self.counter.fetch_add(1, Ordering::SeqCst);
                let ae = conn
                    .request_headers()
                    .get_str(KnownHeaderName::AcceptEncoding)
                    .unwrap_or("none")
                    .to_string();
                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
                    .with_response_header(KnownHeaderName::Vary, "Accept-Encoding")
                    .ok(format!("body-for-{ae}"))
            }
        }
        let counter = Arc::new(AtomicUsize::new(0));
        let server = VaryServer {
            counter: counter.clone(),
        };
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()));

        let mut r1 = client
            .get("http://example.com/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
            .await?;
        assert_eq!(r1.response_body().read_string().await?, "body-for-gzip");

        let mut r2 = client
            .get("http://example.com/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "br")
            .await?;
        assert_eq!(r2.response_body().read_string().await?, "body-for-br");

        let mut r3 = client
            .get("http://example.com/x")
            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
            .await?;
        assert_eq!(r3.response_body().read_string().await?, "body-for-gzip");

        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    #[test(harness)]
    async fn oversized_body_is_served_but_not_cached() -> TestResult {
        let server = CountingServer::new("max-age=600");
        let counter = server.counter.clone();
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()).with_max_cacheable_size(3));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(r1.response_body().read_string().await?, "body-0");

        let mut r2 = client.get("http://example.com/x").await?;
        assert_eq!(r2.response_body().read_string().await?, "body-1");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        Ok(())
    }

    /// A body streamed without a declared length is stored as decoded bytes and
    /// replayed correctly on the next request.
    #[test(harness)]
    async fn chunked_upstream_is_stored_and_replayed_decoded() -> TestResult {
        #[derive(Debug, Clone)]
        struct ChunkedServer {
            counter: Arc<AtomicUsize>,
        }
        impl ServerHandler for ChunkedServer {
            async fn run(&self, conn: ServerConn) -> ServerConn {
                self.counter.fetch_add(1, Ordering::SeqCst);
                let body = Body::new_streaming(
                    futures_lite::io::Cursor::new(b"chunked-body-content".to_vec()),
                    None,
                );
                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
                    .with_body(body)
                    .with_status(Status::Ok)
                    .halt()
            }
        }

        let counter = Arc::new(AtomicUsize::new(0));
        let server = ChunkedServer {
            counter: counter.clone(),
        };
        let client = Client::new(ServerConnector::new(server))
            .with_handler(Cache::new(InMemoryStorage::new()));

        let mut r1 = client.get("http://example.com/x").await?;
        assert_eq!(
            r1.response_body().read_string().await?,
            "chunked-body-content"
        );

        let mut r2 = client.get("http://example.com/x").await?;
        assert_eq!(r2.status(), Some(Status::Ok));
        assert_eq!(
            r2.response_body().read_string().await?,
            "chunked-body-content"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "second request served from cache"
        );
        Ok(())
    }

    /// Connector whose every connection attempt fails, used to simulate transport errors.
    #[derive(Debug)]
    struct FailingConnector {
        inner: ServerConnector<Status>,
    }

    impl FailingConnector {
        fn new() -> Self {
            Self {
                inner: ServerConnector::new(Status::Ok),
            }
        }
    }

    impl Connector for FailingConnector {
        type Runtime = <ServerConnector<Status> as Connector>::Runtime;
        type Transport = <ServerConnector<Status> as Connector>::Transport;
        type Udp = <ServerConnector<Status> as Connector>::Udp;

        async fn connect(&self, _url: &Url) -> io::Result<Self::Transport> {
            Err(io::Error::new(
                io::ErrorKind::ConnectionRefused,
                "test failure",
            ))
        }

        fn runtime(&self) -> Self::Runtime {
            self.inner.runtime().clone()
        }

        async fn resolve(&self, host: &str, port: u16) -> io::Result<Vec<SocketAddr>> {
            self.inner.resolve(host, port).await
        }
    }

    /// Writes a GET entry for `http://example.com/x` with the given Cache-Control and body
    /// directly into `storage`, returning its key.
    async fn populate_stale_entry(
        storage: &InMemoryStorage,
        cache_control: &'static str,
        body: &'static [u8],
    ) -> CacheKey {
        let conn = exchange(
            Method::Get,
            &[],
            Status::Ok,
            &[(KnownHeaderName::CacheControl, cache_control)],
        );
        let policy =
            crate::test_helpers::policy_from(&conn, SystemTime::now(), CacheOptions::default());
        let key = CacheKey::new(Method::Get, "http://example.com/x".parse().unwrap());
        let mut handle = storage.put(key.clone(), policy).await.unwrap();
        use futures_lite::AsyncWriteExt;
        handle.write_all(body).await.unwrap();
        handle.finalize(None).await.unwrap();
        key
    }

    #[test(harness)]
    async fn sie_serves_stale_on_transport_error() -> TestResult {
        let storage = InMemoryStorage::new();
        let _ =
            populate_stale_entry(&storage, "max-age=0, stale-if-error=3600", b"stale body").await;
        let client = Client::new(FailingConnector::new()).with_handler(Cache::new(storage));

        let mut conn = client.get("http://example.com/x").await?;
        assert_eq!(conn.status(), Some(Status::Ok));
        assert_eq!(conn.response_body().read_string().await?, "stale body");
        Ok(())
    }

    #[test(harness)]
    async fn no_sie_propagates_transport_error() -> TestResult {
        let storage = InMemoryStorage::new();
        let _ = populate_stale_entry(&storage, "max-age=0", b"stale body").await;
        let client = Client::new(FailingConnector::new()).with_handler(Cache::new(storage));

        let result = client.get("http://example.com/x").await;
        assert!(
            result.is_err(),
            "expected transport error to propagate, got {result:?}"
        );
        Ok(())
    }

    #[test(harness)]
    async fn sie_serves_stale_on_5xx() -> TestResult {
        let storage = InMemoryStorage::new();
        let _ =
            populate_stale_entry(&storage, "max-age=0, stale-if-error=3600", b"stale body").await;
        let server = ServerConnector::new(Status::ServiceUnavailable);
        let client = Client::new(server).with_handler(Cache::new(storage));

        let mut conn = client.get("http://example.com/x").await?;
        assert_eq!(conn.status(), Some(Status::Ok));
        assert_eq!(conn.response_body().read_string().await?, "stale body");
        Ok(())
    }

    #[test(harness)]
    async fn no_sie_serves_5xx_as_received() -> TestResult {
        let storage = InMemoryStorage::new();
        let _ = populate_stale_entry(&storage, "max-age=0", b"stale body").await;
        let server = ServerConnector::new(Status::ServiceUnavailable);
        let client = Client::new(server).with_handler(Cache::new(storage));

        let conn = client.get("http://example.com/x").await?;
        assert_eq!(conn.status(), Some(Status::ServiceUnavailable));
        Ok(())
    }

    #[test(harness)]
    async fn swr_serves_stale_immediately_and_revalidates_in_background() -> TestResult {
        use futures_lite::AsyncReadExt;

        let storage = InMemoryStorage::new();
        let _ = populate_stale_entry(
            &storage,
            "max-age=0, stale-while-revalidate=3600",
            b"stale-body",
        )
        .await;

        let server = CountingServer::new("max-age=600");
        let counter = server.counter.clone();
        let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));

        let mut conn = client.get("http://example.com/x").await?;
        assert_eq!(conn.status(), Some(Status::Ok));
        assert_eq!(conn.response_body().read_string().await?, "stale-body");

        let runtime = client.connector().runtime();
        for _ in 0..100 {
            if counter.load(Ordering::SeqCst) > 0 {
                break;
            }
            runtime.delay(Duration::from_millis(10)).await;
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "background revalidation should hit the origin"
        );

        let cache = client
            .downcast_handler::<Cache<InMemoryStorage>>()
            .expect("cache handler installed");
        let key = CacheKey::new(Method::Get, "http://example.com/x".parse().unwrap());

        // The origin being hit does not mean the new entry has been written yet, and the
        // stale entry is already present, so poll until the stored *body* has actually
        // been replaced rather than merely until the key is non-empty.
        let mut stored_body = Vec::new();
        for _ in 0..100 {
            let entries = cache.storage().get(&key).await;
            if entries.len() == 1 {
                let mut body = entries[0].clone().open().await.unwrap();
                stored_body.clear();
                body.read_to_end(&mut stored_body).await.unwrap();
                if stored_body == b"body-0" {
                    break;
                }
            }
            runtime.delay(Duration::from_millis(10)).await;
        }

        let entries = cache.storage().get(&key).await;
        assert_eq!(entries.len(), 1);
        assert_eq!(&stored_body, b"body-0");
        Ok(())
    }

    #[test(harness)]
    async fn no_swr_falls_back_to_synchronous_revalidation() -> TestResult {
        let storage = InMemoryStorage::new();
        let _ = populate_stale_entry(&storage, "max-age=0", b"stale-body").await;

        let server = CountingServer::new("max-age=600");
        let counter = server.counter.clone();
        let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));

        let mut conn = client.get("http://example.com/x").await?;
        assert_eq!(conn.response_body().read_string().await?, "body-0");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        Ok(())
    }

    #[test(harness)]
    async fn must_revalidate_disables_swr() -> TestResult {
        let storage = InMemoryStorage::new();
        let _ = populate_stale_entry(
            &storage,
            "max-age=0, must-revalidate, stale-while-revalidate=3600",
            b"stale-body",
        )
        .await;

        let server = CountingServer::new("max-age=600");
        let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));

        let mut conn = client.get("http://example.com/x").await?;
        assert_eq!(conn.response_body().read_string().await?, "body-0");
        Ok(())
    }
}