1use crate::{CapturesNewType, RouteSpecNewType, RouterRef};
2use routefinder::{Match, RouteSpec, Router as Routefinder};
3use std::{
4 collections::BTreeSet,
5 fmt::{self, Debug, Display, Formatter},
6 mem,
7};
8use trillium::{async_trait, Conn, Handler, Info, KnownHeaderName, Method, Upgrade};
9
10const ALL_METHODS: [Method; 5] = [
11 Method::Delete,
12 Method::Get,
13 Method::Patch,
14 Method::Post,
15 Method::Put,
16];
17
18#[derive(Debug)]
19enum MethodSelection {
20 Just(Method),
21 All,
22 Any(Vec<Method>),
23}
24
25impl Display for MethodSelection {
26 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
27 match self {
28 MethodSelection::Just(m) => Display::fmt(m, f),
29 MethodSelection::All => f.write_str("*"),
30 MethodSelection::Any(v) => {
31 f.write_str(&v.iter().map(|m| m.as_ref()).collect::<Vec<_>>().join(", "))
32 }
33 }
34 }
35}
36
37impl PartialEq<Method> for MethodSelection {
38 fn eq(&self, other: &Method) -> bool {
39 match self {
40 MethodSelection::Just(m) => m == other,
41 MethodSelection::All => true,
42 MethodSelection::Any(v) => v.contains(other),
43 }
44 }
45}
46
47impl From<()> for MethodSelection {
48 fn from(_: ()) -> MethodSelection {
49 Self::All
50 }
51}
52
53impl From<Method> for MethodSelection {
54 fn from(method: Method) -> Self {
55 Self::Just(method)
56 }
57}
58
59impl From<&[Method]> for MethodSelection {
60 fn from(methods: &[Method]) -> Self {
61 Self::Any(methods.to_vec())
62 }
63}
64impl From<Vec<Method>> for MethodSelection {
65 fn from(methods: Vec<Method>) -> Self {
66 Self::Any(methods)
67 }
68}
69
70#[derive(Debug, Default)]
71struct MethodRoutefinder(Routefinder<(MethodSelection, Box<dyn Handler>)>);
72impl MethodRoutefinder {
73 fn add<R>(
74 &mut self,
75 method_selection: impl Into<MethodSelection>,
76 path: R,
77 handler: impl Handler,
78 ) where
79 R: TryInto<RouteSpec>,
80 R::Error: Debug,
81 {
82 self.0
83 .add(path, (method_selection.into(), Box::new(handler)))
84 .expect("could not add route")
85 }
86
87 fn methods_matching(&self, path: &str) -> BTreeSet<Method> {
88 let mut set = BTreeSet::new();
89
90 fn extend(ms: &MethodSelection, set: &mut BTreeSet<Method>) {
91 match ms {
92 MethodSelection::All => {
93 set.extend(ALL_METHODS);
94 }
95 MethodSelection::Just(method) => {
96 set.insert(*method);
97 }
98 MethodSelection::Any(methods) => {
99 set.extend(methods);
100 }
101 }
102 }
103
104 if path == "*" {
105 for ms in self.0.iter().map(|(_, (m, _))| m) {
106 extend(ms, &mut set);
107 }
108 } else {
109 for m in self.0.match_iter(path) {
110 extend(&m.0, &mut set);
111 }
112 };
113
114 set.remove(&Method::Options);
115 set
116 }
117
118 fn best_match<'a, 'b>(
119 &'a self,
120 method: Method,
121 path: &'b str,
122 ) -> Option<Match<'a, 'b, (MethodSelection, Box<dyn Handler>)>> {
123 self.0.match_iter(path).find(|m| m.0 == method)
124 }
125}
126
127pub struct Router {
134 routefinder: MethodRoutefinder,
135 handle_options: bool,
136}
137
138impl Default for Router {
139 fn default() -> Self {
140 Self {
141 routefinder: MethodRoutefinder::default(),
142 handle_options: true,
143 }
144 }
145}
146
// Generates one chainable per-method registration function (`get`, `post`, ...)
// inside `impl Router`.
macro_rules! method {
    // Full form: the caller supplies the doc string for the generated function.
    ($name:ident, $variant:ident, $docs:expr) => {
        #[doc = $docs]
        pub fn $name<R>(mut self, path: R, handler: impl Handler) -> Self
        where
            R: TryInto<RouteSpec>,
            R::Error: Debug,
        {
            self.add(path, Method::$variant, handler);
            self
        }
    };

    // Short form: builds a doc string (including a doctest) from the function
    // name and forwards to the full form above.
    ($name:ident, $variant:ident) => {
        method!(
            $name,
            $variant,
            concat!(
                "Registers a handler for the ",
                stringify!($name),
                " http method.

```
# use trillium::Conn;
# use trillium_router::Router;
let router = Router::new().",
                stringify!($name),
                "(\"/some/route\", |conn: Conn| async move {
    conn.ok(\"success\")
});

use trillium_testing::{methods::",
                stringify!($name),
                ", assert_ok, assert_not_handled};
assert_ok!(",
                stringify!($name),
                "(\"/some/route\").on(&router), \"success\");
assert_not_handled!(",
                stringify!($name),
                "(\"/other/route\").on(&router));
```
"
            )
        );
    };
}
194
195impl Router {
196 pub fn new() -> Self {
219 Self::default()
220 }
221
222 pub fn without_options_handling(mut self) -> Self {
227 self.set_options_handling(false);
228 self
229 }
230
231 pub(crate) fn set_options_handling(&mut self, options_enabled: bool) {
237 self.handle_options = options_enabled;
238 }
239
240 pub fn build(builder: impl Fn(RouterRef)) -> Router {
270 let mut router = Router::new();
271 builder(RouterRef::new(&mut router));
272 router
273 }
274
275 fn best_match<'a, 'b>(
276 &'a self,
277 method: Method,
278 path: &'b str,
279 ) -> Option<Match<'a, 'b, (MethodSelection, Box<dyn Handler>)>> {
280 self.routefinder.best_match(method, path)
281 }
282
283 pub fn with_route<M, R>(mut self, method: M, path: R, handler: impl Handler) -> Self
299 where
300 M: TryInto<Method>,
301 <M as TryInto<Method>>::Error: Debug,
302 R: TryInto<RouteSpec>,
303 R::Error: Debug,
304 {
305 self.add(path, method.try_into().unwrap(), handler);
306 self
307 }
308
309 pub(crate) fn add<R>(&mut self, path: R, method: Method, handler: impl Handler)
310 where
311 R: TryInto<RouteSpec>,
312 R::Error: Debug,
313 {
314 self.routefinder.add(method, path, handler);
315 }
316
317 pub(crate) fn add_any<R>(&mut self, methods: &[Method], path: R, handler: impl Handler)
318 where
319 R: TryInto<RouteSpec>,
320 R::Error: Debug,
321 {
322 self.routefinder.add(methods, path, handler)
323 }
324
325 pub(crate) fn add_all<R>(&mut self, path: R, handler: impl Handler)
326 where
327 R: TryInto<RouteSpec>,
328 R::Error: Debug,
329 {
330 self.routefinder.add((), path, handler);
331 }
332
333 pub fn all<R>(mut self, path: R, handler: impl Handler) -> Self
354 where
355 R: TryInto<RouteSpec>,
356 R::Error: Debug,
357 {
358 self.add_all(path, handler);
359 self
360 }
361
362 pub fn any<IntoMethod, R>(
382 mut self,
383 methods: &[IntoMethod],
384 path: R,
385 handler: impl Handler,
386 ) -> Self
387 where
388 IntoMethod: TryInto<Method> + Clone,
389 <IntoMethod as TryInto<Method>>::Error: Debug,
390 R: TryInto<RouteSpec>,
391 R::Error: Debug,
392 {
393 let methods = methods
394 .iter()
395 .cloned()
396 .map(|m| m.try_into().unwrap())
397 .collect::<Vec<_>>();
398 self.add_any(&methods, path, handler);
399 self
400 }
401
402 method!(get, Get);
403 method!(post, Post);
404 method!(put, Put);
405 method!(delete, Delete);
406 method!(patch, Patch);
407}
408
409#[async_trait]
410impl Handler for Router {
411 async fn run(&self, mut conn: Conn) -> Conn {
412 let method = conn.method();
413 let original_captures = conn.take_state();
414 let path = conn.path();
415 let mut has_path = false;
416
417 if let Some(m) = self.best_match(conn.method(), path) {
418 let mut captures = m.captures().into_owned();
419
420 let route = m.route().clone();
421
422 if let Some(CapturesNewType(mut original_captures)) = original_captures {
423 original_captures.append(captures);
424 captures = original_captures;
425 }
426
427 log::debug!("running {}: {}", m.route(), m.1.name());
428 let mut new_conn = m
429 .handler()
430 .1
431 .run({
432 if let Some(wildcard) = captures.wildcard() {
433 conn.push_path(String::from(wildcard));
434 has_path = true;
435 }
436
437 conn.with_state(CapturesNewType(captures))
438 .with_state(RouteSpecNewType(route))
439 })
440 .await;
441
442 if has_path {
443 new_conn.pop_path();
444 }
445 new_conn
446 } else if method == Method::Options && self.handle_options {
447 let allow = self
448 .routefinder
449 .methods_matching(path)
450 .iter()
451 .map(|m| m.as_ref())
452 .collect::<Vec<_>>()
453 .join(", ");
454
455 return conn
456 .with_response_header(KnownHeaderName::Allow, allow)
457 .with_status(200)
458 .halt();
459 } else {
460 log::debug!("{} did not match any route", conn.path());
461 conn
462 }
463 }
464
465 async fn before_send(&self, conn: Conn) -> Conn {
466 let path = conn.path();
467 if let Some(m) = self.best_match(conn.method(), path) {
468 m.handler().1.before_send(conn).await
469 } else {
470 conn
471 }
472 }
473
474 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
475 if let Some(m) = self.best_match(*upgrade.method(), upgrade.path()) {
476 m.1.has_upgrade(upgrade)
477 } else {
478 false
479 }
480 }
481
482 async fn upgrade(&self, upgrade: Upgrade) {
483 self.best_match(*upgrade.method(), upgrade.path())
484 .unwrap()
485 .handler()
486 .1
487 .upgrade(upgrade)
488 .await
489 }
490
491 fn name(&self) -> std::borrow::Cow<'static, str> {
492 "Router".into()
493 }
494
495 async fn init(&mut self, info: &mut Info) {
496 let routefinder = mem::take(&mut self.routefinder);
514 for (route, (methods, mut handler)) in routefinder.0 {
515 handler.init(info).await;
516 self.routefinder.add(methods, route, handler);
517 }
518 }
519}
520
521impl Debug for Router {
522 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
523 f.write_str("Router ")?;
524 let mut set = f.debug_set();
525
526 for (route, (methods, handler)) in &self.routefinder.0 {
527 set.entry(&format_args!("{} {} -> {}", methods, route, handler.name()));
528 }
529 set.finish()
530 }
531}