Broadcast new peer ids for rejoined channel collaborators

Max Brunsfeld created

Change summary

crates/channel/src/channel_buffer.rs            | 21 +++++
crates/collab/src/db.rs                         |  5 +
crates/collab/src/db/queries/buffers.rs         | 69 +++++++++---------
crates/collab/src/rpc.rs                        | 24 +++++
crates/collab/src/tests/channel_buffer_tests.rs | 14 +++
crates/rpc/proto/zed.proto                      | 11 ++
crates/rpc/src/proto.rs                         |  4 
7 files changed, 104 insertions(+), 44 deletions(-)

Detailed changes

crates/channel/src/channel_buffer.rs 🔗

@@ -10,6 +10,7 @@ pub(crate) fn init(client: &Arc<Client>) {
     client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer);
     client.add_model_message_handler(ChannelBuffer::handle_add_channel_buffer_collaborator);
     client.add_model_message_handler(ChannelBuffer::handle_remove_channel_buffer_collaborator);
+    client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborator);
 }
 
 pub struct ChannelBuffer {
@@ -171,6 +172,26 @@ impl ChannelBuffer {
         Ok(())
     }
 
+    async fn handle_update_channel_buffer_collaborator(
+        this: ModelHandle<Self>,
+        message: TypedEnvelope<proto::UpdateChannelBufferCollaborator>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        this.update(&mut cx, |this, cx| {
+            for collaborator in &mut this.collaborators {
+                if collaborator.peer_id == message.payload.old_peer_id {
+                    collaborator.peer_id = message.payload.new_peer_id;
+                    break;
+                }
+            }
+            cx.emit(Event::CollaboratorsChanged);
+            cx.notify();
+        });
+
+        Ok(())
+    }
+
     fn on_buffer_update(
         &mut self,
         _: ModelHandle<language::Buffer>,

crates/collab/src/db.rs 🔗

@@ -435,6 +435,11 @@ pub struct ChannelsForUser {
     pub channels_with_admin_privileges: HashSet<ChannelId>,
 }
 
+pub struct RejoinedChannelBuffer {
+    pub buffer: proto::RejoinedChannelBuffer,
+    pub old_connection_id: ConnectionId,
+}
+
 #[derive(Clone)]
 pub struct JoinRoom {
     pub room: proto::Room,

crates/collab/src/db/queries/buffers.rs 🔗

@@ -94,9 +94,9 @@ impl Database {
         buffers: &[proto::ChannelBufferVersion],
         user_id: UserId,
         connection_id: ConnectionId,
-    ) -> Result<proto::RejoinChannelBuffersResponse> {
+    ) -> Result<Vec<RejoinedChannelBuffer>> {
         self.transaction(|tx| async move {
-            let mut response = proto::RejoinChannelBuffersResponse::default();
+            let mut results = Vec::new();
             for client_buffer in buffers {
                 let channel_id = ChannelId::from_proto(client_buffer.channel_id);
                 if self
@@ -121,28 +121,24 @@ impl Database {
                     continue;
                 }
 
-                // If there is still a disconnected collaborator for the user,
-                // update the connection associated with that collaborator, and reuse
-                // that replica id.
-                if let Some(ix) = collaborators
-                    .iter()
-                    .position(|c| c.user_id == user_id && c.connection_lost)
-                {
-                    let self_collaborator = &mut collaborators[ix];
-                    *self_collaborator = channel_buffer_collaborator::ActiveModel {
-                        id: ActiveValue::Unchanged(self_collaborator.id),
-                        connection_id: ActiveValue::Set(connection_id.id as i32),
-                        connection_server_id: ActiveValue::Set(ServerId(
-                            connection_id.owner_id as i32,
-                        )),
-                        connection_lost: ActiveValue::Set(false),
-                        ..Default::default()
-                    }
-                    .update(&*tx)
-                    .await?;
-                } else {
+                // Find the collaborator record for this user's previous lost
+                // connection. Update it with the new connection id.
+                let Some(self_collaborator) = collaborators
+                    .iter_mut()
+                    .find(|c| c.user_id == user_id && c.connection_lost)
+                else {
                     continue;
+                };
+                let old_connection_id = self_collaborator.connection();
+                *self_collaborator = channel_buffer_collaborator::ActiveModel {
+                    id: ActiveValue::Unchanged(self_collaborator.id),
+                    connection_id: ActiveValue::Set(connection_id.id as i32),
+                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
+                    connection_lost: ActiveValue::Set(false),
+                    ..Default::default()
                 }
+                .update(&*tx)
+                .await?;
 
                 let client_version = version_from_wire(&client_buffer.version);
                 let serialization_version = self
@@ -176,22 +172,25 @@ impl Database {
                     }
                 }
 
-                response.buffers.push(proto::RejoinedChannelBuffer {
-                    channel_id: client_buffer.channel_id,
-                    version: version_to_wire(&server_version),
-                    operations,
-                    collaborators: collaborators
-                        .into_iter()
-                        .map(|collaborator| proto::Collaborator {
-                            peer_id: Some(collaborator.connection().into()),
-                            user_id: collaborator.user_id.to_proto(),
-                            replica_id: collaborator.replica_id.0 as u32,
-                        })
-                        .collect(),
+                results.push(RejoinedChannelBuffer {
+                    old_connection_id,
+                    buffer: proto::RejoinedChannelBuffer {
+                        channel_id: client_buffer.channel_id,
+                        version: version_to_wire(&server_version),
+                        operations,
+                        collaborators: collaborators
+                            .into_iter()
+                            .map(|collaborator| proto::Collaborator {
+                                peer_id: Some(collaborator.connection().into()),
+                                user_id: collaborator.user_id.to_proto(),
+                                replica_id: collaborator.replica_id.0 as u32,
+                            })
+                            .collect(),
+                    },
                 });
             }
 
-            Ok(response)
+            Ok(results)
         })
         .await
     }

crates/collab/src/rpc.rs 🔗

@@ -2553,13 +2553,31 @@ async fn rejoin_channel_buffers(
     session: Session,
 ) -> Result<()> {
     let db = session.db().await;
-    let rejoin_response = db
+    let buffers = db
         .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
         .await?;
 
-    // TODO: inform channel buffer collaborators that this user has rejoined.
+    for buffer in &buffers {
+        let collaborators_to_notify = buffer
+            .buffer
+            .collaborators
+            .iter()
+            .filter_map(|c| Some(c.peer_id?.into()));
+        channel_buffer_updated(
+            session.connection_id,
+            collaborators_to_notify,
+            &proto::UpdateChannelBufferCollaborator {
+                channel_id: buffer.buffer.channel_id,
+                old_peer_id: Some(buffer.old_connection_id.into()),
+                new_peer_id: Some(session.connection_id.into()),
+            },
+            &session.peer,
+        );
+    }
 
-    response.send(rejoin_response)?;
+    response.send(proto::RejoinChannelBuffersResponse {
+        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
+    })?;
 
     Ok(())
 }

crates/collab/src/tests/channel_buffer_tests.rs 🔗

@@ -432,9 +432,8 @@ async fn test_rejoin_channel_buffer(
     // Client A disconnects.
     server.forbid_connections();
     server.disconnect_client(client_a.peer_id().unwrap());
-    // deterministic.advance_clock(RECEIVE_TIMEOUT);
 
-    // Both clients make an edit. Both clients see their own edit.
+    // Both clients make an edit.
     channel_buffer_a.update(cx_a, |buffer, cx| {
         buffer.buffer().update(cx, |buffer, cx| {
             buffer.edit([(1..1, "2")], None, cx);
@@ -445,6 +444,8 @@ async fn test_rejoin_channel_buffer(
             buffer.edit([(0..0, "0")], None, cx);
         })
     });
+
+    // Both clients see their own edit.
     deterministic.run_until_parked();
     channel_buffer_a.read_with(cx_a, |buffer, cx| {
         assert_eq!(buffer.buffer().read(cx).text(), "12");
@@ -453,7 +454,8 @@ async fn test_rejoin_channel_buffer(
         assert_eq!(buffer.buffer().read(cx).text(), "01");
     });
 
-    // Client A reconnects.
+    // Client A reconnects. Both clients see each other's edits, and see
+    // the same collaborators.
     server.allow_connections();
     deterministic.advance_clock(RECEIVE_TIMEOUT);
     channel_buffer_a.read_with(cx_a, |buffer, cx| {
@@ -462,6 +464,12 @@ async fn test_rejoin_channel_buffer(
     channel_buffer_b.read_with(cx_b, |buffer, cx| {
         assert_eq!(buffer.buffer().read(cx).text(), "012");
     });
+
+    channel_buffer_a.read_with(cx_a, |buffer_a, _| {
+        channel_buffer_b.read_with(cx_b, |buffer_b, _| {
+            assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
+        });
+    });
 }
 
 #[track_caller]

crates/rpc/proto/zed.proto 🔗

@@ -153,8 +153,9 @@ message Envelope {
         LeaveChannelBuffer leave_channel_buffer = 134;
         AddChannelBufferCollaborator add_channel_buffer_collaborator = 135;
         RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136;
-        RejoinChannelBuffers rejoin_channel_buffers = 139;
-        RejoinChannelBuffersResponse rejoin_channel_buffers_response = 140; // Current max
+        UpdateChannelBufferCollaborator update_channel_buffer_collaborator = 139;
+        RejoinChannelBuffers rejoin_channel_buffers = 140;
+        RejoinChannelBuffersResponse rejoin_channel_buffers_response = 141; // Current max
     }
 }
 
@@ -434,6 +435,12 @@ message RemoveChannelBufferCollaborator {
     PeerId peer_id = 2;
 }
 
+message UpdateChannelBufferCollaborator {
+    uint64 channel_id = 1;
+    PeerId old_peer_id = 2;
+    PeerId new_peer_id = 3;
+}
+
 message GetDefinition {
      uint64 project_id = 1;
      uint64 buffer_id = 2;

crates/rpc/src/proto.rs 🔗

@@ -259,6 +259,7 @@ messages!(
     (UpdateChannelBuffer, Foreground),
     (RemoveChannelBufferCollaborator, Foreground),
     (AddChannelBufferCollaborator, Foreground),
+    (UpdateChannelBufferCollaborator, Foreground),
 );
 
 request_messages!(
@@ -389,7 +390,8 @@ entity_messages!(
     channel_id,
     UpdateChannelBuffer,
     RemoveChannelBufferCollaborator,
-    AddChannelBufferCollaborator
+    AddChannelBufferCollaborator,
+    UpdateChannelBufferCollaborator
 );
 
 const KIB: usize = 1024;