From dacf984596b412db98dbd40dbf9db1a544161fc3 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 6 May 2026 14:15:36 -0400 Subject: [PATCH] collab: Introduce `UserService` (#55449) This PR introduces a `UserService` trait to Collab. This is a step towards moving Collab away from reading user information directly from the database. We currently have two implementations for the trait: - The `DatabaseUserService`, which leverages the existing query methods to talk to the database - The `FakeUserService`, which will be used in tests Once we're ready, we'll be able to replace the `DatabaseUserService` with a `CloudUserService` to fetch the users from Cloud. Release Notes: - N/A --- .../20221109000000_test_schema.sql | 24 +- crates/collab/src/db/queries/channels.rs | 20 ++ crates/collab/src/db/tables/user.rs | 2 + crates/collab/src/entities/user.rs | 2 + crates/collab/src/lib.rs | 5 + crates/collab/src/rpc.rs | 43 +++- crates/collab/src/services.rs | 3 + crates/collab/src/services/user_service.rs | 243 ++++++++++++++++++ .../tests/integration/channel_guest_tests.rs | 2 +- .../collab/tests/integration/test_server.rs | 24 +- 10 files changed, 334 insertions(+), 34 deletions(-) create mode 100644 crates/collab/src/services.rs create mode 100644 crates/collab/src/services/user_service.rs diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 0ef44682a11a60de5d12a8efc56ad90687a87324..9c39dd4c260b954a2cf8e2cf7374ffc478d9e7b3 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -21,8 +21,8 @@ CREATE UNIQUE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id" CREATE TABLE "contacts" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, - "user_id_a" INTEGER REFERENCES users (id) NOT NULL, - "user_id_b" INTEGER REFERENCES users (id) NOT NULL, + "user_id_a" INTEGER NOT NULL, + "user_id_b" INTEGER NOT NULL, "a_to_b" BOOLEAN NOT NULL, "should_notify" BOOLEAN NOT NULL, "accepted" BOOLEAN NOT NULL @@ -44,7 +44,7 @@ CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id"); CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE, - "host_user_id" INTEGER REFERENCES users (id), + "host_user_id" INTEGER, "host_connection_id" INTEGER, "host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, "unregistered" BOOLEAN NOT NULL DEFAULT FALSE, @@ -208,14 +208,14 @@ CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and CREATE TABLE "room_participants" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "room_id" INTEGER NOT NULL REFERENCES rooms (id), - "user_id" INTEGER NOT NULL REFERENCES users (id), + "user_id" INTEGER NOT NULL, "answering_connection_id" INTEGER, "answering_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE, "answering_connection_lost" BOOLEAN NOT NULL, "location_kind" INTEGER, "location_project_id" INTEGER, "initial_project_id" INTEGER, - "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "calling_user_id" INTEGER NOT NULL, "calling_connection_id" INTEGER NOT NULL, "calling_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE SET NULL, "participant_index" INTEGER, @@ -279,7 +279,7 @@ CREATE INDEX "index_channels_on_parent_path_and_order" ON "channels" ("parent_pa CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, - "user_id" INTEGER NOT NULL REFERENCES users (id), + "user_id" INTEGER NOT NULL, "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE @@ -290,7 +290,7 @@ CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_pa CREATE TABLE "channel_members" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, - "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL, "role" VARCHAR NOT NULL, "accepted" BOOLEAN NOT NULL DEFAULT false, "updated_at" TIMESTAMP NOT NULL DEFAULT now @@ -332,7 +332,7 @@ CREATE TABLE "channel_buffer_collaborators" ( "connection_id" INTEGER NOT NULL, "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE, "connection_lost" BOOLEAN NOT NULL DEFAULT false, - "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL ); @@ -351,7 +351,7 @@ CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_connection ); CREATE TABLE "observed_buffer_edits" ( - "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL, "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, "epoch" INTEGER NOT NULL, "lamport_timestamp" INTEGER NOT NULL, @@ -371,7 +371,7 @@ CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ( CREATE TABLE "notifications" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP, - "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "recipient_id" INTEGER NOT NULL, "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), "entity_id" INTEGER, "content" TEXT, @@ -382,7 +382,7 @@ CREATE TABLE "notifications" ( CREATE INDEX "index_notifications_on_recipient_id_is_read_kind_entity_id" ON "notifications" ("recipient_id", "is_read", "kind", "entity_id"); CREATE TABLE contributors ( - user_id INTEGER REFERENCES users (id), + user_id INTEGER, signed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (user_id) ); @@ -437,7 +437,7 @@ CREATE INDEX "index_breakpoints_on_project_id" ON "breakpoints" ("project_id"); CREATE TABLE IF NOT EXISTS "shared_threads" ( "id" TEXT PRIMARY KEY NOT NULL, - "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "user_id" INTEGER NOT NULL, "title" VARCHAR(512) NOT NULL, "data" BLOB NOT NULL, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 7b435ba1aa2bffc443261c56df0bf26d24d59f7a..b4ee2caa0d69c44426e38f68e29421c3c5d0faca 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -684,6 +684,26 @@ impl Database { .await } + /// Returns the channel memberships for the users with the specified IDs. + #[cfg(feature = "test-support")] + pub async fn get_channel_memberships_for_user_ids( + &self, + channel: &Channel, + ids: Vec, + ) -> Result> { + self.transaction(|tx| async { + let tx = tx; + let members = channel_member::Entity::find() + .filter(channel_member::Column::ChannelId.eq(channel.id)) + .filter(channel_member::Column::UserId.is_in(ids.iter().copied())) + .all(&*tx) + .await?; + + Ok(members) + }) + .await + } + /// Returns the details for the specified channel member. pub async fn get_channel_participant_details( &self, diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index 933e78ed42698462c249ee8cd299db116c1fb921..c797fe41509b9c2b061bfdc331f3ad1ef526879f 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -25,6 +25,8 @@ impl From for crate::entities::User { crate::entities::User { id: user.id, github_login: user.github_login, + github_user_id: user.github_user_id, + name: user.name, admin: user.admin, connected_once: user.connected_once, } diff --git a/crates/collab/src/entities/user.rs b/crates/collab/src/entities/user.rs index 0c31d78ac51002df384f8e58e074e5f0e8804b4f..248916ad81dd3a67aac3f1e6d4b9f4ccad60c702 100644 --- a/crates/collab/src/entities/user.rs +++ b/crates/collab/src/entities/user.rs @@ -4,6 +4,8 @@ use crate::db::UserId; pub struct User { pub id: UserId, pub github_login: String, + pub github_user_id: i32, + pub name: Option, pub admin: bool, pub connected_once: bool, } diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 51541242a4474de951af97cefb49496632973d0a..91259b4ce402e57223907feae6190ac6575e10cd 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -6,6 +6,7 @@ pub mod env; pub mod executor; pub mod rpc; pub mod seed; +pub mod services; use anyhow::Context as _; use aws_config::{BehaviorVersion, Region}; @@ -19,6 +20,8 @@ use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; +use crate::services::{DatabaseUserService, UserService}; + pub const VERSION: &str = env!("CARGO_PKG_VERSION"); pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA"); @@ -216,6 +219,7 @@ pub struct AppState { pub blob_store_client: Option, pub executor: Executor, pub kinesis_client: Option<::aws_sdk_kinesis::Client>, + pub user_service: Arc, pub config: Config, } @@ -259,6 +263,7 @@ impl AppState { } else { None }, + user_service: Arc::new(DatabaseUserService::new(db)), config, }; Ok(Arc::new(this)) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 39f442bcafd9d8816ca886b19370520e5a44e590..3412a40c4a8a7e35add91b76892aefb309d28da4 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2541,10 +2541,11 @@ async fn get_users( .map(UserId::from_proto) .collect(); let users = session - .db() - .await + .app_state + .user_service .get_users_by_ids(user_ids) - .await? + .await?; + let users = users .into_iter() .map(|user| proto::User { id: user.id.to_proto(), @@ -2567,13 +2568,19 @@ async fn fuzzy_search_users( let users = match query.len() { 0 => vec![], 1 | 2 => session - .db() - .await + .app_state + .user_service .get_user_by_github_login(&query) .await? .into_iter() .collect(), - _ => session.db().await.fuzzy_search_users(&query, 10).await?, + _ => { + session + .app_state + .user_service + .fuzzy_search_users(&query, 10) + .await? + } }; let users = users .into_iter() @@ -3163,13 +3170,11 @@ async fn get_channel_members( let channel = db.get_channel(channel_id, session.user_id()).await?; - let (members, users) = db - .get_channel_participant_details(&channel, &request.query, limit) + let (members, users) = session + .app_state + .user_service + .search_channel_members(&channel, &request.query, limit as u32) .await?; - let members = members - .into_iter() - .map(proto::ChannelMember::from) - .collect(); let users = users.into_iter().map(proto::User::from).collect(); response.send(proto::GetChannelMembersResponse { members, users })?; @@ -4081,3 +4086,17 @@ where } } } + +impl From for proto::User { + fn from(user: User) -> Self { + Self { + id: user.id.to_proto(), + avatar_url: format!( + "https://avatars.githubusercontent.com/u/{}?s=128&v=4", + user.github_user_id + ), + github_login: user.github_login, + name: user.name, + } + } +} diff --git a/crates/collab/src/services.rs b/crates/collab/src/services.rs new file mode 100644 index 0000000000000000000000000000000000000000..eb87237236f9bb1fdfe3a1be3e8585ae96a6d2e6 --- /dev/null +++ b/crates/collab/src/services.rs @@ -0,0 +1,3 @@ +mod user_service; + +pub use user_service::*; diff --git a/crates/collab/src/services/user_service.rs b/crates/collab/src/services/user_service.rs new file mode 100644 index 0000000000000000000000000000000000000000..e2696e99ff2610e49d603648a26416029ae50ee2 --- /dev/null +++ b/crates/collab/src/services/user_service.rs @@ -0,0 +1,243 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use rpc::proto; + +use crate::Result; +use crate::db::{Channel, Database, UserId}; +use crate::entities::User; + +#[cfg(feature = "test-support")] +pub use self::fake_user_service::*; + +#[async_trait] +pub trait UserService: Send + Sync + 'static { + async fn get_users_by_ids(&self, ids: Vec) -> Result>; + + async fn get_user_by_github_login(&self, github_login: &str) -> Result>; + + async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result>; + + // NOTE: This method is only tangentially related to users, but we're putting it on the `UserService` to avoid + // introducing a separate service. + // + // We're also using the `proto::ChannelMember` representation in the return type, as we don't yet have a domain + // representation of a channel member (and doesn't seem necessary to introduce one, at this point). + async fn search_channel_members( + &self, + channel: &Channel, + query: &str, + limit: u32, + ) -> Result<(Vec, Vec)>; + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> Arc { + panic!("called as_fake on a real `UserService`"); + } +} + +/// A [`UserService`] implementation backed by the database. +pub struct DatabaseUserService { + database: Arc, +} + +impl DatabaseUserService { + pub fn new(database: Arc) -> Self { + Self { database } + } +} + +#[async_trait] +impl UserService for DatabaseUserService { + async fn get_users_by_ids(&self, ids: Vec) -> Result> { + let users = self.database.get_users_by_ids(ids).await?; + + Ok(users.into_iter().map(User::from).collect()) + } + + async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + let user = self.database.get_user_by_github_login(github_login).await?; + + Ok(user.map(User::from)) + } + + async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result> { + let users = self.database.fuzzy_search_users(query, limit).await?; + + Ok(users.into_iter().map(User::from).collect()) + } + + async fn search_channel_members( + &self, + channel: &Channel, + query: &str, + limit: u32, + ) -> Result<(Vec, Vec)> { + let (members, users) = self + .database + .get_channel_participant_details(channel, query, limit as u64) + .await?; + + Ok(( + members + .into_iter() + .map(proto::ChannelMember::from) + .collect(), + users.into_iter().map(User::from).collect(), + )) + } +} + +#[cfg(feature = "test-support")] +mod fake_user_service { + use std::sync::Weak; + + use collections::HashMap; + use tokio::sync::Mutex; + + use super::*; + + #[derive(Debug)] + pub struct NewUserParams { + pub github_login: String, + pub github_user_id: i32, + } + + pub struct FakeUserService { + this: Weak, + state: Arc>, + database: Arc, + } + + struct FakeUserServiceState { + next_user_id: UserId, + users: HashMap, + } + + impl Default for FakeUserServiceState { + fn default() -> Self { + Self { + next_user_id: UserId(1), + users: HashMap::default(), + } + } + } + + impl FakeUserService { + pub fn new(database: Arc) -> Arc { + Arc::new_cyclic(|this| Self { + this: this.clone(), + state: Arc::new(Mutex::default()), + database, + }) + } + + pub async fn create_user( + &self, + email_address: &str, + name: Option<&str>, + admin: bool, + params: NewUserParams, + ) -> UserId { + let mut state = self.state.lock().await; + + let user_id = state.next_user_id; + let _ = email_address; + state.users.insert( + user_id, + User { + id: user_id, + github_login: params.github_login, + github_user_id: params.github_user_id, + name: name.map(|name| name.to_string()), + admin, + connected_once: false, + }, + ); + + state.next_user_id = UserId(state.next_user_id.0 + 1); + + user_id + } + + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + let state = self.state.lock().await; + + let user = state.users.get(&id).cloned(); + + Ok(user) + } + } + + #[async_trait] + impl UserService for FakeUserService { + async fn get_users_by_ids(&self, ids: Vec) -> Result> { + let state = self.state.lock().await; + + let users = state + .users + .values() + .filter(|user| ids.contains(&user.id)) + .cloned() + .collect(); + + Ok(users) + } + + async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + let state = self.state.lock().await; + + let user = state + .users + .values() + .find(|user| user.github_login == github_login) + .cloned(); + + Ok(user) + } + + async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result> { + let _ = query; + let _ = limit; + unimplemented!("not currently exercised by any tests") + } + + async fn search_channel_members( + &self, + channel: &Channel, + query: &str, + limit: u32, + ) -> Result<(Vec, Vec)> { + let state = self.state.lock().await; + + let users = state + .users + .values() + .filter(|user| user.github_login.contains(query)) + .take(limit as usize) + .cloned() + .collect::>(); + + let members = self + .database + .get_channel_memberships_for_user_ids( + channel, + users.iter().map(|user| user.id).collect(), + ) + .await?; + + Ok(( + members + .into_iter() + .map(proto::ChannelMember::from) + .collect(), + users, + )) + } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> Arc { + self.this.upgrade().unwrap() + } + } +} diff --git a/crates/collab/tests/integration/channel_guest_tests.rs b/crates/collab/tests/integration/channel_guest_tests.rs index 95b1eeca5fc905d1b7db0502ec1f3110d8734746..5065a24c3a40b1ab8506837919f915f69e4d06b0 100644 --- a/crates/collab/tests/integration/channel_guest_tests.rs +++ b/crates/collab/tests/integration/channel_guest_tests.rs @@ -281,7 +281,7 @@ async fn test_channel_requires_zed_cla(cx_a: &mut TestAppContext, cx_b: &mut Tes // User B signs the zed CLA. let user_b = server .app_state - .db + .user_service .get_user_by_github_login("user_b") .await .unwrap() diff --git a/crates/collab/tests/integration/test_server.rs b/crates/collab/tests/integration/test_server.rs index 32f0e29c6dc8ed4d75e2c335ada6ffb1d1fb248c..820bcbd33765874bcc01722a3239c98ba9458f10 100644 --- a/crates/collab/tests/integration/test_server.rs +++ b/crates/collab/tests/integration/test_server.rs @@ -7,9 +7,10 @@ use client::{ proto::PeerId, }; use clock::FakeSystemClock; +use collab::services::{FakeUserService, NewUserParams}; use collab::{ AppState, Config, - db::{NewUserParams, UserId}, + db::UserId, executor::Executor, rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion}, }; @@ -179,14 +180,19 @@ impl TestServer { let clock = Arc::new(FakeSystemClock::new()); - let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await + let user_id = if let Ok(Some(user)) = self + .app_state + .user_service + .get_user_by_github_login(name) + .await { user.id } else { let github_user_id = self.next_github_user_id; self.next_github_user_id += 1; self.app_state - .db + .user_service + .as_fake() .create_user( &format!("{name}@example.com"), None, @@ -197,8 +203,6 @@ impl TestServer { }, ) .await - .expect("creating user failed") - .user_id }; let http = FakeHttpClient::create({ @@ -244,7 +248,7 @@ impl TestServer { let client_name = name.to_string(); let client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); - let db = self.app_state.db.clone(); + let user_service = self.app_state.user_service.clone(); let connection_killers = self.connection_killers.clone(); let forbid_connections = self.forbid_connections.clone(); @@ -268,7 +272,7 @@ impl TestServer { ); let server = server.clone(); - let db = db.clone(); + let user_service = user_service.clone(); let connection_killers = connection_killers.clone(); let forbid_connections = forbid_connections.clone(); let client_name = client_name.clone(); @@ -281,7 +285,8 @@ impl TestServer { let (client_conn, server_conn, killed) = Connection::in_memory(cx.background_executor().clone()); let (connection_id_tx, connection_id_rx) = oneshot::channel(); - let user = db + let user = user_service + .as_fake() .get_user_by_id(user_id) .await .map_err(|e| { @@ -294,7 +299,7 @@ impl TestServer { cx.background_spawn(server.handle_connection( server_conn, client_name, - Principal::User(user.into()), + Principal::User(user), ZedVersion(semver::Version::new(1, 0, 0)), Some("test".to_string()), None, @@ -576,6 +581,7 @@ impl TestServer { blob_store_client: None, executor, kinesis_client: None, + user_service: FakeUserService::new(test_db.db().clone()), config: Config { http_port: 0, database_url: "".into(),