// trillium_router/router.rs

1use crate::{CapturesNewType, RouteSpecNewType, RouterRef};
2use routefinder::{Match, RouteSpec, Router as Routefinder};
3use std::{
4    collections::BTreeSet,
5    fmt::{self, Debug, Display, Formatter},
6    mem,
7};
8use trillium::{async_trait, Conn, Handler, Info, KnownHeaderName, Method, Upgrade};
9
10const ALL_METHODS: [Method; 5] = [
11    Method::Delete,
12    Method::Get,
13    Method::Patch,
14    Method::Post,
15    Method::Put,
16];
17
18#[derive(Debug)]
19enum MethodSelection {
20    Just(Method),
21    All,
22    Any(Vec<Method>),
23}
24
25impl Display for MethodSelection {
26    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
27        match self {
28            MethodSelection::Just(m) => Display::fmt(m, f),
29            MethodSelection::All => f.write_str("*"),
30            MethodSelection::Any(v) => {
31                f.write_str(&v.iter().map(|m| m.as_ref()).collect::<Vec<_>>().join(", "))
32            }
33        }
34    }
35}
36
37impl PartialEq<Method> for MethodSelection {
38    fn eq(&self, other: &Method) -> bool {
39        match self {
40            MethodSelection::Just(m) => m == other,
41            MethodSelection::All => true,
42            MethodSelection::Any(v) => v.contains(other),
43        }
44    }
45}
46
47impl From<()> for MethodSelection {
48    fn from(_: ()) -> MethodSelection {
49        Self::All
50    }
51}
52
53impl From<Method> for MethodSelection {
54    fn from(method: Method) -> Self {
55        Self::Just(method)
56    }
57}
58
59impl From<&[Method]> for MethodSelection {
60    fn from(methods: &[Method]) -> Self {
61        Self::Any(methods.to_vec())
62    }
63}
64impl From<Vec<Method>> for MethodSelection {
65    fn from(methods: Vec<Method>) -> Self {
66        Self::Any(methods)
67    }
68}
69
70#[derive(Debug, Default)]
71struct MethodRoutefinder(Routefinder<(MethodSelection, Box<dyn Handler>)>);
72impl MethodRoutefinder {
73    fn add<R>(
74        &mut self,
75        method_selection: impl Into<MethodSelection>,
76        path: R,
77        handler: impl Handler,
78    ) where
79        R: TryInto<RouteSpec>,
80        R::Error: Debug,
81    {
82        self.0
83            .add(path, (method_selection.into(), Box::new(handler)))
84            .expect("could not add route")
85    }
86
87    fn methods_matching(&self, path: &str) -> BTreeSet<Method> {
88        let mut set = BTreeSet::new();
89
90        fn extend(ms: &MethodSelection, set: &mut BTreeSet<Method>) {
91            match ms {
92                MethodSelection::All => {
93                    set.extend(ALL_METHODS);
94                }
95                MethodSelection::Just(method) => {
96                    set.insert(*method);
97                }
98                MethodSelection::Any(methods) => {
99                    set.extend(methods);
100                }
101            }
102        }
103
104        if path == "*" {
105            for ms in self.0.iter().map(|(_, (m, _))| m) {
106                extend(ms, &mut set);
107            }
108        } else {
109            for m in self.0.match_iter(path) {
110                extend(&m.0, &mut set);
111            }
112        };
113
114        set.remove(&Method::Options);
115        set
116    }
117
118    fn best_match<'a, 'b>(
119        &'a self,
120        method: Method,
121        path: &'b str,
122    ) -> Option<Match<'a, 'b, (MethodSelection, Box<dyn Handler>)>> {
123        self.0.match_iter(path).find(|m| m.0 == method)
124    }
125}
126
127/**
128# The Router handler
129
130See crate level docs for more, as this is the primary type in this crate.
131
132*/
133pub struct Router {
134    routefinder: MethodRoutefinder,
135    handle_options: bool,
136}
137
138impl Default for Router {
139    fn default() -> Self {
140        Self {
141            routefinder: MethodRoutefinder::default(),
142            handle_options: true,
143        }
144    }
145}
146
// Generates a chainable `Router` builder method for a single HTTP method.
//
// `method!(get, Get)` expands to `pub fn get(self, path, handler) -> Self`,
// which registers `handler` for `Method::Get` and carries a generated doc
// comment with a doctest. The three-argument form takes the doc string
// explicitly.
macro_rules! method {
    ($name:ident, $variant:ident) => {
        // Two-argument form: build the doc comment (including a doctest) and
        // defer to the three-argument form below.
        method!(
            $name,
            $variant,
            concat!(
                "Registers a handler for the ",
                stringify!($name),
                " http method.

```
# use trillium::Conn;
# use trillium_router::Router;
let router = Router::new().",
                stringify!($name),
                "(\"/some/route\", |conn: Conn| async move {
  conn.ok(\"success\")
});

use trillium_testing::{methods::",
                stringify!($name),
                ", assert_ok, assert_not_handled};
assert_ok!(",
                stringify!($name),
                "(\"/some/route\").on(&router), \"success\");
assert_not_handled!(",
                stringify!($name),
                "(\"/other/route\").on(&router));
```
"
            )
        );
    };

    ($name:ident, $variant:ident, $docs:expr) => {
        #[doc = $docs]
        pub fn $name<R: TryInto<RouteSpec>>(mut self, path: R, handler: impl Handler) -> Self
        where
            R::Error: Debug,
        {
            self.add(path, Method::$variant, handler);
            self
        }
    };
}
194
195impl Router {
196    /**
197    Constructs a new Router. This is often used with [`Router::get`],
198    [`Router::post`], [`Router::put`], [`Router::delete`], and
199    [`Router::patch`] chainable methods to build up an application.
200
201    For an alternative way of constructing a Router, see [`Router::build`]
202
203    ```
204    # use trillium::Conn;
205    # use trillium_router::Router;
206
207    let router = Router::new()
208        .get("/", |conn: Conn| async move { conn.ok("you have reached the index") })
209        .get("/some/:param", |conn: Conn| async move { conn.ok("you have reached /some/:param") })
210        .post("/", |conn: Conn| async move { conn.ok("post!") });
211
212    use trillium_testing::prelude::*;
213    assert_ok!(get("/").on(&router), "you have reached the index");
214    assert_ok!(get("/some/route").on(&router), "you have reached /some/:param");
215    assert_ok!(post("/").on(&router), "post!");
216    ```
217     */
218    pub fn new() -> Self {
219        Self::default()
220    }
221
222    /**
223    Disable the default behavior of responding to OPTIONS requests
224    with the supported methods at a given path
225    */
226    pub fn without_options_handling(mut self) -> Self {
227        self.set_options_handling(false);
228        self
229    }
230
231    /**
232    enable or disable the router's behavior of responding to OPTIONS requests with the supported methods at given path.
233
234    default: enabled
235     */
236    pub(crate) fn set_options_handling(&mut self, options_enabled: bool) {
237        self.handle_options = options_enabled;
238    }
239
240    /**
241    Another way to build a router, if you don't like the chainable
242    interface described in [`Router::new`]. Note that the argument to
243    the closure is a [`RouterRef`].
244
245    ```
246    # use trillium::Conn;
247    # use trillium_router::Router;
248    let router = Router::build(|mut router| {
249        router.get("/", |conn: Conn| async move {
250            conn.ok("you have reached the index")
251        });
252
253        router.get("/some/:paramroute", |conn: Conn| async move {
254            conn.ok("you have reached /some/:param")
255        });
256
257        router.post("/", |conn: Conn| async move {
258            conn.ok("post!")
259        });
260    });
261
262
263    use trillium_testing::prelude::*;
264    assert_ok!(get("/").on(&router), "you have reached the index");
265    assert_ok!(get("/some/route").on(&router), "you have reached /some/:param");
266    assert_ok!(post("/").on(&router), "post!");
267    ```
268    */
269    pub fn build(builder: impl Fn(RouterRef)) -> Router {
270        let mut router = Router::new();
271        builder(RouterRef::new(&mut router));
272        router
273    }
274
275    fn best_match<'a, 'b>(
276        &'a self,
277        method: Method,
278        path: &'b str,
279    ) -> Option<Match<'a, 'b, (MethodSelection, Box<dyn Handler>)>> {
280        self.routefinder.best_match(method, path)
281    }
282
283    /**
284    Registers a handler for a method other than get, put, post, patch, or delete.
285
286    ```
287    # use trillium::{Conn, Method};
288    # use trillium_router::Router;
289    let router = Router::new()
290        .with_route("OPTIONS", "/some/route", |conn: Conn| async move { conn.ok("directly handling options") })
291        .with_route(Method::Checkin, "/some/route", |conn: Conn| async move { conn.ok("checkin??") });
292
293    use trillium_testing::{prelude::*, TestConn};
294    assert_ok!(TestConn::build(Method::Options, "/some/route", ()).on(&router), "directly handling options");
295    assert_ok!(TestConn::build("checkin", "/some/route", ()).on(&router), "checkin??");
296    ```
297    */
298    pub fn with_route<M, R>(mut self, method: M, path: R, handler: impl Handler) -> Self
299    where
300        M: TryInto<Method>,
301        <M as TryInto<Method>>::Error: Debug,
302        R: TryInto<RouteSpec>,
303        R::Error: Debug,
304    {
305        self.add(path, method.try_into().unwrap(), handler);
306        self
307    }
308
309    pub(crate) fn add<R>(&mut self, path: R, method: Method, handler: impl Handler)
310    where
311        R: TryInto<RouteSpec>,
312        R::Error: Debug,
313    {
314        self.routefinder.add(method, path, handler);
315    }
316
317    pub(crate) fn add_any<R>(&mut self, methods: &[Method], path: R, handler: impl Handler)
318    where
319        R: TryInto<RouteSpec>,
320        R::Error: Debug,
321    {
322        self.routefinder.add(methods, path, handler)
323    }
324
325    pub(crate) fn add_all<R>(&mut self, path: R, handler: impl Handler)
326    where
327        R: TryInto<RouteSpec>,
328        R::Error: Debug,
329    {
330        self.routefinder.add((), path, handler);
331    }
332
333    /**
334    Appends the handler to all (get, post, put, delete, and patch) methods.
335    ```
336    # use trillium::Conn;
337    # use trillium_router::Router;
338    let router = Router::new().all("/any", |conn: Conn| async move {
339        let response = format!("you made a {} request to /any", conn.method());
340        conn.ok(response)
341    });
342
343    use trillium_testing::prelude::*;
344    assert_ok!(get("/any").on(&router), "you made a GET request to /any");
345    assert_ok!(post("/any").on(&router), "you made a POST request to /any");
346    assert_ok!(delete("/any").on(&router), "you made a DELETE request to /any");
347    assert_ok!(patch("/any").on(&router), "you made a PATCH request to /any");
348    assert_ok!(put("/any").on(&router), "you made a PUT request to /any");
349
350    assert_not_handled!(get("/").on(&router));
351    ```
352    */
353    pub fn all<R>(mut self, path: R, handler: impl Handler) -> Self
354    where
355        R: TryInto<RouteSpec>,
356        R::Error: Debug,
357    {
358        self.add_all(path, handler);
359        self
360    }
361
362    /**
363    Appends the handler to each of the provided http methods.
364    ```
365    # use trillium::Conn;
366    # use trillium_router::Router;
367    let router = Router::new().any(&["get", "post"], "/get_or_post", |conn: Conn| async move {
368        let response = format!("you made a {} request to /get_or_post", conn.method());
369        conn.ok(response)
370    });
371
372    use trillium_testing::prelude::*;
373    assert_ok!(get("/get_or_post").on(&router), "you made a GET request to /get_or_post");
374    assert_ok!(post("/get_or_post").on(&router), "you made a POST request to /get_or_post");
375    assert_not_handled!(delete("/any").on(&router));
376    assert_not_handled!(patch("/any").on(&router));
377    assert_not_handled!(put("/any").on(&router));
378    assert_not_handled!(get("/").on(&router));
379    ```
380    */
381    pub fn any<IntoMethod, R>(
382        mut self,
383        methods: &[IntoMethod],
384        path: R,
385        handler: impl Handler,
386    ) -> Self
387    where
388        IntoMethod: TryInto<Method> + Clone,
389        <IntoMethod as TryInto<Method>>::Error: Debug,
390        R: TryInto<RouteSpec>,
391        R::Error: Debug,
392    {
393        let methods = methods
394            .iter()
395            .cloned()
396            .map(|m| m.try_into().unwrap())
397            .collect::<Vec<_>>();
398        self.add_any(&methods, path, handler);
399        self
400    }
401
402    method!(get, Get);
403    method!(post, Post);
404    method!(put, Put);
405    method!(delete, Delete);
406    method!(patch, Patch);
407}
408
409#[async_trait]
410impl Handler for Router {
411    async fn run(&self, mut conn: Conn) -> Conn {
412        let method = conn.method();
413        let original_captures = conn.take_state();
414        let path = conn.path();
415        let mut has_path = false;
416
417        if let Some(m) = self.best_match(conn.method(), path) {
418            let mut captures = m.captures().into_owned();
419
420            let route = m.route().clone();
421
422            if let Some(CapturesNewType(mut original_captures)) = original_captures {
423                original_captures.append(captures);
424                captures = original_captures;
425            }
426
427            log::debug!("running {}: {}", m.route(), m.1.name());
428            let mut new_conn = m
429                .handler()
430                .1
431                .run({
432                    if let Some(wildcard) = captures.wildcard() {
433                        conn.push_path(String::from(wildcard));
434                        has_path = true;
435                    }
436
437                    conn.with_state(CapturesNewType(captures))
438                        .with_state(RouteSpecNewType(route))
439                })
440                .await;
441
442            if has_path {
443                new_conn.pop_path();
444            }
445            new_conn
446        } else if method == Method::Options && self.handle_options {
447            let allow = self
448                .routefinder
449                .methods_matching(path)
450                .iter()
451                .map(|m| m.as_ref())
452                .collect::<Vec<_>>()
453                .join(", ");
454
455            return conn
456                .with_response_header(KnownHeaderName::Allow, allow)
457                .with_status(200)
458                .halt();
459        } else {
460            log::debug!("{} did not match any route", conn.path());
461            conn
462        }
463    }
464
465    async fn before_send(&self, conn: Conn) -> Conn {
466        let path = conn.path();
467        if let Some(m) = self.best_match(conn.method(), path) {
468            m.handler().1.before_send(conn).await
469        } else {
470            conn
471        }
472    }
473
474    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
475        if let Some(m) = self.best_match(*upgrade.method(), upgrade.path()) {
476            m.1.has_upgrade(upgrade)
477        } else {
478            false
479        }
480    }
481
482    async fn upgrade(&self, upgrade: Upgrade) {
483        self.best_match(*upgrade.method(), upgrade.path())
484            .unwrap()
485            .handler()
486            .1
487            .upgrade(upgrade)
488            .await
489    }
490
491    fn name(&self) -> std::borrow::Cow<'static, str> {
492        "Router".into()
493    }
494
495    async fn init(&mut self, info: &mut Info) {
496        // This code is not what a reader would expect, so here's a
497        // brief explanation:
498        //
499        // Currently, the init trait interface must return a Send
500        // future because that's the default for async-trait. We don't
501        // actually need it to be Send, but changing that would be a
502        // semver-minor trillium release.
503        //
504        // Mutable map iterators are not Send, and because we need to
505        // hold that data across await boundaries, we cannot mutate in
506        // place.
507        //
508        // However, because this is only called once at app boot, and
509        // because we have &mut self, it is safe to move the router
510        // contents into this future and then replace it, and the
511        // performance impacts of doing so are unimportant as it is
512        // part of app boot.
513        let routefinder = mem::take(&mut self.routefinder);
514        for (route, (methods, mut handler)) in routefinder.0 {
515            handler.init(info).await;
516            self.routefinder.add(methods, route, handler);
517        }
518    }
519}
520
521impl Debug for Router {
522    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
523        f.write_str("Router ")?;
524        let mut set = f.debug_set();
525
526        for (route, (methods, handler)) in &self.routefinder.0 {
527            set.entry(&format_args!("{} {} -> {}", methods, route, handler.name()));
528        }
529        set.finish()
530    }
531}