Fully test contact request acceptance

Nathan Sobo created

* Be sure we send updates to multiple clients for the same user
* Be sure we send a full contacts update on initial connection

As part of this commit, I fixed an issue where we couldn't disconnect and reconnect in tests. The first disconnect would cause the I/O future to terminate asynchronously, which caused us to sign out even though the active connection didn't belong to that future. I added a guard to ensure that we only sign out if the I/O future is associated with the current connection.

Change summary

crates/client/src/client.rs |  14 +++-
crates/client/src/user.rs   |  57 ++++++++++++++---
crates/collab/src/db.rs     |  12 +++
crates/collab/src/rpc.rs    | 124 +++++++++++++++++++++++---------------
crates/rpc/src/peer.rs      |  10 ++
5 files changed, 149 insertions(+), 68 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -117,7 +117,7 @@ impl EstablishConnectionError {
     }
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Status {
     SignedOut,
     UpgradeRequired,
@@ -293,6 +293,7 @@ impl Client {
     }
 
     fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
+        log::info!("set status on client {}: {:?}", self.id, status);
         let mut state = self.state.write();
         *state.status.0.borrow_mut() = status;
 
@@ -629,10 +630,13 @@ impl Client {
 
     async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
         let executor = cx.background();
+        log::info!("add connection to peer");
         let (connection_id, handle_io, mut incoming) = self
             .peer
             .add_connection(conn, move |duration| executor.timer(duration))
             .await;
+        log::info!("set status to connected {}", connection_id);
+        self.set_status(Status::Connected { connection_id }, cx);
         cx.foreground()
             .spawn({
                 let cx = cx.clone();
@@ -730,15 +734,17 @@ impl Client {
             })
             .detach();
 
-        self.set_status(Status::Connected { connection_id }, cx);
-
         let handle_io = cx.background().spawn(handle_io);
         let this = self.clone();
         let cx = cx.clone();
         cx.foreground()
             .spawn(async move {
                 match handle_io.await {
-                    Ok(()) => this.set_status(Status::SignedOut, &cx),
+                    Ok(()) => {
+                        if *this.status().borrow() == (Status::Connected { connection_id }) {
+                            this.set_status(Status::SignedOut, &cx);
+                        }
+                    }
                     Err(err) => {
                         log::error!("connection error: {:?}", err);
                         this.set_status(Status::ConnectionLost, &cx);

crates/client/src/user.rs 🔗

@@ -1,5 +1,5 @@
 use super::{http::HttpClient, proto, Client, Status, TypedEnvelope};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use futures::{future, AsyncReadExt, Future};
 use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
 use postage::{prelude::Stream, sink::Sink, watch};
@@ -120,6 +120,7 @@ impl UserStore {
         message: proto::UpdateContacts,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
+        log::info!("update contacts on client {:?}", message);
         let mut user_ids = HashSet::new();
         for contact in &message.contacts {
             user_ids.insert(contact.user_id);
@@ -167,24 +168,51 @@ impl UserStore {
                 HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());
 
             this.update(&mut cx, |this, cx| {
+                // Remove contacts
                 this.contacts
                     .retain(|contact| !removed_contacts.contains(&contact.user.id));
-                this.contacts.extend(updated_contacts);
-                this.contacts
-                    .sort_by(|a, b| a.user.github_login.cmp(&b.user.github_login));
+                // Update existing contacts and insert new ones
+                for updated_contact in updated_contacts {
+                    match this
+                        .contacts
+                        .binary_search_by_key(&&updated_contact.user.github_login, |contact| {
+                            &contact.user.github_login
+                        }) {
+                        Ok(ix) => this.contacts[ix] = updated_contact,
+                        Err(ix) => this.contacts.insert(ix, updated_contact),
+                    }
+                }
                 cx.notify();
 
+                // Remove incoming contact requests
                 this.incoming_contact_requests
                     .retain(|user| !removed_incoming_requests.contains(&user.id));
-                this.incoming_contact_requests.extend(incoming_requests);
-                this.incoming_contact_requests
-                    .sort_by(|a, b| a.github_login.cmp(&b.github_login));
+                // Update existing incoming requests and insert new ones
+                for request in incoming_requests {
+                    match this
+                        .incoming_contact_requests
+                        .binary_search_by_key(&&request.github_login, |contact| {
+                            &contact.github_login
+                        }) {
+                        Ok(ix) => this.incoming_contact_requests[ix] = request,
+                        Err(ix) => this.incoming_contact_requests.insert(ix, request),
+                    }
+                }
 
+                // Remove outgoing contact requests
                 this.outgoing_contact_requests
                     .retain(|user| !removed_outgoing_requests.contains(&user.id));
-                this.outgoing_contact_requests.extend(outgoing_requests);
-                this.outgoing_contact_requests
-                    .sort_by(|a, b| a.github_login.cmp(&b.github_login));
+                // Update existing incoming requests and insert new ones
+                for request in outgoing_requests {
+                    match this
+                        .outgoing_contact_requests
+                        .binary_search_by_key(&&request.github_login, |contact| {
+                            &contact.github_login
+                        }) {
+                        Ok(ix) => this.outgoing_contact_requests[ix] = request,
+                        Err(ix) => this.outgoing_contact_requests.insert(ix, request),
+                    }
+                }
             });
 
             Ok(())
@@ -242,6 +270,13 @@ impl UserStore {
         }
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn clear_contacts(&mut self) {
+        self.contacts.clear();
+        self.incoming_contact_requests.clear();
+        self.outgoing_contact_requests.clear();
+    }
+
     pub fn get_users(
         &mut self,
         mut user_ids: Vec<u64>,
@@ -297,7 +332,7 @@ impl UserStore {
         let http = self.http.clone();
         cx.spawn_weak(|this, mut cx| async move {
             if let Some(rpc) = client.upgrade() {
-                let response = rpc.request(request).await?;
+                let response = rpc.request(request).await.context("error loading users")?;
                 let users = future::join_all(
                     response
                         .users

crates/collab/src/db.rs 🔗

@@ -1097,6 +1097,7 @@ pub mod tests {
         contacts: Mutex<Vec<FakeContact>>,
     }
 
+    #[derive(Debug)]
     struct FakeContact {
         requester_id: UserId,
         responder_id: UserId,
@@ -1166,8 +1167,13 @@ pub mod tests {
             Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
         }
 
-        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
-            unimplemented!()
+        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+            Ok(self
+                .users
+                .lock()
+                .values()
+                .find(|user| user.github_login == github_login)
+                .cloned())
         }
 
         async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
@@ -1183,6 +1189,7 @@ pub mod tests {
             let mut current = Vec::new();
             let mut outgoing_requests = Vec::new();
             let mut incoming_requests = Vec::new();
+
             for contact in self.contacts.lock().iter() {
                 if contact.requester_id == id {
                     if contact.accepted {
@@ -1201,6 +1208,7 @@ pub mod tests {
                     }
                 }
             }
+
             Ok(Contacts {
                 current,
                 outgoing_requests,

crates/collab/src/rpc.rs 🔗

@@ -274,10 +274,7 @@ impl Server {
             {
                 let mut store = this.store_mut().await;
                 store.add_connection(connection_id, user_id);
-                let update_contacts = store.build_initial_contacts_update(contacts);
-                for connection_id in store.connection_ids_for_user(user_id) {
-                    this.peer.send(connection_id, update_contacts.clone())?;
-                }
+                this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
             }
 
             let handle_io = handle_io.fuse();
@@ -959,7 +956,6 @@ impl Server {
             .send_contact_request(requester_id, responder_id)
             .await?;
         
-
         // Update outgoing contact requests of requester
         let mut update = proto::UpdateContacts::default();
         update.outgoing_requests.push(responder_id.to_proto());
@@ -5035,18 +5031,21 @@ mod tests {
                 .collect()
         }
     }
-
-    #[gpui::test(iterations = 1)] // TODO: More iterations
-    async fn test_contacts_requests(executor: Arc<Deterministic>, cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+    
+    #[gpui::test(iterations = 10)]
+    async fn test_contact_requests(executor: Arc<Deterministic>, cx_a: &mut TestAppContext, cx_a2: &mut TestAppContext, cx_b: &mut TestAppContext, cx_b2: &mut TestAppContext) {
         cx_a.foreground().forbid_parking();
 
         // Connect to a server as 3 clients.
         let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(cx_a, "user_a").await;
+        let client_a2 = server.create_client(cx_a2, "user_a").await;
         let client_b = server.create_client(cx_b, "user_b").await;
+        let client_b2 = server.create_client(cx_b2, "user_b").await;
+        
+        assert_eq!(client_a.user_id().unwrap(), client_a2.user_id().unwrap());
 
         // User A requests that user B become their contact
-
         client_a
             .user_store
             .read_with(cx_a, |store, _| {
@@ -5054,55 +5053,56 @@ mod tests {
             })
             .await
             .unwrap();
-        
         executor.run_until_parked();
         
-        // Both parties see the pending request appear. User B accepts the request.
+        // Both users see the pending request appear in all their clients.
+        assert_eq!(client_a.summarize_contacts(&cx_a).outgoing_requests, &["user_b"]);
+        assert_eq!(client_a2.summarize_contacts(&cx_a2).outgoing_requests, &["user_b"]);
+        assert_eq!(client_b.summarize_contacts(&cx_b).incoming_requests, &["user_a"]);
+        assert_eq!(client_b2.summarize_contacts(&cx_b2).incoming_requests, &["user_a"]);
         
-        client_a.user_store.read_with(cx_a, |store, _| {
-            let contacts = store
-                .outgoing_contact_requests()
-                .iter()
-                .map(|contact| contact.github_login.clone())
-                .collect::<Vec<_>>();
-            assert_eq!(contacts, &["user_b"]);
-        });
-                
+        // Contact requests are present upon connecting (tested here via disconnect/reconnect)
+        disconnect_and_reconnect(&client_a, cx_a).await;
+        disconnect_and_reconnect(&client_b, cx_b).await;
+        executor.run_until_parked();
+        assert_eq!(client_a.summarize_contacts(&cx_a).outgoing_requests, &["user_b"]);
+        assert_eq!(client_b.summarize_contacts(&cx_b).incoming_requests, &["user_a"]);
+        
+        // User B accepts the request.
         client_b.user_store.read_with(cx_b, |store, _| {
-            let contacts = store
-                .incoming_contact_requests()
-                .iter()
-                .map(|contact| contact.github_login.clone())
-                .collect::<Vec<_>>();
-            assert_eq!(contacts, &["user_a"]);
-            
             store.respond_to_contact_request(client_a.user_id().unwrap(), true)
         }).await.unwrap();
 
         executor.run_until_parked();
 
-        // User B sees user A as their contact now, and the incoming request from them is removed
-        client_b.user_store.read_with(cx_b, |store, _| {
-            let contacts = store
-                .contacts()
-                .iter()
-                .map(|contact| contact.user.github_login.clone())
-                .collect::<Vec<_>>();
-            assert_eq!(contacts, &["user_a"]);
-            assert!(store.incoming_contact_requests().is_empty());
-        });
-
-        // User A sees user B as their contact now, and the outgoing request to them is removed
-        client_a.user_store.read_with(cx_a, |store, _| {
-            let contacts = store
-                .contacts()
-                .iter()
-                .map(|contact| contact.user.github_login.clone())
-                .collect::<Vec<_>>();
-            assert_eq!(contacts, &["user_b"]);
-            assert!(store.outgoing_contact_requests().is_empty());
-        });
+        // User B sees user A as their contact now in all client, and the incoming request from them is removed.
+        let contacts_b = client_b.summarize_contacts(&cx_b);
+        assert_eq!(contacts_b.current, &["user_a"]);
+        assert!(contacts_b.incoming_requests.is_empty());
+        let contacts_b2 = client_b2.summarize_contacts(&cx_b2);
+        assert_eq!(contacts_b2.current, &["user_a"]);
+        assert!(contacts_b2.incoming_requests.is_empty());
         
+        // User A sees user B as their contact now in all clients, and the outgoing request to them is removed.
+        let contacts_a = client_a.summarize_contacts(&cx_a);
+        assert_eq!(contacts_a.current, &["user_b"]);
+        assert!(contacts_a.outgoing_requests.is_empty());
+        let contacts_a2 = client_a2.summarize_contacts(&cx_a2);
+        assert_eq!(contacts_a2.current, &["user_b"]);
+        assert!(contacts_a2.outgoing_requests.is_empty());
+
+        // Contacts are present upon connecting (tested here via disconnect/reconnect)
+        disconnect_and_reconnect(&client_a, cx_a).await;
+        disconnect_and_reconnect(&client_b, cx_b).await;
+        executor.run_until_parked();
+        assert_eq!(client_a.summarize_contacts(&cx_a).current, &["user_b"]);
+        // assert_eq!(client_b.summarize_contacts(&cx_b).current, &["user_a"]);
+        
+        async fn disconnect_and_reconnect(client: &TestClient, cx: &mut TestAppContext) {
+            client.disconnect(&cx.to_async()).unwrap();
+            client.clear_contacts(cx);
+            client.authenticate_and_connect(false, &cx.to_async()).await.unwrap();
+        }
     }
 
     #[gpui::test(iterations = 10)]
@@ -6143,7 +6143,11 @@ mod tests {
             });
 
             let http = FakeHttpClient::with_404_response();
-            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
+            let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await {
+                user.id
+            } else {
+                self.app_state.db.create_user(name, false).await.unwrap()
+            };
             let client_name = name.to_string();
             let mut client = Client::new(http.clone());
             let server = self.server.clone();
@@ -6295,6 +6299,12 @@ mod tests {
             &self.client
         }
     }
+    
+    struct ContactsSummary {
+        pub current: Vec<String>,
+        pub outgoing_requests: Vec<String>,
+        pub incoming_requests: Vec<String>,
+    }
 
     impl TestClient {
         pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
@@ -6310,6 +6320,22 @@ mod tests {
                 .read_with(cx, |user_store, _| user_store.watch_current_user());
             while authed_user.next().await.unwrap().is_none() {}
         }
+        
+        fn clear_contacts(&self, cx: &mut TestAppContext) {
+            self.user_store.update(cx, |store, _| {
+                store.clear_contacts();
+            });
+        }
+        
+        fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
+            self.user_store.read_with(cx, |store, cx| {
+                ContactsSummary {
+                    current: store.contacts().iter().map(|contact| contact.user.github_login.clone()).collect(),
+                    outgoing_requests: store.outgoing_contact_requests().iter().map(|user| user.github_login.clone()).collect(),
+                    incoming_requests: store.incoming_contact_requests().iter().map(|user| user.github_login.clone()).collect(),
+                }            
+            })
+        }
 
         async fn build_local_project(
             &mut self,

crates/rpc/src/peer.rs 🔗

@@ -173,7 +173,10 @@ impl Peer {
                                     Err(anyhow!("timed out writing message"))?;
                                 }
                             }
-                            None => return Ok(()),
+                            None => {
+                                log::info!("outgoing channel closed");
+                                return Ok(())
+                            },
                         },
                         incoming = read_message => {
                             let incoming = incoming.context("received invalid RPC message")?;
@@ -181,7 +184,10 @@ impl Peer {
                             if let proto::Message::Envelope(incoming) = incoming {
                                 match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
                                     Some(Ok(_)) => {},
-                                    Some(Err(_)) => return Ok(()),
+                                    Some(Err(_)) => {
+                                        log::info!("incoming channel closed");
+                                        return Ok(())
+                                    },
                                     None => Err(anyhow!("timed out processing incoming message"))?,
                                 }
                             }