1use std::sync::Arc;
2
3use super::db::{self, UserId};
4use crate::{AppState, Error, Result};
5use anyhow::{anyhow, Context};
6use axum::{
7 http::{self, Request, StatusCode},
8 middleware::Next,
9 response::IntoResponse,
10};
11use rand::thread_rng;
12use scrypt::{
13 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
14 Scrypt,
15};
16
17pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
18 let mut auth_header = req
19 .headers()
20 .get(http::header::AUTHORIZATION)
21 .and_then(|header| header.to_str().ok())
22 .ok_or_else(|| {
23 Error::Http(
24 StatusCode::BAD_REQUEST,
25 "missing authorization header".to_string(),
26 )
27 })?
28 .split_whitespace();
29
30 let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
31 Error::Http(
32 StatusCode::BAD_REQUEST,
33 "missing user id in authorization header".to_string(),
34 )
35 })?);
36
37 let access_token = auth_header.next().ok_or_else(|| {
38 Error::Http(
39 StatusCode::BAD_REQUEST,
40 "missing access token in authorization header".to_string(),
41 )
42 })?;
43
44 let state = req.extensions().get::<Arc<AppState>>().unwrap();
45 let mut credentials_valid = false;
46 for password_hash in state.db.get_access_token_hashes(user_id).await? {
47 if verify_access_token(access_token, &password_hash)? {
48 credentials_valid = true;
49 break;
50 }
51 }
52
53 if credentials_valid {
54 let user = state
55 .db
56 .get_user_by_id(user_id)
57 .await?
58 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
59 req.extensions_mut().insert(user);
60 Ok::<_, Error>(next.run(req).await)
61 } else {
62 Err(Error::Http(
63 StatusCode::UNAUTHORIZED,
64 "invalid credentials".to_string(),
65 ))
66 }
67}
68
69const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
70
71pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result<String> {
72 let access_token = rpc::auth::random_token();
73 let access_token_hash =
74 hash_access_token(&access_token).context("failed to hash access token")?;
75 db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
76 .await?;
77 Ok(access_token)
78}
79
80fn hash_access_token(token: &str) -> Result<String> {
81 // Avoid slow hashing in debug mode.
82 let params = if cfg!(debug_assertions) {
83 scrypt::Params::new(1, 1, 1).unwrap()
84 } else {
85 scrypt::Params::recommended()
86 };
87
88 Ok(Scrypt
89 .hash_password(
90 token.as_bytes(),
91 None,
92 params,
93 &SaltString::generate(thread_rng()),
94 )
95 .map_err(anyhow::Error::new)?
96 .to_string())
97}
98
99pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
100 let native_app_public_key =
101 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
102 let encrypted_access_token = native_app_public_key
103 .encrypt_string(access_token)
104 .context("failed to encrypt access token with public key")?;
105 Ok(encrypted_access_token)
106}
107
108pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
109 let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
110 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
111}