Start work on rejoining channel buffers

Max Brunsfeld created

Change summary

crates/channel/src/channel_buffer.rs            |  27 +
crates/channel/src/channel_store.rs             | 153 +++++++++-
crates/collab/src/db/queries/buffers.rs         | 265 ++++++++++++++----
crates/collab/src/rpc.rs                        |  25 +
crates/collab/src/tests/channel_buffer_tests.rs | 138 ++++++---
crates/language/src/proto.rs                    |  25 +
crates/project/src/project.rs                   |  27 -
crates/rpc/proto/zed.proto                      |  26 +
crates/rpc/src/proto.rs                         |   3 
9 files changed, 526 insertions(+), 163 deletions(-)

Detailed changes

crates/channel/src/channel_buffer.rs 🔗

@@ -17,6 +17,7 @@ pub struct ChannelBuffer {
     connected: bool,
     collaborators: Vec<proto::Collaborator>,
     buffer: ModelHandle<language::Buffer>,
+    buffer_epoch: u64,
     client: Arc<Client>,
     subscription: Option<client::Subscription>,
 }
@@ -73,6 +74,7 @@ impl ChannelBuffer {
 
             Self {
                 buffer,
+                buffer_epoch: response.epoch,
                 client,
                 connected: true,
                 collaborators,
@@ -82,6 +84,26 @@ impl ChannelBuffer {
         }))
     }
 
+    pub(crate) fn replace_collaborators(
+        &mut self,
+        collaborators: Vec<proto::Collaborator>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        for old_collaborator in &self.collaborators {
+            if collaborators
+                .iter()
+                .any(|c| c.replica_id == old_collaborator.replica_id)
+            {
+                self.buffer.update(cx, |buffer, cx| {
+                    buffer.remove_peer(old_collaborator.replica_id as u16, cx)
+                });
+            }
+        }
+        self.collaborators = collaborators;
+        cx.emit(Event::CollaboratorsChanged);
+        cx.notify();
+    }
+
     async fn handle_update_channel_buffer(
         this: ModelHandle<Self>,
         update_channel_buffer: TypedEnvelope<proto::UpdateChannelBuffer>,
@@ -166,6 +188,10 @@ impl ChannelBuffer {
         }
     }
 
+    pub fn epoch(&self) -> u64 {
+        self.buffer_epoch
+    }
+
     pub fn buffer(&self) -> ModelHandle<language::Buffer> {
         self.buffer.clone()
     }
