auth.rs

  1use crate::{
  2    db::{self, AccessTokenId, Database, UserId},
  3    AppState, Error, Result,
  4};
  5use anyhow::{anyhow, Context};
  6use axum::{
  7    http::{self, Request, StatusCode},
  8    middleware::Next,
  9    response::IntoResponse,
 10};
 11use prometheus::{exponential_buckets, register_histogram, Histogram};
 12use rand::thread_rng;
 13use scrypt::{
 14    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 15    Scrypt,
 16};
 17use serde::{Deserialize, Serialize};
 18use std::sync::OnceLock;
 19use std::{sync::Arc, time::Instant};
 20
 21#[derive(Clone, Debug, Default, PartialEq, Eq)]
 22pub struct Impersonator(pub Option<db::User>);
 23
 24/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
 25/// and one for the access tokens that we issue.
 26pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 27    let mut auth_header = req
 28        .headers()
 29        .get(http::header::AUTHORIZATION)
 30        .and_then(|header| header.to_str().ok())
 31        .ok_or_else(|| {
 32            Error::Http(
 33                StatusCode::UNAUTHORIZED,
 34                "missing authorization header".to_string(),
 35            )
 36        })?
 37        .split_whitespace();
 38
 39    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
 40        Error::Http(
 41            StatusCode::BAD_REQUEST,
 42            "missing user id in authorization header".to_string(),
 43        )
 44    })?);
 45
 46    let access_token = auth_header.next().ok_or_else(|| {
 47        Error::Http(
 48            StatusCode::BAD_REQUEST,
 49            "missing access token in authorization header".to_string(),
 50        )
 51    })?;
 52
 53    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 54
 55    // In development, allow impersonation using the admin API token.
 56    // Don't allow this in production because we can't tell who is doing
 57    // the impersonating.
 58    let validate_result = if let (Some(admin_token), true) = (
 59        access_token.strip_prefix("ADMIN_TOKEN:"),
 60        state.config.is_development(),
 61    ) {
 62        Ok(VerifyAccessTokenResult {
 63            is_valid: state.config.api_token == admin_token,
 64            impersonator_id: None,
 65        })
 66    } else {
 67        verify_access_token(&access_token, user_id, &state.db).await
 68    };
 69
 70    if let Ok(validate_result) = validate_result {
 71        if validate_result.is_valid {
 72            let user = state
 73                .db
 74                .get_user_by_id(user_id)
 75                .await?
 76                .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 77
 78            let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
 79                let impersonator = state
 80                    .db
 81                    .get_user_by_id(impersonator_id)
 82                    .await?
 83                    .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
 84                Some(impersonator)
 85            } else {
 86                None
 87            };
 88            req.extensions_mut().insert(user);
 89            req.extensions_mut().insert(Impersonator(impersonator));
 90            return Ok::<_, Error>(next.run(req).await);
 91        }
 92    }
 93
 94    Err(Error::Http(
 95        StatusCode::UNAUTHORIZED,
 96        "invalid credentials".to_string(),
 97    ))
 98}
 99
/// Token-retention limit passed to `Database::create_access_token` whenever a new access
/// token is issued (presumably the number of tokens kept per user — confirm in the db layer).
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
101
102#[derive(Serialize, Deserialize)]
103struct AccessTokenJson {
104    version: usize,
105    id: AccessTokenId,
106    token: String,
107}
108
109/// Creates a new access token to identify the given user. before returning it, you should
110/// encrypt it with the user's public key.
111pub async fn create_access_token(
112    db: &db::Database,
113    user_id: UserId,
114    impersonated_user_id: Option<UserId>,
115) -> Result<String> {
116    const VERSION: usize = 1;
117    let access_token = rpc::auth::random_token();
118    let access_token_hash =
119        hash_access_token(&access_token).context("failed to hash access token")?;
120    let id = db
121        .create_access_token(
122            user_id,
123            impersonated_user_id,
124            &access_token_hash,
125            MAX_ACCESS_TOKENS_TO_STORE,
126        )
127        .await?;
128    Ok(serde_json::to_string(&AccessTokenJson {
129        version: VERSION,
130        id,
131        token: access_token,
132    })?)
133}
134
135fn hash_access_token(token: &str) -> Result<String> {
136    // Avoid slow hashing in debug mode.
137    let params = if cfg!(debug_assertions) {
138        scrypt::Params::new(1, 1, 1).unwrap()
139    } else {
140        scrypt::Params::new(14, 8, 1).unwrap()
141    };
142
143    Ok(Scrypt
144        .hash_password(
145            token.as_bytes(),
146            None,
147            params,
148            &SaltString::generate(thread_rng()),
149        )
150        .map_err(anyhow::Error::new)?
151        .to_string())
152}
153
154/// Encrypts the given access token with the given public key to avoid leaking it on the way
155/// to the client.
156pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
157    let native_app_public_key =
158        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
159    let encrypted_access_token = native_app_public_key
160        .encrypt_string(access_token)
161        .context("failed to encrypt access token with public key")?;
162    Ok(encrypted_access_token)
163}
164
165pub struct VerifyAccessTokenResult {
166    pub is_valid: bool,
167    pub impersonator_id: Option<UserId>,
168}
169
170/// Checks that the given access token is valid for the given user.
171pub async fn verify_access_token(
172    token: &str,
173    user_id: UserId,
174    db: &Arc<Database>,
175) -> Result<VerifyAccessTokenResult> {
176    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
177    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
178        register_histogram!(
179            "access_token_hashing_time",
180            "time spent hashing access tokens",
181            exponential_buckets(10.0, 2.0, 10).unwrap(),
182        )
183        .unwrap()
184    });
185
186    let token: AccessTokenJson = serde_json::from_str(&token)?;
187
188    let db_token = db.get_access_token(token.id).await?;
189    let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
190    if token_user_id != user_id {
191        return Err(anyhow!("no such access token"))?;
192    }
193
194    let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
195    let t0 = Instant::now();
196    let is_valid = Scrypt
197        .verify_password(token.token.as_bytes(), &db_hash)
198        .is_ok();
199    let duration = t0.elapsed();
200    log::info!("hashed access token in {:?}", duration);
201    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
202    Ok(VerifyAccessTokenResult {
203        is_valid,
204        impersonator_id: if db_token.impersonated_user_id.is_some() {
205            Some(db_token.user_id)
206        } else {
207            None
208        },
209    })
210}