Update collaborators as worktrees are opened/shared/closed

Antonio Scandurra created

Change summary

server/src/rpc.rs | 163 ++++++++++++++++++++++++++----------------------
1 file changed, 87 insertions(+), 76 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -99,7 +99,7 @@ impl Server {
         server
             .add_handler(Server::ping)
             .add_handler(Server::open_worktree)
-            .add_handler(Server::close_worktree)
+            .add_handler(Server::handle_close_worktree)
             .add_handler(Server::share_worktree)
             .add_handler(Server::unshare_worktree)
             .add_handler(Server::join_worktree)
@@ -195,22 +195,7 @@ impl Server {
 
     async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id).await;
-        let worktree_ids = self.remove_connection(connection_id).await;
-        for worktree_id in worktree_ids {
-            let state = self.state.read().await;
-            if let Some(worktree) = state.worktrees.get(&worktree_id) {
-                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
-                    self.peer.send(
-                        conn_id,
-                        proto::RemovePeer {
-                            worktree_id,
-                            peer_id: connection_id.0,
-                        },
-                    )
-                })
-                .await?;
-            }
-        }
+        self.remove_connection(connection_id).await?;
         Ok(())
     }
 
@@ -233,29 +218,20 @@ impl Server {
     }
 
     // Remove the given connection and its association with any worktrees.
-    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
+    async fn remove_connection(
+        self: &Arc<Server>,
+        connection_id: ConnectionId,
+    ) -> tide::Result<()> {
         let mut worktree_ids = Vec::new();
         let mut state = self.state.write().await;
         if let Some(connection) = state.connections.remove(&connection_id) {
+            worktree_ids = connection.worktrees.into_iter().collect();
+
             for channel_id in connection.channels {
                 if let Some(channel) = state.channels.get_mut(&channel_id) {
                     channel.connection_ids.remove(&connection_id);
                 }
             }
-            for worktree_id in connection.worktrees {
-                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
-                    if worktree.host_connection_id == connection_id {
-                        worktree_ids.push(worktree_id);
-                    } else if let Some(share_state) = worktree.share.as_mut() {
-                        if let Some(replica_id) =
-                            share_state.guest_connection_ids.remove(&connection_id)
-                        {
-                            share_state.active_replica_ids.remove(&replica_id);
-                            worktree_ids.push(worktree_id);
-                        }
-                    }
-                }
-            }
 
             let user_connections = state
                 .connections_by_user_id
@@ -266,7 +242,12 @@ impl Server {
                 state.connections_by_user_id.remove(&connection.user_id);
             }
         }
-        worktree_ids
+
+        for worktree_id in worktree_ids {
+            self.close_worktree(worktree_id, connection_id).await?;
+        }
+
+        Ok(())
     }
 
     async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
