collab: Introduce `UserService` (#55449)

Marshall Bowers created

This PR introduces a `UserService` trait to Collab.

This is a step towards moving Collab away from reading user information
directly from the database.

We currently have two implementations of the trait:

- The `DatabaseUserService`, which leverages the existing query methods
to talk to the database
- The `FakeUserService`, which will be used in tests

Once we're ready, we'll be able to replace the `DatabaseUserService`
with a `CloudUserService` that fetches users from Cloud.

Release Notes:

- N/A

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |  24 
crates/collab/src/db/queries/channels.rs                       |  20 
crates/collab/src/db/tables/user.rs                            |   2 
crates/collab/src/entities/user.rs                             |   2 
crates/collab/src/lib.rs                                       |   5 
crates/collab/src/rpc.rs                                       |  43 
crates/collab/src/services.rs                                  |   3 
crates/collab/src/services/user_service.rs                     | 243 ++++
crates/collab/tests/integration/channel_guest_tests.rs         |   2 
crates/collab/tests/integration/test_server.rs                 |  24 
10 files changed, 334 insertions(+), 34 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -21,8 +21,8 @@ CREATE UNIQUE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"
 
 CREATE TABLE "contacts" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
-    "user_id_a" INTEGER REFERENCES users (id) NOT NULL,
-    "user_id_b" INTEGER REFERENCES users (id) NOT NULL,
+    "user_id_a" INTEGER NOT NULL,
+    "user_id_b" INTEGER NOT NULL,
     "a_to_b" BOOLEAN NOT NULL,
     "should_notify" BOOLEAN NOT NULL,
     "accepted" BOOLEAN NOT NULL
@@ -44,7 +44,7 @@ CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id");
 CREATE TABLE "projects" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "room_id" INTEGER REFERENCES rooms (id) ON DELETE CASCADE,
-    "host_user_id" INTEGER REFERENCES users (id),
+    "host_user_id" INTEGER,
     "host_connection_id" INTEGER,
     "host_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE,
     "unregistered" BOOLEAN NOT NULL DEFAULT FALSE,
@@ -208,14 +208,14 @@ CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_connection_id_and
 CREATE TABLE "room_participants" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "room_id" INTEGER NOT NULL REFERENCES rooms (id),
-    "user_id" INTEGER NOT NULL REFERENCES users (id),
+    "user_id" INTEGER NOT NULL,
     "answering_connection_id" INTEGER,
     "answering_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE CASCADE,
     "answering_connection_lost" BOOLEAN NOT NULL,
     "location_kind" INTEGER,
     "location_project_id" INTEGER,
     "initial_project_id" INTEGER,
-    "calling_user_id" INTEGER NOT NULL REFERENCES users (id),
+    "calling_user_id" INTEGER NOT NULL,
     "calling_connection_id" INTEGER NOT NULL,
     "calling_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE SET NULL,
     "participant_index" INTEGER,
@@ -279,7 +279,7 @@ CREATE INDEX "index_channels_on_parent_path_and_order" ON "channels" ("parent_pa
 
 CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
-    "user_id" INTEGER NOT NULL REFERENCES users (id),
+    "user_id" INTEGER NOT NULL,
     "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
     "connection_id" INTEGER NOT NULL,
     "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
@@ -290,7 +290,7 @@ CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_pa
 CREATE TABLE "channel_members" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
-    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "user_id" INTEGER NOT NULL,
     "role" VARCHAR NOT NULL,
     "accepted" BOOLEAN NOT NULL DEFAULT false,
     "updated_at" TIMESTAMP NOT NULL DEFAULT now
@@ -332,7 +332,7 @@ CREATE TABLE "channel_buffer_collaborators" (
     "connection_id" INTEGER NOT NULL,
     "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE,
     "connection_lost" BOOLEAN NOT NULL DEFAULT false,
-    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "user_id" INTEGER NOT NULL,
     "replica_id" INTEGER NOT NULL
 );
 
@@ -351,7 +351,7 @@ CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_connection
 );
 
 CREATE TABLE "observed_buffer_edits" (
-    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "user_id" INTEGER NOT NULL,
     "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
     "epoch" INTEGER NOT NULL,
     "lamport_timestamp" INTEGER NOT NULL,
@@ -371,7 +371,7 @@ CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" (
 CREATE TABLE "notifications" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP,
-    "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "recipient_id" INTEGER NOT NULL,
     "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
     "entity_id" INTEGER,
     "content" TEXT,
@@ -382,7 +382,7 @@ CREATE TABLE "notifications" (
 CREATE INDEX "index_notifications_on_recipient_id_is_read_kind_entity_id" ON "notifications" ("recipient_id", "is_read", "kind", "entity_id");
 
 CREATE TABLE contributors (
-    user_id INTEGER REFERENCES users (id),
+    user_id INTEGER,
     signed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
     PRIMARY KEY (user_id)
 );
@@ -437,7 +437,7 @@ CREATE INDEX "index_breakpoints_on_project_id" ON "breakpoints" ("project_id");
 
 CREATE TABLE IF NOT EXISTS "shared_threads" (
     "id" TEXT PRIMARY KEY NOT NULL,
-    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "user_id" INTEGER NOT NULL,
     "title" VARCHAR(512) NOT NULL,
     "data" BLOB NOT NULL,
     "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,

crates/collab/src/db/queries/channels.rs 🔗

@@ -684,6 +684,26 @@ impl Database {
         .await
     }
 
+    /// Returns the channel memberships for the users with the specified IDs.
+    #[cfg(feature = "test-support")]
+    pub async fn get_channel_memberships_for_user_ids(
+        &self,
+        channel: &Channel,
+        ids: Vec<UserId>,
+    ) -> Result<Vec<channel_member::Model>> {
+        self.transaction(|tx| async {
+            let tx = tx;
+            let members = channel_member::Entity::find()
+                .filter(channel_member::Column::ChannelId.eq(channel.id))
+                .filter(channel_member::Column::UserId.is_in(ids.iter().copied()))
+                .all(&*tx)
+                .await?;
+
+            Ok(members)
+        })
+        .await
+    }
+
     /// Returns the details for the specified channel member.
     pub async fn get_channel_participant_details(
         &self,

crates/collab/src/db/tables/user.rs 🔗

@@ -25,6 +25,8 @@ impl From<Model> for crate::entities::User {
         crate::entities::User {
             id: user.id,
             github_login: user.github_login,
+            github_user_id: user.github_user_id,
+            name: user.name,
             admin: user.admin,
             connected_once: user.connected_once,
         }

crates/collab/src/entities/user.rs 🔗

@@ -4,6 +4,8 @@ use crate::db::UserId;
 pub struct User {
     pub id: UserId,
     pub github_login: String,
+    pub github_user_id: i32,
+    pub name: Option<String>,
     pub admin: bool,
     pub connected_once: bool,
 }

crates/collab/src/lib.rs 🔗

@@ -6,6 +6,7 @@ pub mod env;
 pub mod executor;
 pub mod rpc;
 pub mod seed;
+pub mod services;
 
 use anyhow::Context as _;
 use aws_config::{BehaviorVersion, Region};
@@ -19,6 +20,8 @@ use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
 use util::ResultExt;
 
+use crate::services::{DatabaseUserService, UserService};
+
 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
 pub const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 
@@ -216,6 +219,7 @@ pub struct AppState {
     pub blob_store_client: Option<aws_sdk_s3::Client>,
     pub executor: Executor,
     pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
+    pub user_service: Arc<dyn UserService>,
     pub config: Config,
 }
 
@@ -259,6 +263,7 @@ impl AppState {
             } else {
                 None
             },
+            user_service: Arc::new(DatabaseUserService::new(db)),
             config,
         };
         Ok(Arc::new(this))

crates/collab/src/rpc.rs 🔗

@@ -2541,10 +2541,11 @@ async fn get_users(
         .map(UserId::from_proto)
         .collect();
     let users = session
-        .db()
-        .await
+        .app_state
+        .user_service
         .get_users_by_ids(user_ids)
-        .await?
+        .await?;
+    let users = users
         .into_iter()
         .map(|user| proto::User {
             id: user.id.to_proto(),
@@ -2567,13 +2568,19 @@ async fn fuzzy_search_users(
     let users = match query.len() {
         0 => vec![],
         1 | 2 => session
-            .db()
-            .await
+            .app_state
+            .user_service
             .get_user_by_github_login(&query)
             .await?
             .into_iter()
             .collect(),
-        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
+        _ => {
+            session
+                .app_state
+                .user_service
+                .fuzzy_search_users(&query, 10)
+                .await?
+        }
     };
     let users = users
         .into_iter()
@@ -3163,13 +3170,11 @@ async fn get_channel_members(
 
     let channel = db.get_channel(channel_id, session.user_id()).await?;
 
-    let (members, users) = db
-        .get_channel_participant_details(&channel, &request.query, limit)
+    let (members, users) = session
+        .app_state
+        .user_service
+        .search_channel_members(&channel, &request.query, limit as u32)
         .await?;
-    let members = members
-        .into_iter()
-        .map(proto::ChannelMember::from)
-        .collect();
     let users = users.into_iter().map(proto::User::from).collect();
 
     response.send(proto::GetChannelMembersResponse { members, users })?;
@@ -4081,3 +4086,17 @@ where
         }
     }
 }
+
+impl From<User> for proto::User {
+    fn from(user: User) -> Self {
+        Self {
+            id: user.id.to_proto(),
+            avatar_url: format!(
+                "https://avatars.githubusercontent.com/u/{}?s=128&v=4",
+                user.github_user_id
+            ),
+            github_login: user.github_login,
+            name: user.name,
+        }
+    }
+}

crates/collab/src/services/user_service.rs 🔗

@@ -0,0 +1,243 @@
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use rpc::proto;
+
+use crate::Result;
+use crate::db::{Channel, Database, UserId};
+use crate::entities::User;
+
+#[cfg(feature = "test-support")]
+pub use self::fake_user_service::*;
+
+#[async_trait]
+pub trait UserService: Send + Sync + 'static {
+    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
+
+    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
+
+    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
+
+    // NOTE: This method is only tangentially related to users, but we're putting it on the `UserService` to avoid
+    // introducing a separate service.
+    //
+    // We're also using the `proto::ChannelMember` representation in the return type, as we don't yet have a domain
+    // representation of a channel member (and doesn't seem necessary to introduce one, at this point).
+    async fn search_channel_members(
+        &self,
+        channel: &Channel,
+        query: &str,
+        limit: u32,
+    ) -> Result<(Vec<proto::ChannelMember>, Vec<User>)>;
+
+    #[cfg(feature = "test-support")]
+    fn as_fake(&self) -> Arc<FakeUserService> {
+        panic!("called as_fake on a real `UserService`");
+    }
+}
+
+/// A [`UserService`] implementation backed by the database.
+pub struct DatabaseUserService {
+    database: Arc<Database>,
+}
+
+impl DatabaseUserService {
+    pub fn new(database: Arc<Database>) -> Self {
+        Self { database }
+    }
+}
+
+#[async_trait]
+impl UserService for DatabaseUserService {
+    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+        let users = self.database.get_users_by_ids(ids).await?;
+
+        Ok(users.into_iter().map(User::from).collect())
+    }
+
+    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+        let user = self.database.get_user_by_github_login(github_login).await?;
+
+        Ok(user.map(User::from))
+    }
+
+    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>> {
+        let users = self.database.fuzzy_search_users(query, limit).await?;
+
+        Ok(users.into_iter().map(User::from).collect())
+    }
+
+    async fn search_channel_members(
+        &self,
+        channel: &Channel,
+        query: &str,
+        limit: u32,
+    ) -> Result<(Vec<proto::ChannelMember>, Vec<User>)> {
+        let (members, users) = self
+            .database
+            .get_channel_participant_details(channel, query, limit as u64)
+            .await?;
+
+        Ok((
+            members
+                .into_iter()
+                .map(proto::ChannelMember::from)
+                .collect(),
+            users.into_iter().map(User::from).collect(),
+        ))
+    }
+}
+
+#[cfg(feature = "test-support")]
+mod fake_user_service {
+    use std::sync::Weak;
+
+    use collections::HashMap;
+    use tokio::sync::Mutex;
+
+    use super::*;
+
+    #[derive(Debug)]
+    pub struct NewUserParams {
+        pub github_login: String,
+        pub github_user_id: i32,
+    }
+
+    pub struct FakeUserService {
+        this: Weak<Self>,
+        state: Arc<Mutex<FakeUserServiceState>>,
+        database: Arc<Database>,
+    }
+
+    struct FakeUserServiceState {
+        next_user_id: UserId,
+        users: HashMap<UserId, User>,
+    }
+
+    impl Default for FakeUserServiceState {
+        fn default() -> Self {
+            Self {
+                next_user_id: UserId(1),
+                users: HashMap::default(),
+            }
+        }
+    }
+
+    impl FakeUserService {
+        pub fn new(database: Arc<Database>) -> Arc<Self> {
+            Arc::new_cyclic(|this| Self {
+                this: this.clone(),
+                state: Arc::new(Mutex::default()),
+                database,
+            })
+        }
+
+        pub async fn create_user(
+            &self,
+            email_address: &str,
+            name: Option<&str>,
+            admin: bool,
+            params: NewUserParams,
+        ) -> UserId {
+            let mut state = self.state.lock().await;
+
+            let user_id = state.next_user_id;
+            let _ = email_address;
+            state.users.insert(
+                user_id,
+                User {
+                    id: user_id,
+                    github_login: params.github_login,
+                    github_user_id: params.github_user_id,
+                    name: name.map(|name| name.to_string()),
+                    admin,
+                    connected_once: false,
+                },
+            );
+
+            state.next_user_id = UserId(state.next_user_id.0 + 1);
+
+            user_id
+        }
+
+        pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+            let state = self.state.lock().await;
+
+            let user = state.users.get(&id).cloned();
+
+            Ok(user)
+        }
+    }
+
+    #[async_trait]
+    impl UserService for FakeUserService {
+        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+            let state = self.state.lock().await;
+
+            let users = state
+                .users
+                .values()
+                .filter(|user| ids.contains(&user.id))
+                .cloned()
+                .collect();
+
+            Ok(users)
+        }
+
+        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+            let state = self.state.lock().await;
+
+            let user = state
+                .users
+                .values()
+                .find(|user| user.github_login == github_login)
+                .cloned();
+
+            Ok(user)
+        }
+
+        async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>> {
+            let _ = query;
+            let _ = limit;
+            unimplemented!("not currently exercised by any tests")
+        }
+
+        async fn search_channel_members(
+            &self,
+            channel: &Channel,
+            query: &str,
+            limit: u32,
+        ) -> Result<(Vec<proto::ChannelMember>, Vec<User>)> {
+            let state = self.state.lock().await;
+
+            let users = state
+                .users
+                .values()
+                .filter(|user| user.github_login.contains(query))
+                .take(limit as usize)
+                .cloned()
+                .collect::<Vec<_>>();
+
+            let members = self
+                .database
+                .get_channel_memberships_for_user_ids(
+                    channel,
+                    users.iter().map(|user| user.id).collect(),
+                )
+                .await?;
+
+            Ok((
+                members
+                    .into_iter()
+                    .map(proto::ChannelMember::from)
+                    .collect(),
+                users,
+            ))
+        }
+
+        #[cfg(feature = "test-support")]
+        fn as_fake(&self) -> Arc<FakeUserService> {
+            self.this.upgrade().unwrap()
+        }
+    }
+}

crates/collab/tests/integration/channel_guest_tests.rs 🔗

@@ -281,7 +281,7 @@ async fn test_channel_requires_zed_cla(cx_a: &mut TestAppContext, cx_b: &mut Tes
     // User B signs the zed CLA.
     let user_b = server
         .app_state
-        .db
+        .user_service
         .get_user_by_github_login("user_b")
         .await
         .unwrap()

crates/collab/tests/integration/test_server.rs 🔗

@@ -7,9 +7,10 @@ use client::{
     proto::PeerId,
 };
 use clock::FakeSystemClock;
+use collab::services::{FakeUserService, NewUserParams};
 use collab::{
     AppState, Config,
-    db::{NewUserParams, UserId},
+    db::UserId,
     executor::Executor,
     rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion},
 };
@@ -179,14 +180,19 @@ impl TestServer {
 
         let clock = Arc::new(FakeSystemClock::new());
 
-        let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
+        let user_id = if let Ok(Some(user)) = self
+            .app_state
+            .user_service
+            .get_user_by_github_login(name)
+            .await
         {
             user.id
         } else {
             let github_user_id = self.next_github_user_id;
             self.next_github_user_id += 1;
             self.app_state
-                .db
+                .user_service
+                .as_fake()
                 .create_user(
                     &format!("{name}@example.com"),
                     None,
@@ -197,8 +203,6 @@ impl TestServer {
                     },
                 )
                 .await
-                .expect("creating user failed")
-                .user_id
         };
 
         let http = FakeHttpClient::create({
@@ -244,7 +248,7 @@ impl TestServer {
         let client_name = name.to_string();
         let client = cx.update(|cx| Client::new(clock, http.clone(), cx));
         let server = self.server.clone();
-        let db = self.app_state.db.clone();
+        let user_service = self.app_state.user_service.clone();
         let connection_killers = self.connection_killers.clone();
         let forbid_connections = self.forbid_connections.clone();
 
@@ -268,7 +272,7 @@ impl TestServer {
                 );
 
                 let server = server.clone();
-                let db = db.clone();
+                let user_service = user_service.clone();
                 let connection_killers = connection_killers.clone();
                 let forbid_connections = forbid_connections.clone();
                 let client_name = client_name.clone();
@@ -281,7 +285,8 @@ impl TestServer {
                         let (client_conn, server_conn, killed) =
                             Connection::in_memory(cx.background_executor().clone());
                         let (connection_id_tx, connection_id_rx) = oneshot::channel();
-                        let user = db
+                        let user = user_service
+                            .as_fake()
                             .get_user_by_id(user_id)
                             .await
                             .map_err(|e| {
@@ -294,7 +299,7 @@ impl TestServer {
                         cx.background_spawn(server.handle_connection(
                             server_conn,
                             client_name,
-                            Principal::User(user.into()),
+                            Principal::User(user),
                             ZedVersion(semver::Version::new(1, 0, 0)),
                             Some("test".to_string()),
                             None,
@@ -576,6 +581,7 @@ impl TestServer {
             blob_store_client: None,
             executor,
             kinesis_client: None,
+            user_service: FakeUserService::new(test_db.db().clone()),
             config: Config {
                 http_port: 0,
                 database_url: "".into(),