From 26dae3c04ecc9a41cb3983fc5caa829f243cc4f7 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 17 Mar 2023 11:13:50 -0700 Subject: [PATCH 1/4] Lookup access tokens by id when authenticating a connection This avoids the cost of hashing an access token multiple times, to compare it to all known access tokens for a given user. Co-authored-by: Antonio Scandurra --- crates/collab/src/auth.rs | 41 +++++++++++++++-------- crates/collab/src/db.rs | 30 +++++++---------- crates/collab/src/db/tests.rs | 61 +++++++++++++++++++++++++++-------- 3 files changed, 86 insertions(+), 46 deletions(-) diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 0c9cf33a6b94b369a9ea47e92e254ffd87e151ab..796bf1c40fac54ee1cddfa6cf37e508eee63d4f6 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,5 +1,5 @@ use crate::{ - db::{self, UserId}, + db::{self, AccessTokenId, UserId}, AppState, Error, Result, }; use anyhow::{anyhow, Context}; @@ -13,6 +13,7 @@ use scrypt::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Scrypt, }; +use serde::{Deserialize, Serialize}; use std::sync::Arc; pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { @@ -42,20 +43,19 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into ) })?; - let mut credentials_valid = false; let state = req.extensions().get::>().unwrap(); - if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") { - if state.config.api_token == admin_token { - credentials_valid = true; - } + let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") { + state.config.api_token == admin_token } else { - for password_hash in state.db.get_access_token_hashes(user_id).await? { - if verify_access_token(access_token, &password_hash)? { - credentials_valid = true; - break; - } + let access_token: AccessTokenJson = serde_json::from_str(&access_token)?; + + let token = state.db.get_access_token(access_token.id).await?; + if token.user_id != user_id { + return Err(anyhow!("no such access token"))?; } - } + + verify_access_token(&access_token.token, &token.hash)? + }; if credentials_valid { let user = state @@ -75,13 +75,26 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; +#[derive(Serialize, Deserialize)] +struct AccessTokenJson { + version: usize, + id: AccessTokenId, + token: String, +} + pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result { + const VERSION: usize = 1; let access_token = rpc::auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; - db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) + let id = db + .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) .await?; - Ok(access_token) + Ok(serde_json::to_string(&AccessTokenJson { + version: VERSION, + id, + token: access_token, + })?) } fn hash_access_token(token: &str) -> Result { diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e8374f4acd504092b04d52be55b1f41d1a0d8a1f..72f8d9c70379344346b0a4f32917b1bbe498eaf3 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2746,16 +2746,16 @@ impl Database { // access tokens - pub async fn create_access_token_hash( + pub async fn create_access_token( &self, user_id: UserId, access_token_hash: &str, max_access_token_count: usize, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async { let tx = tx; - access_token::ActiveModel { + let token = access_token::ActiveModel { user_id: ActiveValue::set(user_id), hash: ActiveValue::set(access_token_hash.into()), ..Default::default() @@ -2778,26 +2778,20 @@ impl Database { ) .exec(&*tx) .await?; - Ok(()) + Ok(token.id) }) .await } - pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - Hash, - } - + pub async fn get_access_token( + &self, + access_token_id: AccessTokenId, + ) -> Result { self.transaction(|tx| async move { - Ok(access_token::Entity::find() - .select_only() - .column(access_token::Column::Hash) - .filter(access_token::Column::UserId.eq(user_id)) - .order_by_desc(access_token::Column::Id) - .into_values::<_, QueryAs>() - .all(&*tx) - .await?) + Ok(access_token::Entity::find_by_id(access_token_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such access token"))?) }) .await } diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 9cd79ea6d15942959135a3db899c91070d8e7b2b..855dfec91feaf627085019dbb62a5cd247b9803a 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -177,30 +177,63 @@ test_both_dbs!( .unwrap() .user_id; - db.create_access_token_hash(user, "h1", 3).await.unwrap(); - db.create_access_token_hash(user, "h2", 3).await.unwrap(); + let token_1 = db.create_access_token(user, "h1", 2).await.unwrap(); + let token_2 = db.create_access_token(user, "h2", 2).await.unwrap(); assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h2".to_string(), "h1".to_string()] + db.get_access_token(token_1).await.unwrap(), + access_token::Model { + id: token_1, + user_id: user, + hash: "h1".into(), + } ); - - db.create_access_token_hash(user, "h3", 3).await.unwrap(); assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + db.get_access_token(token_2).await.unwrap(), + access_token::Model { + id: token_2, + user_id: user, + hash: "h2".into() + } ); - db.create_access_token_hash(user, "h4", 3).await.unwrap(); + let token_3 = db.create_access_token(user, "h3", 2).await.unwrap(); assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + db.get_access_token(token_3).await.unwrap(), + access_token::Model { + id: token_3, + user_id: user, + hash: "h3".into() + } ); + assert_eq!( + db.get_access_token(token_2).await.unwrap(), + access_token::Model { + id: token_2, + user_id: user, + hash: "h2".into() + } + ); + assert!(db.get_access_token(token_1).await.is_err()); - db.create_access_token_hash(user, "h5", 3).await.unwrap(); + let token_4 = db.create_access_token(user, "h4", 2).await.unwrap(); assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h5".to_string(), "h4".to_string(), "h3".to_string()] + db.get_access_token(token_4).await.unwrap(), + access_token::Model { + id: token_4, + user_id: user, + hash: "h4".into() + } + ); + assert_eq!( + db.get_access_token(token_3).await.unwrap(), + access_token::Model { + id: token_3, + user_id: user, + hash: "h3".into() + } ); + assert!(db.get_access_token(token_2).await.is_err()); + assert!(db.get_access_token(token_1).await.is_err()); } ); From 9633a4b527d4f9c73a81e3868e4d2876a3972e8f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 17 Mar 2023 13:56:12 -0700 Subject: [PATCH 2/4] Return a 400, not a 500 when token validation fails Co-authored-by: Antonio Scandurra --- crates/collab/src/auth.rs | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 796bf1c40fac54ee1cddfa6cf37e508eee63d4f6..3156212cdfe0d913c4e8fda89d59e9fdf073bcd2 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,5 +1,5 @@ use crate::{ - db::{self, AccessTokenId, UserId}, + db::{self, AccessTokenId, Database, UserId}, AppState, Error, Result, }; use anyhow::{anyhow, Context}; @@ -47,14 +47,9 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") { state.config.api_token == admin_token } else { - let access_token: AccessTokenJson = serde_json::from_str(&access_token)?; - - let token = state.db.get_access_token(access_token.id).await?; - if token.user_id != user_id { - return Err(anyhow!("no such access token"))?; - } - - verify_access_token(&access_token.token, &token.hash)? + verify_access_token(&access_token, user_id, &state.db) + .await + .unwrap_or(false) }; if credentials_valid { @@ -125,7 +120,16 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result Result { - let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?; - Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok()) +pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc) -> Result { + let token: AccessTokenJson = serde_json::from_str(&token)?; + + let db_token = db.get_access_token(token.id).await?; + if db_token.user_id != user_id { + return Err(anyhow!("no such access token"))?; + } + + let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?; + Ok(Scrypt + .verify_password(token.token.as_bytes(), &db_hash) + .is_ok()) } From 623133ffa0b8634722638d42a0fe718f6a5db533 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 17 Mar 2023 14:31:39 -0700 Subject: [PATCH 3/4] Reduce scrypt work factor to speed up websocket authentication Co-authored-by: Mikayla Maki --- crates/collab/src/auth.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 3156212cdfe0d913c4e8fda89d59e9fdf073bcd2..9819e600c3c94da7a9ed6d5037838fe38983cc69 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -97,7 +97,7 @@ fn hash_access_token(token: &str) -> Result { let params = if cfg!(debug_assertions) { scrypt::Params::new(1, 1, 1).unwrap() } else { - scrypt::Params::recommended() + scrypt::Params::new(14, 8, 1).unwrap() }; Ok(Scrypt From b8e8363a729aa8aecd331e4c484f691e2072528f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 17 Mar 2023 14:32:13 -0700 Subject: [PATCH 4/4] Add logging and metric for time spent hashing auth tokens Co-authored-by: Mikayla Maki --- crates/collab/src/auth.rs | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 9819e600c3c94da7a9ed6d5037838fe38983cc69..9ce602c5778c25efe89017b58b3498108de38038 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -8,13 +8,24 @@ use axum::{ middleware::Next, response::IntoResponse, }; +use lazy_static::lazy_static; +use prometheus::{exponential_buckets, register_histogram, Histogram}; use rand::thread_rng; use scrypt::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Scrypt, }; use serde::{Deserialize, Serialize}; -use std::sync::Arc; +use std::{sync::Arc, time::Instant}; + +lazy_static! { + static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!( + "access_token_hashing_time", + "time spent hashing access tokens", + exponential_buckets(10.0, 2.0, 10).unwrap(), + ) + .unwrap(); +} pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { let mut auth_header = req @@ -129,7 +140,12 @@ pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc