From f07cec59def58dd8199bdced27b1aafbeb5755d8 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 18 Feb 2026 18:20:52 -0500 Subject: [PATCH] collab: Validate access tokens through Cloud (#49535) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR updates Collab to make it validate access tokens through Cloud instead of doing it in-house. We're reusing the `GET /client/users/me` endpoint—which is what we also call on the client—to validate the user's access token. We only need to do this when establishing a WebSocket connection, so the increased latency of a network hop shouldn't be a problem. Closes CLO-308. Release Notes: - N/A --- Cargo.lock | 37 +--- Cargo.toml | 1 - crates/collab/Cargo.toml | 3 - crates/collab/src/auth.rs | 104 ++--------- crates/collab/src/lib.rs | 19 ++ crates/collab/src/main.rs | 5 +- .../collab/tests/integration/collab_tests.rs | 172 ------------------ .../collab/tests/integration/test_server.rs | 1 + 8 files changed, 41 insertions(+), 301 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 001dc23ce5e3373b0c42ca868b892ca29f989a1d..fae2f36a891f7088eb3363305cdd451f619a50ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3202,7 +3202,6 @@ dependencies = [ "aws-sdk-kinesis", "aws-sdk-s3", "axum", - "base64 0.22.1", "buffer_diff", "call", "channel", @@ -3260,7 +3259,6 @@ dependencies = [ "remote_server", "reqwest 0.11.27", "rpc", - "scrypt", "sea-orm", "sea-orm-macros", "semver", @@ -3272,7 +3270,6 @@ dependencies = [ "smol", "sqlx", "strum 0.27.2", - "subtle", "task", "telemetry_events", "text", @@ -11463,17 +11460,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "password-hash" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" -dependencies = [ - "base64ct", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "paste" version = "1.0.15" @@ -11559,7 +11545,7 @@ checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" dependencies = [ "digest", "hmac", - "password-hash 0.4.2", + "password-hash", "sha2", ] @@ -14515,15 +14501,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "salsa20" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" -dependencies = [ - "cipher", -] - [[package]] name = "same-file" version = "1.0.6" @@ -14666,18 +14643,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "scrypt" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f" -dependencies = [ - "password-hash 0.5.0", - "pbkdf2 0.12.2", - "salsa20", - "sha2", -] - [[package]] name = "sct" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index d2f1fe27aecf989153b9f23d2adf870c0db48e35..fd4962e74928d7d13fbeeab5ce17b0a7b1ba4e59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -668,7 +668,6 @@ stacksafe = "0.1" streaming-iterator = "0.1" strsim = "0.11" strum = { version = "0.27.2", features = ["derive"] } -subtle = "2.5.0" syn = { version = "2.0.101", features = ["full", "extra-traits", "visit-mut"] } sys-locale = "0.3.1" sysinfo = "0.37.0" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 56c3268497246ceab90d1a039d6dc8027b596ce6..5db06ef8e73d3cf276f73fbd8aa53e932e6c75b8 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -33,7 +33,6 @@ aws-config = { version = "1.1.5" } aws-sdk-kinesis = "1.51.0" aws-sdk-s3 = { version = "1.15.0" } axum = { version = "0.6", features = ["json", "headers", "ws"] } -base64.workspace = true chrono.workspace = true clock.workspace = true cloud_api_types.workspace = true @@ -53,7 +52,6 @@ prost.workspace = true rand.workspace = true reqwest = { version = "0.11", features = ["json"] } rpc.workspace = true -scrypt = "0.11" # sea-orm and sea-orm-macros versions must match exactly. sea-orm = { version = "=1.1.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] } sea-orm-macros = "=1.1.10" @@ -63,7 +61,6 @@ serde_json.workspace = true sha2.workspace = true sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] } strum.workspace = true -subtle.workspace = true telemetry_events.workspace = true text.workspace = true time.workspace = true diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index f87f97c453f209de09c77475e73171d2a6863ce7..5cd377d605b1d59742018d1f1fb52a1fc2d70287 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,26 +1,13 @@ -use crate::{ - AppState, Error, Result, - db::{AccessTokenId, Database, UserId}, - rpc::Principal, -}; +use crate::{AppState, Error, db::UserId, rpc::Principal}; use anyhow::Context as _; use axum::{ http::{self, Request, StatusCode}, middleware::Next, response::IntoResponse, }; -use base64::prelude::*; -use prometheus::{Histogram, exponential_buckets, register_histogram}; +use cloud_api_types::GetAuthenticatedUserResponse; pub use rpc::auth::random_token; -use scrypt::{ - Scrypt, - password_hash::{PasswordHash, PasswordVerifier}, -}; -use serde::{Deserialize, Serialize}; -use sha2::Digest; -use std::sync::OnceLock; -use std::{sync::Arc, time::Instant}; -use subtle::ConstantTimeEq; +use std::sync::Arc; /// Validates the authorization header and adds an Extension to the request. /// Authorization: @@ -64,11 +51,23 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into ) })?; - let validate_result = verify_access_token(access_token, user_id, &state.db).await; + let http_client = state.http_client.clone().expect("no HTTP client"); + + let response = http_client + .get(format!("{}/client/users/me", state.config.zed_cloud_url())) + .header("Content-Type", "application/json") + .header("Authorization", format!("{user_id} {access_token}")) + .send() + .await + .context("failed to validate access token")?; + if let Ok(response) = response.error_for_status() { + let response_body: GetAuthenticatedUserResponse = response + .json() + .await + .context("failed to parse response body")?; + + let user_id = UserId(response_body.user.id); - if let Ok(validate_result) = validate_result - && validate_result.is_valid - { let user = state .db .get_user_by_id(user_id) @@ -84,68 +83,3 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into "invalid credentials".to_string(), )) } - -#[derive(Serialize, Deserialize)] -pub struct AccessTokenJson { - pub version: usize, - pub id: AccessTokenId, - pub token: String, -} - -/// Hashing prevents anyone with access to the database being able to login. -/// As the token is randomly generated, we don't need to worry about scrypt-style -/// protection. -pub fn hash_access_token(token: &str) -> String { - let digest = sha2::Sha256::digest(token); - format!("$sha256${}", BASE64_URL_SAFE.encode(digest)) -} - -pub struct VerifyAccessTokenResult { - pub is_valid: bool, -} - -/// Checks that the given access token is valid for the given user. -pub async fn verify_access_token( - token: &str, - user_id: UserId, - db: &Arc, -) -> Result { - static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock = OnceLock::new(); - let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| { - register_histogram!( - "access_token_hashing_time", - "time spent hashing access tokens", - exponential_buckets(10.0, 2.0, 10).unwrap(), - ) - .unwrap() - }); - - let token: AccessTokenJson = serde_json::from_str(token)?; - - let db_token = db.get_access_token(token.id).await?; - if db_token.user_id != user_id { - return Err(anyhow::anyhow!("no such access token"))?; - } - let t0 = Instant::now(); - - let is_valid = if db_token.hash.starts_with("$scrypt$") { - let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?; - Scrypt - .verify_password(token.token.as_bytes(), &db_hash) - .is_ok() - } else { - let token_hash = hash_access_token(&token.token); - db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into() - }; - - let duration = t0.elapsed(); - log::info!("hashed access token in {:?}", duration); - metric_access_token_hashing_time.observe(duration.as_millis() as f64); - - if is_valid && db_token.hash.starts_with("$scrypt$") { - let new_hash = hash_access_token(&token.token); - db.update_access_token_hash(db_token.id, &new_hash).await?; - } - - Ok(VerifyAccessTokenResult { is_valid }) -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 25337fe3fa00ce72234f6ac5f107db0248d03ce7..7af4216ca5ee69a75757d80e6584acfb5c8f8aa2 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -18,6 +18,9 @@ use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); +pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); + pub type Result = std::result::Result; pub enum Error { @@ -150,6 +153,14 @@ impl Config { } } + /// Returns the base Zed Cloud URL. + pub fn zed_cloud_url(&self) -> &str { + match self.zed_environment.as_ref() { + "development" => "http://localhost:8787", + _ => "https://cloud.zed.dev", + } + } + #[cfg(feature = "test-support")] pub fn test() -> Self { Self { @@ -199,6 +210,7 @@ impl ServiceMode { pub struct AppState { pub db: Arc, + pub http_client: Option, pub livekit_client: Option>, pub blob_store_client: Option, pub executor: Executor, @@ -228,9 +240,16 @@ impl AppState { None }; + let user_agent = format!("Collab/{VERSION} ({})", REVISION.unwrap_or("unknown")); + let http_client = reqwest::Client::builder() + .user_agent(user_agent) + .build() + .context("failed to construct HTTP client")?; + let db = Arc::new(db); let this = Self { db: db.clone(), + http_client: Some(http_client), livekit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), executor, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 3dc170e831be57d0ee19f24640ac13c9f1c90adc..72eebbe39c20f33102208dba6ac9a8607b15f5be 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -7,12 +7,12 @@ use axum::{ routing::get, }; -use collab::ServiceMode; use collab::api::CloudflareIpCountryHeader; use collab::{ AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, }; +use collab::{REVISION, ServiceMode, VERSION}; use db::Database; use std::{ env::args, @@ -28,9 +28,6 @@ use tracing_subscriber::{ }; use util::ResultExt as _; -const VERSION: &str = env!("CARGO_PKG_VERSION"); -const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); - #[expect(clippy::result_large_err)] #[tokio::main] async fn main() -> Result<()> { diff --git a/crates/collab/tests/integration/collab_tests.rs b/crates/collab/tests/integration/collab_tests.rs index 3376f12b203c532cfb77f561f6300a832eafc6e1..8c568c5c4e1f9b8414b48d5b7175763ded5e89c9 100644 --- a/crates/collab/tests/integration/collab_tests.rs +++ b/crates/collab/tests/integration/collab_tests.rs @@ -50,175 +50,3 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic fn channel_id(room: &Entity, cx: &mut TestAppContext) -> Option { cx.read(|cx| room.read(cx).channel_id()) } - -mod auth_token_tests { - use collab::auth::{ - AccessTokenJson, VerifyAccessTokenResult, hash_access_token, verify_access_token, - }; - use rand::prelude::*; - use scrypt::Scrypt; - use scrypt::password_hash::{PasswordHasher, SaltString}; - use sea_orm::EntityTrait; - - use collab::db::{Database, NewUserParams, UserId, access_token}; - use collab::*; - - const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; - - async fn create_access_token(db: &db::Database, user_id: UserId) -> Result { - const VERSION: usize = 1; - let access_token = ::rpc::auth::random_token(); - let access_token_hash = hash_access_token(&access_token); - let id = db - .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) - .await?; - Ok(serde_json::to_string(&AccessTokenJson { - version: VERSION, - id, - token: access_token, - })?) - } - - #[gpui::test] - async fn test_verify_access_token(cx: &mut gpui::TestAppContext) { - let test_db = crate::db_tests::TestDb::sqlite(cx.executor()); - let db = test_db.db(); - - let user = db - .create_user( - "example@example.com", - None, - false, - NewUserParams { - github_login: "example".into(), - github_user_id: 1, - }, - ) - .await - .unwrap(); - - let token = create_access_token(db, user.user_id).await.unwrap(); - assert!(matches!( - verify_access_token(&token, user.user_id, db).await.unwrap(), - VerifyAccessTokenResult { is_valid: true } - )); - - let old_token = create_previous_access_token(user.user_id, db) - .await - .unwrap(); - - let old_token_id = serde_json::from_str::(&old_token) - .unwrap() - .id; - - let hash = db - .transaction(|tx| async move { - Ok(access_token::Entity::find_by_id(old_token_id) - .one(&*tx) - .await?) - }) - .await - .unwrap() - .unwrap() - .hash; - assert!(hash.starts_with("$scrypt$")); - - assert!(matches!( - verify_access_token(&old_token, user.user_id, db) - .await - .unwrap(), - VerifyAccessTokenResult { is_valid: true } - )); - - let hash = db - .transaction(|tx| async move { - Ok(access_token::Entity::find_by_id(old_token_id) - .one(&*tx) - .await?) - }) - .await - .unwrap() - .unwrap() - .hash; - assert!(hash.starts_with("$sha256$")); - - assert!(matches!( - verify_access_token(&old_token, user.user_id, db) - .await - .unwrap(), - VerifyAccessTokenResult { is_valid: true } - )); - - assert!(matches!( - verify_access_token(&token, user.user_id, db).await.unwrap(), - VerifyAccessTokenResult { is_valid: true } - )); - } - - async fn create_previous_access_token(user_id: UserId, db: &Database) -> Result { - let access_token = collab::auth::random_token(); - let access_token_hash = previous_hash_access_token(&access_token)?; - let id = db - .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) - .await?; - Ok(serde_json::to_string(&AccessTokenJson { - version: 1, - id, - token: access_token, - })?) - } - - #[expect(clippy::result_large_err)] - fn previous_hash_access_token(token: &str) -> Result { - // Avoid slow hashing in debug mode. - let params = if cfg!(debug_assertions) { - scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap() - } else { - scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap() - }; - - Ok(Scrypt - .hash_password_customized( - token.as_bytes(), - None, - None, - params, - &SaltString::generate(PasswordHashRngCompat::new()), - ) - .map_err(anyhow::Error::new)? - .to_string()) - } - - // TODO: remove once we password_hash v0.6 is released. - struct PasswordHashRngCompat(rand::rngs::ThreadRng); - - impl PasswordHashRngCompat { - fn new() -> Self { - Self(rand::rng()) - } - } - - impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat { - fn next_u32(&mut self) -> u32 { - self.0.next_u32() - } - - fn next_u64(&mut self) -> u64 { - self.0.next_u64() - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); - } - - fn try_fill_bytes( - &mut self, - dest: &mut [u8], - ) -> Result<(), scrypt::password_hash::rand_core::Error> { - self.fill_bytes(dest); - Ok(()) - } - } - - impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {} -} diff --git a/crates/collab/tests/integration/test_server.rs b/crates/collab/tests/integration/test_server.rs index 6bc02433e2f724d96b7911a3ed3c741377b5e70f..7405c7140b72595a908a4a6ac3226e7a5476050a 100644 --- a/crates/collab/tests/integration/test_server.rs +++ b/crates/collab/tests/integration/test_server.rs @@ -564,6 +564,7 @@ impl TestServer { ) -> Arc { Arc::new(AppState { db: test_db.db().clone(), + http_client: None, livekit_client: Some(Arc::new(livekit_test_server.create_api_client())), blob_store_client: None, executor,