trillium_method_override/
lib.rs1#![forbid(unsafe_code)]
20#![deny(
21 missing_copy_implementations,
22 rustdoc::missing_crate_level_docs,
23 missing_debug_implementations,
24 missing_docs,
25 nonstandard_style,
26 unused_qualifications
27)]
28
29use querystrong::QueryStrong;
30use std::{collections::HashSet, fmt::Debug};
31use trillium::{async_trait, conn_unwrap, Conn, Handler, Method};
32
33#[derive(Clone, Debug)]
39pub struct MethodOverride {
40 param: &'static str,
41 allowed_methods: HashSet<Method>,
42}
43
44impl Default for MethodOverride {
45 fn default() -> Self {
46 Self {
47 param: "_method",
48 allowed_methods: HashSet::from_iter([Method::Put, Method::Patch, Method::Delete]),
49 }
50 }
51}
52
53impl MethodOverride {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn with_allowed_methods<M>(mut self, methods: impl IntoIterator<Item = M>) -> Self
70 where
71 M: TryInto<Method>,
72 <M as TryInto<Method>>::Error: Debug,
73 {
74 self.allowed_methods = methods.into_iter().map(|m| m.try_into().unwrap()).collect();
75 self
76 }
77
78 pub fn with_param_name(mut self, param_name: &'static str) -> Self {
89 self.param = param_name;
90 self
91 }
92}
93
94#[async_trait]
95impl Handler for MethodOverride {
96 async fn run(&self, mut conn: Conn) -> Conn {
97 if conn.method() != Method::Post {
98 return conn;
99 }
100 let qs = conn_unwrap!(QueryStrong::parse(conn.querystring()).ok(), conn);
101 let method_str = conn_unwrap!(qs.get_str(self.param), conn);
102 let method: Method = conn_unwrap!(method_str.try_into().ok(), conn);
103 if self.allowed_methods.contains(&method) {
104 conn.inner_mut().set_method(method);
105 }
106 conn
107 }
108}
109
110pub fn method_override() -> MethodOverride {
112 MethodOverride::new()
113}