auth.rs

  1use crate::{
  2    AppState, Error, Result,
  3    db::{AccessTokenId, Database, UserId},
  4    rpc::Principal,
  5};
  6use anyhow::Context as _;
  7use axum::{
  8    http::{self, Request, StatusCode},
  9    middleware::Next,
 10    response::IntoResponse,
 11};
 12use base64::prelude::*;
 13use prometheus::{Histogram, exponential_buckets, register_histogram};
 14pub use rpc::auth::random_token;
 15use scrypt::{
 16    Scrypt,
 17    password_hash::{PasswordHash, PasswordVerifier},
 18};
 19use serde::{Deserialize, Serialize};
 20use sha2::Digest;
 21use std::sync::OnceLock;
 22use std::{sync::Arc, time::Instant};
 23use subtle::ConstantTimeEq;
 24
 25/// Validates the authorization header and adds an Extension<Principal> to the request.
 26/// Authorization: <user-id> <token>
 27///   <token> can be an access_token attached to that user, or an access token of an admin
 28///   or (in development) the string ADMIN:<config.api_token>.
 29/// Authorization: "dev-server-token" <token>
 30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 31    let mut auth_header = req
 32        .headers()
 33        .get(http::header::AUTHORIZATION)
 34        .and_then(|header| header.to_str().ok())
 35        .ok_or_else(|| {
 36            Error::http(
 37                StatusCode::UNAUTHORIZED,
 38                "missing authorization header".to_string(),
 39            )
 40        })?
 41        .split_whitespace();
 42
 43    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 44
 45    let first = auth_header.next().unwrap_or("");
 46    if first == "dev-server-token" {
 47        Err(Error::http(
 48            StatusCode::UNAUTHORIZED,
 49            "Dev servers were removed in Zed 0.157 please upgrade to SSH remoting".to_string(),
 50        ))?;
 51    }
 52
 53    let user_id = UserId(first.parse().map_err(|_| {
 54        Error::http(
 55            StatusCode::BAD_REQUEST,
 56            "missing user id in authorization header".to_string(),
 57        )
 58    })?);
 59
 60    let access_token = auth_header.next().ok_or_else(|| {
 61        Error::http(
 62            StatusCode::BAD_REQUEST,
 63            "missing access token in authorization header".to_string(),
 64        )
 65    })?;
 66
 67    let validate_result = verify_access_token(access_token, user_id, &state.db).await;
 68
 69    if let Ok(validate_result) = validate_result
 70        && validate_result.is_valid
 71    {
 72        let user = state
 73            .db
 74            .get_user_by_id(user_id)
 75            .await?
 76            .with_context(|| format!("user {user_id} not found"))?;
 77
 78        req.extensions_mut().insert(Principal::User(user));
 79        return Ok::<_, Error>(next.run(req).await);
 80    }
 81
 82    Err(Error::http(
 83        StatusCode::UNAUTHORIZED,
 84        "invalid credentials".to_string(),
 85    ))
 86}
 87
 88#[derive(Serialize, Deserialize)]
 89pub struct AccessTokenJson {
 90    pub version: usize,
 91    pub id: AccessTokenId,
 92    pub token: String,
 93}
 94
 95/// Hashing prevents anyone with access to the database being able to login.
 96/// As the token is randomly generated, we don't need to worry about scrypt-style
 97/// protection.
 98pub fn hash_access_token(token: &str) -> String {
 99    let digest = sha2::Sha256::digest(token);
100    format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
101}
102
/// Outcome of checking an access token with `verify_access_token`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VerifyAccessTokenResult {
    /// `true` when the presented token matched the hash stored for it.
    pub is_valid: bool,
}
106
107/// Checks that the given access token is valid for the given user.
108pub async fn verify_access_token(
109    token: &str,
110    user_id: UserId,
111    db: &Arc<Database>,
112) -> Result<VerifyAccessTokenResult> {
113    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
114    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
115        register_histogram!(
116            "access_token_hashing_time",
117            "time spent hashing access tokens",
118            exponential_buckets(10.0, 2.0, 10).unwrap(),
119        )
120        .unwrap()
121    });
122
123    let token: AccessTokenJson = serde_json::from_str(token)?;
124
125    let db_token = db.get_access_token(token.id).await?;
126    if db_token.user_id != user_id {
127        return Err(anyhow::anyhow!("no such access token"))?;
128    }
129    let t0 = Instant::now();
130
131    let is_valid = if db_token.hash.starts_with("$scrypt$") {
132        let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
133        Scrypt
134            .verify_password(token.token.as_bytes(), &db_hash)
135            .is_ok()
136    } else {
137        let token_hash = hash_access_token(&token.token);
138        db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
139    };
140
141    let duration = t0.elapsed();
142    log::info!("hashed access token in {:?}", duration);
143    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
144
145    if is_valid && db_token.hash.starts_with("$scrypt$") {
146        let new_hash = hash_access_token(&token.token);
147        db.update_access_token_hash(db_token.id, &new_hash).await?;
148    }
149
150    Ok(VerifyAccessTokenResult { is_valid })
151}