@@ -279,7 +260,7 @@ impl Server {
         request: TypedEnvelope<proto::OpenWorktree>,
     ) -> tide::Result<()> {
         let receipt = request.receipt();
-        let user_id = self
+        let host_user_id = self
             .state
             .read()
             .await
@@ -289,7 +270,7 @@ impl Server {
         for github_login in request.payload.collaborator_logins {
             match self.app_state.db.create_user(&github_login, false).await {
                 Ok(collaborator_user_id) => {
-                    if collaborator_user_id != user_id {
+                    if collaborator_user_id != host_user_id {
                         collaborator_user_ids.push(collaborator_user_id);
                     }
                 }
@@ -303,18 +284,24 @@ impl Server {
             }
         }
 
-        let mut state = self.state.write().await;
-        let worktree_id = state.add_worktree(Worktree {
-            host_connection_id: request.sender_id,
-            collaborator_user_ids: collaborator_user_ids.clone(),
-            root_name: request.payload.root_name,
-            share: None,
-        });
+        let worktree_id;
+        let mut user_ids;
+        {
+            let mut state = self.state.write().await;
+            worktree_id = state.add_worktree(Worktree {
+                host_connection_id: request.sender_id,
+                collaborator_user_ids: collaborator_user_ids.clone(),
+                root_name: request.payload.root_name,
+                share: None,
+            });
+            user_ids = collaborator_user_ids;
+            user_ids.push(host_user_id);
+        }
 
         self.peer
             .respond(receipt, proto::OpenWorktreeResponse { worktree_id })
             .await?;
-        self.update_collaborators(&collaborator_user_ids).await?;
+        self.update_collaborators_for_users(&user_ids).await?;
 
         Ok(())
     }
@@ -323,6 +310,11 @@ impl Server {
         self: Arc<Server>,
         mut request: TypedEnvelope<proto::ShareWorktree>,
     ) -> tide::Result<()> {
+        let host_user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let worktree = request
             .payload
             .worktree
@@ -332,6 +324,7 @@ impl Server {
             .into_iter()
             .map(|entry| (entry.id, entry))
             .collect();
+
         let mut state = self.state.write().await;
         if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
             worktree.share = Some(WorktreeShare {
@@ -339,13 +332,15 @@ impl Server {
                 active_replica_ids: Default::default(),
                 entries,
             });
+
+            let mut user_ids = worktree.collaborator_user_ids.clone();
+            user_ids.push(host_user_id);
+
+            drop(state);
             self.peer
                 .respond(request.receipt(), proto::ShareWorktreeResponse {})
                 .await?;
-
-            let collaborator_user_ids = worktree.collaborator_user_ids.clone();
-            drop(state);
-            self.update_collaborators(&collaborator_user_ids).await?;
+            self.update_collaborators_for_users(&user_ids).await?;
         } else {
             self.peer
                 .respond_with_error(
@@ -364,9 +359,14 @@ impl Server {
         request: TypedEnvelope<proto::UnshareWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
+        let host_user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
 
         let connection_ids;
-        let collaborator_user_ids;
+        let mut user_ids;
         {
             let mut state = self.state.write().await;
             let worktree = state.write_worktree(worktree_id, request.sender_id)?;
@@ -375,7 +375,8 @@ impl Server {
             }
 
             connection_ids = worktree.connection_ids();
-            collaborator_user_ids = worktree.collaborator_user_ids.clone();
+            user_ids = worktree.collaborator_user_ids.clone();
+            user_ids.push(host_user_id);
             worktree.share.take();
             for connection_id in &connection_ids {
                 if let Some(connection) = state.connections.get_mut(connection_id) {
@@ -389,7 +390,7 @@ impl Server {
                 .send(conn_id, proto::UnshareWorktree { worktree_id })
         })
         .await?;
-        self.update_collaborators(&collaborator_user_ids).await?;
+        self.update_collaborators_for_users(&user_ids).await?;
 
         Ok(())
     }
@@ -407,7 +408,7 @@ impl Server {
 
         let response;
         let connection_ids;
-        let collaborator_user_ids;
+        let mut user_ids;
         let mut state = self.state.write().await;
         match state.join_worktree(request.sender_id, user_id, worktree_id) {
             Ok((peer_replica_id, worktree)) => {
@@ -426,8 +427,6 @@ impl Server {
                         });
                     }
                 }
-                connection_ids = worktree.connection_ids();
-                collaborator_user_ids = worktree.collaborator_user_ids.clone();
                 response = proto::JoinWorktreeResponse {
                     worktree: Some(proto::Worktree {
                         id: worktree_id,
@@ -437,6 +436,11 @@ impl Server {
                     replica_id: peer_replica_id as u32,
                     peers,
                 };
+
+                let host_connection_id = worktree.host_connection_id;
+                connection_ids = worktree.connection_ids();
+                user_ids = worktree.collaborator_user_ids.clone();
+                user_ids.push(state.user_id_for_connection(host_connection_id)?);
             }
             Err(error) => {
                 self.peer
@@ -465,55 +469,69 @@ impl Server {
         })
         .await?;
         self.peer.respond(request.receipt(), response).await?;
-        self.update_collaborators(&collaborator_user_ids).await?;
+        self.update_collaborators_for_users(&user_ids).await?;
 
         Ok(())
     }
 
-    async fn close_worktree(
+    async fn handle_close_worktree(
         self: Arc<Server>,
         request: TypedEnvelope<proto::CloseWorktree>,
     ) -> tide::Result<()> {
-        let worktree_id = request.payload.worktree_id;
+        self.close_worktree(request.payload.worktree_id, request.sender_id)
+            .await
+    }
+
+    async fn close_worktree(
+        self: &Arc<Server>,
+        worktree_id: u64,
+        conn_id: ConnectionId,
+    ) -> tide::Result<()> {
         let connection_ids;
+        let mut user_ids;
+
         let mut is_host = false;
         let mut is_guest = false;
         {
             let mut state = self.state.write().await;
-            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
+            let worktree = state.write_worktree(worktree_id, conn_id)?;
+            let host_connection_id = worktree.host_connection_id;
             connection_ids = worktree.connection_ids();
+            user_ids = worktree.collaborator_user_ids.clone();
 
-            if worktree.host_connection_id == request.sender_id {
+            if worktree.host_connection_id == conn_id {
                 is_host = true;
                 state.remove_worktree(worktree_id);
             } else {
                 let share = worktree.share_mut()?;
-                if let Some(replica_id) = share.guest_connection_ids.remove(&request.sender_id) {
+                if let Some(replica_id) = share.guest_connection_ids.remove(&conn_id) {
                     is_guest = true;
                     share.active_replica_ids.remove(&replica_id);
                 }
             }
+
+            user_ids.push(state.user_id_for_connection(host_connection_id)?);
         }
 
         if is_host {
-            broadcast(request.sender_id, connection_ids, |conn_id| {
+            broadcast(conn_id, connection_ids, |conn_id| {
                 self.peer
                     .send(conn_id, proto::UnshareWorktree { worktree_id })
             })
             .await?;
         } else if is_guest {
-            broadcast(request.sender_id, connection_ids, |conn_id| {
+            broadcast(conn_id, connection_ids, |conn_id| {
                 self.peer.send(
                     conn_id,
                     proto::RemovePeer {
                         worktree_id,
-                        peer_id: request.sender_id.0,
+                        peer_id: conn_id.0,
                     },
                 )
             })
             .await?
         }
-
+        self.update_collaborators_for_users(&user_ids).await?;
         Ok(())
     }
 
@@ -694,7 +712,10 @@ impl Server {
         Ok(())
     }
 
-    async fn update_collaborators(self: &Arc<Server>, user_ids: &[UserId]) -> tide::Result<()> {
+    async fn update_collaborators_for_users<'a>(
+        self: &Arc<Server>,
+        user_ids: impl IntoIterator<Item = &'a UserId>,
+    ) -> tide::Result<()> {
         let mut send_futures = Vec::new();
 
         let state = self.state.read().await;
@@ -730,15 +751,8 @@ impl Server {
                 });
             }
 
-            let connection_ids = self
-                .state
-                .read()
-                .await
-                .user_connection_ids(*user_id)
-                .collect::<Vec<_>>();
-
             let collaborators = collaborators.into_values().collect::<Vec<_>>();
-            for connection_id in connection_ids {
+            for connection_id in state.user_connection_ids(*user_id) {
                 send_futures.push(self.peer.send(
                     connection_id,
                     proto::UpdateCollaborators {
@@ -748,6 +762,7 @@ impl Server {
             }
         }
 
+        drop(state);
         futures::future::try_join_all(send_futures).await?;
 
         Ok(())
@@ -1052,10 +1067,6 @@ impl ServerState {
             .copied()
     }
 
-    fn is_online(&self, user_id: UserId) -> bool {
-        self.connections_by_user_id.contains_key(&user_id)
-    }
-
     // Add the given connection as a guest of the given worktree
     fn join_worktree(
         &mut self,