Merge pull request #2305 from zed-industries/faster-access-token-validation

Max Brunsfeld created

Faster access token validation

Change summary

crates/collab/src/auth.rs     | 73 ++++++++++++++++++++++++++----------
crates/collab/src/db.rs       | 30 ++++++---------
crates/collab/src/db/tests.rs | 61 +++++++++++++++++++++++-------
3 files changed, 112 insertions(+), 52 deletions(-)

Detailed changes

crates/collab/src/auth.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    db::{self, UserId},
+    db::{self, AccessTokenId, Database, UserId},
     AppState, Error, Result,
 };
 use anyhow::{anyhow, Context};
@@ -8,12 +8,24 @@ use axum::{
     middleware::Next,
     response::IntoResponse,
 };
+use lazy_static::lazy_static;
+use prometheus::{exponential_buckets, register_histogram, Histogram};
 use rand::thread_rng;
 use scrypt::{
     password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
     Scrypt,
 };
-use std::sync::Arc;
+use serde::{Deserialize, Serialize};
+use std::{sync::Arc, time::Instant};
+
+lazy_static! {
+    static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!(
+        "access_token_hashing_time",
+        "time spent hashing access tokens",
+        exponential_buckets(10.0, 2.0, 10).unwrap(),
+    )
+    .unwrap();
+}
 
 pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
     let mut auth_header = req
@@ -42,20 +54,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         )
     })?;
 
-    let mut credentials_valid = false;
     let state = req.extensions().get::<Arc<AppState>>().unwrap();
