collab: Validate access tokens through Cloud (#49535)

Marshall Bowers created

This PR updates Collab to make it validate access tokens through Cloud
instead of doing it in-house.

We're reusing the `GET /client/users/me` endpoint—which is what we also
call on the client—to validate the user's access token.

We only need to do this when establishing a WebSocket connection, so the
increased latency of a network hop shouldn't be a problem.

Closes CLO-308.

Release Notes:

- N/A

Change summary

Cargo.lock                                      |  37 ---
Cargo.toml                                      |   1 
crates/collab/Cargo.toml                        |   3 
crates/collab/src/auth.rs                       | 104 ++---------
crates/collab/src/lib.rs                        |  19 ++
crates/collab/src/main.rs                       |   5 
crates/collab/tests/integration/collab_tests.rs | 172 -------------------
crates/collab/tests/integration/test_server.rs  |   1 
8 files changed, 41 insertions(+), 301 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3202,7 +3202,6 @@ dependencies = [
  "aws-sdk-kinesis",
  "aws-sdk-s3",
  "axum",
- "base64 0.22.1",
  "buffer_diff",
  "call",
  "channel",
@@ -3260,7 +3259,6 @@ dependencies = [
  "remote_server",
  "reqwest 0.11.27",
  "rpc",
- "scrypt",
  "sea-orm",
  "sea-orm-macros",
  "semver",
@@ -3272,7 +3270,6 @@ dependencies = [
  "smol",
  "sqlx",
  "strum 0.27.2",
- "subtle",
  "task",
  "telemetry_events",
  "text",
@@ -11463,17 +11460,6 @@ dependencies = [
  "subtle",
 ]
 
-[[package]]
-name = "password-hash"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
-dependencies = [
- "base64ct",
- "rand_core 0.6.4",
- "subtle",
-]
-
 [[package]]
 name = "paste"
 version = "1.0.15"
@@ -11559,7 +11545,7 @@ checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
 dependencies = [
  "digest",
  "hmac",
- "password-hash 0.4.2",
+ "password-hash",
  "sha2",
 ]
 
@@ -14515,15 +14501,6 @@ dependencies = [
  "serde_json",
 ]
 
-[[package]]
-name = "salsa20"
-version = "0.10.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213"
-dependencies = [
- "cipher",
-]
-
 [[package]]
 name = "same-file"
 version = "1.0.6"
@@ -14666,18 +14643,6 @@ dependencies = [
  "syn 2.0.106",
 ]
 
-[[package]]
-name = "scrypt"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f"
-dependencies = [
- "password-hash 0.5.0",
- "pbkdf2 0.12.2",
- "salsa20",
- "sha2",
-]
-
 [[package]]
 name = "sct"
 version = "0.7.1"

Cargo.toml 🔗

@@ -668,7 +668,6 @@ stacksafe = "0.1"
 streaming-iterator = "0.1"
 strsim = "0.11"
 strum = { version = "0.27.2", features = ["derive"] }
-subtle = "2.5.0"
 syn = { version = "2.0.101", features = ["full", "extra-traits", "visit-mut"] }
 sys-locale = "0.3.1"
 sysinfo = "0.37.0"

crates/collab/Cargo.toml 🔗

@@ -33,7 +33,6 @@ aws-config = { version = "1.1.5" }
 aws-sdk-kinesis = "1.51.0"
 aws-sdk-s3 = { version = "1.15.0" }
 axum = { version = "0.6", features = ["json", "headers", "ws"] }
-base64.workspace = true
 chrono.workspace = true
 clock.workspace = true
 cloud_api_types.workspace = true
@@ -53,7 +52,6 @@ prost.workspace = true
 rand.workspace = true
 reqwest = { version = "0.11", features = ["json"] }
 rpc.workspace = true
-scrypt = "0.11"
 # sea-orm and sea-orm-macros versions must match exactly.
 sea-orm = { version = "=1.1.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
 sea-orm-macros = "=1.1.10"
@@ -63,7 +61,6 @@ serde_json.workspace = true
 sha2.workspace = true
 sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
 strum.workspace = true
-subtle.workspace = true
 telemetry_events.workspace = true
 text.workspace = true
 time.workspace = true

crates/collab/src/auth.rs 🔗

@@ -1,26 +1,13 @@
-use crate::{
-    AppState, Error, Result,
-    db::{AccessTokenId, Database, UserId},
-    rpc::Principal,
-};
+use crate::{AppState, Error, db::UserId, rpc::Principal};
 use anyhow::Context as _;
 use axum::{
     http::{self, Request, StatusCode},
     middleware::Next,
     response::IntoResponse,
 };
-use base64::prelude::*;
-use prometheus::{Histogram, exponential_buckets, register_histogram};
+use cloud_api_types::GetAuthenticatedUserResponse;
 pub use rpc::auth::random_token;
-use scrypt::{
-    Scrypt,
-    password_hash::{PasswordHash, PasswordVerifier},
-};
-use serde::{Deserialize, Serialize};
-use sha2::Digest;
-use std::sync::OnceLock;
-use std::{sync::Arc, time::Instant};
-use subtle::ConstantTimeEq;
+use std::sync::Arc;
 
 /// Validates the authorization header and adds an Extension<Principal> to the request.
 /// Authorization: <user-id> <token>
@@ -64,11 +51,23 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         )
     })?;
 
-    let validate_result = verify_access_token(access_token, user_id, &state.db).await;
+    let http_client = state.http_client.clone().expect("no HTTP client");
+
+    let response = http_client
+        .get(format!("{}/client/users/me", state.config.zed_cloud_url()))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("{user_id} {access_token}"))
+        .send()
+        .await
+        .context("failed to validate access token")?;
+    if let Ok(response) = response.error_for_status() {
+        let response_body: GetAuthenticatedUserResponse = response
+            .json()
+            .await
+            .context("failed to parse response body")?;
+
+        let user_id = UserId(response_body.user.id);
 
-    if let Ok(validate_result) = validate_result
-        && validate_result.is_valid
-    {
         let user = state
             .db
             .get_user_by_id(user_id)
@@ -84,68 +83,3 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         "invalid credentials".to_string(),
     ))
 }
