Get server integration tests passing again

Max Brunsfeld and Nathan Sobo created

* Set up UserStore to have the current user, so that
  channel messages can be sent. This is needed now that
  pending messages are represented more similarly to
  regular messages.
* Drop buffer inside of an `AppContext.update` block, so that
  the Buffer's release hook is called in time.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

server/src/db.rs     |  37 ++++++-
server/src/rpc.rs    | 206 +++++++++++++++++++++++----------------------
zed/src/channel.rs   |   2 
zed/src/user.rs      |   8 +
zed/src/workspace.rs |   4 
5 files changed, 140 insertions(+), 117 deletions(-)

Detailed changes

server/src/db.rs 🔗

@@ -133,14 +133,18 @@ impl Db {
             let query = "
                 SELECT users.*
                 FROM
-                    users, channel_memberships
+                    users LEFT JOIN channel_memberships
+                ON
+                    channel_memberships.user_id = users.id
                 WHERE
-                    users.id = ANY ($1) AND
-                    channel_memberships.user_id = users.id AND
-                    channel_memberships.channel_id IN (
-                        SELECT channel_id
-                        FROM channel_memberships
-                        WHERE channel_memberships.user_id = $2
+                    users.id = $2 OR
+                    (
+                        users.id = ANY ($1) AND
+                        channel_memberships.channel_id IN (
+                            SELECT channel_id
+                            FROM channel_memberships
+                            WHERE channel_memberships.user_id = $2
+                        )
                     )
             ";
 
@@ -455,7 +459,7 @@ macro_rules! id_type {
 }
 
 id_type!(UserId);
-#[derive(Debug, FromRow, Serialize)]
+#[derive(Debug, FromRow, Serialize, PartialEq)]
 pub struct User {
     pub id: UserId,
     pub github_login: String,
@@ -563,6 +567,23 @@ pub mod tests {
         }
     }
 
+    #[gpui::test]
+    async fn test_get_users_by_ids() {
+        let test_db = TestDb::new();
+        let db = test_db.db();
+        let user_id = db.create_user("user", false).await.unwrap();
+        assert_eq!(
+            db.get_users_by_ids(user_id, Some(user_id).iter().copied())
+                .await
+                .unwrap(),
+            vec![User {
+                id: user_id,
+                github_login: "user".to_string(),
+                admin: false,
+            }]
+        )
+    }
+
     #[gpui::test]
     async fn test_recent_channel_messages() {
         let test_db = TestDb::new();

server/src/rpc.rs 🔗

@@ -1039,8 +1039,8 @@ mod tests {
 
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
-        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
 
         cx_a.foreground().forbid_parking();
 
@@ -1124,7 +1124,7 @@ mod tests {
             .await;
 
         // Close the buffer as client A, see that the buffer is closed.
-        drop(buffer_a);
+        cx_a.update(move |_| drop(buffer_a));
         worktree_a
             .condition(&cx_a, |tree, cx| !tree.has_open_buffer("b.txt", cx))
             .await;
@@ -1147,9 +1147,9 @@ mod tests {
 
         // Connect to a server as 3 clients.
         let mut server = TestServer::start().await;
-        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
-        let (_, client_c) = server.create_client(&mut cx_c, "user_c").await;
+        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
+        let (client_c, _) = server.create_client(&mut cx_c, "user_c").await;
 
         let fs = Arc::new(FakeFs::new());
 
@@ -1288,8 +1288,8 @@ mod tests {
 
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
-        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
 
         // Share a local worktree as client A
         let fs = Arc::new(FakeFs::new());
@@ -1369,8 +1369,8 @@ mod tests {
 
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
-        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (_, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, _) = server.create_client(&mut cx_b, "user_b").await;
 
         // Share a local worktree as client A
         let fs = Arc::new(FakeFs::new());
@@ -1429,8 +1429,8 @@ mod tests {
 
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
-        let (_, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (_, client_b) = server.create_client(&mut cx_a, "user_b").await;
+        let (client_a, _) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, _) = server.create_client(&mut cx_a, "user_b").await;
 
         // Share a local worktree as client A
         let fs = Arc::new(FakeFs::new());
@@ -1484,38 +1484,39 @@ mod tests {
     #[gpui::test]
     async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
-        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
 
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
-        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
 
         // Create an org that includes these 2 users.
         let db = &server.app_state.db;
         let org_id = db.create_org("Test Org", "test-org").await.unwrap();
-        db.add_org_member(org_id, user_id_a, false).await.unwrap();
-        db.add_org_member(org_id, user_id_b, false).await.unwrap();
+        db.add_org_member(org_id, current_user_id(&user_store_a), false)
+            .await
+            .unwrap();
+        db.add_org_member(org_id, current_user_id(&user_store_b), false)
+            .await
+            .unwrap();
 
         // Create a channel that includes all the users.
         let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
-        db.add_channel_member(channel_id, user_id_a, false)
+        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
             .await
             .unwrap();
-        db.add_channel_member(channel_id, user_id_b, false)
+        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
             .await
             .unwrap();
         db.create_channel_message(
             channel_id,
-            user_id_b,
+            current_user_id(&user_store_b),
             "hello A, it's B.",
             OffsetDateTime::now_utc(),
         )
         .await
         .unwrap();
 
-        let user_store_a =
-            UserStore::new(client_a.clone(), http.clone(), cx_a.background().as_ref());
         let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
         channels_a
             .condition(&mut cx_a, |list, _| list.available_channels().is_some())
@@ -1536,12 +1537,10 @@ mod tests {
         channel_a
             .condition(&cx_a, |channel, _| {
                 channel_messages(channel)
-                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
             })
             .await;
 
-        let user_store_b =
-            UserStore::new(client_b.clone(), http.clone(), cx_b.background().as_ref());
         let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
         channels_b
             .condition(&mut cx_b, |list, _| list.available_channels().is_some())
@@ -1563,7 +1562,7 @@ mod tests {
         channel_b
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
-                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
             })
             .await;
 
@@ -1575,28 +1574,25 @@ mod tests {
                     .detach();
                 let task = channel.send_message("sup".to_string(), cx).unwrap();
                 assert_eq!(
-                    channel
-                        .pending_messages()
-                        .iter()
-                        .map(|m| &m.body)
-                        .collect::<Vec<_>>(),
-                    &["oh, hi B.", "sup"]
+                    channel_messages(channel),
+                    &[
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
+                        ("user_a".to_string(), "sup".to_string(), true)
+                    ]
                 );
                 task
             })
             .await
             .unwrap();
 
-        channel_a
-            .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
-            .await;
         channel_b
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
                     == [
-                        ("user_b".to_string(), "hello A, it's B.".to_string()),
-                        ("user_a".to_string(), "oh, hi B.".to_string()),
-                        ("user_a".to_string(), "sup".to_string()),
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
+                        ("user_a".to_string(), "sup".to_string(), false),
                     ]
             })
             .await;
@@ -1616,33 +1612,25 @@ mod tests {
         server
             .condition(|state| !state.channels.contains_key(&channel_id))
             .await;
-
-        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
-            channel
-                .messages()
-                .cursor::<(), ()>()
-                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
-                .collect()
-        }
     }
 
     #[gpui::test]
     async fn test_chat_message_validation(mut cx_a: TestAppContext) {
         cx_a.foreground().forbid_parking();
-        let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
 
         let mut server = TestServer::start().await;
-        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
 
         let db = &server.app_state.db;
         let org_id = db.create_org("Test Org", "test-org").await.unwrap();
         let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
-        db.add_org_member(org_id, user_id_a, false).await.unwrap();
-        db.add_channel_member(channel_id, user_id_a, false)
+        db.add_org_member(org_id, current_user_id(&user_store_a), false)
+            .await
+            .unwrap();
+        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
             .await
             .unwrap();
 
-        let user_store_a = UserStore::new(client_a.clone(), http, cx_a.background().as_ref());
         let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
         channels_a
             .condition(&mut cx_a, |list, _| list.available_channels().is_some())
@@ -1692,27 +1680,31 @@ mod tests {
 
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
-        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
-        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let (client_a, user_store_a) = server.create_client(&mut cx_a, "user_a").await;
+        let (client_b, user_store_b) = server.create_client(&mut cx_b, "user_b").await;
         let mut status_b = client_b.status();
 
         // Create an org that includes these 2 users.
         let db = &server.app_state.db;
         let org_id = db.create_org("Test Org", "test-org").await.unwrap();
-        db.add_org_member(org_id, user_id_a, false).await.unwrap();
-        db.add_org_member(org_id, user_id_b, false).await.unwrap();
+        db.add_org_member(org_id, current_user_id(&user_store_a), false)
+            .await
+            .unwrap();
+        db.add_org_member(org_id, current_user_id(&user_store_b), false)
+            .await
+            .unwrap();
 
         // Create a channel that includes all the users.
         let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
-        db.add_channel_member(channel_id, user_id_a, false)
+        db.add_channel_member(channel_id, current_user_id(&user_store_a), false)
             .await
             .unwrap();
-        db.add_channel_member(channel_id, user_id_b, false)
+        db.add_channel_member(channel_id, current_user_id(&user_store_b), false)
             .await
             .unwrap();
         db.create_channel_message(
             channel_id,
-            user_id_b,
+            current_user_id(&user_store_b),
             "hello A, it's B.",
             OffsetDateTime::now_utc(),
         )
@@ -1742,13 +1734,11 @@ mod tests {
         channel_a
             .condition(&cx_a, |channel, _| {
                 channel_messages(channel)
-                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
             })
             .await;
 
-        let user_store_b =
-            UserStore::new(client_b.clone(), http.clone(), cx_b.background().as_ref());
-        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
+        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b.clone(), client_b, cx));
         channels_b
             .condition(&mut cx_b, |list, _| list.available_channels().is_some())
             .await;
@@ -1769,13 +1759,13 @@ mod tests {
         channel_b
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
-                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+                    == [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
             })
             .await;
 
         // Disconnect client B, ensuring we can still access its cached channel data.
         server.forbid_connections();
-        server.disconnect_client(user_id_b);
+        server.disconnect_client(current_user_id(&user_store_b));
         while !matches!(
             status_b.recv().await,
             Some(rpc::Status::ReconnectionError { .. })
@@ -1793,7 +1783,7 @@ mod tests {
         channel_b.read_with(&cx_b, |channel, _| {
             assert_eq!(
                 channel_messages(channel),
-                [("user_b".to_string(), "hello A, it's B.".to_string())]
+                [("user_b".to_string(), "hello A, it's B.".to_string(), false)]
             )
         });
 
@@ -1806,12 +1796,12 @@ mod tests {
                     .detach();
                 let task = channel.send_message("sup".to_string(), cx).unwrap();
                 assert_eq!(
-                    channel
-                        .pending_messages()
-                        .iter()
-                        .map(|m| &m.body)
-                        .collect::<Vec<_>>(),
-                    &["oh, hi B.", "sup"]
+                    channel_messages(channel),
+                    &[
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_a".to_string(), "oh, hi B.".to_string(), true),
+                        ("user_a".to_string(), "sup".to_string(), true)
+                    ]
                 );
                 task
             })
@@ -1827,9 +1817,9 @@ mod tests {
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
                     == [
-                        ("user_b".to_string(), "hello A, it's B.".to_string()),
-                        ("user_a".to_string(), "oh, hi B.".to_string()),
-                        ("user_a".to_string(), "sup".to_string()),
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
+                        ("user_a".to_string(), "sup".to_string(), false),
                     ]
             })
             .await;
@@ -1845,10 +1835,10 @@ mod tests {
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
                     == [
-                        ("user_b".to_string(), "hello A, it's B.".to_string()),
-                        ("user_a".to_string(), "oh, hi B.".to_string()),
-                        ("user_a".to_string(), "sup".to_string()),
-                        ("user_a".to_string(), "you online?".to_string()),
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
+                        ("user_a".to_string(), "sup".to_string(), false),
+                        ("user_a".to_string(), "you online?".to_string(), false),
                     ]
             })
             .await;
@@ -1863,22 +1853,14 @@ mod tests {
             .condition(&cx_a, |channel, _| {
                 channel_messages(channel)
                     == [
-                        ("user_b".to_string(), "hello A, it's B.".to_string()),
-                        ("user_a".to_string(), "oh, hi B.".to_string()),
-                        ("user_a".to_string(), "sup".to_string()),
-                        ("user_a".to_string(), "you online?".to_string()),
-                        ("user_b".to_string(), "yep".to_string()),
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_a".to_string(), "oh, hi B.".to_string(), false),
+                        ("user_a".to_string(), "sup".to_string(), false),
+                        ("user_a".to_string(), "you online?".to_string(), false),
+                        ("user_b".to_string(), "yep".to_string(), false),
                     ]
             })
             .await;
-
-        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
-            channel
-                .messages()
-                .cursor::<(), ()>()
-                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
-                .collect()
-        }
     }
 
     struct TestServer {
@@ -1913,8 +1895,8 @@ mod tests {
             &mut self,
             cx: &mut TestAppContext,
             name: &str,
-        ) -> (UserId, Arc<Client>) {
-            let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
+        ) -> (Arc<Client>, Arc<UserStore>) {
+            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
             let client_name = name.to_string();
             let mut client = Client::new();
             let server = self.server.clone();
@@ -1926,13 +1908,13 @@ mod tests {
                     cx.spawn(|_| async move {
                         let access_token = "the-token".to_string();
                         Ok(Credentials {
-                            user_id: client_user_id.0 as u64,
+                            user_id: user_id.0 as u64,
                             access_token,
                         })
                     })
                 })
                 .override_establish_connection(move |credentials, cx| {
-                    assert_eq!(credentials.user_id, client_user_id.0 as u64);
+                    assert_eq!(credentials.user_id, user_id.0 as u64);
                     assert_eq!(credentials.access_token, "the-token");
 
                     let server = server.clone();
@@ -1946,24 +1928,26 @@ mod tests {
                             )))
                         } else {
                             let (client_conn, server_conn, kill_conn) = Connection::in_memory();
-                            connection_killers.lock().insert(client_user_id, kill_conn);
+                            connection_killers.lock().insert(user_id, kill_conn);
                             cx.background()
-                                .spawn(server.handle_connection(
-                                    server_conn,
-                                    client_name,
-                                    client_user_id,
-                                ))
+                                .spawn(server.handle_connection(server_conn, client_name, user_id))
                                 .detach();
                             Ok(client_conn)
                         }
                     })
                 });
 
+            let http = FakeHttpClient::new(|_| async move { Ok(surf::http::Response::new(404)) });
             client
                 .authenticate_and_connect(&cx.to_async())
                 .await
                 .unwrap();
-            (client_user_id, client)
+
+            let user_store = UserStore::new(client.clone(), http, &cx.background());
+            let mut authed_user = user_store.watch_current_user();
+            while authed_user.recv().await.unwrap().is_none() {}
+
+            (client, user_store)
         }
 
         fn disconnect_client(&self, user_id: UserId) {
@@ -2019,6 +2003,24 @@ mod tests {
         }
     }
 
