Maintain server state consistency when removing a connection

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

server/src/rpc.rs | 107 ++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 101 insertions(+), 6 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -3,6 +3,7 @@ use super::{
     db::{ChannelId, MessageId, UserId},
     AppState,
 };
+use crate::errors::TideResultExt;
 use anyhow::anyhow;
 use async_std::{sync::RwLock, task};
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
@@ -49,7 +50,7 @@ pub struct Server {
 struct ServerState {
     connections: HashMap<ConnectionId, ConnectionState>,
     connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
-    pub worktrees: HashMap<u64, Worktree>,
+    worktrees: HashMap<u64, Worktree>,
     visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
     channels: HashMap<ChannelId, Channel>,
     next_worktree_id: u64,
@@ -707,15 +708,19 @@ impl Server {
             {
                 let worktree = &state.worktrees[worktree_id];
 
-                let mut participants = HashSet::new();
+                let mut guests = HashSet::new();
                 if let Ok(share) = worktree.share() {
                     for guest_connection_id in share.guest_connection_ids.keys() {
-                        let user_id = state.user_id_for_connection(*guest_connection_id)?;
-                        participants.insert(user_id.to_proto());
+                        let user_id = state
+                            .user_id_for_connection(*guest_connection_id)
+                            .context("stale worktree guest connection")?;
+                        guests.insert(user_id.to_proto());
                     }
                 }
 
-                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
+                let host_user_id = state
+                    .user_id_for_connection(worktree.host_connection_id)
+                    .context("stale worktree host connection")?;
                 let host =
                     collaborators
                         .entry(host_user_id)
@@ -726,7 +731,7 @@ impl Server {
                 host.worktrees.push(proto::WorktreeMetadata {
                     root_name: worktree.root_name.clone(),
                     is_shared: worktree.share().is_ok(),
-                    participants: participants.into_iter().collect(),
+                    participants: guests.into_iter().collect(),
                 });
             }
 
@@ -1137,7 +1142,14 @@ impl ServerState {
                 .insert(worktree_id);
         }
         self.next_worktree_id += 1;
+        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
+            connection.worktrees.insert(worktree_id);
+        }
         self.worktrees.insert(worktree_id, worktree);
+
+        #[cfg(test)]
+        self.check_invariants();
+
         worktree_id
     }
 
@@ -1161,6 +1173,89 @@ impl ServerState {
                 visible_worktrees.remove(&worktree_id);
             }
         }
+
+        #[cfg(test)]
+        self.check_invariants();
+    }
+
+    #[cfg(test)]
+    fn check_invariants(&self) {
+        for (connection_id, connection) in &self.connections {
+            for worktree_id in &connection.worktrees {
+                let worktree = &self.worktrees.get(&worktree_id).unwrap();
+                if worktree.host_connection_id != *connection_id {
+                    assert!(worktree
+                        .share()
+                        .unwrap()
+                        .guest_connection_ids
+                        .contains_key(connection_id));
+                }
+            }
+            for channel_id in &connection.channels {
+                let channel = self.channels.get(channel_id).unwrap();
+                assert!(channel.connection_ids.contains(connection_id));
+            }
+            assert!(self
+                .connections_by_user_id
+                .get(&connection.user_id)
+                .unwrap()
+                .contains(connection_id));
+        }
+
+        for (user_id, connection_ids) in &self.connections_by_user_id {
+            for connection_id in connection_ids {
+                assert_eq!(
+                    self.connections.get(connection_id).unwrap().user_id,
+                    *user_id
+                );
+            }
+        }
+
+        for (worktree_id, worktree) in &self.worktrees {
+            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
+            assert!(host_connection.worktrees.contains(worktree_id));
+
+            for collaborator_id in &worktree.collaborator_user_ids {
+                let visible_worktree_ids = self
+                    .visible_worktrees_by_user_id
+                    .get(collaborator_id)
+                    .unwrap();
+                assert!(visible_worktree_ids.contains(worktree_id));
+            }
+
+            if let Some(share) = &worktree.share {
+                for guest_connection_id in share.guest_connection_ids.keys() {
+                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
+                    assert!(guest_connection.worktrees.contains(worktree_id));
+                }
+                assert_eq!(
+                    share.active_replica_ids.len(),
+                    share.guest_connection_ids.len(),
+                );
+                assert_eq!(
+                    share.active_replica_ids,
+                    share
+                        .guest_connection_ids
+                        .values()
+                        .copied()
+                        .collect::<HashSet<_>>(),
+                );
+            }
+        }
+
+        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
+            for worktree_id in visible_worktree_ids {
+                let worktree = self.worktrees.get(worktree_id).unwrap();
+                assert!(worktree.collaborator_user_ids.contains(user_id));
+            }
+        }
+
+        for (channel_id, channel) in &self.channels {
+            for connection_id in &channel.connection_ids {
+                let connection = self.connections.get(connection_id).unwrap();
+                assert!(connection.channels.contains(channel_id));
+            }
+        }
     }
 }