-
-#[derive(Serialize, Deserialize)]
-pub struct AccessTokenJson {
-    pub version: usize,
-    pub id: AccessTokenId,
-    pub token: String,
-}
-
-/// Hashing prevents anyone with access to the database being able to login.
-/// As the token is randomly generated, we don't need to worry about scrypt-style
-/// protection.
-pub fn hash_access_token(token: &str) -> String {
-    let digest = sha2::Sha256::digest(token);
-    format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
-}
-
-pub struct VerifyAccessTokenResult {
-    pub is_valid: bool,
-}
-
-/// Checks that the given access token is valid for the given user.
-pub async fn verify_access_token(
-    token: &str,
-    user_id: UserId,
-    db: &Arc<Database>,
-) -> Result<VerifyAccessTokenResult> {
-    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
-    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
-        register_histogram!(
-            "access_token_hashing_time",
-            "time spent hashing access tokens",
-            exponential_buckets(10.0, 2.0, 10).unwrap(),
-        )
-        .unwrap()
-    });
-
-    let token: AccessTokenJson = serde_json::from_str(token)?;
-
-    let db_token = db.get_access_token(token.id).await?;
-    if db_token.user_id != user_id {
-        return Err(anyhow::anyhow!("no such access token"))?;
-    }
-    let t0 = Instant::now();
-
-    let is_valid = if db_token.hash.starts_with("$scrypt$") {
-        let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
-        Scrypt
-            .verify_password(token.token.as_bytes(), &db_hash)
-            .is_ok()
-    } else {
-        let token_hash = hash_access_token(&token.token);
-        db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
-    };
-
-    let duration = t0.elapsed();
-    log::info!("hashed access token in {:?}", duration);
-    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
-
-    if is_valid && db_token.hash.starts_with("$scrypt$") {
-        let new_hash = hash_access_token(&token.token);
-        db.update_access_token_hash(db_token.id, &new_hash).await?;
-    }
-
-    Ok(VerifyAccessTokenResult { is_valid })
-}

