Implement clearing stale channel buffer participants on server restart

Max Brunsfeld and Mikayla created

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

crates/channel/src/channel_store.rs             |  4 
crates/collab/src/db.rs                         |  1 
crates/collab/src/db/queries/buffers.rs         | 26 ++++
crates/collab/src/db/queries/rooms.rs           |  2 
crates/collab/src/db/queries/servers.rs         |  1 
crates/collab/src/rpc.rs                        |  9 +
crates/collab/src/tests/channel_buffer_tests.rs | 96 ++++++++++++++++++
7 files changed, 133 insertions(+), 6 deletions(-)

Detailed changes

crates/channel/src/channel_store.rs 🔗

@@ -500,6 +500,10 @@ impl ChannelStore {
             }
         }
 
+        if buffer_versions.is_empty() {
+            return Task::ready(Ok(()));
+        }
+
         let response = self.client.request(proto::RejoinChannelBuffers {
             buffers: buffer_versions,
         });

crates/collab/src/db.rs 🔗

@@ -435,6 +435,7 @@ pub struct ChannelsForUser {
     pub channels_with_admin_privileges: HashSet<ChannelId>,
 }
 
+#[derive(Debug)]
 pub struct RejoinedChannelBuffer {
     pub buffer: proto::RejoinedChannelBuffer,
     pub old_connection_id: ConnectionId,

crates/collab/src/db/queries/buffers.rs 🔗

@@ -118,6 +118,7 @@ impl Database {
                 // connection, then the client's buffer can be synchronized with
                 // the server's buffer.
                 if buffer.epoch as u64 != client_buffer.epoch {
+                    log::info!("can't rejoin buffer, epoch has changed");
                     continue;
                 }
 
@@ -128,6 +129,7 @@ impl Database {
                     c.user_id == user_id
                         && (c.connection_lost || c.connection_server_id != server_id)
                 }) else {
+                    log::info!("can't rejoin buffer, no previous collaborator found");
                     continue;
                 };
                 let old_connection_id = self_collaborator.connection();
@@ -196,16 +198,36 @@ impl Database {
         .await
     }
 
-    pub async fn refresh_channel_buffer(
+    pub async fn clear_stale_channel_buffer_collaborators(
         &self,
         channel_id: ChannelId,
         server_id: ServerId,
     ) -> Result<RefreshedChannelBuffer> {
         self.transaction(|tx| async move {
+            let collaborators = channel_buffer_collaborator::Entity::find()
+                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
+                .all(&*tx)
+                .await?;
+
             let mut connection_ids = Vec::new();
             let mut removed_collaborators = Vec::new();
+            let mut collaborator_ids_to_remove = Vec::new();
+            for collaborator in &collaborators {
+                if !collaborator.connection_lost && collaborator.connection_server_id == server_id {
+                    connection_ids.push(collaborator.connection());
+                } else {
+                    removed_collaborators.push(proto::RemoveChannelBufferCollaborator {
+                        channel_id: channel_id.to_proto(),
+                        peer_id: Some(collaborator.connection().into()),
+                    });
+                    collaborator_ids_to_remove.push(collaborator.id);
+                }
+            }
 
-            // TODO
+            channel_buffer_collaborator::Entity::delete_many()
+                .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
+                .exec(&*tx)
+                .await?;
 
             Ok(RefreshedChannelBuffer {
                 connection_ids,

crates/collab/src/db/queries/rooms.rs 🔗

@@ -1,7 +1,7 @@
 use super::*;
 
 impl Database {
-    pub async fn refresh_room(
+    pub async fn clear_stale_room_participants(
         &self,
         room_id: RoomId,
         new_server_id: ServerId,

crates/collab/src/rpc.rs 🔗

@@ -285,11 +285,15 @@ impl Server {
                     .trace_err()
                 {
                     tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
+                    tracing::info!(
+                        stale_channel_buffer_count = channel_ids.len(),
+                        "retrieved stale channel buffers"
+                    );
 
                     for channel_id in channel_ids {
                         if let Some(refreshed_channel_buffer) = app_state
                             .db
-                            .refresh_channel_buffer(channel_id, server_id)
+                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                             .await
                             .trace_err()
                         {
@@ -309,7 +313,7 @@ impl Server {
 
                         if let Some(mut refreshed_room) = app_state
                             .db
-                            .refresh_room(room_id, server_id)
+                            .clear_stale_room_participants(room_id, server_id)
                             .await
                             .trace_err()
                         {
@@ -873,6 +877,7 @@ async fn connection_lost(
 
     futures::select_biased! {
         _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
+            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
             leave_room_for_session(&session).await.trace_err();
             leave_channel_buffers_for_session(&session)
                 .await

crates/collab/src/tests/channel_buffer_tests.rs 🔗

@@ -1,4 +1,7 @@
-use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
+use crate::{
+    rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
+    tests::TestServer,
+};
 use call::ActiveCall;
 use channel::Channel;
 use client::UserId;
@@ -472,6 +475,97 @@ async fn test_rejoin_channel_buffer(
     });
 }
 
+#[gpui::test]
+async fn test_channel_buffers_and_server_restarts(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+    cx_c: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    let client_c = server.create_client(cx_c, "user_c").await;
+
+    let channel_id = server
+        .make_channel(
+            "the-channel",
+            (&client_a, cx_a),
+            &mut [(&client_b, cx_b), (&client_c, cx_c)],
+        )
+        .await;
+
+    let channel_buffer_a = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+    let channel_buffer_b = client_b
+        .channel_store()
+        .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+    let _channel_buffer_c = client_c
+        .channel_store()
+        .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+
+    channel_buffer_a.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "1")], None, cx);
+        })
+    });
+    deterministic.run_until_parked();
+
+    // Client C can't reconnect.
+    client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending()));
+
+    // Server stops.
+    server.reset().await;
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
+
+    // While the server is down, clients A and B each make an edit.
+    channel_buffer_a.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(1..1, "2")], None, cx);
+        })
+    });
+    channel_buffer_b.update(cx_b, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "0")], None, cx);
+        })
+    });
+
+    // Server restarts.
+    server.start().await.unwrap();
+    deterministic.advance_clock(CLEANUP_TIMEOUT);
+
+    // Clients A and B reconnect. They see each other's edits, and see
+    // that client C has disconnected.
+    channel_buffer_a.read_with(cx_a, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "012");
+    });
+    channel_buffer_b.read_with(cx_b, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "012");
+    });
+
+    channel_buffer_a.read_with(cx_a, |buffer_a, _| {
+        channel_buffer_b.read_with(cx_b, |buffer_b, _| {
+            assert_eq!(
+                buffer_a
+                    .collaborators()
+                    .iter()
+                    .map(|c| c.user_id)
+                    .collect::<Vec<_>>(),
+                vec![client_a.user_id().unwrap(), client_b.user_id().unwrap()]
+            );
+            assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
+        });
+    });
+}
+
 #[track_caller]
 fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
     assert_eq!(