Fix race condition in channel notes rejoin (#50034)

Conrad Irwin created

Closes #49998

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added a solid test coverage and/or screenshots from doing manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- Fixed a (very rare) crash that could happen due to lost edits in
channel buffers

Change summary

crates/channel/src/channel_buffer.rs                    |  11 
crates/channel/src/channel_store.rs                     |  55 +++
crates/collab/tests/integration/channel_buffer_tests.rs | 160 +++++++++++
3 files changed, 219 insertions(+), 7 deletions(-)

Detailed changes

crates/channel/src/channel_buffer.rs 🔗

@@ -22,6 +22,7 @@ pub(crate) fn init(client: &AnyProtoClient) {
 pub struct ChannelBuffer {
     pub channel_id: ChannelId,
     connected: bool,
+    rejoining: bool,
     collaborators: HashMap<PeerId, Collaborator>,
     user_store: Entity<UserStore>,
     channel_store: Entity<ChannelStore>,
@@ -84,6 +85,7 @@ impl ChannelBuffer {
                 buffer_epoch: response.epoch,
                 client,
                 connected: true,
+                rejoining: false,
                 collaborators: Default::default(),
                 acknowledge_task: None,
                 channel_id: channel.id,
@@ -111,6 +113,7 @@ impl ChannelBuffer {
 
     pub fn connected(&mut self, cx: &mut Context<Self>) {
         self.connected = true;
+        self.rejoining = false;
         if self.subscription.is_none() {
             let Ok(subscription) = self.client.subscribe_to_entity(self.channel_id.0) else {
                 return;
@@ -120,6 +123,10 @@ impl ChannelBuffer {
         }
     }
 
+    pub(crate) fn set_rejoining(&mut self, rejoining: bool) {
+        self.rejoining = rejoining;
+    }
+
     pub fn remote_id(&self, cx: &App) -> BufferId {
         self.buffer.read(cx).remote_id()
     }
@@ -204,6 +211,9 @@ impl ChannelBuffer {
                     return;
                 }
                 let operation = language::proto::serialize_operation(operation);
+                if self.rejoining {
+                    return;
+                }
                 self.client
                     .send(proto::UpdateChannelBuffer {
                         channel_id: self.channel_id.0,
@@ -263,6 +273,7 @@ impl ChannelBuffer {
         log::info!("channel buffer {} disconnected", self.channel_id);
         if self.connected {
             self.connected = false;
+            self.rejoining = false;
             self.subscription.take();
             cx.emit(ChannelBufferEvent::Disconnected);
             cx.notify()

crates/channel/src/channel_store.rs 🔗

@@ -855,12 +855,18 @@ impl ChannelStore {
             if let OpenEntityHandle::Open(buffer) = buffer
                 && let Some(buffer) = buffer.upgrade()
             {
-                let channel_buffer = buffer.read(cx);
-                let buffer = channel_buffer.buffer().read(cx);
-                buffer_versions.push(proto::ChannelBufferVersion {
-                    channel_id: channel_buffer.channel_id.0,
-                    epoch: channel_buffer.epoch(),
-                    version: language::proto::serialize_version(&buffer.version()),
+                buffer.update(cx, |channel_buffer, cx| {
+                    // Block on_buffer_update from sending UpdateChannelBuffer messages
+                    // until the rejoin completes. This prevents a race condition where
+                    // edits made during the rejoin async gap could inflate the server
+                    // version, causing offline edits to be filtered out by serialize_ops.
+                    channel_buffer.set_rejoining(true);
+                    let inner_buffer = channel_buffer.buffer().read(cx);
+                    buffer_versions.push(proto::ChannelBufferVersion {
+                        channel_id: channel_buffer.channel_id.0,
+                        epoch: channel_buffer.epoch(),
+                        version: language::proto::serialize_version(&inner_buffer.version()),
+                    });
                 });
             }
         }
@@ -874,7 +880,26 @@ impl ChannelStore {
         });
 
         cx.spawn(async move |this, cx| {
-            let mut response = response.await?;
+            let response = match response.await {
+                Ok(response) => response,
+                Err(err) => {
+                    // Clear rejoining flag on all buffers since the rejoin failed
+                    this.update(cx, |this, cx| {
+                        for buffer in this.opened_buffers.values() {
+                            if let OpenEntityHandle::Open(buffer) = buffer {
+                                if let Some(buffer) = buffer.upgrade() {
+                                    buffer.update(cx, |channel_buffer, _| {
+                                        channel_buffer.set_rejoining(false);
+                                    });
+                                }
+                            }
+                        }
+                    })
+                    .ok();
+                    return Err(err);
+                }
+            };
+            let mut response = response;
 
             this.update(cx, |this, cx| {
                 this.opened_buffers.retain(|_, buffer| match buffer {
@@ -948,6 +973,22 @@ impl ChannelStore {
     fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut Context<Self>) {
         cx.notify();
         self.did_subscribe = false;
+
+        // If we're waiting for reconnect, set rejoining=true on all buffers immediately.
+        // This prevents operations from being sent during the reconnection window,
+        // before handle_connect has a chance to run and capture the version.
+        if wait_for_reconnect {
+            for buffer in self.opened_buffers.values() {
+                if let OpenEntityHandle::Open(buffer) = buffer {
+                    if let Some(buffer) = buffer.upgrade() {
+                        buffer.update(cx, |channel_buffer, _| {
+                            channel_buffer.set_rejoining(true);
+                        });
+                    }
+                }
+            }
+        }
+
         self.disconnect_channel_buffers_task.get_or_insert_with(|| {
             cx.spawn(async move |this, cx| {
                 if wait_for_reconnect {

crates/collab/tests/integration/channel_buffer_tests.rs 🔗

@@ -3,6 +3,7 @@ use call::ActiveCall;
 use channel::ACKNOWLEDGE_DEBOUNCE_INTERVAL;
 use client::{Collaborator, ParticipantIndex, UserId};
 use collab::rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT};
+
 use collab_ui::channel_view::ChannelView;
 use collections::HashMap;
 use editor::{Anchor, Editor, MultiBufferOffset, ToOffset};
@@ -698,6 +699,165 @@ async fn test_channel_buffer_changes_persist(
     });
 }
 
+#[gpui::test]
+async fn test_channel_buffer_operations_lost_on_reconnect(
+    executor: BackgroundExecutor,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    let mut server = TestServer::start(executor.clone()).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let channel_id = server
+        .make_channel(
+            "the-channel",
+            None,
+            (&client_a, cx_a),
+            &mut [(&client_b, cx_b)],
+        )
+        .await;
+
+    // Both clients open the channel buffer.
+    let channel_buffer_a = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+    let channel_buffer_b = client_b
+        .channel_store()
+        .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+
+    // Step 1: Client A makes an initial edit that syncs to B.
+    channel_buffer_a.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "a")], None, cx);
+        })
+    });
+    executor.run_until_parked();
+
+    // Verify both clients see "a".
+    channel_buffer_a.read_with(cx_a, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "a");
+    });
+    channel_buffer_b.read_with(cx_b, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "a");
+    });
+
+    // Step 2: Disconnect client A. Do NOT advance past RECONNECT_TIMEOUT
+    // so that the buffer stays in `opened_buffers` for rejoin.
+    server.forbid_connections();
+    server.disconnect_client(client_a.peer_id().unwrap());
+    executor.run_until_parked();
+
+    // Step 3: While disconnected, client A makes an offline edit ("b").
+    // on_buffer_update fires but client.send() fails because transport is down.
+    channel_buffer_a.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(1..1, "b")], None, cx);
+        })
+    });
+    executor.run_until_parked();
+
+    // Client A sees "ab" locally; B still sees "a".
+    channel_buffer_a.read_with(cx_a, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "ab");
+    });
+    channel_buffer_b.read_with(cx_b, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "a");
+    });
+
+    // Step 4: Reconnect and make a racing edit in parallel.
+    //
+    // The race condition occurs when:
+    // 1. Transport reconnects, handle_connect captures version V (with "b") and sends RejoinChannelBuffers
+    // 2. DURING the async gap (awaiting response), user makes edit "c"
+    // 3. on_buffer_update sends UpdateChannelBuffer (succeeds because transport is up)
+    // 4. Server receives BOTH messages concurrently (FuturesUnordered)
+    // 5. If UpdateChannelBuffer commits first, server version is inflated to include "c"
+    // 6. RejoinChannelBuffers reads inflated version and sends it back
+    // 7. Client's serialize_ops(inflated_version) filters out "b" (offline edit)
+    //    because the inflated version's timestamp covers "b"'s timestamp
+
+    // Get the buffer handle for spawning
+    let buffer_for_edit = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer());
+
+    // Spawn the edit task - it will wait for executor to run it
+    let edit_task = cx_a.spawn({
+        let buffer = buffer_for_edit;
+        async move |mut cx| {
+            let _ = buffer.update(&mut cx, |buffer, cx| {
+                buffer.edit([(2..2, "c")], None, cx);
+            });
+        }
+    });
+
+    // Allow connections so reconnect can succeed
+    server.allow_connections();
+
+    // Advance clock to trigger reconnection attempt
+    executor.advance_clock(RECEIVE_TIMEOUT);
+
+    // Run the edit task - this races with handle_connect
+    edit_task.detach();
+
+    // Let everything settle.
+    executor.run_until_parked();
+
+    // Step 7: Read final buffer text from both clients.
+    let text_a = channel_buffer_a.read_with(cx_a, |buffer, cx| buffer.buffer().read(cx).text());
+    let text_b = channel_buffer_b.read_with(cx_b, |buffer, cx| buffer.buffer().read(cx).text());
+
+    // Both clients must see the same text containing all three edits.
+    assert_eq!(
+        text_a, text_b,
+        "Client A and B diverged! A sees {:?}, B sees {:?}. \
+         Operations were lost during reconnection.",
+        text_a, text_b
+    );
+    assert!(
+        text_a.contains('a'),
+        "Initial edit 'a' missing from final text {:?}",
+        text_a
+    );
+    assert!(
+        text_a.contains('b'),
+        "Offline edit 'b' missing from final text {:?}. \
+         This is the reconnection race bug: the offline operation was \
+         filtered out by serialize_ops because the server_version was \
+         inflated by a racing UpdateChannelBuffer.",
+        text_a
+    );
+    assert!(
+        text_a.contains('c'),
+        "Racing edit 'c' missing from final text {:?}",
+        text_a
+    );
+
+    // Step 8: Verify the invariant directly — every operation known to
+    // client A must be observed by client B's version. If any operation
+    // in A's history is not covered by B's version, it was lost.
+    channel_buffer_a.read_with(cx_a, |buf_a, cx_a_inner| {
+        let buffer_a = buf_a.buffer().read(cx_a_inner);
+        let ops_a = buffer_a.operations();
+        channel_buffer_b.read_with(cx_b, |buf_b, cx_b_inner| {
+            let buffer_b = buf_b.buffer().read(cx_b_inner);
+            let version_b = buffer_b.version();
+            for (lamport, _op) in ops_a.iter() {
+                assert!(
+                    version_b.observed(*lamport),
+                    "Operation with lamport timestamp {:?} from client A \
+                     is NOT observed by client B's version. This operation \
+                     was lost during reconnection.",
+                    lamport
+                );
+            }
+        });
+    });
+}
+
 #[track_caller]
 fn assert_collaborators(collaborators: &HashMap<PeerId, Collaborator>, ids: &[Option<UserId>]) {
     let mut user_ids = collaborators