-    if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
-        if state.config.api_token == admin_token {
-            credentials_valid = true;
-        }
+    let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
+        state.config.api_token == admin_token
     } else {
-        for password_hash in state.db.get_access_token_hashes(user_id).await? {
-            if verify_access_token(access_token, &password_hash)? {
-                credentials_valid = true;
-                break;
-            }
-        }
-    }
+        verify_access_token(&access_token, user_id, &state.db)
+            .await
+            .unwrap_or(false)
+    };
 
     if credentials_valid {
         let user = state
@@ -75,13 +81,26 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
 
 const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 
+#[derive(Serialize, Deserialize)]
+struct AccessTokenJson {
+    version: usize,
+    id: AccessTokenId,
+    token: String,
+}
+
 pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
+    const VERSION: usize = 1;
     let access_token = rpc::auth::random_token();
     let access_token_hash =
         hash_access_token(&access_token).context("failed to hash access token")?;
-    db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
+    let id = db
+        .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
         .await?;
-    Ok(access_token)
+    Ok(serde_json::to_string(&AccessTokenJson {
+        version: VERSION,
+        id,
+        token: access_token,
+    })?)
 }
 
 fn hash_access_token(token: &str) -> Result<String> {
@@ -89,7 +108,7 @@ fn hash_access_token(token: &str) -> Result<String> {
     let params = if cfg!(debug_assertions) {
         scrypt::Params::new(1, 1, 1).unwrap()
     } else {
-        scrypt::Params::recommended()
+        scrypt::Params::new(14, 8, 1).unwrap()
     };
 
     Ok(Scrypt
@@ -112,7 +131,21 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
     Ok(encrypted_access_token)
 }
 
-pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
-    let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
-    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
+pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
+    let token: AccessTokenJson = serde_json::from_str(&token)?;
+
+    let db_token = db.get_access_token(token.id).await?;
+    if db_token.user_id != user_id {
+        return Err(anyhow!("no such access token"))?;
+    }
+
+    let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
+    let t0 = Instant::now();
+    let is_valid = Scrypt
+        .verify_password(token.token.as_bytes(), &db_hash)
+        .is_ok();
+    let duration = t0.elapsed();
+    log::info!("hashed access token in {:?}", duration);
+    METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
+    Ok(is_valid)
 }

crates/collab/src/db.rs 🔗

@@ -2746,16 +2746,16 @@ impl Database {
 
     // access tokens
 
-    pub async fn create_access_token_hash(
+    pub async fn create_access_token(
         &self,
         user_id: UserId,
         access_token_hash: &str,
         max_access_token_count: usize,
-    ) -> Result<()> {
+    ) -> Result<AccessTokenId> {
         self.transaction(|tx| async {
             let tx = tx;
 
-            access_token::ActiveModel {
+            let token = access_token::ActiveModel {
                 user_id: ActiveValue::set(user_id),
                 hash: ActiveValue::set(access_token_hash.into()),
                 ..Default::default()
@@ -2778,26 +2778,20 @@ impl Database {
                 )
                 .exec(&*tx)
                 .await?;
-            Ok(())
+            Ok(token.id)
         })
         .await
     }
 
-    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
-        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
-        enum QueryAs {
-            Hash,
-        }
-
+    pub async fn get_access_token(
+        &self,
+        access_token_id: AccessTokenId,
+    ) -> Result<access_token::Model> {
         self.transaction(|tx| async move {
-            Ok(access_token::Entity::find()
-                .select_only()
-                .column(access_token::Column::Hash)
-                .filter(access_token::Column::UserId.eq(user_id))
-                .order_by_desc(access_token::Column::Id)
-                .into_values::<_, QueryAs>()
-                .all(&*tx)
-                .await?)
+            Ok(access_token::Entity::find_by_id(access_token_id)
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such access token"))?)
         })
         .await
     }

crates/collab/src/db/tests.rs 🔗

@@ -177,30 +177,63 @@ test_both_dbs!(
             .unwrap()
             .user_id;
 
-        db.create_access_token_hash(user, "h1", 3).await.unwrap();
-        db.create_access_token_hash(user, "h2", 3).await.unwrap();
+        let token_1 = db.create_access_token(user, "h1", 2).await.unwrap();
+        let token_2 = db.create_access_token(user, "h2", 2).await.unwrap();
         assert_eq!(
-            db.get_access_token_hashes(user).await.unwrap(),
-            &["h2".to_string(), "h1".to_string()]
+            db.get_access_token(token_1).await.unwrap(),
+            access_token::Model {
+                id: token_1,
+                user_id: user,
+                hash: "h1".into(),
+            }
         );
-
-        db.create_access_token_hash(user, "h3", 3).await.unwrap();
         assert_eq!(
-            db.get_access_token_hashes(user).await.unwrap(),
-            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
+            db.get_access_token(token_2).await.unwrap(),
+            access_token::Model {
+                id: token_2,
+                user_id: user,
+                hash: "h2".into()
+            }
         );
 
-        db.create_access_token_hash(user, "h4", 3).await.unwrap();
+        let token_3 = db.create_access_token(user, "h3", 2).await.unwrap();
         assert_eq!(
-            db.get_access_token_hashes(user).await.unwrap(),
-            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
+            db.get_access_token(token_3).await.unwrap(),
+            access_token::Model {
+                id: token_3,
+                user_id: user,
+                hash: "h3".into()
+            }
         );
+        assert_eq!(
+            db.get_access_token(token_2).await.unwrap(),
+            access_token::Model {
+                id: token_2,
+                user_id: user,
+                hash: "h2".into()
+            }
+        );
+        assert!(db.get_access_token(token_1).await.is_err());
 
-        db.create_access_token_hash(user, "h5", 3).await.unwrap();
+        let token_4 = db.create_access_token(user, "h4", 2).await.unwrap();
         assert_eq!(
-            db.get_access_token_hashes(user).await.unwrap(),
-            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
+            db.get_access_token(token_4).await.unwrap(),
+            access_token::Model {
+                id: token_4,
+                user_id: user,
+                hash: "h4".into()
+            }
+        );
+        assert_eq!(
+            db.get_access_token(token_3).await.unwrap(),
+            access_token::Model {
+                id: token_3,
+                user_id: user,
+                hash: "h3".into()
+            }
         );
+        assert!(db.get_access_token(token_2).await.is_err());
+        assert!(db.get_access_token(token_1).await.is_err());
     }
 );