Skip to main content

trillium_basic_auth/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
12//! Basic authentication for trillium.rs
13//!
14//! ```rust,no_run
15//! use trillium_basic_auth::BasicAuth;
16//! trillium_smol::run((
17//!     BasicAuth::new("trillium", "7r1ll1um").with_realm("rust"),
18//!     |conn: trillium::Conn| async move { conn.ok("authenticated") },
19//! ));
20//! ```
21
// Attaches the README as documentation to an empty, test-only module, so the
// build fails under `cfg(test)` if `../README.md` cannot be included.
// NOTE(review): rustdoc collects doctests with `cfg(doctest)` set rather than
// `cfg(test)`; if the intent is to run the README's code examples as doctests,
// confirm whether this gate should be `cfg(doctest)` instead.
#[cfg(test)]
#[doc = include_str!("../README.md")]
mod readme {}
25
26use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
27use trillium::{
28    Conn, Handler,
29    KnownHeaderName::{Authorization, WwwAuthenticate},
30    Status,
31};
32
/// basic auth handler
///
/// Guards the handlers that follow it with a single username/password pair.
/// A request whose `Authorization` header matches the configured credentials
/// continues with a clone of the [`Credentials`] stored in conn state; any
/// other request is halted with an `Unauthorized` status and a
/// `WWW-Authenticate` challenge.
#[derive(Clone, Debug)]
pub struct BasicAuth {
    /// the one username/password pair this handler accepts
    credentials: Credentials,
    /// the realm exactly as passed to `with_realm` (unescaped), if any
    realm: Option<String>,

    // precomputed/derived data fields:
    /// `Authorization` header value derived from `credentials`, of the form
    /// `Basic <base64(username:password)>`
    expected_header: String,
    /// `WWW-Authenticate` header value sent when a request is denied
    www_authenticate: String,
}
43
/// basic auth username-password credentials
///
/// [`BasicAuth`] places a clone of these in conn state for every request it
/// allows, so downstream handlers can look up who authenticated.
// NOTE(review): the derived `Debug` output includes the password in plain
// text; be careful when logging values of this type.
#[derive(Clone, Debug, PartialEq, Eq, fieldwork::Fieldwork)]
#[fieldwork(get)]
pub struct Credentials {
    /// username, the part before the `:` in the encoded `username:password` pair
    username: String,

    /// password, the part after the `:` in the encoded `username:password` pair
    password: String,
}
54
55impl Credentials {
56    fn new(username: &str, password: &str) -> Self {
57        Self {
58            username: String::from(username),
59            password: String::from(password),
60        }
61    }
62
63    fn expected_header(&self) -> String {
64        format!(
65            "Basic {}",
66            BASE64.encode(format!("{}:{}", self.username, self.password))
67        )
68    }
69
70    // const BASIC: &str = "Basic ";
71    // pub fn for_conn(conn: &Conn) -> Option<Self> {
72    //     conn.request_headers()
73    //         .get_str(KnownHeaderName::Authorization)
74    //         .and_then(|value| {
75    //             if value[..BASIC.len().min(value.len())].eq_ignore_ascii_case(BASIC) {
76    //                 Some(&value[BASIC.len()..])
77    //             } else {
78    //                 None
79    //             }
80    //         })
81    //         .and_then(|base64_credentials| BASE64.decode(base64_credentials).ok())
82    //         .and_then(|credential_bytes| String::from_utf8(credential_bytes).ok())
83    //         .and_then(|mut credential_string| {
84    //             credential_string.find(":").map(|colon| {
85    //                 let password = credential_string.split_off(colon + 1).into();
86    //                 credential_string.pop();
87    //                 Self {
88    //                     username: credential_string.into(),
89    //                     password,
90    //                 }
91    //             })
92    //         })
93    // }
94}
95
96impl BasicAuth {
97    /// build a new basic auth handler with the provided username and password
98    pub fn new(username: &str, password: &str) -> Self {
99        let credentials = Credentials::new(username, password);
100        let expected_header = credentials.expected_header();
101        let realm = None;
102        Self {
103            expected_header,
104            credentials,
105            realm,
106            www_authenticate: String::from("Basic"),
107        }
108    }
109
110    /// provide a realm for the www-authenticate response sent by this handler
111    pub fn with_realm(mut self, realm: &str) -> Self {
112        self.www_authenticate = format!("Basic realm=\"{}\"", realm.replace('\"', "\\\""));
113        self.realm = Some(String::from(realm));
114        self
115    }
116
117    fn is_allowed(&self, conn: &Conn) -> bool {
118        conn.request_headers().get_str(Authorization) == Some(&*self.expected_header)
119    }
120
121    fn deny(&self, conn: Conn) -> Conn {
122        conn.with_status(Status::Unauthorized)
123            .with_response_header(WwwAuthenticate, self.www_authenticate.clone())
124            .halt()
125    }
126}
127
128impl Handler for BasicAuth {
129    async fn run(&self, conn: Conn) -> Conn {
130        if self.is_allowed(&conn) {
131            conn.with_state(self.credentials.clone())
132        } else {
133            self.deny(conn)
134        }
135    }
136}