Detailed changes
@@ -3202,7 +3202,6 @@ dependencies = [
"aws-sdk-kinesis",
"aws-sdk-s3",
"axum",
- "base64 0.22.1",
"buffer_diff",
"call",
"channel",
@@ -3260,7 +3259,6 @@ dependencies = [
"remote_server",
"reqwest 0.11.27",
"rpc",
- "scrypt",
"sea-orm",
"sea-orm-macros",
"semver",
@@ -3272,7 +3270,6 @@ dependencies = [
"smol",
"sqlx",
"strum 0.27.2",
- "subtle",
"task",
"telemetry_events",
"text",
@@ -11463,17 +11460,6 @@ dependencies = [
"subtle",
]
-[[package]]
-name = "password-hash"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
-dependencies = [
- "base64ct",
- "rand_core 0.6.4",
- "subtle",
-]
-
[[package]]
name = "paste"
version = "1.0.15"
@@ -11559,7 +11545,7 @@ checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
dependencies = [
"digest",
"hmac",
- "password-hash 0.4.2",
+ "password-hash",
"sha2",
]
@@ -14515,15 +14501,6 @@ dependencies = [
"serde_json",
]
-[[package]]
-name = "salsa20"
-version = "0.10.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213"
-dependencies = [
- "cipher",
-]
-
[[package]]
name = "same-file"
version = "1.0.6"
@@ -14666,18 +14643,6 @@ dependencies = [
"syn 2.0.106",
]
-[[package]]
-name = "scrypt"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f"
-dependencies = [
- "password-hash 0.5.0",
- "pbkdf2 0.12.2",
- "salsa20",
- "sha2",
-]
-
[[package]]
name = "sct"
version = "0.7.1"
@@ -668,7 +668,6 @@ stacksafe = "0.1"
streaming-iterator = "0.1"
strsim = "0.11"
strum = { version = "0.27.2", features = ["derive"] }
-subtle = "2.5.0"
syn = { version = "2.0.101", features = ["full", "extra-traits", "visit-mut"] }
sys-locale = "0.3.1"
sysinfo = "0.37.0"
@@ -33,7 +33,6 @@ aws-config = { version = "1.1.5" }
aws-sdk-kinesis = "1.51.0"
aws-sdk-s3 = { version = "1.15.0" }
axum = { version = "0.6", features = ["json", "headers", "ws"] }
-base64.workspace = true
chrono.workspace = true
clock.workspace = true
cloud_api_types.workspace = true
@@ -53,7 +52,6 @@ prost.workspace = true
rand.workspace = true
reqwest = { version = "0.11", features = ["json"] }
rpc.workspace = true
-scrypt = "0.11"
# sea-orm and sea-orm-macros versions must match exactly.
sea-orm = { version = "=1.1.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
sea-orm-macros = "=1.1.10"
@@ -63,7 +61,6 @@ serde_json.workspace = true
sha2.workspace = true
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
strum.workspace = true
-subtle.workspace = true
telemetry_events.workspace = true
text.workspace = true
time.workspace = true
@@ -1,26 +1,13 @@
-use crate::{
- AppState, Error, Result,
- db::{AccessTokenId, Database, UserId},
- rpc::Principal,
-};
+use crate::{AppState, Error, db::UserId, rpc::Principal};
use anyhow::Context as _;
use axum::{
http::{self, Request, StatusCode},
middleware::Next,
response::IntoResponse,
};
-use base64::prelude::*;
-use prometheus::{Histogram, exponential_buckets, register_histogram};
+use cloud_api_types::GetAuthenticatedUserResponse;
pub use rpc::auth::random_token;
-use scrypt::{
- Scrypt,
- password_hash::{PasswordHash, PasswordVerifier},
-};
-use serde::{Deserialize, Serialize};
-use sha2::Digest;
-use std::sync::OnceLock;
-use std::{sync::Arc, time::Instant};
-use subtle::ConstantTimeEq;
+use std::sync::Arc;
/// Validates the authorization header and adds an Extension<Principal> to the request.
/// Authorization: <user-id> <token>
@@ -64,11 +51,23 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
)
})?;
- let validate_result = verify_access_token(access_token, user_id, &state.db).await;
+ let http_client = state.http_client.clone().expect("no HTTP client");
+
+ let response = http_client
+ .get(format!("{}/client/users/me", state.config.zed_cloud_url()))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("{user_id} {access_token}"))
+ .send()
+ .await
+ .context("failed to validate access token")?;
+ if let Ok(response) = response.error_for_status() {
+ let response_body: GetAuthenticatedUserResponse = response
+ .json()
+ .await
+ .context("failed to parse response body")?;
+
+ let user_id = UserId(response_body.user.id);
- if let Ok(validate_result) = validate_result
- && validate_result.is_valid
- {
let user = state
.db
.get_user_by_id(user_id)
@@ -84,68 +83,3 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
"invalid credentials".to_string(),
))
}
-
-#[derive(Serialize, Deserialize)]
-pub struct AccessTokenJson {
- pub version: usize,
- pub id: AccessTokenId,
- pub token: String,
-}
-
-/// Hashing prevents anyone with access to the database being able to login.
-/// As the token is randomly generated, we don't need to worry about scrypt-style
-/// protection.
-pub fn hash_access_token(token: &str) -> String {
- let digest = sha2::Sha256::digest(token);
- format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
-}
-
-pub struct VerifyAccessTokenResult {
- pub is_valid: bool,
-}
-
-/// Checks that the given access token is valid for the given user.
-pub async fn verify_access_token(
- token: &str,
- user_id: UserId,
- db: &Arc<Database>,
-) -> Result<VerifyAccessTokenResult> {
- static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
- let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
- register_histogram!(
- "access_token_hashing_time",
- "time spent hashing access tokens",
- exponential_buckets(10.0, 2.0, 10).unwrap(),
- )
- .unwrap()
- });
-
- let token: AccessTokenJson = serde_json::from_str(token)?;
-
- let db_token = db.get_access_token(token.id).await?;
- if db_token.user_id != user_id {
- return Err(anyhow::anyhow!("no such access token"))?;
- }
- let t0 = Instant::now();
-
- let is_valid = if db_token.hash.starts_with("$scrypt$") {
- let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
- Scrypt
- .verify_password(token.token.as_bytes(), &db_hash)
- .is_ok()
- } else {
- let token_hash = hash_access_token(&token.token);
- db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
- };
-
- let duration = t0.elapsed();
- log::info!("hashed access token in {:?}", duration);
- metric_access_token_hashing_time.observe(duration.as_millis() as f64);
-
- if is_valid && db_token.hash.starts_with("$scrypt$") {
- let new_hash = hash_access_token(&token.token);
- db.update_access_token_hash(db_token.id, &new_hash).await?;
- }
-
- Ok(VerifyAccessTokenResult { is_valid })
-}
@@ -18,6 +18,9 @@ use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
+pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
+
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
@@ -150,6 +153,14 @@ impl Config {
}
}
+ /// Returns the base Zed Cloud URL.
+ pub fn zed_cloud_url(&self) -> &str {
+ match self.zed_environment.as_ref() {
+ "development" => "http://localhost:8787",
+ _ => "https://cloud.zed.dev",
+ }
+ }
+
#[cfg(feature = "test-support")]
pub fn test() -> Self {
Self {
@@ -199,6 +210,7 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
+ pub http_client: Option<reqwest::Client>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub executor: Executor,
@@ -228,9 +240,16 @@ impl AppState {
None
};
+ let user_agent = format!("Collab/{VERSION} ({})", REVISION.unwrap_or("unknown"));
+ let http_client = reqwest::Client::builder()
+ .user_agent(user_agent)
+ .build()
+ .context("failed to construct HTTP client")?;
+
let db = Arc::new(db);
let this = Self {
db: db.clone(),
+ http_client: Some(http_client),
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
executor,
@@ -7,12 +7,12 @@ use axum::{
routing::get,
};
-use collab::ServiceMode;
use collab::api::CloudflareIpCountryHeader;
use collab::{
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor,
};
+use collab::{REVISION, ServiceMode, VERSION};
use db::Database;
use std::{
env::args,
@@ -28,9 +28,6 @@ use tracing_subscriber::{
};
use util::ResultExt as _;
-const VERSION: &str = env!("CARGO_PKG_VERSION");
-const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
-
#[expect(clippy::result_large_err)]
#[tokio::main]
async fn main() -> Result<()> {
@@ -50,175 +50,3 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
cx.read(|cx| room.read(cx).channel_id())
}
-
-mod auth_token_tests {
- use collab::auth::{
- AccessTokenJson, VerifyAccessTokenResult, hash_access_token, verify_access_token,
- };
- use rand::prelude::*;
- use scrypt::Scrypt;
- use scrypt::password_hash::{PasswordHasher, SaltString};
- use sea_orm::EntityTrait;
-
- use collab::db::{Database, NewUserParams, UserId, access_token};
- use collab::*;
-
- const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
-
- async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
- const VERSION: usize = 1;
- let access_token = ::rpc::auth::random_token();
- let access_token_hash = hash_access_token(&access_token);
- let id = db
- .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
- .await?;
- Ok(serde_json::to_string(&AccessTokenJson {
- version: VERSION,
- id,
- token: access_token,
- })?)
- }
-
- #[gpui::test]
- async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
- let test_db = crate::db_tests::TestDb::sqlite(cx.executor());
- let db = test_db.db();
-
- let user = db
- .create_user(
- "example@example.com",
- None,
- false,
- NewUserParams {
- github_login: "example".into(),
- github_user_id: 1,
- },
- )
- .await
- .unwrap();
-
- let token = create_access_token(db, user.user_id).await.unwrap();
- assert!(matches!(
- verify_access_token(&token, user.user_id, db).await.unwrap(),
- VerifyAccessTokenResult { is_valid: true }
- ));
-
- let old_token = create_previous_access_token(user.user_id, db)
- .await
- .unwrap();
-
- let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
- .unwrap()
- .id;
-
- let hash = db
- .transaction(|tx| async move {
- Ok(access_token::Entity::find_by_id(old_token_id)
- .one(&*tx)
- .await?)
- })
- .await
- .unwrap()
- .unwrap()
- .hash;
- assert!(hash.starts_with("$scrypt$"));
-
- assert!(matches!(
- verify_access_token(&old_token, user.user_id, db)
- .await
- .unwrap(),
- VerifyAccessTokenResult { is_valid: true }
- ));
-
- let hash = db
- .transaction(|tx| async move {
- Ok(access_token::Entity::find_by_id(old_token_id)
- .one(&*tx)
- .await?)
- })
- .await
- .unwrap()
- .unwrap()
- .hash;
- assert!(hash.starts_with("$sha256$"));
-
- assert!(matches!(
- verify_access_token(&old_token, user.user_id, db)
- .await
- .unwrap(),
- VerifyAccessTokenResult { is_valid: true }
- ));
-
- assert!(matches!(
- verify_access_token(&token, user.user_id, db).await.unwrap(),
- VerifyAccessTokenResult { is_valid: true }
- ));
- }
-
- async fn create_previous_access_token(user_id: UserId, db: &Database) -> Result<String> {
- let access_token = collab::auth::random_token();
- let access_token_hash = previous_hash_access_token(&access_token)?;
- let id = db
- .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
- .await?;
- Ok(serde_json::to_string(&AccessTokenJson {
- version: 1,
- id,
- token: access_token,
- })?)
- }
-
- #[expect(clippy::result_large_err)]
- fn previous_hash_access_token(token: &str) -> Result<String> {
- // Avoid slow hashing in debug mode.
- let params = if cfg!(debug_assertions) {
- scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
- } else {
- scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
- };
-
- Ok(Scrypt
- .hash_password_customized(
- token.as_bytes(),
- None,
- None,
- params,
- &SaltString::generate(PasswordHashRngCompat::new()),
- )
- .map_err(anyhow::Error::new)?
- .to_string())
- }
-
- // TODO: remove once we password_hash v0.6 is released.
- struct PasswordHashRngCompat(rand::rngs::ThreadRng);
-
- impl PasswordHashRngCompat {
- fn new() -> Self {
- Self(rand::rng())
- }
- }
-
- impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
- fn next_u32(&mut self) -> u32 {
- self.0.next_u32()
- }
-
- fn next_u64(&mut self) -> u64 {
- self.0.next_u64()
- }
-
- fn fill_bytes(&mut self, dest: &mut [u8]) {
- self.0.fill_bytes(dest);
- }
-
- fn try_fill_bytes(
- &mut self,
- dest: &mut [u8],
- ) -> Result<(), scrypt::password_hash::rand_core::Error> {
- self.fill_bytes(dest);
- Ok(())
- }
- }
-
- impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
-}
@@ -564,6 +564,7 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
+ http_client: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
executor,