@@ -179,6 +205,7 @@ impl ChannelBuffer {
     }
 
     pub(crate) fn disconnect(&mut self, cx: &mut ModelContext<Self>) {
+        log::info!("channel buffer {} disconnected", self.channel.id);
         if self.connected {
             self.connected = false;
             self.subscription.take();

crates/channel/src/channel_store.rs 🔗

@@ -1,13 +1,15 @@
 use crate::channel_buffer::ChannelBuffer;
 use anyhow::{anyhow, Result};
-use client::{Client, Status, Subscription, User, UserId, UserStore};
+use client::{Client, Subscription, User, UserId, UserStore};
 use collections::{hash_map, HashMap, HashSet};
 use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
 use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use rpc::{proto, TypedEnvelope};
-use std::sync::Arc;
+use std::{mem, sync::Arc, time::Duration};
 use util::ResultExt;
 
+pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
+
 pub type ChannelId = u64;
 
 pub struct ChannelStore {
@@ -22,7 +24,8 @@ pub struct ChannelStore {
     client: Arc<Client>,
     user_store: ModelHandle<UserStore>,
     _rpc_subscription: Subscription,
-    _watch_connection_status: Task<()>,
+    _watch_connection_status: Task<Option<()>>,
+    disconnect_channel_buffers_task: Option<Task<()>>,
     _update_channels: Task<()>,
 }
 
@@ -67,24 +70,20 @@ impl ChannelStore {
         let rpc_subscription =
             client.add_message_handler(cx.handle(), Self::handle_update_channels);
 
-        let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
         let mut connection_status = client.status();
+        let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
         let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
             while let Some(status) = connection_status.next().await {
-                if !status.is_connected() {
-                    if let Some(this) = this.upgrade(&cx) {
-                        this.update(&mut cx, |this, cx| {
-                            if matches!(status, Status::ConnectionLost | Status::SignedOut) {
-                                this.handle_disconnect(cx);
-                            } else {
-                                this.disconnect_buffers(cx);
-                            }
-                        });
-                    } else {
-                        break;
-                    }
+                let this = this.upgrade(&cx)?;
+                if status.is_connected() {
+                    this.update(&mut cx, |this, cx| this.handle_connect(cx))
+                        .await
+                        .log_err()?;
+                } else {
+                    this.update(&mut cx, |this, cx| this.handle_disconnect(cx));
                 }
             }
+            Some(())
         });
 
         Self {
@@ -100,6 +99,7 @@ impl ChannelStore {
             user_store,
             _rpc_subscription: rpc_subscription,
             _watch_connection_status: watch_connection_status,
+            disconnect_channel_buffers_task: None,
             _update_channels: cx.spawn_weak(|this, mut cx| async move {
                 while let Some(update_channels) = update_channels_rx.next().await {
                     if let Some(this) = this.upgrade(&cx) {
@@ -482,8 +482,102 @@ impl ChannelStore {
         Ok(())
     }
 
-    fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) {
-        self.disconnect_buffers(cx);
+    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        self.disconnect_channel_buffers_task.take();
+
+        let mut buffer_versions = Vec::new();
+        for buffer in self.opened_buffers.values() {
+            if let OpenedChannelBuffer::Open(buffer) = buffer {
+                if let Some(buffer) = buffer.upgrade(cx) {
+                    let channel_buffer = buffer.read(cx);
+                    let buffer = channel_buffer.buffer().read(cx);
+                    buffer_versions.push(proto::ChannelBufferVersion {
+                        channel_id: channel_buffer.channel().id,
+                        epoch: channel_buffer.epoch(),
+                        version: language::proto::serialize_version(&buffer.version()),
+                    });
+                }
+            }
+        }
+
+        let response = self.client.request(proto::RejoinChannelBuffers {
+            buffers: buffer_versions,
+        });
+
+        cx.spawn(|this, mut cx| async move {
+            let mut response = response.await?;
+
+            this.update(&mut cx, |this, cx| {
+                this.opened_buffers.retain(|_, buffer| match buffer {
+                    OpenedChannelBuffer::Open(channel_buffer) => {
+                        let Some(channel_buffer) = channel_buffer.upgrade(cx) else {
+                            return false;
+                        };
+
+                        channel_buffer.update(cx, |channel_buffer, cx| {
+                            let channel_id = channel_buffer.channel().id;
+                            if let Some(remote_buffer) = response
+                                .buffers
+                                .iter_mut()
+                                .find(|buffer| buffer.channel_id == channel_id)
+                            {
+                                let channel_id = channel_buffer.channel().id;
+                                let remote_version =
+                                    language::proto::deserialize_version(&remote_buffer.version);
+
+                                channel_buffer.replace_collaborators(
+                                    mem::take(&mut remote_buffer.collaborators),
+                                    cx,
+                                );
+
+                                let operations = channel_buffer
+                                    .buffer()
+                                    .update(cx, |buffer, cx| {
+                                        let outgoing_operations =
+                                            buffer.serialize_ops(Some(remote_version), cx);
+                                        let incoming_operations =
+                                            mem::take(&mut remote_buffer.operations)
+                                                .into_iter()
+                                                .map(language::proto::deserialize_operation)
+                                                .collect::<Result<Vec<_>>>()?;
+                                        buffer.apply_ops(incoming_operations, cx)?;
+                                        anyhow::Ok(outgoing_operations)
+                                    })
+                                    .log_err();
+
+                                if let Some(operations) = operations {
+                                    let client = this.client.clone();
+                                    cx.background()
+                                        .spawn(async move {
+                                            let operations = operations.await;
+                                            for chunk in
+                                                language::proto::split_operations(operations)
+                                            {
+                                                client
+                                                    .send(proto::UpdateChannelBuffer {
+                                                        channel_id,
+                                                        operations: chunk,
+                                                    })
+                                                    .ok();
+                                            }
+                                        })
+                                        .detach();
+                                    return true;
+                                }
+                            }
+
+                            channel_buffer.disconnect(cx);
+                            false
+                        })
+                    }
+                    OpenedChannelBuffer::Loading(_) => true,
+                });
+            });
+            anyhow::Ok(())
+        })
+    }
+
+    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
         self.channels_by_id.clear();
         self.channel_invitations.clear();
         self.channel_participants.clear();
@@ -491,16 +585,23 @@ impl ChannelStore {
         self.channel_paths.clear();
         self.outgoing_invites.clear();
         cx.notify();
-    }
 
-    fn disconnect_buffers(&mut self, cx: &mut ModelContext<ChannelStore>) {
-        for (_, buffer) in self.opened_buffers.drain() {
-            if let OpenedChannelBuffer::Open(buffer) = buffer {
-                if let Some(buffer) = buffer.upgrade(cx) {
-                    buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
+        self.disconnect_channel_buffers_task.get_or_insert_with(|| {
+            cx.spawn_weak(|this, mut cx| async move {
+                cx.background().timer(RECONNECT_TIMEOUT).await;
+                if let Some(this) = this.upgrade(&cx) {
+                    this.update(&mut cx, |this, cx| {
+                        for (_, buffer) in this.opened_buffers.drain() {
+                            if let OpenedChannelBuffer::Open(buffer) = buffer {
+                                if let Some(buffer) = buffer.upgrade(cx) {
+                                    buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
+                                }
+                            }
+                        }
+                    });
                 }
-            }
-        }
+            })
+        });
     }
 
     pub(crate) fn update_channels(

crates/collab/src/db/queries/buffers.rs 🔗

@@ -10,8 +10,6 @@ impl Database {
         connection: ConnectionId,
     ) -> Result<proto::JoinChannelBufferResponse> {
         self.transaction(|tx| async move {
-            let tx = tx;
-
             self.check_user_is_channel_member(channel_id, user_id, &tx)
                 .await?;
 
@@ -70,7 +68,6 @@ impl Database {
             .await?;
             collaborators.push(collaborator);
 
-            // Assemble the buffer state
             let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
 
             Ok(proto::JoinChannelBufferResponse {
@@ -78,6 +75,7 @@ impl Database {
                 replica_id: replica_id.to_proto() as u32,
                 base_text,
                 operations,
+                epoch: buffer.epoch as u64,
                 collaborators: collaborators
                     .into_iter()
                     .map(|collaborator| proto::Collaborator {
@@ -91,6 +89,113 @@ impl Database {
         .await
     }
 
+    pub async fn rejoin_channel_buffers(
+        &self,
+        buffers: &[proto::ChannelBufferVersion],
+        user_id: UserId,
+        connection_id: ConnectionId,
+    ) -> Result<proto::RejoinChannelBuffersResponse> {
+        self.transaction(|tx| async move {
+            let mut response = proto::RejoinChannelBuffersResponse::default();
+            for client_buffer in buffers {
+                let channel_id = ChannelId::from_proto(client_buffer.channel_id);
+                if self
+                    .check_user_is_channel_member(channel_id, user_id, &*tx)
+                    .await
+                    .is_err()
+                {
+                    log::info!("user is not a member of channel");
+                    continue;
+                }
+
+                let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
+                let mut collaborators = channel_buffer_collaborator::Entity::find()
+                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
+                    .all(&*tx)
+                    .await?;
+
+                // If the buffer epoch hasn't changed since the client lost
+                // connection, then the client's buffer can be syncronized with
+                // the server's buffer.
+                if buffer.epoch as u64 != client_buffer.epoch {
+                    continue;
+                }
+
+                // If there is still a disconnected collaborator for the user,
+                // update the connection associated with that collaborator, and reuse
+                // that replica id.
+                if let Some(ix) = collaborators
+                    .iter()
+                    .position(|c| c.user_id == user_id && c.connection_lost)
+                {
+                    let self_collaborator = &mut collaborators[ix];
+                    *self_collaborator = channel_buffer_collaborator::ActiveModel {
+                        id: ActiveValue::Unchanged(self_collaborator.id),
+                        connection_id: ActiveValue::Set(connection_id.id as i32),
+                        connection_server_id: ActiveValue::Set(ServerId(
+                            connection_id.owner_id as i32,
+                        )),
+                        connection_lost: ActiveValue::Set(false),
+                        ..Default::default()
+                    }
+                    .update(&*tx)
+                    .await?;
+                } else {
+                    continue;
+                }
+
+                let client_version = version_from_wire(&client_buffer.version);
+                let serialization_version = self
+                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
+                    .await?;
+
+                let mut rows = buffer_operation::Entity::find()
+                    .filter(
+                        buffer_operation::Column::BufferId
+                            .eq(buffer.id)
+                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
+                    )
+                    .stream(&*tx)
+                    .await?;
+
+                // Find the server's version vector and any operations
+                // that the client has not seen.
+                let mut server_version = clock::Global::new();
+                let mut operations = Vec::new();
+                while let Some(row) = rows.next().await {
+                    let row = row?;
+                    let timestamp = clock::Lamport {
+                        replica_id: row.replica_id as u16,
+                        value: row.lamport_timestamp as u32,
+                    };
+                    server_version.observe(timestamp);
+                    if !client_version.observed(timestamp) {
+                        operations.push(proto::Operation {
+                            variant: Some(operation_from_storage(row, serialization_version)?),
+                        })
+                    }
+                }
+
+                response.buffers.push(proto::RejoinedChannelBuffer {
+                    channel_id: client_buffer.channel_id,
+                    version: version_to_wire(&server_version),
+                    operations,
+                    collaborators: collaborators
+                        .into_iter()
+                        .map(|collaborator| proto::Collaborator {
+                            peer_id: Some(collaborator.connection().into()),
+                            user_id: collaborator.user_id.to_proto(),
+                            replica_id: collaborator.replica_id.0 as u32,
+                        })
+                        .collect(),
+                });
+            }
+
+            Ok(response)
+        })
+        .await
+    }
+
     pub async fn leave_channel_buffer(
         &self,
         channel_id: ChannelId,
@@ -103,6 +208,39 @@ impl Database {
         .await
     }
 
+    pub async fn leave_channel_buffers(
+        &self,
+        connection: ConnectionId,
+    ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
+        self.transaction(|tx| async move {
+            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+            enum QueryChannelIds {
+                ChannelId,
+            }
+
+            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
+                .select_only()
+                .column(channel_buffer_collaborator::Column::ChannelId)
+                .filter(Condition::all().add(
+                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
+                ))
+                .into_values::<_, QueryChannelIds>()
+                .all(&*tx)
+                .await?;
+
+            let mut result = Vec::new();
+            for channel_id in channel_ids {
+                let collaborators = self
+                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
+                    .await?;
+                result.push((channel_id, collaborators));
+            }
+
+            Ok(result)
+        })
+        .await
+    }
+
     pub async fn leave_channel_buffer_internal(
         &self,
         channel_id: ChannelId,
@@ -143,45 +281,12 @@ impl Database {
         drop(rows);
 
         if connections.is_empty() {
-            self.snapshot_buffer(channel_id, &tx).await?;
+            self.snapshot_channel_buffer(channel_id, &tx).await?;
         }
 
         Ok(connections)
     }
 
-    pub async fn leave_channel_buffers(
-        &self,
-        connection: ConnectionId,
-    ) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
-        self.transaction(|tx| async move {
-            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
-            enum QueryChannelIds {
-                ChannelId,
-            }
-
-            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
-                .select_only()
-                .column(channel_buffer_collaborator::Column::ChannelId)
-                .filter(Condition::all().add(
-                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
-                ))
-                .into_values::<_, QueryChannelIds>()
-                .all(&*tx)
-                .await?;
-
-            let mut result = Vec::new();
-            for channel_id in channel_ids {
-                let collaborators = self
-                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
-                    .await?;
-                result.push((channel_id, collaborators));
-            }
-
-            Ok(result)
-        })
-        .await
-    }
-
     pub async fn get_channel_buffer_collaborators(
         &self,
         channel_id: ChannelId,
@@ -224,20 +329,9 @@ impl Database {
                 .await?
                 .ok_or_else(|| anyhow!("no such buffer"))?;
 
-            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
-            enum QueryVersion {
-                OperationSerializationVersion,
-            }
-
-            let serialization_version: i32 = buffer
-                .find_related(buffer_snapshot::Entity)
-                .select_only()
-                .column(buffer_snapshot::Column::OperationSerializationVersion)
-                .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
-                .into_values::<_, QueryVersion>()
-                .one(&*tx)
-                .await?
-                .ok_or_else(|| anyhow!("missing buffer snapshot"))?;
+            let serialization_version = self
+                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
+                .await?;
 
             let operations = operations
                 .iter()
@@ -270,6 +364,38 @@ impl Database {
         .await
     }
 
+    async fn get_buffer_operation_serialization_version(
+        &self,
+        buffer_id: BufferId,
+        epoch: i32,
+        tx: &DatabaseTransaction,
+    ) -> Result<i32> {
+        Ok(buffer_snapshot::Entity::find()
+            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
+            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
+            .select_only()
+            .column(buffer_snapshot::Column::OperationSerializationVersion)
+            .into_values::<_, QueryOperationSerializationVersion>()
+            .one(&*tx)
+            .await?
+            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
+    }
+
+    async fn get_channel_buffer(
+        &self,
+        channel_id: ChannelId,
+        tx: &DatabaseTransaction,
+    ) -> Result<buffer::Model> {
+        Ok(channel::Model {
+            id: channel_id,
+            ..Default::default()
+        }
+        .find_related(buffer::Entity)
+        .one(&*tx)
+        .await?
+        .ok_or_else(|| anyhow!("no such buffer"))?)
+    }
+
     async fn get_buffer_state(
         &self,
         buffer: &buffer::Model,
@@ -303,27 +429,20 @@ impl Database {
             .await?;
         let mut operations = Vec::new();
         while let Some(row) = rows.next().await {
-            let row = row?;
-
-            let operation = operation_from_storage(row, version)?;
             operations.push(proto::Operation {
-                variant: Some(operation),
+                variant: Some(operation_from_storage(row?, version)?),
             })
         }
 
         Ok((base_text, operations))
     }
 
-    async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
-        let buffer = channel::Model {
-            id: channel_id,
-            ..Default::default()
-        }
-        .find_related(buffer::Entity)
-        .one(&*tx)
-        .await?
-        .ok_or_else(|| anyhow!("no such buffer"))?;
-
+    async fn snapshot_channel_buffer(
+        &self,
+        channel_id: ChannelId,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        let buffer = self.get_channel_buffer(channel_id, tx).await?;
         let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
         if operations.is_empty() {
             return Ok(());
@@ -527,6 +646,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
     version
 }
 
+fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
+    let mut message = Vec::new();
+    for entry in version.iter() {
+        message.push(proto::VectorClockEntry {
+            replica_id: entry.replica_id as u32,
+            timestamp: entry.value,
+        });
+    }
+    message
+}
+
+#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+enum QueryOperationSerializationVersion {
+    OperationSerializationVersion,
+}
+
 mod storage {
     #![allow(non_snake_case)]
     use prost::Message;

crates/collab/src/rpc.rs 🔗

@@ -251,6 +251,7 @@ impl Server {
             .add_request_handler(join_channel_buffer)
             .add_request_handler(leave_channel_buffer)
             .add_message_handler(update_channel_buffer)
+            .add_request_handler(rejoin_channel_buffers)
             .add_request_handler(get_channel_members)
             .add_request_handler(respond_to_channel_invite)
             .add_request_handler(join_channel)
@@ -854,13 +855,12 @@ async fn connection_lost(
         .await
         .trace_err();
 
-    leave_channel_buffers_for_session(&session)
-        .await
-        .trace_err();
-
     futures::select_biased! {
         _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
             leave_room_for_session(&session).await.trace_err();
+            leave_channel_buffers_for_session(&session)
+                .await
+                .trace_err();
 
             if !session
                 .connection_pool()
@@ -2547,6 +2547,23 @@ async fn update_channel_buffer(
     Ok(())
 }
 
+async fn rejoin_channel_buffers(
+    request: proto::RejoinChannelBuffers,
+    response: Response<proto::RejoinChannelBuffers>,
+    session: Session,
+) -> Result<()> {
+    let db = session.db().await;
+    let rejoin_response = db
+        .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
+        .await?;
+
+    // TODO: inform channel buffer collaborators that this user has rejoined.
+
+    response.send(rejoin_response)?;
+
+    Ok(())
+}
+
 async fn leave_channel_buffer(
     request: proto::LeaveChannelBuffer,
     response: Response<proto::LeaveChannelBuffer>,

crates/collab/src/tests/channel_buffer_tests.rs 🔗

@@ -21,20 +21,19 @@ async fn test_core_channel_buffers(
     let client_a = server.create_client(cx_a, "user_a").await;
     let client_b = server.create_client(cx_b, "user_b").await;
 
-    let zed_id = server
+    let channel_id = server
         .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
         .await;
 
     // Client A joins the channel buffer
     let channel_buffer_a = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
 
     // Client A edits the buffer
     let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer());
-
     buffer_a.update(cx_a, |buffer, cx| {
         buffer.edit([(0..0, "hello world")], None, cx)
     });
@@ -45,17 +44,15 @@ async fn test_core_channel_buffers(
         buffer.edit([(0..5, "goodbye")], None, cx)
     });
     buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx));
-    deterministic.run_until_parked();
-
     assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world");
+    deterministic.run_until_parked();
 
     // Client B joins the channel buffer
     let channel_buffer_b = client_b
         .channel_store()
-        .update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx))
+        .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
-
     channel_buffer_b.read_with(cx_b, |buffer, _| {
         assert_collaborators(
             buffer.collaborators(),
@@ -91,9 +88,7 @@ async fn test_core_channel_buffers(
     // Client A rejoins the channel buffer
     let _channel_buffer_a = client_a
         .channel_store()
-        .update(cx_a, |channels, cx| {
-            channels.open_channel_buffer(zed_id, cx)
-        })
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
     deterministic.run_until_parked();
@@ -136,7 +131,7 @@ async fn test_channel_buffer_replica_ids(
 
     let channel_id = server
         .make_channel(
-            "zed",
+            "the-channel",
             (&client_a, cx_a),
             &mut [(&client_b, cx_b), (&client_c, cx_c)],
         )
@@ -160,23 +155,17 @@ async fn test_channel_buffer_replica_ids(
     // C first so that the replica IDs in the project and the channel buffer are different
     let channel_buffer_c = client_c
         .channel_store()
-        .update(cx_c, |channel, cx| {
-            channel.open_channel_buffer(channel_id, cx)
-        })
+        .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
     let channel_buffer_b = client_b
         .channel_store()
-        .update(cx_b, |channel, cx| {
-            channel.open_channel_buffer(channel_id, cx)
-        })
+        .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
     let channel_buffer_a = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| {
-            channel.open_channel_buffer(channel_id, cx)
-        })
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
 
@@ -286,28 +275,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
     let mut server = TestServer::start(&deterministic).await;
     let client_a = server.create_client(cx_a, "user_a").await;
 
-    let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await;
+    let channel_id = server
+        .make_channel("the-channel", (&client_a, cx_a), &mut [])
+        .await;
 
     let channel_buffer_1 = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
     let channel_buffer_2 = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
     let channel_buffer_3 = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
 
     // All concurrent tasks for opening a channel buffer return the same model handle.
-    let (channel_buffer_1, channel_buffer_2, channel_buffer_3) =
+    let (channel_buffer, channel_buffer_2, channel_buffer_3) =
         future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3)
             .await
             .unwrap();
-    let model_id = channel_buffer_1.id();
-    assert_eq!(channel_buffer_1, channel_buffer_2);
-    assert_eq!(channel_buffer_1, channel_buffer_3);
+    let channel_buffer_model_id = channel_buffer.id();
+    assert_eq!(channel_buffer, channel_buffer_2);
+    assert_eq!(channel_buffer, channel_buffer_3);
 
-    channel_buffer_1.update(cx_a, |buffer, cx| {
+    channel_buffer.update(cx_a, |buffer, cx| {
         buffer.buffer().update(cx, |buffer, cx| {
             buffer.edit([(0..0, "hello")], None, cx);
         })
@@ -315,7 +306,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
     deterministic.run_until_parked();
 
     cx_a.update(|_| {
-        drop(channel_buffer_1);
+        drop(channel_buffer);
         drop(channel_buffer_2);
         drop(channel_buffer_3);
     });
@@ -324,10 +315,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
     // The channel buffer can be reopened after dropping it.
     let channel_buffer = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
-    assert_ne!(channel_buffer.id(), model_id);
+    assert_ne!(channel_buffer.id(), channel_buffer_model_id);
     channel_buffer.update(cx_a, |buffer, cx| {
         buffer.buffer().update(cx, |buffer, _| {
             assert_eq!(buffer.text(), "hello");
@@ -347,22 +338,17 @@ async fn test_channel_buffer_disconnect(
     let client_b = server.create_client(cx_b, "user_b").await;
 
     let channel_id = server
-        .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
+        .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
         .await;
 
     let channel_buffer_a = client_a
         .channel_store()
-        .update(cx_a, |channel, cx| {
-            channel.open_channel_buffer(channel_id, cx)
-        })
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
-
     let channel_buffer_b = client_b
         .channel_store()
-        .update(cx_b, |channel, cx| {
-            channel.open_channel_buffer(channel_id, cx)
-        })
+        .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
         .await
         .unwrap();
 
@@ -375,7 +361,7 @@ async fn test_channel_buffer_disconnect(
             buffer.channel().as_ref(),
             &Channel {
                 id: channel_id,
-                name: "zed".to_string()
+                name: "the-channel".to_string()
             }
         );
         assert!(!buffer.is_connected());
@@ -403,13 +389,81 @@ async fn test_channel_buffer_disconnect(
             buffer.channel().as_ref(),
             &Channel {
                 id: channel_id,
-                name: "zed".to_string()
+                name: "the-channel".to_string()
             }
         );
         assert!(!buffer.is_connected());
     });
 }
 
+#[gpui::test]
+async fn test_rejoin_channel_buffer(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let channel_id = server
+        .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
+        .await;
+
+    let channel_buffer_a = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+    let channel_buffer_b = client_b
+        .channel_store()
+        .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
+        .await
+        .unwrap();
+
+    channel_buffer_a.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "1")], None, cx);
+        })
+    });
+    deterministic.run_until_parked();
+
+    // Client A disconnects.
+    server.forbid_connections();
+    server.disconnect_client(client_a.peer_id().unwrap());
+    // deterministic.advance_clock(RECEIVE_TIMEOUT);
+
+    // Both clients make an edit. Both clients see their own edit.
+    channel_buffer_a.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(1..1, "2")], None, cx);
+        })
+    });
+    channel_buffer_b.update(cx_b, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "0")], None, cx);
+        })
+    });
+    deterministic.run_until_parked();
+    channel_buffer_a.read_with(cx_a, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "12");
+    });
+    channel_buffer_b.read_with(cx_b, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "01");
+    });
+
+    // Client A reconnects.
+    server.allow_connections();
+    deterministic.advance_clock(RECEIVE_TIMEOUT);
+    channel_buffer_a.read_with(cx_a, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "012");
+    });
+    channel_buffer_b.read_with(cx_b, |buffer, cx| {
+        assert_eq!(buffer.buffer().read(cx).text(), "012");
+    });
+}
+
 #[track_caller]
 fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
     assert_eq!(

crates/language/src/proto.rs 🔗

@@ -127,6 +127,31 @@ pub fn serialize_undo_map_entry(
     }
 }
 
+pub fn split_operations(
+    mut operations: Vec<proto::Operation>,
+) -> impl Iterator<Item = Vec<proto::Operation>> {
+    #[cfg(any(test, feature = "test-support"))]
+    const CHUNK_SIZE: usize = 5;
+
+    #[cfg(not(any(test, feature = "test-support")))]
+    const CHUNK_SIZE: usize = 100;
+
+    let mut done = false;
+    std::iter::from_fn(move || {
+        if done {
+            return None;
+        }
+
+        let operations = operations
+            .drain(..std::cmp::min(CHUNK_SIZE, operations.len()))
+            .collect::<Vec<_>>();
+        if operations.is_empty() {
+            done = true;
+        }
+        Some(operations)
+    })
+}
+
 pub fn serialize_selections(selections: &Arc<[Selection<Anchor>]>) -> Vec<proto::Selection> {
     selections.iter().map(serialize_selection).collect()
 }

crates/project/src/project.rs 🔗

@@ -35,7 +35,7 @@ use language::{
     point_to_lsp,
     proto::{
         deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version,
-        serialize_anchor, serialize_version,
+        serialize_anchor, serialize_version, split_operations,
     },
     range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction,
     CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent,
@@ -8200,31 +8200,6 @@ impl LspAdapterDelegate for ProjectLspAdapterDelegate {
     }
 }
 
-fn split_operations(
-    mut operations: Vec<proto::Operation>,
-) -> impl Iterator<Item = Vec<proto::Operation>> {
-    #[cfg(any(test, feature = "test-support"))]
-    const CHUNK_SIZE: usize = 5;
-
-    #[cfg(not(any(test, feature = "test-support")))]
-    const CHUNK_SIZE: usize = 100;
-
-    let mut done = false;
-    std::iter::from_fn(move || {
-        if done {
-            return None;
-        }
-
-        let operations = operations
-            .drain(..cmp::min(CHUNK_SIZE, operations.len()))
-            .collect::<Vec<_>>();
-        if operations.is_empty() {
-            done = true;
-        }
-        Some(operations)
-    })
-}
-
 fn serialize_symbol(symbol: &Symbol) -> proto::Symbol {
     proto::Symbol {
         language_server_name: symbol.language_server_name.0.to_string(),

crates/rpc/proto/zed.proto 🔗

@@ -1,6 +1,8 @@
 syntax = "proto3";
 package zed.messages;
 
+// Looking for a number? Search "// Current max"
+
 message PeerId {
     uint32 owner_id = 1;
     uint32 id = 2;
@@ -151,6 +153,8 @@ message Envelope {
         LeaveChannelBuffer leave_channel_buffer = 134;
         AddChannelBufferCollaborator add_channel_buffer_collaborator = 135;
         RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136;
+        RejoinChannelBuffers rejoin_channel_buffers = 139;
+        RejoinChannelBuffersResponse rejoin_channel_buffers_response = 140; // Current max
     }
 }
 
@@ -616,6 +620,12 @@ message BufferVersion {
     repeated VectorClockEntry version = 2;
 }
 
+message ChannelBufferVersion {
+    uint64 channel_id = 1;
+    repeated VectorClockEntry version = 2;
+    uint64 epoch = 3;
+}
+
 enum FormatTrigger {
     Save = 0;
     Manual = 1;
@@ -1008,12 +1018,28 @@ message JoinChannelBuffer {
     uint64 channel_id = 1;
 }
 
+message RejoinChannelBuffers {
+    repeated ChannelBufferVersion buffers = 1;
+}
+
+message RejoinChannelBuffersResponse {
+    repeated RejoinedChannelBuffer buffers = 1;
+}
+
 message JoinChannelBufferResponse {
     uint64 buffer_id = 1;
     uint32 replica_id = 2;
     string base_text = 3;
     repeated Operation operations = 4;
     repeated Collaborator collaborators = 5;
+    uint64 epoch = 6;
+}
+
+message RejoinedChannelBuffer {
+    uint64 channel_id = 1;
+    repeated VectorClockEntry version = 2;
+    repeated Operation operations = 3;
+    repeated Collaborator collaborators = 4;
 }
 
 message LeaveChannelBuffer {

crates/rpc/src/proto.rs 🔗

@@ -229,6 +229,8 @@ messages!(
     (StartLanguageServer, Foreground),
     (SynchronizeBuffers, Foreground),
     (SynchronizeBuffersResponse, Foreground),
+    (RejoinChannelBuffers, Foreground),
+    (RejoinChannelBuffersResponse, Foreground),
     (Test, Foreground),
     (Unfollow, Foreground),
     (UnshareProject, Foreground),
@@ -319,6 +321,7 @@ request_messages!(
     (SearchProject, SearchProjectResponse),
     (ShareProject, ShareProjectResponse),
     (SynchronizeBuffers, SynchronizeBuffersResponse),
+    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
     (Test, Test),
     (UpdateBuffer, Ack),
     (UpdateParticipantLocation, Ack),