auth.rs

  1use crate::{
  2    db::{self, AccessTokenId, Database, UserId},
  3    AppState, Error, Result,
  4};
  5use anyhow::{anyhow, Context};
  6use axum::{
  7    http::{self, Request, StatusCode},
  8    middleware::Next,
  9    response::IntoResponse,
 10};
 11use lazy_static::lazy_static;
 12use prometheus::{exponential_buckets, register_histogram, Histogram};
 13use rand::thread_rng;
 14use scrypt::{
 15    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 16    Scrypt,
 17};
 18use serde::{Deserialize, Serialize};
 19use std::{sync::Arc, time::Instant};
 20
 21lazy_static! {
 22    static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!(
 23        "access_token_hashing_time",
 24        "time spent hashing access tokens",
 25        exponential_buckets(10.0, 2.0, 10).unwrap(),
 26    )
 27    .unwrap();
 28}
 29
 30/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
 31/// and one for the access tokens that we issue.
 32pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 33    let mut auth_header = req
 34        .headers()
 35        .get(http::header::AUTHORIZATION)
 36        .and_then(|header| header.to_str().ok())
 37        .ok_or_else(|| {
 38            Error::Http(
 39                StatusCode::UNAUTHORIZED,
 40                "missing authorization header".to_string(),
 41            )
 42        })?
 43        .split_whitespace();
 44
 45    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
 46        Error::Http(
 47            StatusCode::BAD_REQUEST,
 48            "missing user id in authorization header".to_string(),
 49        )
 50    })?);
 51
 52    let access_token = auth_header.next().ok_or_else(|| {
 53        Error::Http(
 54            StatusCode::BAD_REQUEST,
 55            "missing access token in authorization header".to_string(),
 56        )
 57    })?;
 58
 59    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 60    let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
 61        state.config.api_token == admin_token
 62    } else {
 63        verify_access_token(&access_token, user_id, &state.db)
 64            .await
 65            .unwrap_or(false)
 66    };
 67
 68    if credentials_valid {
 69        let user = state
 70            .db
 71            .get_user_by_id(user_id)
 72            .await?
 73            .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 74        req.extensions_mut().insert(user);
 75        Ok::<_, Error>(next.run(req).await)
 76    } else {
 77        Err(Error::Http(
 78            StatusCode::UNAUTHORIZED,
 79            "invalid credentials".to_string(),
 80        ))
 81    }
 82}
 83
/// Limit passed to `Database::create_access_token` whenever a new access token is issued.
/// NOTE(review): presumably the maximum number of tokens retained per user, with older ones
/// pruned — confirm against the db layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 85
/// Client-facing form of an access token. It is serialized to JSON by `create_access_token`
/// and parsed back by `verify_access_token`.
///
/// Field order determines the key order of the serialized JSON.
#[derive(Serialize, Deserialize)]
struct AccessTokenJson {
    /// Format version. `create_access_token` currently writes 1.
    /// NOTE(review): `verify_access_token` does not inspect this field.
    version: usize,
    /// Id returned by `Database::create_access_token`, used to look the token up again via
    /// `Database::get_access_token`.
    id: AccessTokenId,
    /// The plaintext random secret. The database is only given its scrypt hash.
    token: String,
}
 92
 93/// Creates a new access token to identify the given user. before returning it, you should
 94/// encrypt it with the user's public key.
 95pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
 96    const VERSION: usize = 1;
 97    let access_token = rpc::auth::random_token();
 98    let access_token_hash =
 99        hash_access_token(&access_token).context("failed to hash access token")?;
100    let id = db
101        .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
102        .await?;
103    Ok(serde_json::to_string(&AccessTokenJson {
104        version: VERSION,
105        id,
106        token: access_token,
107    })?)
108}
109
110fn hash_access_token(token: &str) -> Result<String> {
111    // Avoid slow hashing in debug mode.
112    let params = if cfg!(debug_assertions) {
113        scrypt::Params::new(1, 1, 1).unwrap()
114    } else {
115        scrypt::Params::new(14, 8, 1).unwrap()
116    };
117
118    Ok(Scrypt
119        .hash_password(
120            token.as_bytes(),
121            None,
122            params,
123            &SaltString::generate(thread_rng()),
124        )
125        .map_err(anyhow::Error::new)?
126        .to_string())
127}
128
129/// Encrypts the given access token with the given public key to avoid leaking it on the way
130/// to the client.
131pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
132    let native_app_public_key =
133        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
134    let encrypted_access_token = native_app_public_key
135        .encrypt_string(access_token)
136        .context("failed to encrypt access token with public key")?;
137    Ok(encrypted_access_token)
138}
139
140/// verify access token returns true if the given token is valid for the given user.
141pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
142    let token: AccessTokenJson = serde_json::from_str(&token)?;
143
144    let db_token = db.get_access_token(token.id).await?;
145    if db_token.user_id != user_id {
146        return Err(anyhow!("no such access token"))?;
147    }
148
149    let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
150    let t0 = Instant::now();
151    let is_valid = Scrypt
152        .verify_password(token.token.as_bytes(), &db_hash)
153        .is_ok();
154    let duration = t0.elapsed();
155    log::info!("hashed access token in {:?}", duration);
156    METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
157    Ok(is_valid)
158}