auth.rs

  1use std::sync::Arc;
  2
  3use super::db::{self, UserId};
  4use crate::{AppState, Error, Result};
  5use anyhow::{anyhow, Context};
  6use axum::{
  7    http::{self, Request, StatusCode},
  8    middleware::Next,
  9    response::IntoResponse,
 10};
 11use rand::thread_rng;
 12use scrypt::{
 13    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 14    Scrypt,
 15};
 16
 17pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 18    let mut auth_header = req
 19        .headers()
 20        .get(http::header::AUTHORIZATION)
 21        .and_then(|header| header.to_str().ok())
 22        .ok_or_else(|| {
 23            Error::Http(
 24                StatusCode::BAD_REQUEST,
 25                "missing authorization header".to_string(),
 26            )
 27        })?
 28        .split_whitespace();
 29
 30    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
 31        Error::Http(
 32            StatusCode::BAD_REQUEST,
 33            "missing user id in authorization header".to_string(),
 34        )
 35    })?);
 36
 37    let access_token = auth_header.next().ok_or_else(|| {
 38        Error::Http(
 39            StatusCode::BAD_REQUEST,
 40            "missing access token in authorization header".to_string(),
 41        )
 42    })?;
 43
 44    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 45    let mut credentials_valid = false;
 46    for password_hash in state.db.get_access_token_hashes(user_id).await? {
 47        if verify_access_token(access_token, &password_hash)? {
 48            credentials_valid = true;
 49            break;
 50        }
 51    }
 52
 53    if credentials_valid {
 54        let user = state
 55            .db
 56            .get_user_by_id(user_id)
 57            .await?
 58            .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 59        req.extensions_mut().insert(user);
 60        Ok::<_, Error>(next.run(req).await)
 61    } else {
 62        Err(Error::Http(
 63            StatusCode::UNAUTHORIZED,
 64            "invalid credentials".to_string(),
 65        ))
 66    }
 67}
 68
 69const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 70
 71pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result<String> {
 72    let access_token = rpc::auth::random_token();
 73    let access_token_hash =
 74        hash_access_token(&access_token).context("failed to hash access token")?;
 75    db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
 76        .await?;
 77    Ok(access_token)
 78}
 79
 80fn hash_access_token(token: &str) -> Result<String> {
 81    // Avoid slow hashing in debug mode.
 82    let params = if cfg!(debug_assertions) {
 83        scrypt::Params::new(1, 1, 1).unwrap()
 84    } else {
 85        scrypt::Params::recommended()
 86    };
 87
 88    Ok(Scrypt
 89        .hash_password(
 90            token.as_bytes(),
 91            None,
 92            params,
 93            &SaltString::generate(thread_rng()),
 94        )
 95        .map_err(anyhow::Error::new)?
 96        .to_string())
 97}
 98
 99pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
100    let native_app_public_key =
101        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
102    let encrypted_access_token = native_app_public_key
103        .encrypt_string(access_token)
104        .context("failed to encrypt access token with public key")?;
105    Ok(encrypted_access_token)
106}
107
108pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
109    let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
110    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
111}