+    fn current_user_id(user_store: &Arc<UserStore>) -> UserId {
+        UserId::from_proto(user_store.current_user().unwrap().id)
+    }
+
+    fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
+        channel
+            .messages()
+            .cursor::<(), ()>()
+            .map(|m| {
+                (
+                    m.sender.github_login.clone(),
+                    m.body.clone(),
+                    m.is_pending(),
+                )
+            })
+            .collect()
+    }
+
     struct EmptyView;
 
     impl gpui::Entity for EmptyView {

zed/src/channel.rs 🔗

@@ -238,8 +238,6 @@ impl Channel {
         let current_user = self
             .user_store
             .current_user()
-            .borrow()
-            .clone()
             .ok_or_else(|| anyhow!("current_user is not present"))?;
 
         let channel_id = self.details.id;

zed/src/user.rs 🔗

@@ -111,8 +111,12 @@ impl UserStore {
             .ok_or_else(|| anyhow!("server responded with no users"))
     }
 
-    pub fn current_user(&self) -> &watch::Receiver<Option<Arc<User>>> {
-        &self.current_user
+    pub fn current_user(&self) -> Option<Arc<User>> {
+        self.current_user.borrow().clone()
+    }
+
+    pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
+        self.current_user.clone()
     }
 }
 

zed/src/workspace.rs 🔗

@@ -389,7 +389,7 @@ impl Workspace {
         );
         right_sidebar.add_item("icons/user-16.svg", cx.add_view(|_| ProjectBrowser).into());
 
-        let mut current_user = app_state.user_store.current_user().clone();
+        let mut current_user = app_state.user_store.watch_current_user().clone();
         let mut connection_status = app_state.rpc.status().clone();
         let _observe_current_user = cx.spawn_weak(|this, mut cx| async move {
             current_user.recv().await;
@@ -990,8 +990,6 @@ impl Workspace {
         let avatar = if let Some(avatar) = self
             .user_store
             .current_user()
-            .borrow()
-            .as_ref()
             .and_then(|user| user.avatar.clone())
         {
             Image::new(avatar)