WIP: Update contacts based on deltas rather than snapshots

Nathan Sobo created

Change summary

crates/client/src/user.rs      |  61 ++++++++++++--
crates/collab/src/db.rs        |  64 +++++++-------
crates/collab/src/rpc.rs       | 144 ++++++++++++++++++++---------------
crates/collab/src/rpc/store.rs | 146 ++++++++++++++++++++++++++---------
crates/rpc/proto/zed.proto     |  10 +
5 files changed, 277 insertions(+), 148 deletions(-)

Detailed changes

crates/client/src/user.rs 🔗

@@ -36,6 +36,8 @@ pub struct UserStore {
     update_contacts_tx: watch::Sender<Option<proto::UpdateContacts>>,
     current_user: watch::Receiver<Option<Arc<User>>>,
     contacts: Vec<Arc<Contact>>,
+    incoming_contact_requests: Vec<Arc<User>>,
+    outgoing_contact_requests: Vec<Arc<User>>,
     client: Weak<Client>,
     http: Arc<dyn HttpClient>,
     _maintain_contacts: Task<()>,
@@ -63,6 +65,8 @@ impl UserStore {
             users: Default::default(),
             current_user: current_user_rx,
             contacts: Default::default(),
+            incoming_contact_requests: Default::default(),
+            outgoing_contact_requests: Default::default(),
             client: Arc::downgrade(&client),
             update_contacts_tx,
             http,
@@ -121,29 +125,64 @@ impl UserStore {
             user_ids.insert(contact.user_id);
             user_ids.extend(contact.projects.iter().flat_map(|w| &w.guests).copied());
         }
-        user_ids.extend(message.pending_requests_to_user_ids.iter());
-        user_ids.extend(
-            message
-                .pending_requests_from_user_ids
-                .iter()
-                .map(|req| req.user_id),
-        );
+        user_ids.extend(message.incoming_requests.iter().map(|req| req.user_id));
+        user_ids.extend(message.outgoing_requests.iter());
 
         let load_users = self.get_users(user_ids.into_iter().collect(), cx);
         cx.spawn(|this, mut cx| async move {
             load_users.await?;
 
-            let mut contacts = Vec::new();
+            // Users are fetched in parallel above and cached in call to get_users
+            // No need to paralellize here
+            let mut updated_contacts = Vec::new();
             for contact in message.contacts {
-                contacts.push(Arc::new(
+                updated_contacts.push(Arc::new(
                     Contact::from_proto(contact, &this, &mut cx).await?,
                 ));
             }
 
+            let mut incoming_requests = Vec::new();
+            for request in message.incoming_requests {
+                incoming_requests.push(
+                    this.update(&mut cx, |this, cx| this.fetch_user(request.user_id, cx))
+                        .await?,
+                );
+            }
+
+            let mut outgoing_requests = Vec::new();
+            for requested_user_id in message.outgoing_requests {
+                outgoing_requests.push(
+                    this.update(&mut cx, |this, cx| this.fetch_user(requested_user_id, cx))
+                        .await?,
+                );
+            }
+
+            let removed_contacts =
+                HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
+            let removed_incoming_requests =
+                HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
+            let removed_outgoing_requests =
+                HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());
+
             this.update(&mut cx, |this, cx| {
-                contacts.sort_by(|a, b| a.user.github_login.cmp(&b.user.github_login));
-                this.contacts = contacts;
+                this.contacts
+                    .retain(|contact| !removed_contacts.contains(&contact.user.id));
+                this.contacts.extend(updated_contacts);
+                this.contacts
+                    .sort_by(|a, b| a.user.github_login.cmp(&b.user.github_login));
                 cx.notify();
+
+                this.incoming_contact_requests
+                    .retain(|user| !removed_incoming_requests.contains(&user.id));
+                this.incoming_contact_requests.extend(incoming_requests);
+                this.incoming_contact_requests
+                    .sort_by(|a, b| a.github_login.cmp(&b.github_login));
+
+                this.outgoing_contact_requests
+                    .retain(|user| !removed_outgoing_requests.contains(&user.id));
+                this.outgoing_contact_requests.extend(outgoing_requests);
+                this.outgoing_contact_requests
+                    .sort_by(|a, b| a.github_login.cmp(&b.github_login));
             });
 
             Ok(())

crates/collab/src/db.rs 🔗

@@ -199,8 +199,8 @@ impl Db for PostgresDb {
             .fetch(&self.pool);
 
         let mut current = Vec::new();
-        let mut requests_sent = Vec::new();
-        let mut requests_received = Vec::new();
+        let mut outgoing_requests = Vec::new();
+        let mut incoming_requests = Vec::new();
         while let Some(row) = rows.next().await {
             let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 
@@ -208,9 +208,9 @@ impl Db for PostgresDb {
                 if accepted {
                     current.push(user_id_b);
                 } else if a_to_b {
-                    requests_sent.push(user_id_b);
+                    outgoing_requests.push(user_id_b);
                 } else {
-                    requests_received.push(IncomingContactRequest {
+                    incoming_requests.push(IncomingContactRequest {
                         requesting_user_id: user_id_b,
                         should_notify,
                     });
@@ -219,20 +219,20 @@ impl Db for PostgresDb {
                 if accepted {
                     current.push(user_id_a);
                 } else if a_to_b {
-                    requests_received.push(IncomingContactRequest {
+                    incoming_requests.push(IncomingContactRequest {
                         requesting_user_id: user_id_a,
                         should_notify,
                     });
                 } else {
-                    requests_sent.push(user_id_a);
+                    outgoing_requests.push(user_id_a);
                 }
             }
         }
 
         Ok(Contacts {
             current,
-            requests_sent,
-            requests_received,
+            outgoing_requests,
+            incoming_requests,
         })
     }
 
@@ -669,8 +669,8 @@ pub struct ChannelMessage {
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub struct Contacts {
     pub current: Vec<UserId>,
-    pub requests_sent: Vec<UserId>,
-    pub requests_received: Vec<IncomingContactRequest>,
+    pub incoming_requests: Vec<IncomingContactRequest>,
+    pub outgoing_requests: Vec<UserId>,
 }
 
 #[derive(Clone, Debug, PartialEq, Eq)]
@@ -914,8 +914,8 @@ pub mod tests {
                 db.get_contacts(user_1).await.unwrap(),
                 Contacts {
                     current: vec![],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
 
@@ -925,16 +925,16 @@ pub mod tests {
                 db.get_contacts(user_1).await.unwrap(),
                 Contacts {
                     current: vec![],
-                    requests_sent: vec![user_2],
-                    requests_received: vec![],
+                    outgoing_requests: vec![user_2],
+                    incoming_requests: vec![],
                 },
             );
             assert_eq!(
                 db.get_contacts(user_2).await.unwrap(),
                 Contacts {
                     current: vec![],
-                    requests_sent: vec![],
-                    requests_received: vec![IncomingContactRequest {
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![IncomingContactRequest {
                         requesting_user_id: user_1,
                         should_notify: true
                     }],
@@ -951,8 +951,8 @@ pub mod tests {
                 db.get_contacts(user_2).await.unwrap(),
                 Contacts {
                     current: vec![],
-                    requests_sent: vec![],
-                    requests_received: vec![IncomingContactRequest {
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![IncomingContactRequest {
                         requesting_user_id: user_1,
                         should_notify: false
                     }],
@@ -972,16 +972,16 @@ pub mod tests {
                 db.get_contacts(user_1).await.unwrap(),
                 Contacts {
                     current: vec![user_2],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
             assert_eq!(
                 db.get_contacts(user_2).await.unwrap(),
                 Contacts {
                     current: vec![user_1],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
 
@@ -997,16 +997,16 @@ pub mod tests {
                 db.get_contacts(user_1).await.unwrap(),
                 Contacts {
                     current: vec![user_2, user_3],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
             assert_eq!(
                 db.get_contacts(user_3).await.unwrap(),
                 Contacts {
                     current: vec![user_1],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
 
@@ -1019,16 +1019,16 @@ pub mod tests {
                 db.get_contacts(user_2).await.unwrap(),
                 Contacts {
                     current: vec![user_1],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
             assert_eq!(
                 db.get_contacts(user_3).await.unwrap(),
                 Contacts {
                     current: vec![user_1],
-                    requests_sent: vec![],
-                    requests_received: vec![],
+                    outgoing_requests: vec![],
+                    incoming_requests: vec![],
                 },
             );
         }
@@ -1203,8 +1203,8 @@ pub mod tests {
             }
             Ok(Contacts {
                 current,
-                requests_sent,
-                requests_received,
+                outgoing_requests,
+                incoming_requests,
             })
         }
 

crates/collab/src/rpc.rs 🔗

@@ -246,7 +246,7 @@ impl Server {
         user_id: UserId,
         mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
         executor: E,
-    ) -> impl Future<Output = ()> {
+    ) -> impl Future<Output = Result<()>> {
         let mut this = self.clone();
         let span = info_span!("handle connection", %user_id, %address);
         async move {
@@ -269,10 +269,15 @@ impl Server {
                 let _ = send_connection_id.send(connection_id).await;
             }
 
+            let contacts = this.app_state.db.get_contacts(user_id).await?;
+            
             {
-                let mut state = this.state_mut().await;
-                state.add_connection(connection_id, user_id);
-                this.update_contacts_for_users(&*state, &[user_id]);
+                let mut store = this.store_mut().await;
+                store.add_connection(connection_id, user_id);
+                let update_contacts = store.build_initial_contacts_update(contacts);
+                for connection_id in store.connection_ids_for_user(user_id) {
+                    this.peer.send(connection_id, update_contacts.clone());
+                }
             }
 
             let handle_io = handle_io.fuse();
@@ -322,14 +327,15 @@ impl Server {
             if let Err(error) = this.sign_out(connection_id).await {
                 tracing::error!(%error, "error signing out");
             }
+            
+            Ok(())
         }.instrument(span)
     }
 
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
         self.peer.disconnect(connection_id);
-        let mut state = self.state_mut().await;
-        let removed_connection = state.remove_connection(connection_id)?;
-
+        let removed_connection = self.store_mut().await.remove_connection(connection_id)?;
+    
         for (project_id, project) in removed_connection.hosted_projects {
             if let Some(share) = project.share {
                 broadcast(
@@ -354,8 +360,22 @@ impl Server {
                 )
             });
         }
-
-        self.update_contacts_for_users(&*state, removed_connection.contact_ids.iter());
+                
+        let contacts_to_update = self.app_state.db.get_contacts(removed_connection.user_id).await?;
+        let mut update = proto::UpdateContacts::default();
+        update.contacts.push(proto::Contact {
+            user_id: removed_connection.user_id.to_proto(),
+            projects: Default::default(),
+            online: false,
+        });
+        
+        let store = self.store().await;
+        for user_id in contacts_to_update.current {
+            for connection_id in store.connection_ids_for_user(user_id) {
+                self.peer.send(connection_id, update.clone());
+            }
+        }        
+        
         Ok(())
     }
 
@@ -374,7 +394,7 @@ impl Server {
         response: Response<proto::RegisterProject>,
     ) -> Result<()> {
         let project_id = {
-            let mut state = self.state_mut().await;
+            let mut state = self.store_mut().await;
             let user_id = state.user_id_for_connection(request.sender_id)?;
             state.register_project(request.sender_id, user_id)
         };
@@ -386,7 +406,7 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::UnregisterProject>,
     ) -> Result<()> {
-        let mut state = self.state_mut().await;
+        let mut state = self.store_mut().await;
         let project = state.unregister_project(request.payload.project_id, request.sender_id)?;
         self.update_contacts_for_users(&*state, &project.authorized_user_ids());
         Ok(())
@@ -397,7 +417,7 @@ impl Server {
         request: TypedEnvelope<proto::ShareProject>,
         response: Response<proto::ShareProject>,
     ) -> Result<()> {
-        let mut state = self.state_mut().await;
+        let mut state = self.store_mut().await;
         let project = state.share_project(request.payload.project_id, request.sender_id)?;
         self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
         response.send(proto::Ack {})?;
@@ -409,7 +429,7 @@ impl Server {
         request: TypedEnvelope<proto::UnshareProject>,
     ) -> Result<()> {
         let project_id = request.payload.project_id;
-        let mut state = self.state_mut().await;
+        let mut state = self.store_mut().await;
         let project = state.unshare_project(project_id, request.sender_id)?;
         broadcast(request.sender_id, project.connection_ids, |conn_id| {
             self.peer
@@ -426,7 +446,7 @@ impl Server {
     ) -> Result<()> {
         let project_id = request.payload.project_id;
 
-        let state = &mut *self.state_mut().await;
+        let state = &mut *self.store_mut().await;
         let user_id = state.user_id_for_connection(request.sender_id)?;
         let (response_payload, connection_ids, contact_user_ids) = state
             .join_project(request.sender_id, user_id, project_id)
@@ -502,7 +522,7 @@ impl Server {
     ) -> Result<()> {
         let sender_id = request.sender_id;
         let project_id = request.payload.project_id;
-        let mut state = self.state_mut().await;
+        let mut state = self.store_mut().await;
         let worktree = state.leave_project(sender_id, project_id)?;
         broadcast(sender_id, worktree.connection_ids, |conn_id| {
             self.peer.send(
@@ -528,7 +548,7 @@ impl Server {
             contact_user_ids.insert(contact_user_id);
         }
 
-        let mut state = self.state_mut().await;
+        let mut state = self.store_mut().await;
         let host_user_id = state.user_id_for_connection(request.sender_id)?;
         contact_user_ids.insert(host_user_id);
 
@@ -562,7 +582,7 @@ impl Server {
     ) -> Result<()> {
         let project_id = request.payload.project_id;
         let worktree_id = request.payload.worktree_id;
-        let mut state = self.state_mut().await;
+        let mut state = self.store_mut().await;
         let (worktree, guest_connection_ids) =
             state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
         broadcast(request.sender_id, guest_connection_ids, |conn_id| {
@@ -583,7 +603,7 @@ impl Server {
         request: TypedEnvelope<proto::UpdateWorktree>,
         response: Response<proto::UpdateWorktree>,
     ) -> Result<()> {
-        let connection_ids = self.state_mut().await.update_worktree(
+        let connection_ids = self.store_mut().await.update_worktree(
             request.sender_id,
             request.payload.project_id,
             request.payload.worktree_id,
@@ -609,7 +629,7 @@ impl Server {
             .summary
             .clone()
             .ok_or_else(|| anyhow!("invalid summary"))?;
-        let receiver_ids = self.state_mut().await.update_diagnostic_summary(
+        let receiver_ids = self.store_mut().await.update_diagnostic_summary(
             request.payload.project_id,
             request.payload.worktree_id,
             request.sender_id,
@@ -627,7 +647,7 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::StartLanguageServer>,
     ) -> Result<()> {
-        let receiver_ids = self.state_mut().await.start_language_server(
+        let receiver_ids = self.store_mut().await.start_language_server(
             request.payload.project_id,
             request.sender_id,
             request
@@ -648,7 +668,7 @@ impl Server {
         request: TypedEnvelope<proto::UpdateLanguageServer>,
     ) -> Result<()> {
         let receiver_ids = self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
@@ -667,7 +687,7 @@ impl Server {
         T: EntityMessage + RequestMessage,
     {
         let host_connection_id = self
-            .state()
+            .store()
             .await
             .read_project(request.payload.remote_entity_id(), request.sender_id)?
             .host_connection_id;
@@ -686,7 +706,7 @@ impl Server {
         response: Response<proto::SaveBuffer>,
     ) -> Result<()> {
         let host = self
-            .state()
+            .store()
             .await
             .read_project(request.payload.project_id, request.sender_id)?
             .host_connection_id;
@@ -696,7 +716,7 @@ impl Server {
             .await?;
 
         let mut guests = self
-            .state()
+            .store()
             .await
             .read_project(request.payload.project_id, request.sender_id)?
             .connection_ids();
@@ -715,7 +735,7 @@ impl Server {
         response: Response<proto::UpdateBuffer>,
     ) -> Result<()> {
         let receiver_ids = self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
@@ -731,7 +751,7 @@ impl Server {
         request: TypedEnvelope<proto::UpdateBufferFile>,
     ) -> Result<()> {
         let receiver_ids = self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
@@ -746,7 +766,7 @@ impl Server {
         request: TypedEnvelope<proto::BufferReloaded>,
     ) -> Result<()> {
         let receiver_ids = self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
@@ -761,7 +781,7 @@ impl Server {
         request: TypedEnvelope<proto::BufferSaved>,
     ) -> Result<()> {
         let receiver_ids = self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
@@ -779,7 +799,7 @@ impl Server {
         let leader_id = ConnectionId(request.payload.leader_id);
         let follower_id = request.sender_id;
         if !self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, follower_id)?
             .contains(&leader_id)
@@ -800,7 +820,7 @@ impl Server {
     async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
         let leader_id = ConnectionId(request.payload.leader_id);
         if !self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?
             .contains(&leader_id)
@@ -817,7 +837,7 @@ impl Server {
         request: TypedEnvelope<proto::UpdateFollowers>,
     ) -> Result<()> {
         let connection_ids = self
-            .state()
+            .store()
             .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         let leader_id = request
@@ -845,7 +865,7 @@ impl Server {
         response: Response<proto::GetChannels>,
     ) -> Result<()> {
         let user_id = self
-            .state()
+            .store()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channels = self.app_state.db.get_accessible_channels(user_id).await?;
@@ -958,28 +978,28 @@ impl Server {
         Ok(())
     }
 
-    #[instrument(skip(self, state, user_ids))]
-    fn update_contacts_for_users<'a>(
-        self: &Arc<Self>,
-        state: &Store,
-        user_ids: impl IntoIterator<Item = &'a UserId>,
-    ) {
-        for user_id in user_ids {
-            let contacts = state.contacts_for_user(*user_id);
-            for connection_id in state.connection_ids_for_user(*user_id) {
-                self.peer
-                    .send(
-                        connection_id,
-                        proto::UpdateContacts {
-                            contacts: contacts.clone(),
-                            pending_requests_from_user_ids: Default::default(),
-                            pending_requests_to_user_ids: Default::default(),
-                        },
-                    )
-                    .trace_err();
-            }
-        }
-    }
+    // #[instrument(skip(self, state, user_ids))]
+    // fn update_contacts_for_users<'a>(
+    //     self: &Arc<Self>,
+    //     state: &Store,
+    //     user_ids: impl IntoIterator<Item = &'a UserId>,
+    // ) {
+    //     for user_id in user_ids {
+    //         let contacts = state.contacts_for_user(*user_id);
+    //         for connection_id in state.connection_ids_for_user(*user_id) {
+    //             self.peer
+    //                 .send(
+    //                     connection_id,
+    //                     proto::UpdateContacts {
+    //                         contacts: contacts.clone(),
+    //                         pending_requests_from_user_ids: Default::default(),
+    //                         pending_requests_to_user_ids: Default::default(),
+    //                     },
+    //                 )
+    //                 .trace_err();
+    //         }
+    //     }
+    // }
 
     async fn join_channel(
         self: Arc<Self>,
@@ -987,7 +1007,7 @@ impl Server {
         response: Response<proto::JoinChannel>,
     ) -> Result<()> {
         let user_id = self
-            .state()
+            .store()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -1000,7 +1020,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state_mut()
+        self.store_mut()
             .await
             .join_channel(request.sender_id, channel_id);
         let messages = self
@@ -1029,7 +1049,7 @@ impl Server {
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> Result<()> {
         let user_id = self
-            .state()
+            .store()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -1042,7 +1062,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state_mut()
+        self.store_mut()
             .await
             .leave_channel(request.sender_id, channel_id);
 
@@ -1058,7 +1078,7 @@ impl Server {
         let user_id;
         let connection_ids;
         {
-            let state = self.state().await;
+            let state = self.store().await;
             user_id = state.user_id_for_connection(request.sender_id)?;
             connection_ids = state.channel_connection_ids(channel_id)?;
         }
@@ -1112,7 +1132,7 @@ impl Server {
         response: Response<proto::GetChannelMessages>,
     ) -> Result<()> {
         let user_id = self
-            .state()
+            .store()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -1150,7 +1170,7 @@ impl Server {
         Ok(())
     }
 
-    async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
+    async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
         #[cfg(test)]
         tokio::task::yield_now().await;
         let guard = self.store.read().await;
@@ -1162,7 +1182,7 @@ impl Server {
         }
     }
 
-    async fn state_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
+    async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
         #[cfg(test)]
         tokio::task::yield_now().await;
         let guard = self.store.write().await;

crates/collab/src/rpc/store.rs 🔗

@@ -1,4 +1,4 @@
-use crate::db::{ChannelId, UserId};
+use crate::db::{self, ChannelId, UserId};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use rpc::{proto, ConnectionId};
@@ -58,6 +58,7 @@ pub type ReplicaId = u16;
 
 #[derive(Default)]
 pub struct RemovedConnectionState {
+    pub user_id: UserId,
     pub hosted_projects: HashMap<u64, Project>,
     pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
     pub contact_ids: HashSet<UserId>,
@@ -151,6 +152,7 @@ impl Store {
         }
 
         let mut result = RemovedConnectionState::default();
+        result.user_id = connection.user_id;
         for project_id in connection.projects.clone() {
             if let Ok(project) = self.unregister_project(project_id, connection_id) {
                 result.contact_ids.extend(project.authorized_user_ids());
@@ -213,51 +215,115 @@ impl Store {
             .copied()
     }
 
-    pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
-        let mut contacts = HashMap::default();
-        for project_id in self
-            .visible_projects_by_user_id
-            .get(&user_id)
-            .unwrap_or(&HashSet::default())
-        {
-            let project = &self.projects[project_id];
+    pub fn build_initial_contacts_update(&self, contacts: db::Contacts) -> proto::UpdateContacts {
+        let mut update = proto::UpdateContacts::default();
+        for user_id in contacts.current {
+            update.contacts.push(self.contact_for_user(user_id));
+        }
 
-            let mut guests = HashSet::default();
-            if let Ok(share) = project.share() {
-                for guest_connection_id in share.guests.keys() {
-                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
-                        guests.insert(user_id.to_proto());
-                    }
-                }
-            }
+        for request in contacts.incoming_requests {
+            update
+                .incoming_requests
+                .push(proto::IncomingContactRequest {
+                    user_id: request.requesting_user_id.to_proto(),
+                    should_notify: request.should_notify,
+                })
+        }
 
-            if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
-                let mut worktree_root_names = project
-                    .worktrees
-                    .values()
-                    .filter(|worktree| worktree.visible)
-                    .map(|worktree| worktree.root_name.clone())
-                    .collect::<Vec<_>>();
-                worktree_root_names.sort_unstable();
-                contacts
-                    .entry(host_user_id)
-                    .or_insert_with(|| proto::Contact {
-                        user_id: host_user_id.to_proto(),
-                        projects: Vec::new(),
-                    })
-                    .projects
-                    .push(proto::ProjectMetadata {
-                        id: *project_id,
-                        worktree_root_names,
-                        is_shared: project.share.is_some(),
-                        guests: guests.into_iter().collect(),
-                    });
-            }
+        for requested_user_id in contacts.outgoing_requests {
+            update.outgoing_requests.push(requested_user_id.to_proto())
         }
 
-        contacts.into_values().collect()
+        update
     }
 
+    pub fn contact_for_user(&self, user_id: UserId) -> proto::Contact {
+        proto::Contact {
+            user_id: user_id.to_proto(),
+            projects: self.project_metadata_for_user(user_id),
+            online: self.connection_ids_for_user(user_id).next().is_some(),
+        }
+    }
+
+    pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
+        let project_ids = self
+            .connections_by_user_id
+            .get(&user_id)
+            .unwrap_or_else(|| &HashSet::default())
+            .iter()
+            .filter_map(|connection_id| self.connections.get(connection_id))
+            .flat_map(|connection| connection.projects.iter().copied());
+
+        let mut metadata = Vec::new();
+        for project_id in project_ids {
+            if let Some(project) = self.projects.get(&project_id) {
+                metadata.push(proto::ProjectMetadata {
+                    id: project_id,
+                    is_shared: project.share.is_some(),
+                    worktree_root_names: project
+                        .worktrees
+                        .values()
+                        .map(|worktree| worktree.root_name)
+                        .collect(),
+                    guests: project
+                        .share
+                        .iter()
+                        .flat_map(|share| {
+                            share.guests.values().map(|(_, user_id)| user_id.to_proto())
+                        })
+                        .collect(),
+                });
+            }
+        }
+
+        metadata
+    }
+
+    // pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
+    //     let mut contacts = HashMap::default();
+    //     for project_id in self
+    //         .visible_projects_by_user_id
+    //         .get(&user_id)
+    //         .unwrap_or(&HashSet::default())
+    //     {
+    //         let project = &self.projects[project_id];
+
+    //         let mut guests = HashSet::default();
+    //         if let Ok(share) = project.share() {
+    //             for guest_connection_id in share.guests.keys() {
+    //                 if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
+    //                     guests.insert(user_id.to_proto());
+    //                 }
+    //             }
+    //         }
+
+    //         if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
+    //             let mut worktree_root_names = project
+    //                 .worktrees
+    //                 .values()
+    //                 .filter(|worktree| worktree.visible)
+    //                 .map(|worktree| worktree.root_name.clone())
+    //                 .collect::<Vec<_>>();
+    //             worktree_root_names.sort_unstable();
+    //             contacts
+    //                 .entry(host_user_id)
+    //                 .or_insert_with(|| proto::Contact {
+    //                     user_id: host_user_id.to_proto(),
+    //                     projects: Vec::new(),
+    //                 })
+    //                 .projects
+    //                 .push(proto::ProjectMetadata {
+    //                     id: *project_id,
+    //                     worktree_root_names,
+    //                     is_shared: project.share.is_some(),
+    //                     guests: guests.into_iter().collect(),
+    //                 });
+    //         }
+    //     }
+
+    //     contacts.into_values().collect()
+    // }
+
     pub fn register_project(
         &mut self,
         host_connection_id: ConnectionId,

crates/rpc/proto/zed.proto 🔗

@@ -591,13 +591,16 @@ message GetChannelMessagesResponse {
 
 message UpdateContacts {
     repeated Contact contacts = 1;
-    repeated IncomingContactRequest pending_requests_from_user_ids = 2;
-    repeated uint64 pending_requests_to_user_ids = 3;
+    repeated uint64 remove_contacts = 2;
+    repeated IncomingContactRequest incoming_requests = 3;
+    repeated uint64 remove_incoming_requests = 4;
+    repeated uint64 outgoing_requests = 5;
+    repeated uint64 remove_outgoing_requests = 6;
 }
 
 message IncomingContactRequest {
     uint64 user_id = 1;
-    bool show_notification = 2;
+    bool should_notify = 2;
 }
 
 message UpdateDiagnostics {
@@ -868,6 +871,7 @@ message ChannelMessage {
 message Contact {
     uint64 user_id = 1;
     repeated ProjectMetadata projects = 2;
+    bool online = 3;
 }
 
 message ProjectMetadata {