auth.rs

  1use crate::{
  2    db::{self, AccessTokenId, Database, UserId},
  3    AppState, Error, Result,
  4};
  5use anyhow::{anyhow, Context};
  6use axum::{
  7    http::{self, Request, StatusCode},
  8    middleware::Next,
  9    response::IntoResponse,
 10};
 11use lazy_static::lazy_static;
 12use prometheus::{exponential_buckets, register_histogram, Histogram};
 13use rand::thread_rng;
 14use scrypt::{
 15    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 16    Scrypt,
 17};
 18use serde::{Deserialize, Serialize};
 19use std::{sync::Arc, time::Instant};
 20
 21lazy_static! {
 22    static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!(
 23        "access_token_hashing_time",
 24        "time spent hashing access tokens",
 25        exponential_buckets(10.0, 2.0, 10).unwrap(),
 26    )
 27    .unwrap();
 28}
 29
 30#[derive(Clone, Debug, Default, PartialEq, Eq)]
 31pub struct Impersonator(pub Option<db::User>);
 32
 33/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
 34/// and one for the access tokens that we issue.
 35pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 36    let mut auth_header = req
 37        .headers()
 38        .get(http::header::AUTHORIZATION)
 39        .and_then(|header| header.to_str().ok())
 40        .ok_or_else(|| {
 41            Error::Http(
 42                StatusCode::UNAUTHORIZED,
 43                "missing authorization header".to_string(),
 44            )
 45        })?
 46        .split_whitespace();
 47
 48    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
 49        Error::Http(
 50            StatusCode::BAD_REQUEST,
 51            "missing user id in authorization header".to_string(),
 52        )
 53    })?);
 54
 55    let access_token = auth_header.next().ok_or_else(|| {
 56        Error::Http(
 57            StatusCode::BAD_REQUEST,
 58            "missing access token in authorization header".to_string(),
 59        )
 60    })?;
 61
 62    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 63
 64    // In development, allow impersonation using the admin API token.
 65    // Don't allow this in production because we can't tell who is doing
 66    // the impersonating.
 67    let validate_result = if let (Some(admin_token), true) = (
 68        access_token.strip_prefix("ADMIN_TOKEN:"),
 69        state.config.is_development(),
 70    ) {
 71        Ok(VerifyAccessTokenResult {
 72            is_valid: state.config.api_token == admin_token,
 73            impersonator_id: None,
 74        })
 75    } else {
 76        verify_access_token(&access_token, user_id, &state.db).await
 77    };
 78
 79    if let Ok(validate_result) = validate_result {
 80        if validate_result.is_valid {
 81            let user = state
 82                .db
 83                .get_user_by_id(user_id)
 84                .await?
 85                .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 86
 87            let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
 88                let impersonator = state
 89                    .db
 90                    .get_user_by_id(impersonator_id)
 91                    .await?
 92                    .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
 93                Some(impersonator)
 94            } else {
 95                None
 96            };
 97            req.extensions_mut().insert(user);
 98            req.extensions_mut().insert(Impersonator(impersonator));
 99            return Ok::<_, Error>(next.run(req).await);
100        }
101    }
102
103    Err(Error::Http(
104        StatusCode::UNAUTHORIZED,
105        "invalid credentials".to_string(),
106    ))
107}
108
/// Cap passed to `Database::create_access_token` whenever a new token is issued.
/// Presumably this is the number of tokens kept per user; confirm against the `db` layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
110
111#[derive(Serialize, Deserialize)]
112struct AccessTokenJson {
113    version: usize,
114    id: AccessTokenId,
115    token: String,
116}
117
118/// Creates a new access token to identify the given user. before returning it, you should
119/// encrypt it with the user's public key.
120pub async fn create_access_token(
121    db: &db::Database,
122    user_id: UserId,
123    impersonated_user_id: Option<UserId>,
124) -> Result<String> {
125    const VERSION: usize = 1;
126    let access_token = rpc::auth::random_token();
127    let access_token_hash =
128        hash_access_token(&access_token).context("failed to hash access token")?;
129    let id = db
130        .create_access_token(
131            user_id,
132            impersonated_user_id,
133            &access_token_hash,
134            MAX_ACCESS_TOKENS_TO_STORE,
135        )
136        .await?;
137    Ok(serde_json::to_string(&AccessTokenJson {
138        version: VERSION,
139        id,
140        token: access_token,
141    })?)
142}
143
144fn hash_access_token(token: &str) -> Result<String> {
145    // Avoid slow hashing in debug mode.
146    let params = if cfg!(debug_assertions) {
147        scrypt::Params::new(1, 1, 1).unwrap()
148    } else {
149        scrypt::Params::new(14, 8, 1).unwrap()
150    };
151
152    Ok(Scrypt
153        .hash_password(
154            token.as_bytes(),
155            None,
156            params,
157            &SaltString::generate(thread_rng()),
158        )
159        .map_err(anyhow::Error::new)?
160        .to_string())
161}
162
163/// Encrypts the given access token with the given public key to avoid leaking it on the way
164/// to the client.
165pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
166    let native_app_public_key =
167        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
168    let encrypted_access_token = native_app_public_key
169        .encrypt_string(access_token)
170        .context("failed to encrypt access token with public key")?;
171    Ok(encrypted_access_token)
172}
173
174pub struct VerifyAccessTokenResult {
175    pub is_valid: bool,
176    pub impersonator_id: Option<UserId>,
177}
178
179/// Checks that the given access token is valid for the given user.
180pub async fn verify_access_token(
181    token: &str,
182    user_id: UserId,
183    db: &Arc<Database>,
184) -> Result<VerifyAccessTokenResult> {
185    let token: AccessTokenJson = serde_json::from_str(&token)?;
186
187    let db_token = db.get_access_token(token.id).await?;
188    let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
189    if token_user_id != user_id {
190        return Err(anyhow!("no such access token"))?;
191    }
192
193    let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
194    let t0 = Instant::now();
195    let is_valid = Scrypt
196        .verify_password(token.token.as_bytes(), &db_hash)
197        .is_ok();
198    let duration = t0.elapsed();
199    log::info!("hashed access token in {:?}", duration);
200    METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
201    Ok(VerifyAccessTokenResult {
202        is_valid,
203        impersonator_id: if db_token.impersonated_user_id.is_some() {
204            Some(db_token.user_id)
205        } else {
206            None
207        },
208    })
209}