crates/collab/src/lib.rs 🔗

@@ -18,6 +18,9 @@ use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
 use util::ResultExt;
 
+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
+pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
+
 pub type Result<T, E = Error> = std::result::Result<T, E>;
 
 pub enum Error {
@@ -150,6 +153,14 @@ impl Config {
         }
     }
 
+    /// Returns the base Zed Cloud URL.
+    pub fn zed_cloud_url(&self) -> &str {
+        match self.zed_environment.as_ref() {
+            "development" => "http://localhost:8787",
+            _ => "https://cloud.zed.dev",
+        }
+    }
+
     #[cfg(feature = "test-support")]
     pub fn test() -> Self {
         Self {
@@ -199,6 +210,7 @@ impl ServiceMode {
 
 pub struct AppState {
     pub db: Arc<Database>,
+    pub http_client: Option<reqwest::Client>,
     pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
     pub blob_store_client: Option<aws_sdk_s3::Client>,
     pub executor: Executor,
@@ -228,9 +240,16 @@ impl AppState {
             None
         };
 
+        let user_agent = format!("Collab/{VERSION} ({})", REVISION.unwrap_or("unknown"));
+        let http_client = reqwest::Client::builder()
+            .user_agent(user_agent)
+            .build()
+            .context("failed to construct HTTP client")?;
+
         let db = Arc::new(db);
         let this = Self {
             db: db.clone(),
+            http_client: Some(http_client),
             livekit_client,
             blob_store_client: build_blob_store_client(&config).await.log_err(),
             executor,

crates/collab/src/main.rs 🔗

@@ -7,12 +7,12 @@ use axum::{
     routing::get,
 };
 
-use collab::ServiceMode;
 use collab::api::CloudflareIpCountryHeader;
 use collab::{
     AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
     executor::Executor,
 };
+use collab::{REVISION, ServiceMode, VERSION};
 use db::Database;
 use std::{
     env::args,
@@ -28,9 +28,6 @@ use tracing_subscriber::{
 };
 use util::ResultExt as _;
 
-const VERSION: &str = env!("CARGO_PKG_VERSION");
-const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
-
 #[expect(clippy::result_large_err)]
 #[tokio::main]
 async fn main() -> Result<()> {

crates/collab/tests/integration/collab_tests.rs 🔗

@@ -50,175 +50,3 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
 fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
     cx.read(|cx| room.read(cx).channel_id())
 }
-
-mod auth_token_tests {
-    use collab::auth::{
-        AccessTokenJson, VerifyAccessTokenResult, hash_access_token, verify_access_token,
-    };
-    use rand::prelude::*;
-    use scrypt::Scrypt;
-    use scrypt::password_hash::{PasswordHasher, SaltString};
-    use sea_orm::EntityTrait;
-
-    use collab::db::{Database, NewUserParams, UserId, access_token};
-    use collab::*;
-
-    const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
-
-    async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
-        const VERSION: usize = 1;
-        let access_token = ::rpc::auth::random_token();
-        let access_token_hash = hash_access_token(&access_token);
-        let id = db
-            .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
-            .await?;
-        Ok(serde_json::to_string(&AccessTokenJson {
-            version: VERSION,
-            id,
-            token: access_token,
-        })?)
-    }
-
-    #[gpui::test]
-    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
-        let test_db = crate::db_tests::TestDb::sqlite(cx.executor());
-        let db = test_db.db();
-
-        let user = db
-            .create_user(
-                "example@example.com",
-                None,
-                false,
-                NewUserParams {
-                    github_login: "example".into(),
-                    github_user_id: 1,
-                },
-            )
-            .await
-            .unwrap();
-
-        let token = create_access_token(db, user.user_id).await.unwrap();
-        assert!(matches!(
-            verify_access_token(&token, user.user_id, db).await.unwrap(),
-            VerifyAccessTokenResult { is_valid: true }
-        ));
-
-        let old_token = create_previous_access_token(user.user_id, db)
-            .await
-            .unwrap();
-
-        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
-            .unwrap()
-            .id;
-
-        let hash = db
-            .transaction(|tx| async move {
-                Ok(access_token::Entity::find_by_id(old_token_id)
-                    .one(&*tx)
-                    .await?)
-            })
-            .await
-            .unwrap()
-            .unwrap()
-            .hash;
-        assert!(hash.starts_with("$scrypt$"));
-
-        assert!(matches!(
-            verify_access_token(&old_token, user.user_id, db)
-                .await
-                .unwrap(),
-            VerifyAccessTokenResult { is_valid: true }
-        ));
-
-        let hash = db
-            .transaction(|tx| async move {
-                Ok(access_token::Entity::find_by_id(old_token_id)
-                    .one(&*tx)
-                    .await?)
-            })
-            .await
-            .unwrap()
-            .unwrap()
-            .hash;
-        assert!(hash.starts_with("$sha256$"));
-
-        assert!(matches!(
-            verify_access_token(&old_token, user.user_id, db)
-                .await
-                .unwrap(),
-            VerifyAccessTokenResult { is_valid: true }
-        ));
-
-        assert!(matches!(
-            verify_access_token(&token, user.user_id, db).await.unwrap(),
-            VerifyAccessTokenResult { is_valid: true }
-        ));
-    }
-
-    async fn create_previous_access_token(user_id: UserId, db: &Database) -> Result<String> {
-        let access_token = collab::auth::random_token();
-        let access_token_hash = previous_hash_access_token(&access_token)?;
-        let id = db
-            .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
-            .await?;
-        Ok(serde_json::to_string(&AccessTokenJson {
-            version: 1,
-            id,
-            token: access_token,
-        })?)
-    }
-
-    #[expect(clippy::result_large_err)]
-    fn previous_hash_access_token(token: &str) -> Result<String> {
-        // Avoid slow hashing in debug mode.
-        let params = if cfg!(debug_assertions) {
-            scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
-        } else {
-            scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
-        };
-
-        Ok(Scrypt
-            .hash_password_customized(
-                token.as_bytes(),
-                None,
-                None,
-                params,
-                &SaltString::generate(PasswordHashRngCompat::new()),
-            )
-            .map_err(anyhow::Error::new)?
-            .to_string())
-    }
-
-    // TODO: remove once we password_hash v0.6 is released.
-    struct PasswordHashRngCompat(rand::rngs::ThreadRng);
-
-    impl PasswordHashRngCompat {
-        fn new() -> Self {
-            Self(rand::rng())
-        }
-    }
-
-    impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
-        fn next_u32(&mut self) -> u32 {
-            self.0.next_u32()
-        }
-
-        fn next_u64(&mut self) -> u64 {
-            self.0.next_u64()
-        }
-
-        fn fill_bytes(&mut self, dest: &mut [u8]) {
-            self.0.fill_bytes(dest);
-        }
-
-        fn try_fill_bytes(
-            &mut self,
-            dest: &mut [u8],
-        ) -> Result<(), scrypt::password_hash::rand_core::Error> {
-            self.fill_bytes(dest);
-            Ok(())
-        }
-    }
-
-    impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
-}

crates/collab/tests/integration/test_server.rs 🔗

@@ -564,6 +564,7 @@ impl TestServer {
     ) -> Arc<AppState> {
         Arc::new(AppState {
             db: test_db.db().clone(),
+            http_client: None,
             livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
             blob_store_client: None,
             executor,