Skip to main content

trillium_cache/
client.rs

1//! Client-side cache handler.
2//!
3//! [`Cache`] wires [`CacheStorage`] + [`CachePolicy`] onto a `trillium-client` request
4//! lifecycle. Feature-gated behind `client`.
5//!
6//! ## Position in the handler chain
7//!
8//! Add `Cache` *last* in the handler tuple:
9//!
10//! ```ignore
11//! client.with_handler((Logger::new(), Cache::new(storage)));
12//! ```
13//!
14//! Reasons:
15//! - `run` runs in declared order; the cache should be the last `run` so it can short-circuit the
16//!   network with a fresh hit.
17//! - `after_response` runs in reverse declared order; the cache should be the first
18//!   `after_response` so it can read the response body and replace it with a streaming tee before
19//!   any other handler reads the (one-shot) network body.
20//!
21//! ## Streaming
22//!
//! On miss, the cache installs a streaming tee between the origin response body and the
//! user — bytes flow to storage and the user concurrently. Trailers propagate to both.
//! Because storage is fed by the user's reads, a response whose body is never read is
//! never stored. The cap on stored body size is enforced mid-stream; if exceeded, the
//! cache write is aborted and the remainder of the body passes through unmodified.
27
28use crate::{
29    CacheKey, CacheOptions, CachePolicy, CacheStorage, PutHandle, StoredEntry,
30    tee::TeeingReader,
31    validation::{AfterResponse, BeforeRequest},
32};
33use futures_lite::{AsyncReadExt, AsyncWriteExt};
34use std::{sync::Arc, time::SystemTime};
35use trillium_client::{
36    Body, Client, ClientHandler, Conn, ConnExt, Headers, KnownHeaderName, Method, ResponseBody,
37    Result, Url,
38};
39
/// Default cap on stored response body size: 16 MiB (16 × 2^20 bytes).
const DEFAULT_MAX_CACHEABLE_SIZE: u64 = 16 << 20;
41
42/// Cache handler. Mount on a [`trillium_client::Client`] together with
43/// a [`CacheStorage`] backend.
44///
45/// `Cache` is cheap to `Clone`: storage is held in an `Arc`, so clones
46/// share the same backend.
47#[derive(Debug)]
48pub struct Cache<S: CacheStorage> {
49    storage: Arc<S>,
50    options: CacheOptions,
51    max_cacheable_size: u64,
52}
53
54impl<S: CacheStorage> Clone for Cache<S> {
55    fn clone(&self) -> Self {
56        Self {
57            storage: Arc::clone(&self.storage),
58            options: self.options,
59            max_cacheable_size: self.max_cacheable_size,
60        }
61    }
62}
63
64impl<S: CacheStorage> Cache<S> {
65    /// Construct a cache handler with default options
66    /// ([`CacheOptions::default`]) and a 16 MiB body-size cap.
67    pub fn new(storage: S) -> Self {
68        Self {
69            storage: Arc::new(storage),
70            options: CacheOptions::default(),
71            max_cacheable_size: DEFAULT_MAX_CACHEABLE_SIZE,
72        }
73    }
74
75    /// Replace the cache options.
76    pub fn with_options(mut self, options: CacheOptions) -> Self {
77        self.options = options;
78        self
79    }
80
81    /// Mark this cache as a *shared cache* (proxy/CDN). Equivalent to
82    /// `with_options` with `shared: true`.
83    pub fn shared(mut self) -> Self {
84        self.options.shared = true;
85        self
86    }
87
88    /// Set the cap on response body bytes the cache will store.
89    /// Responses larger than this pass through but are not stored. If
90    /// the cap is exceeded mid-stream, the cache write is aborted and
91    /// the remainder of the body passes through unmodified.
92    pub fn with_max_cacheable_size(mut self, max: u64) -> Self {
93        self.max_cacheable_size = max;
94        self
95    }
96
97    /// Borrow the storage backend.
98    pub fn storage(&self) -> &S {
99        &self.storage
100    }
101}
102
103// State stashed in the conn's typeset by `run` for `after_response` to
104// pick up.
105enum CacheCtx<E: StoredEntry> {
106    /// Cache hit — `run` already populated a synthetic response and
107    /// halted. `after_response` is a no-op.
108    Hit,
109    /// Stored entry was stale and a conditional revalidation request
110    /// has been spliced onto the conn. `after_response` reconciles the
111    /// origin's reply (304 vs 200) with the stored entry.
112    Revalidation { stored: E, key: CacheKey },
113    /// Cache miss — no stored entry matched. If the response is
114    /// storable, `after_response` will install a streaming tee.
115    Miss { key: CacheKey },
116    /// Unsafe method (POST/PUT/DELETE/...). On a non-error response,
117    /// `after_response` invalidates the target URI per RFC 9111 §4.4.
118    Unsafe { url: Url },
119}
120
121impl<E: StoredEntry> std::fmt::Debug for CacheCtx<E> {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            Self::Hit => f.write_str("Hit"),
125            Self::Revalidation { key, .. } => f
126                .debug_struct("Revalidation")
127                .field("key", key)
128                .finish_non_exhaustive(),
129            Self::Miss { key } => f.debug_struct("Miss").field("key", key).finish(),
130            Self::Unsafe { url } => f.debug_struct("Unsafe").field("url", url).finish(),
131        }
132    }
133}
134
135impl<S: CacheStorage> ClientHandler for Cache<S> {
136    async fn run(&self, conn: &mut Conn) -> Result<()> {
137        let method = conn.method();
138        let key = CacheKey::new(method, conn.url().clone());
139        log::trace!("cache: run {method} {}", conn.url());
140
141        // RFC 9111 §4.4: don't read from cache for unsafe methods;
142        // possibly invalidate after the round-trip.
143        if !method.is_safe() {
144            log::trace!("cache: unsafe method {method}, bypassing cache read");
145            conn.insert_state(CacheCtx::<S::StoredEntry>::Unsafe {
146                url: conn.url().clone(),
147            });
148            return Ok(());
149        }
150
151        let now = SystemTime::now();
152        let entries = self.storage.get(&key).await;
153        log::trace!("cache: {} stored candidate(s) for {key}", entries.len());
154
155        for entry in entries {
156            match entry.policy().before_request(conn.request_headers(), now) {
157                BeforeRequest::Fresh(cached) => {
158                    log::trace!("cache: hit (fresh) for {key}, serving cached response");
159                    *conn.response_headers_mut() = cached.headers;
160                    let body = match entry.open().await {
161                        Ok(b) => b,
162                        Err(e) => {
163                            log::warn!(
164                                "cache: open for hit failed for {key}: {e}, passing through"
165                            );
166                            // Reset the override; let the network round-trip happen.
167                            return Ok(());
168                        }
169                    };
170                    conn.set_status(cached.status)
171                        .set_response_body(body)
172                        .halt()
173                        .insert_state(CacheCtx::<S::StoredEntry>::Hit);
174                    return Ok(());
175                }
176
177                BeforeRequest::NotModified(cached) => {
178                    log::trace!("cache: hit (fresh, conditional matches) for {key}, serving 304");
179                    *conn.response_headers_mut() = cached.headers;
180                    conn.set_status(cached.status)
181                        .set_response_body(b"" as &[u8])
182                        .halt()
183                        .insert_state(CacheCtx::<S::StoredEntry>::Hit);
184                    return Ok(());
185                }
186
187                BeforeRequest::Stale {
188                    request_headers,
189                    matches: true,
190                } => {
191                    // RFC 9111 §4.2.4 stale-while-revalidate: if the
192                    // entry is within its SWR window, serve it
193                    // immediately and revalidate in the background.
194                    if entry.policy().is_swr_eligible(now) {
195                        log::trace!(
196                            "cache: stale-while-revalidate for {key}, serving stale + spawning \
197                             background revalidation"
198                        );
199                        let entry_for_bg = entry.clone();
200                        self.spawn_background_revalidation(
201                            conn,
202                            entry_for_bg,
203                            key.clone(),
204                            request_headers,
205                        );
206                        match self.serve_stale(conn, entry, now).await {
207                            Ok(()) => {
208                                conn.halt();
209                                conn.insert_state(CacheCtx::<S::StoredEntry>::Hit);
210                            }
211                            Err(e) => {
212                                log::warn!(
213                                    "cache: open for stale serve failed for {key}: {e}, passing \
214                                     through"
215                                );
216                            }
217                        }
218                        return Ok(());
219                    }
220                    // Otherwise fall through to synchronous revalidation.
221                    log::trace!("cache: stale for {key}, sending conditional revalidation request");
222                    *conn.request_headers_mut() = request_headers;
223                    conn.insert_state(CacheCtx::Revalidation { stored: entry, key });
224                    return Ok(());
225                }
226
227                BeforeRequest::Stale { matches: false, .. } => {
228                    log::trace!("cache: candidate vary-mismatch for {key}, trying next");
229                    continue;
230                }
231            }
232        }
233
234        log::trace!("cache: miss for {key}, forwarding to origin");
235        conn.insert_state(CacheCtx::<S::StoredEntry>::Miss { key });
236        Ok(())
237    }
238
239    async fn after_response(&self, conn: &mut Conn) -> Result<()> {
240        let Some(ctx) = conn.take_state::<CacheCtx<S::StoredEntry>>() else {
241            log::trace!("cache: after_response with no CacheCtx, nothing to do");
242            return Ok(());
243        };
244
245        // RFC 9111 §4.2.4 / RFC 5861 stale-if-error: if revalidation hit a transport-level
246        // failure or a 5xx, and the stored entry is SIE-eligible, serve it instead.
247        if let CacheCtx::Revalidation { ref stored, .. } = ctx {
248            let now = SystemTime::now();
249            let origin_failed =
250                conn.error().is_some() || conn.status().is_some_and(|s| s.is_server_error());
251            if origin_failed && stored.policy().is_sie_eligible(now) {
252                log::trace!(
253                    "cache: stale-if-error recovery for {} (origin error/{:?}), serving stale",
254                    conn.url(),
255                    conn.status()
256                );
257                if let Err(e) = self.serve_stale(conn, stored.clone(), now).await {
258                    log::warn!(
259                        "cache: open for stale serve failed for {}: {e}, propagating error",
260                        conn.url()
261                    );
262                    return Ok(());
263                }
264                conn.take_error();
265                return Ok(());
266            }
267        }
268
269        if conn.status().is_none() {
270            log::trace!(
271                "cache: transport error with no SIE recovery for {}, propagating",
272                conn.url()
273            );
274            return Ok(());
275        }
276
277        match ctx {
278            CacheCtx::Hit => {
279                log::trace!("cache: hit confirmed in after_response for {}", conn.url());
280                Ok(())
281            }
282            CacheCtx::Revalidation { stored, key } => {
283                self.handle_revalidation(conn, stored, key).await
284            }
285            CacheCtx::Miss { key } => self.handle_miss(conn, key).await,
286            CacheCtx::Unsafe { url } => {
287                let status = conn.status().expect("checked above");
288                if status.is_success() || status.is_redirection() {
289                    log::trace!(
290                        "cache: unsafe method {} → {}, invalidating GET and HEAD entries for {url}",
291                        conn.method(),
292                        status
293                    );
294                    self.invalidate_url(&url).await;
295
296                    for header in [KnownHeaderName::Location, KnownHeaderName::ContentLocation] {
297                        let Some(value) = conn.response_headers().get_str(header) else {
298                            continue;
299                        };
300                        let Ok(target) = url.join(value) else {
301                            log::trace!(
302                                "cache: unsafe method secondary invalidation: {header} value \
303                                 {value:?} did not resolve, skipping"
304                            );
305                            continue;
306                        };
307                        if target.host_str() != url.host_str() {
308                            log::trace!(
309                                "cache: unsafe method secondary invalidation: {header} target \
310                                 {target} differs in host from request URL, skipping (§4.4 DoS \
311                                 guard)"
312                            );
313                            continue;
314                        }
315                        log::trace!(
316                            "cache: unsafe method secondary invalidation via {header}: {target}"
317                        );
318                        self.invalidate_url(&target).await;
319                    }
320                } else {
321                    log::trace!(
322                        "cache: unsafe method {} → {} for {url}, no invalidation",
323                        conn.method(),
324                        status
325                    );
326                }
327                Ok(())
328            }
329        }
330    }
331}
332
333impl<S: CacheStorage> Cache<S> {
334    // §4.4: invalidate any stored entries for this URI under the methods
335    // we'd ever cache (GET and HEAD).
336    async fn invalidate_url(&self, url: &Url) {
337        self.storage
338            .invalidate(&CacheKey::new(Method::Get, url.clone()))
339            .await;
340        self.storage
341            .invalidate(&CacheKey::new(Method::Head, url.clone()))
342            .await;
343    }
344
345    // RFC 9111 §4.2.4 / RFC 5861: apply a stored stale entry to the
346    // conn as the served response. Used by both stale-while-revalidate
347    // and stale-if-error paths.
348    async fn serve_stale(
349        &self,
350        conn: &mut Conn,
351        stored: S::StoredEntry,
352        now: SystemTime,
353    ) -> std::io::Result<()> {
354        let cached = stored.policy().cached_response(now);
355        let body = stored.open().await?;
356        conn.set_status(cached.status);
357        *conn.response_headers_mut() = cached.headers;
358        conn.set_response_body(body);
359        Ok(())
360    }
361
362    // RFC 9111 §4.2.4: spawn a background revalidation so the user gets
363    // an immediate stale response while the cache refreshes.
364    //
365    // We share the runtime + connector + pool with the user's client
366    // (cloning `conn.client()` is cheap — the underlying pools are
367    // Arc-shared). The bypass client has its handler stack replaced
368    // with `()` so the cache handler doesn't recurse on itself.
369    fn spawn_background_revalidation(
370        &self,
371        conn: &Conn,
372        stored: S::StoredEntry,
373        key: CacheKey,
374        request_headers: Headers,
375    ) {
376        let runtime = conn.client().connector().runtime();
377        let bypass_client = conn.client().clone().with_handler(());
378        let cache = self.clone();
379        let method = conn.method();
380        let url = conn.url().clone();
381        log::trace!("cache: spawning background revalidation for {key}");
382
383        let _detached = runtime.spawn(async move {
384            cache
385                .background_revalidation(bypass_client, method, url, request_headers, stored, key)
386                .await;
387        });
388    }
389
390    async fn background_revalidation(
391        self,
392        client: Client,
393        method: Method,
394        url: Url,
395        request_headers: Headers,
396        mut stored: S::StoredEntry,
397        key: CacheKey,
398    ) {
399        let mut new_conn = client.build_conn(method, url);
400        *new_conn.request_headers_mut() = request_headers;
401
402        if let Err(e) = (&mut new_conn).await {
403            log::trace!(
404                "cache: background revalidation transport error for {key} ({e}), leaving stored \
405                 entry"
406            );
407            return;
408        }
409
410        let now = SystemTime::now();
411        let new_status = new_conn
412            .status()
413            .expect("background revalidation: response not yet received");
414        match stored.policy().after_response(
415            new_conn.request_headers(),
416            new_status,
417            new_conn.response_headers(),
418            now,
419        ) {
420            AfterResponse::NotModified(new_policy, _) => {
421                log::trace!("cache: background revalidation 304 for {key}, refreshing entry");
422                if let Err(e) = stored.refresh_policy(new_policy).await {
423                    log::warn!("cache: background refresh_policy failed for {key}: {e}");
424                }
425            }
426            AfterResponse::Modified => {
427                let new_request_method = new_conn.method();
428                let new_request_headers = new_conn.request_headers().clone();
429                let new_response_headers = new_conn.response_headers().clone();
430                if !CachePolicy::is_storable(
431                    new_request_method,
432                    &new_request_headers,
433                    new_status,
434                    &new_response_headers,
435                    &self.options,
436                ) {
437                    log::trace!(
438                        "cache: background revalidation 200 for {key}, response not storable, \
439                         dropping"
440                    );
441                    return;
442                }
443                let new_policy = CachePolicy::new(
444                    new_request_method,
445                    &new_request_headers,
446                    new_status,
447                    new_response_headers,
448                    now,
449                    self.options,
450                );
451                let put_handle = match self.storage.put(key.clone(), new_policy).await {
452                    Ok(h) => h,
453                    Err(e) => {
454                        log::warn!(
455                            "cache: background put({key}) failed: {e}, leaving stored entry"
456                        );
457                        return;
458                    }
459                };
460                let Some(body) = new_conn.take_response_body() else {
461                    log::trace!(
462                        "cache: background revalidation 200 for {key}, no body, leaving stored \
463                         entry"
464                    );
465                    return;
466                };
467                if let Err(e) = copy_into_storage(body, put_handle, self.max_cacheable_size).await {
468                    log::warn!(
469                        "cache: background copy into storage failed for {key}: {e}, leaving \
470                         stored entry"
471                    );
472                }
473            }
474        }
475    }
476
477    async fn handle_revalidation(
478        &self,
479        conn: &mut Conn,
480        mut stored: S::StoredEntry,
481        key: CacheKey,
482    ) -> Result<()> {
483        let now = SystemTime::now();
484        let new_status = conn.status().expect("checked above");
485        match stored.policy().after_response(
486            conn.request_headers(),
487            new_status,
488            conn.response_headers(),
489            now,
490        ) {
491            AfterResponse::NotModified(new_policy, cached_response) => {
492                log::trace!(
493                    "cache: revalidation 304 for {key}, reusing stored body and refreshing entry"
494                );
495                if let Err(e) = stored.refresh_policy(new_policy).await {
496                    log::warn!("cache: refresh_policy failed for {key}: {e}");
497                }
498                let body = match stored.open().await {
499                    Ok(b) => b,
500                    Err(e) => {
501                        log::warn!("cache: open after 304 failed for {key}: {e}, passing through");
502                        return Ok(());
503                    }
504                };
505                conn.set_status(cached_response.status);
506                *conn.response_headers_mut() = cached_response.headers;
507                conn.set_response_body(body);
508                Ok(())
509            }
510            AfterResponse::Modified => {
511                // Drop the stored entry; treat as a fresh miss against the same key. The new
512                // entry replaces any stored variant with the same Vary signature.
513                drop(stored);
514                self.handle_miss(conn, key).await
515            }
516        }
517    }
518
519    async fn handle_miss(&self, conn: &mut Conn, key: CacheKey) -> Result<()> {
520        let status = conn.status().expect("checked above");
521        if !CachePolicy::is_storable(
522            conn.method(),
523            conn.request_headers(),
524            status,
525            conn.response_headers(),
526            &self.options,
527        ) {
528            log::trace!("cache: miss for {key}, response not storable, passing through");
529            return Ok(());
530        }
531
532        // Skip the put entirely when content-length is known and already over cap.
533        if let Some(len) = conn
534            .response_headers()
535            .get_str(KnownHeaderName::ContentLength)
536            .and_then(|s| s.parse::<u64>().ok())
537            && len > self.max_cacheable_size
538        {
539            log::trace!(
540                "cache: miss for {key}, body {len} > max {}, not caching",
541                self.max_cacheable_size
542            );
543            return Ok(());
544        }
545
546        let policy = CachePolicy::new(
547            conn.method(),
548            conn.request_headers(),
549            status,
550            conn.response_headers().clone(),
551            SystemTime::now(),
552            self.options,
553        );
554        let put_handle = match self.storage.put(key.clone(), policy).await {
555            Ok(h) => h,
556            Err(e) => {
557                log::warn!("cache: put({key}) failed: {e}, passing through");
558                return Ok(());
559            }
560        };
561
562        let Some(response_body) = conn.take_response_body() else {
563            log::trace!("cache: miss for {key}, no body, passing through");
564            return Ok(());
565        };
566        let len = response_body.content_length();
567        // Strip wire-format chunked framing so the tee stores the decoded body. The outer
568        // body re-frames for the downstream when `len` is None.
569        let upstream = Body::new_with_trailers(response_body, len).without_chunked_framing();
570        log::trace!("cache: miss for {key}, streaming through tee");
571        let tee = TeeingReader::new(upstream, put_handle, self.max_cacheable_size);
572        conn.set_response_body(Body::new_with_trailers(tee, len));
573        Ok(())
574    }
575}
576
577// Copy a response body into a put handle, finalizing on EOF with whatever trailers the body
578// exposes. Used by background revalidation, where there's no concurrent user reader; the cap
579// is enforced by aborting when exceeded.
580async fn copy_into_storage<P: PutHandle>(
581    body: ResponseBody<'static>,
582    mut put: P,
583    cap: u64,
584) -> std::io::Result<()> {
585    let len = body.content_length();
586    // Strip wire-format chunked framing so storage gets the decoded body, not chunk bytes.
587    let mut body = Body::new_with_trailers(body, len).without_chunked_framing();
588    let mut buf = [0u8; 8192];
589    let mut total: u64 = 0;
590    loop {
591        let n = body.read(&mut buf).await?;
592        if n == 0 {
593            break;
594        }
595        total = total.saturating_add(n as u64);
596        if total > cap {
597            // Drop put_handle without finalizing — storage gets nothing.
598            drop(put);
599            log::trace!("cache: background copy exceeded cap {cap}, aborting cache write");
600            return Ok(());
601        }
602        put.write_all(&buf[..n]).await?;
603    }
604    let trailers = body.trailers();
605    put.finalize(trailers).await
606}
607
608#[cfg(test)]
609mod tests {
610    use super::*;
611    use crate::InMemoryStorage;
612    use std::sync::{
613        Arc,
614        atomic::{AtomicUsize, Ordering},
615    };
616    use trillium::{Conn as ServerConn, Handler as ServerHandler, KnownHeaderName, Status};
617    use trillium_client::Client;
618    use trillium_testing::{ServerConnector, TestResult, harness, test};
619
620    #[derive(Debug, Clone)]
621    struct CountingServer {
622        counter: Arc<AtomicUsize>,
623        cache_control: &'static str,
624        etag: Option<&'static str>,
625    }
626
627    impl CountingServer {
628        fn new(cache_control: &'static str) -> Self {
629            Self {
630                counter: Arc::new(AtomicUsize::new(0)),
631                cache_control,
632                etag: None,
633            }
634        }
635
636        fn with_etag(mut self, etag: &'static str) -> Self {
637            self.etag = Some(etag);
638            self
639        }
640    }
641
642    impl ServerHandler for CountingServer {
643        async fn run(&self, conn: ServerConn) -> ServerConn {
644            let n = self.counter.fetch_add(1, Ordering::SeqCst);
645
646            if let Some(etag) = self.etag {
647                if conn.request_headers().get_str(KnownHeaderName::IfNoneMatch) == Some(etag) {
648                    return conn
649                        .with_status(Status::NotModified)
650                        .with_response_header(KnownHeaderName::Etag, etag)
651                        .halt();
652                }
653            }
654
655            let mut conn = conn
656                .with_response_header(KnownHeaderName::CacheControl, self.cache_control)
657                .ok(format!("body-{n}"));
658            if let Some(etag) = self.etag {
659                conn.response_headers_mut()
660                    .insert(KnownHeaderName::Etag, etag);
661            }
662            conn
663        }
664    }
665
666    fn cache_client(server: CountingServer) -> (Client, Arc<AtomicUsize>) {
667        let counter = server.counter.clone();
668        let client = Client::new(ServerConnector::new(server))
669            .with_handler(Cache::new(InMemoryStorage::new()));
670        (client, counter)
671    }
672
673    #[test(harness)]
674    async fn first_request_misses_subsequent_request_hits() -> TestResult {
675        let (client, counter) = cache_client(CountingServer::new("max-age=600"));
676
677        let mut r1 = client.get("http://example.com/x").await?;
678        assert_eq!(r1.status(), Some(Status::Ok));
679        assert_eq!(r1.response_body().read_string().await?, "body-0");
680
681        let mut r2 = client.get("http://example.com/x").await?;
682        assert_eq!(r2.status(), Some(Status::Ok));
683        assert_eq!(r2.response_body().read_string().await?, "body-0");
684        assert_eq!(counter.load(Ordering::SeqCst), 1, "server only hit once");
685        Ok(())
686    }
687
688    #[test(harness)]
689    async fn different_urls_dont_collide() -> TestResult {
690        let (client, counter) = cache_client(CountingServer::new("max-age=600"));
691
692        let mut r1 = client.get("http://example.com/a").await?;
693        let mut r2 = client.get("http://example.com/b").await?;
694        assert_eq!(r1.response_body().read_string().await?, "body-0");
695        assert_eq!(r2.response_body().read_string().await?, "body-1");
696        assert_eq!(counter.load(Ordering::SeqCst), 2);
697        Ok(())
698    }
699
700    #[test(harness)]
701    async fn no_store_response_is_not_cached() -> TestResult {
702        let (client, counter) = cache_client(CountingServer::new("no-store"));
703
704        let mut r1 = client.get("http://example.com/x").await?;
705        assert_eq!(r1.response_body().read_string().await?, "body-0");
706
707        let mut r2 = client.get("http://example.com/x").await?;
708        assert_eq!(r2.response_body().read_string().await?, "body-1");
709        assert_eq!(counter.load(Ordering::SeqCst), 2);
710        Ok(())
711    }
712
713    #[test(harness)]
714    async fn post_invalidates_existing_entry() -> TestResult {
715        let (client, counter) = cache_client(CountingServer::new("max-age=600"));
716
717        let mut r1 = client.get("http://example.com/x").await?;
718        assert_eq!(r1.response_body().read_string().await?, "body-0");
719
720        let _ = client.post("http://example.com/x").await?;
721
722        let mut r3 = client.get("http://example.com/x").await?;
723        assert_eq!(r3.response_body().read_string().await?, "body-2");
724        assert_eq!(counter.load(Ordering::SeqCst), 3);
725        Ok(())
726    }
727
728    #[test(harness)]
729    async fn post_invalidates_location_and_content_location_targets() -> TestResult {
730        #[derive(Debug, Clone, Default)]
731        struct LclServer(Arc<AtomicUsize>);
732        impl ServerHandler for LclServer {
733            async fn run(&self, conn: ServerConn) -> ServerConn {
734                let n = self.0.fetch_add(1, Ordering::SeqCst);
735                if conn.method() == Method::Post {
736                    conn.with_response_header(KnownHeaderName::Location, "/loc")
737                        .with_response_header(KnownHeaderName::ContentLocation, "/cl")
738                        .ok(format!("post-body-{n}"))
739                } else {
740                    conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
741                        .ok(format!("get-body-{n}"))
742                }
743            }
744        }
745
746        let server = LclServer::default();
747        let counter = Arc::clone(&server.0);
748        let client = Client::new(ServerConnector::new(server))
749            .with_handler(Cache::new(InMemoryStorage::new()));
750
751        // Read each body so the streaming tee actually commits to storage.
752        let mut loc = client.get("http://example.com/loc").await?;
753        let _ = loc.response_body().read_string().await?;
754        let mut cl = client.get("http://example.com/cl").await?;
755        let _ = cl.response_body().read_string().await?;
756        assert_eq!(counter.load(Ordering::SeqCst), 2);
757
758        let _ = client.post("http://example.com/anything").await?;
759
760        let _ = client.get("http://example.com/loc").await?;
761        let _ = client.get("http://example.com/cl").await?;
762        assert_eq!(
763            counter.load(Ordering::SeqCst),
764            5,
765            "POST + 2 re-fetches should hit the origin again"
766        );
767        Ok(())
768    }
769
770    #[test(harness)]
771    async fn cross_host_location_does_not_invalidate() -> TestResult {
772        #[derive(Debug, Clone, Default)]
773        struct CrossHostServer(Arc<AtomicUsize>);
774        impl ServerHandler for CrossHostServer {
775            async fn run(&self, conn: ServerConn) -> ServerConn {
776                let n = self.0.fetch_add(1, Ordering::SeqCst);
777                if conn.method() == Method::Post {
778                    conn.with_response_header(KnownHeaderName::Location, "http://other.example/loc")
779                        .ok(format!("post-{n}"))
780                } else {
781                    conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
782                        .ok(format!("get-{n}"))
783                }
784            }
785        }
786
787        let server = CrossHostServer::default();
788        let counter = Arc::clone(&server.0);
789        let client = Client::new(ServerConnector::new(server))
790            .with_handler(Cache::new(InMemoryStorage::new()));
791
792        // Read the body to drive the tee into storage. Streaming-cache contract: nothing is
793        // cached unless the body is read.
794        let mut populating = client.get("http://other.example/loc").await?;
795        let _ = populating.response_body().read_string().await?;
796        assert_eq!(counter.load(Ordering::SeqCst), 1);
797
798        let _ = client.post("http://example.com/anything").await?;
799
800        let mut r = client.get("http://other.example/loc").await?;
801        assert_eq!(r.response_body().read_string().await?, "get-0");
802        assert_eq!(
803            counter.load(Ordering::SeqCst),
804            2,
805            "no extra GET to other.example"
806        );
807        Ok(())
808    }
809
810    #[test(harness)]
811    async fn stale_with_etag_revalidates_to_304() -> TestResult {
812        let (client, counter) = cache_client(CountingServer::new("max-age=0").with_etag(r#""v1""#));
813
814        let mut r1 = client.get("http://example.com/x").await?;
815        assert_eq!(r1.response_body().read_string().await?, "body-0");
816        assert_eq!(counter.load(Ordering::SeqCst), 1);
817
818        let mut r2 = client.get("http://example.com/x").await?;
819        assert_eq!(r2.status(), Some(Status::Ok));
820        assert_eq!(r2.response_body().read_string().await?, "body-0");
821        assert_eq!(counter.load(Ordering::SeqCst), 2);
822        Ok(())
823    }
824
825    #[test(harness)]
826    async fn stale_with_mismatching_etag_replaces_body() -> TestResult {
827        #[derive(Debug, Clone)]
828        struct AlwaysFresh {
829            counter: Arc<AtomicUsize>,
830        }
831        impl ServerHandler for AlwaysFresh {
832            async fn run(&self, conn: ServerConn) -> ServerConn {
833                let n = self.counter.fetch_add(1, Ordering::SeqCst);
834                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=0")
835                    .with_response_header(KnownHeaderName::Etag, r#""rolling""#)
836                    .ok(format!("body-{n}"))
837            }
838        }
839        let counter = Arc::new(AtomicUsize::new(0));
840        let server = AlwaysFresh {
841            counter: counter.clone(),
842        };
843        let client = Client::new(ServerConnector::new(server))
844            .with_handler(Cache::new(InMemoryStorage::new()));
845
846        let mut r1 = client.get("http://example.com/x").await?;
847        assert_eq!(r1.response_body().read_string().await?, "body-0");
848
849        let mut r2 = client.get("http://example.com/x").await?;
850        assert_eq!(r2.response_body().read_string().await?, "body-1");
851        assert_eq!(counter.load(Ordering::SeqCst), 2);
852        Ok(())
853    }
854
855    #[test(harness)]
856    async fn vary_isolates_entries_by_request_header() -> TestResult {
857        #[derive(Debug, Clone)]
858        struct VaryServer {
859            counter: Arc<AtomicUsize>,
860        }
861        impl ServerHandler for VaryServer {
862            async fn run(&self, conn: ServerConn) -> ServerConn {
863                self.counter.fetch_add(1, Ordering::SeqCst);
864                let ae = conn
865                    .request_headers()
866                    .get_str(KnownHeaderName::AcceptEncoding)
867                    .unwrap_or("none")
868                    .to_string();
869                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
870                    .with_response_header(KnownHeaderName::Vary, "Accept-Encoding")
871                    .ok(format!("body-for-{ae}"))
872            }
873        }
874        let counter = Arc::new(AtomicUsize::new(0));
875        let server = VaryServer {
876            counter: counter.clone(),
877        };
878        let client = Client::new(ServerConnector::new(server))
879            .with_handler(Cache::new(InMemoryStorage::new()));
880
881        let mut r1 = client
882            .get("http://example.com/x")
883            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
884            .await?;
885        assert_eq!(r1.response_body().read_string().await?, "body-for-gzip");
886
887        let mut r2 = client
888            .get("http://example.com/x")
889            .with_request_header(KnownHeaderName::AcceptEncoding, "br")
890            .await?;
891        assert_eq!(r2.response_body().read_string().await?, "body-for-br");
892
893        let mut r3 = client
894            .get("http://example.com/x")
895            .with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
896            .await?;
897        assert_eq!(r3.response_body().read_string().await?, "body-for-gzip");
898
899        assert_eq!(counter.load(Ordering::SeqCst), 2);
900        Ok(())
901    }
902
903    #[test(harness)]
904    async fn oversized_body_is_served_but_not_cached() -> TestResult {
905        let server = CountingServer::new("max-age=600");
906        let counter = server.counter.clone();
907        let client = Client::new(ServerConnector::new(server))
908            .with_handler(Cache::new(InMemoryStorage::new()).with_max_cacheable_size(3));
909
910        let mut r1 = client.get("http://example.com/x").await?;
911        assert_eq!(r1.response_body().read_string().await?, "body-0");
912
913        let mut r2 = client.get("http://example.com/x").await?;
914        assert_eq!(r2.response_body().read_string().await?, "body-1");
915        assert_eq!(counter.load(Ordering::SeqCst), 2);
916        Ok(())
917    }
918
919    // A chunked (unknown-length) upstream body must be stored and replayed *decoded* — not as
920    // raw chunk framing. Every other test here uses fixed-length bodies, which read raw and so
921    // never exercised the framing path.
922    #[test(harness)]
923    async fn chunked_upstream_is_stored_and_replayed_decoded() -> TestResult {
924        #[derive(Debug, Clone)]
925        struct ChunkedServer {
926            counter: Arc<AtomicUsize>,
927        }
928        impl ServerHandler for ChunkedServer {
929            async fn run(&self, conn: ServerConn) -> ServerConn {
930                self.counter.fetch_add(1, Ordering::SeqCst);
931                // No known length -> the server frames this as Transfer-Encoding: chunked.
932                let body = Body::new_streaming(
933                    futures_lite::io::Cursor::new(b"chunked-body-content".to_vec()),
934                    None,
935                );
936                conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
937                    .with_body(body)
938                    .with_status(Status::Ok)
939                    .halt()
940            }
941        }
942
943        let counter = Arc::new(AtomicUsize::new(0));
944        let server = ChunkedServer {
945            counter: counter.clone(),
946        };
947        let client = Client::new(ServerConnector::new(server))
948            .with_handler(Cache::new(InMemoryStorage::new()));
949
950        // MISS: the pass-through must deliver the decoded body, not chunk framing.
951        let mut r1 = client.get("http://example.com/x").await?;
952        assert_eq!(
953            r1.response_body().read_string().await?,
954            "chunked-body-content"
955        );
956
957        // HIT: the stored copy must replay decoded, with a known content-length.
958        let mut r2 = client.get("http://example.com/x").await?;
959        assert_eq!(r2.status(), Some(Status::Ok));
960        assert_eq!(
961            r2.response_body().read_string().await?,
962            "chunked-body-content"
963        );
964        assert_eq!(
965            counter.load(Ordering::SeqCst),
966            1,
967            "second request served from cache"
968        );
969        Ok(())
970    }
971
972    // ===== §4.2.4 / RFC 5861 stale-if-error =====
973
974    use crate::test_helpers::exchange;
975    use std::{io, net::SocketAddr};
976    use trillium_client::{Connector, Url};
977
978    /// Connector that always fails to connect. Used to drive the
979    /// transport-error code path in `Conn::exec` for SIE tests.
980    #[derive(Debug)]
981    struct FailingConnector {
982        inner: ServerConnector<Status>,
983    }
984
985    impl FailingConnector {
986        fn new() -> Self {
987            Self {
988                inner: ServerConnector::new(Status::Ok),
989            }
990        }
991    }
992
993    impl Connector for FailingConnector {
994        type Runtime = <ServerConnector<Status> as Connector>::Runtime;
995        type Transport = <ServerConnector<Status> as Connector>::Transport;
996        type Udp = <ServerConnector<Status> as Connector>::Udp;
997
998        async fn connect(&self, _url: &Url) -> io::Result<Self::Transport> {
999            Err(io::Error::new(
1000                io::ErrorKind::ConnectionRefused,
1001                "test failure",
1002            ))
1003        }
1004
1005        fn runtime(&self) -> Self::Runtime {
1006            self.inner.runtime().clone()
1007        }
1008
1009        async fn resolve(&self, host: &str, port: u16) -> io::Result<Vec<SocketAddr>> {
1010            self.inner.resolve(host, port).await
1011        }
1012    }
1013
1014    /// Build a stale, SIE-eligible cache entry by hand and pre-populate
1015    /// `storage`. Returns the URL/key under which the entry was stored.
1016    async fn populate_stale_entry(
1017        storage: &InMemoryStorage,
1018        cache_control: &'static str,
1019        body: &'static [u8],
1020    ) -> CacheKey {
1021        let conn = exchange(
1022            Method::Get,
1023            &[],
1024            Status::Ok,
1025            &[(KnownHeaderName::CacheControl, cache_control)],
1026        );
1027        let policy =
1028            crate::test_helpers::policy_from(&conn, SystemTime::now(), CacheOptions::default());
1029        let key = CacheKey::new(Method::Get, "http://example.com/x".parse().unwrap());
1030        let mut handle = storage.put(key.clone(), policy).await.unwrap();
1031        use futures_lite::AsyncWriteExt;
1032        handle.write_all(body).await.unwrap();
1033        handle.finalize(None).await.unwrap();
1034        key
1035    }
1036
1037    #[test(harness)]
1038    async fn sie_serves_stale_on_transport_error() -> TestResult {
1039        let storage = InMemoryStorage::new();
1040        let _ =
1041            populate_stale_entry(&storage, "max-age=0, stale-if-error=3600", b"stale body").await;
1042        let client = Client::new(FailingConnector::new()).with_handler(Cache::new(storage));
1043
1044        let mut conn = client.get("http://example.com/x").await?;
1045        assert_eq!(conn.status(), Some(Status::Ok));
1046        assert_eq!(conn.response_body().read_string().await?, "stale body");
1047        Ok(())
1048    }
1049
1050    #[test(harness)]
1051    async fn no_sie_propagates_transport_error() -> TestResult {
1052        let storage = InMemoryStorage::new();
1053        let _ = populate_stale_entry(&storage, "max-age=0", b"stale body").await;
1054        let client = Client::new(FailingConnector::new()).with_handler(Cache::new(storage));
1055
1056        let result = client.get("http://example.com/x").await;
1057        assert!(
1058            result.is_err(),
1059            "expected transport error to propagate, got {result:?}"
1060        );
1061        Ok(())
1062    }
1063
1064    #[test(harness)]
1065    async fn sie_serves_stale_on_5xx() -> TestResult {
1066        let storage = InMemoryStorage::new();
1067        let _ =
1068            populate_stale_entry(&storage, "max-age=0, stale-if-error=3600", b"stale body").await;
1069        let server = ServerConnector::new(Status::ServiceUnavailable);
1070        let client = Client::new(server).with_handler(Cache::new(storage));
1071
1072        let mut conn = client.get("http://example.com/x").await?;
1073        assert_eq!(conn.status(), Some(Status::Ok));
1074        assert_eq!(conn.response_body().read_string().await?, "stale body");
1075        Ok(())
1076    }
1077
1078    #[test(harness)]
1079    async fn no_sie_serves_5xx_as_received() -> TestResult {
1080        let storage = InMemoryStorage::new();
1081        let _ = populate_stale_entry(&storage, "max-age=0", b"stale body").await;
1082        let server = ServerConnector::new(Status::ServiceUnavailable);
1083        let client = Client::new(server).with_handler(Cache::new(storage));
1084
1085        let conn = client.get("http://example.com/x").await?;
1086        assert_eq!(conn.status(), Some(Status::ServiceUnavailable));
1087        Ok(())
1088    }
1089
1090    // ===== §4.2.4 / RFC 5861 stale-while-revalidate =====
1091
1092    use std::time::Duration;
1093
1094    #[test(harness)]
1095    async fn swr_serves_stale_immediately_and_revalidates_in_background() -> TestResult {
1096        let storage = InMemoryStorage::new();
1097        let _ = populate_stale_entry(
1098            &storage,
1099            "max-age=0, stale-while-revalidate=3600",
1100            b"stale-body",
1101        )
1102        .await;
1103
1104        let server = CountingServer::new("max-age=600");
1105        let counter = server.counter.clone();
1106        let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));
1107
1108        let mut conn = client.get("http://example.com/x").await?;
1109        assert_eq!(conn.status(), Some(Status::Ok));
1110        assert_eq!(conn.response_body().read_string().await?, "stale-body");
1111
1112        let runtime = client.connector().runtime();
1113        for _ in 0..100 {
1114            if counter.load(Ordering::SeqCst) > 0 {
1115                break;
1116            }
1117            runtime.delay(Duration::from_millis(10)).await;
1118        }
1119        assert_eq!(
1120            counter.load(Ordering::SeqCst),
1121            1,
1122            "background revalidation should hit the origin"
1123        );
1124
1125        let cache = client
1126            .downcast_handler::<Cache<InMemoryStorage>>()
1127            .expect("cache handler installed");
1128        let key = CacheKey::new(Method::Get, "http://example.com/x".parse().unwrap());
1129        // Wait briefly for the background put to land in storage.
1130        for _ in 0..100 {
1131            if !cache.storage().get(&key).await.is_empty() {
1132                break;
1133            }
1134            runtime.delay(Duration::from_millis(10)).await;
1135        }
1136        let entries = cache.storage().get(&key).await;
1137        assert_eq!(entries.len(), 1);
1138        let body = entries[0].clone().open().await.unwrap();
1139        use futures_lite::AsyncReadExt;
1140        let mut buf = Vec::new();
1141        let mut body = body;
1142        body.read_to_end(&mut buf).await.unwrap();
1143        assert_eq!(&buf, b"body-0");
1144        Ok(())
1145    }
1146
1147    #[test(harness)]
1148    async fn no_swr_falls_back_to_synchronous_revalidation() -> TestResult {
1149        let storage = InMemoryStorage::new();
1150        let _ = populate_stale_entry(&storage, "max-age=0", b"stale-body").await;
1151
1152        let server = CountingServer::new("max-age=600");
1153        let counter = server.counter.clone();
1154        let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));
1155
1156        let mut conn = client.get("http://example.com/x").await?;
1157        assert_eq!(conn.response_body().read_string().await?, "body-0");
1158        assert_eq!(counter.load(Ordering::SeqCst), 1);
1159        Ok(())
1160    }
1161
1162    #[test(harness)]
1163    async fn must_revalidate_disables_swr() -> TestResult {
1164        let storage = InMemoryStorage::new();
1165        let _ = populate_stale_entry(
1166            &storage,
1167            "max-age=0, must-revalidate, stale-while-revalidate=3600",
1168            b"stale-body",
1169        )
1170        .await;
1171
1172        let server = CountingServer::new("max-age=600");
1173        let client = Client::new(ServerConnector::new(server)).with_handler(Cache::new(storage));
1174
1175        let mut conn = client.get("http://example.com/x").await?;
1176        assert_eq!(conn.response_body().read_string().await?, "body-0");
1177        Ok(())
1178    }
1179}