Add an integration test to verify collaborators are kept up to date

Antonio Scandurra created

Change summary

server/src/db.rs  | 101 +++-----------------------------------
server/src/rpc.rs | 127 ++++++++++++++++++++++++++++++++++++++++++++++--
zed/src/user.rs   |  44 +++++++++++-----
3 files changed, 161 insertions(+), 111 deletions(-)

Detailed changes

server/src/db.rs 🔗

@@ -108,53 +108,16 @@ impl Db {
         })
     }
 
-    pub async fn get_users_by_ids(
-        &self,
-        requester_id: UserId,
-        ids: impl Iterator<Item = UserId>,
-    ) -> Result<Vec<User>> {
-        let mut include_requester = false;
-        let ids = ids
-            .map(|id| {
-                if id == requester_id {
-                    include_requester = true;
-                }
-                id.0
-            })
-            .collect::<Vec<_>>();
-
+    pub async fn get_users_by_ids(&self, ids: impl Iterator<Item = UserId>) -> Result<Vec<User>> {
+        let ids = ids.map(|id| id.0).collect::<Vec<_>>();
         test_support!(self, {
-            // Only return users that are in a common channel with the requesting user.
-            // Also allow the requesting user to return their own data, even if they aren't
-            // in any channels.
             let query = "
-                SELECT
-                    users.*
-                FROM
-                    users, channel_memberships
-                WHERE
-                    users.id = ANY ($1) AND
-                    channel_memberships.user_id = users.id AND
-                    channel_memberships.channel_id IN (
-                        SELECT channel_id
-                        FROM channel_memberships
-                        WHERE channel_memberships.user_id = $2
-                    )
-                UNION
-                SELECT
-                    users.*
-                FROM
-                    users
-                WHERE
-                    $3 AND users.id = $2
+                SELECT users.*
+                FROM users
+                WHERE users.id = ANY ($1)
             ";
 
-            sqlx::query_as(query)
-                .bind(&ids)
-                .bind(requester_id)
-                .bind(include_requester)
-                .fetch_all(&self.pool)
-                .await
+            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
         })
     }
 
@@ -582,45 +545,11 @@ pub mod tests {
         let friend1 = db.create_user("friend-1", false).await.unwrap();
         let friend2 = db.create_user("friend-2", false).await.unwrap();
         let friend3 = db.create_user("friend-3", false).await.unwrap();
-        let stranger = db.create_user("stranger", false).await.unwrap();
 
-        // A user can read their own info, even if they aren't in any channels.
         assert_eq!(
-            db.get_users_by_ids(
-                user,
-                [user, friend1, friend2, friend3, stranger].iter().copied()
-            )
-            .await
-            .unwrap(),
-            vec![User {
-                id: user,
-                github_login: "user".to_string(),
-                admin: false,
-            },],
-        );
-
-        // A user can read the info of any other user who is in a shared channel
-        // with them.
-        let org = db.create_org("test org", "test-org").await.unwrap();
-        let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
-        let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
-        let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();
-
-        db.add_channel_member(chan1, user, false).await.unwrap();
-        db.add_channel_member(chan2, user, false).await.unwrap();
-        db.add_channel_member(chan1, friend1, false).await.unwrap();
-        db.add_channel_member(chan1, friend2, false).await.unwrap();
-        db.add_channel_member(chan2, friend2, false).await.unwrap();
-        db.add_channel_member(chan2, friend3, false).await.unwrap();
-        db.add_channel_member(chan3, stranger, false).await.unwrap();
-
-        assert_eq!(
-            db.get_users_by_ids(
-                user,
-                [user, friend1, friend2, friend3, stranger].iter().copied()
-            )
-            .await
-            .unwrap(),
+            db.get_users_by_ids([user, friend1, friend2, friend3].iter().copied())
+                .await
+                .unwrap(),
             vec![
                 User {
                     id: user,
@@ -644,18 +573,6 @@ pub mod tests {
                 }
             ]
         );
-
-        // The user's own info is only returned if they request it.
-        assert_eq!(
-            db.get_users_by_ids(user, [friend1].iter().copied())
-                .await
-                .unwrap(),
-            vec![User {
-                id: friend1,
-                github_login: "friend-1".to_string(),
-                admin: false,
-            },]
-        )
     }
 
     #[gpui::test]

server/src/rpc.rs 🔗

@@ -149,6 +149,9 @@ impl Server {
             let (connection_id, handle_io, mut incoming_rx) =
                 this.peer.add_connection(connection).await;
             this.add_connection(connection_id, user_id).await;
+            if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
+                log::error!("error updating collaborators for {:?}: {}", user_id, err);
+            }
 
             let handle_io = handle_io.fuse();
             futures::pin_mut!(handle_io);
@@ -668,17 +671,12 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetUsers>,
     ) -> tide::Result<()> {
-        let user_id = self
-            .state
-            .read()
-            .await
-            .user_id_for_connection(request.sender_id)?;
         let receipt = request.receipt();
         let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
         let users = self
             .app_state
             .db
-            .get_users_by_ids(user_id, user_ids)
+            .get_users_by_ids(user_ids)
             .await?
             .into_iter()
             .map(|user| proto::User {
@@ -2150,6 +2148,123 @@ mod tests {
             .await;
     }
 
+    #[gpui::test]
+    async fn test_collaborators(
+        mut cx_a: TestAppContext,
+        mut cx_b: TestAppContext,
+        mut cx_c: TestAppContext,
+    ) {
+        cx_a.foreground().forbid_parking();
+        let lang_registry = Arc::new(LanguageRegistry::new());
+
+        // Connect to a server as 3 clients.
+        let mut server = TestServer::start().await;
+        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
+        let (_client_c, user_store_c) = server.create_client(&mut cx_c, "user_c").await;
+
+        let fs = Arc::new(FakeFs::new());
+
+        // Share a worktree as client A.
+        fs.insert_tree(
+            "/a",
+            json!({
+                ".zed.toml": r#"collaborators = ["user_b", "user_c"]"#,
+            }),
+        )
+        .await;
+
+        let worktree_a = Worktree::open_local(
+            client_a.clone(),
+            "/a".as_ref(),
+            fs.clone(),
+            lang_registry.clone(),
+            &mut cx_a.to_async(),
+        )
+        .await
+        .unwrap();
+
+        user_store_a
+            .condition(&cx_a, |user_store, _| {
+                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
+            })
+            .await;
+        user_store_b
+            .condition(&cx_b, |user_store, _| {
+                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
+            })
+            .await;
+        user_store_c
+            .condition(&cx_c, |user_store, _| {
+                collaborators(user_store) == vec![("user_a", vec![("a", vec![])])]
+            })
+            .await;
+
+        let worktree_id = worktree_a
+            .update(&mut cx_a, |tree, cx| tree.as_local_mut().unwrap().share(cx))
+            .await
+            .unwrap();
+
+        let _worktree_b = Worktree::open_remote(
+            client_b.clone(),
+            worktree_id,
+            lang_registry.clone(),
+            &mut cx_b.to_async(),
+        )
+        .await
+        .unwrap();
+
+        user_store_a
+            .condition(&cx_a, |user_store, _| {
+                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
+            })
+            .await;
+        user_store_b
+            .condition(&cx_b, |user_store, _| {
+                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
+            })
+            .await;
+        user_store_c
+            .condition(&cx_c, |user_store, _| {
+                collaborators(user_store) == vec![("user_a", vec![("a", vec!["user_b"])])]
+            })
+            .await;
+
+        cx_a.update(move |_| drop(worktree_a));
+        user_store_a
+            .condition(&cx_a, |user_store, _| collaborators(user_store) == vec![])
+            .await;
+        user_store_b
+            .condition(&cx_b, |user_store, _| collaborators(user_store) == vec![])
+            .await;
+        user_store_c
+            .condition(&cx_c, |user_store, _| collaborators(user_store) == vec![])
+            .await;
+
+        fn collaborators(user_store: &UserStore) -> Vec<(&str, Vec<(&str, Vec<&str>)>)> {
+            user_store
+                .collaborators()
+                .iter()
+                .map(|collaborator| {
+                    let worktrees = collaborator
+                        .worktrees
+                        .iter()
+                        .map(|w| {
+                            (
+                                w.root_name.as_str(),
+                                w.participants
+                                    .iter()
+                                    .map(|p| p.github_login.as_str())
+                                    .collect(),
+                            )
+                        })
+                        .collect();
+                    (collaborator.user.github_login.as_str(), worktrees)
+                })
+                .collect()
+        }
+    }
+
     struct TestServer {
         peer: Arc<Peer>,
         app_state: Arc<AppState>,

zed/src/user.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     http::{HttpClient, Method, Request, Url},
-    rpc::{self, Client, Status},
+    rpc::{Client, Status},
     util::TryFutureExt,
 };
 use anyhow::{anyhow, Context, Result};
@@ -21,13 +21,13 @@ pub struct User {
 }
 
 #[derive(Debug)]
-struct Collaborator {
+pub struct Collaborator {
     pub user: Arc<User>,
     pub worktrees: Vec<WorktreeMetadata>,
 }
 
 #[derive(Debug)]
-struct WorktreeMetadata {
+pub struct WorktreeMetadata {
     pub root_name: String,
     pub is_shared: bool,
     pub participants: Vec<Arc<User>>,
@@ -39,7 +39,7 @@ pub struct UserStore {
     collaborators: Vec<Collaborator>,
     rpc: Arc<Client>,
     http: Arc<dyn HttpClient>,
-    _maintain_collaborators: rpc::Subscription,
+    _maintain_collaborators: Task<()>,
     _maintain_current_user: Task<()>,
 }
 
@@ -52,13 +52,31 @@ impl Entity for UserStore {
 impl UserStore {
     pub fn new(rpc: Arc<Client>, http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
         let (mut current_user_tx, current_user_rx) = watch::channel();
+        let (mut update_collaborators_tx, mut update_collaborators_rx) =
+            watch::channel::<Option<proto::UpdateCollaborators>>();
+        let update_collaborators_subscription = rpc.subscribe(
+            cx,
+            move |_: &mut Self, msg: TypedEnvelope<proto::UpdateCollaborators>, _, _| {
+                let _ = update_collaborators_tx.blocking_send(Some(msg.payload));
+                Ok(())
+            },
+        );
         Self {
             users: Default::default(),
             current_user: current_user_rx,
             collaborators: Default::default(),
             rpc: rpc.clone(),
             http,
-            _maintain_collaborators: rpc.subscribe(cx, Self::update_collaborators),
+            _maintain_collaborators: cx.spawn_weak(|this, mut cx| async move {
+                let _subscription = update_collaborators_subscription;
+                while let Some(message) = update_collaborators_rx.recv().await {
+                    if let Some((message, this)) = message.zip(this.upgrade(&cx)) {
+                        this.update(&mut cx, |this, cx| this.update_collaborators(message, cx))
+                            .log_err()
+                            .await;
+                    }
+                }
+            }),
             _maintain_current_user: cx.spawn_weak(|this, mut cx| async move {
                 let mut status = rpc.status();
                 while let Some(status) = status.recv().await {
@@ -84,12 +102,11 @@ impl UserStore {
 
     fn update_collaborators(
         &mut self,
-        message: TypedEnvelope<proto::UpdateCollaborators>,
-        _: Arc<Client>,
+        message: proto::UpdateCollaborators,
         cx: &mut ModelContext<Self>,
-    ) -> Result<()> {
+    ) -> Task<Result<()>> {
         let mut user_ids = HashSet::new();
-        for collaborator in &message.payload.collaborators {
+        for collaborator in &message.collaborators {
             user_ids.insert(collaborator.user_id);
             user_ids.extend(
                 collaborator
@@ -105,7 +122,7 @@ impl UserStore {
             load_users.await?;
 
             let mut collaborators = Vec::new();
-            for collaborator in message.payload.collaborators {
+            for collaborator in message.collaborators {
                 collaborators.push(Collaborator::from_proto(collaborator, &this, &mut cx).await?);
             }
 
@@ -114,11 +131,12 @@ impl UserStore {
                 cx.notify();
             });
 
-            Result::<_, anyhow::Error>::Ok(())
+            Ok(())
         })
-        .detach();
+    }
 
-        Ok(())
+    pub fn collaborators(&self) -> &[Collaborator] {
+        &self.collaborators
     }
 
     pub fn load_users(