Skip to main content

trillium_method_override/
lib.rs

1/*!
2# Trillium method override handler
3
4This allows http clients that are unable to generate http methods
5other than `GET` and `POST` to use `POST` requests that are
6interpreted as other methods such as `PUT`, `PATCH`, or `DELETE`.
7
8This is currently supported with a querystring parameter of
9`_method`. To change the querystring parameter's name, use
10[`MethodOverride::with_param_name`]
11
12By default, the only methods allowed are `PUT`, `PATCH`, and
13`DELETE`. To override this, use
14[`MethodOverride::with_allowed_methods`]
15
16Subsequent handlers see the requested method on the conn instead of
17POST.
18*/
19#![forbid(unsafe_code)]
20#![deny(
21    missing_copy_implementations,
22    rustdoc::missing_crate_level_docs,
23    missing_debug_implementations,
24    missing_docs,
25    nonstandard_style,
26    unused_qualifications
27)]
28
29use querystrong::QueryStrong;
30use std::{collections::HashSet, fmt::Debug};
31use trillium::{async_trait, conn_unwrap, Conn, Handler, Method};
32
33/**
34Trillium method override handler
35
36See crate-level docs for an explanation
37*/
38#[derive(Clone, Debug)]
39pub struct MethodOverride {
40    param: &'static str,
41    allowed_methods: HashSet<Method>,
42}
43
44impl Default for MethodOverride {
45    fn default() -> Self {
46        Self {
47            param: "_method",
48            allowed_methods: HashSet::from_iter([Method::Put, Method::Patch, Method::Delete]),
49        }
50    }
51}
52
53impl MethodOverride {
54    /// constructs a new MethodOverride handler with default allowed methods and param name
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /**
60    replace the default allowed methods with the provided list of methods
61
62    default: `put`, `patch`, `delete`
63
64    ```
65    # use trillium_method_override::MethodOverride;
66    let handler = MethodOverride::new().with_allowed_methods(["put", "patch"]);
67    ```
68    */
69    pub fn with_allowed_methods<M>(mut self, methods: impl IntoIterator<Item = M>) -> Self
70    where
71        M: TryInto<Method>,
72        <M as TryInto<Method>>::Error: Debug,
73    {
74        self.allowed_methods = methods.into_iter().map(|m| m.try_into().unwrap()).collect();
75        self
76    }
77
78    /**
79    replace the default param name with the provided param name
80
81    default: `_method`
82    ```
83    # use trillium_method_override::MethodOverride;
84    let handler = MethodOverride::new().with_param_name("_http_method");
85    ```
86    */
87
88    pub fn with_param_name(mut self, param_name: &'static str) -> Self {
89        self.param = param_name;
90        self
91    }
92}
93
94#[async_trait]
95impl Handler for MethodOverride {
96    async fn run(&self, mut conn: Conn) -> Conn {
97        if conn.method() != Method::Post {
98            return conn;
99        }
100        let qs = conn_unwrap!(QueryStrong::parse(conn.querystring()).ok(), conn);
101        let method_str = conn_unwrap!(qs.get_str(self.param), conn);
102        let method: Method = conn_unwrap!(method_str.try_into().ok(), conn);
103        if self.allowed_methods.contains(&method) {
104            conn.inner_mut().set_method(method);
105        }
106        conn
107    }
108}
109
110/// Alias for MethodOverride::new()
111pub fn method_override() -> MethodOverride {
112    MethodOverride::new()
113}