diff --git a/crates/channel/src/channel_buffer.rs b/crates/channel/src/channel_buffer.rs index 29f4d3493c6d0fe9e2fc041695f40fe48225c76f..98ecbc5dcf5215ad98c9c591d4c17dce9f19d61a 100644 --- a/crates/channel/src/channel_buffer.rs +++ b/crates/channel/src/channel_buffer.rs @@ -17,6 +17,7 @@ pub struct ChannelBuffer { connected: bool, collaborators: Vec, buffer: ModelHandle, + buffer_epoch: u64, client: Arc, subscription: Option, } @@ -73,6 +74,7 @@ impl ChannelBuffer { Self { buffer, + buffer_epoch: response.epoch, client, connected: true, collaborators, @@ -82,6 +84,26 @@ impl ChannelBuffer { })) } + pub(crate) fn replace_collaborators( + &mut self, + collaborators: Vec, + cx: &mut ModelContext, + ) { + for old_collaborator in &self.collaborators { + if collaborators + .iter() + .any(|c| c.replica_id == old_collaborator.replica_id) + { + self.buffer.update(cx, |buffer, cx| { + buffer.remove_peer(old_collaborator.replica_id as u16, cx) + }); + } + } + self.collaborators = collaborators; + cx.emit(Event::CollaboratorsChanged); + cx.notify(); + } + async fn handle_update_channel_buffer( this: ModelHandle, update_channel_buffer: TypedEnvelope, @@ -166,6 +188,10 @@ impl ChannelBuffer { } } + pub fn epoch(&self) -> u64 { + self.buffer_epoch + } + pub fn buffer(&self) -> ModelHandle { self.buffer.clone() } @@ -179,6 +205,7 @@ impl ChannelBuffer { } pub(crate) fn disconnect(&mut self, cx: &mut ModelContext) { + log::info!("channel buffer {} disconnected", self.channel.id); if self.connected { self.connected = false; self.subscription.take(); diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 861f731331ca6337ed7d798162ceb3321ad170fe..ec1652581d603e3be510489c3fbc159735a10849 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,13 +1,15 @@ use crate::channel_buffer::ChannelBuffer; use anyhow::{anyhow, Result}; -use client::{Client, Status, Subscription, User, UserId, UserStore}; +use client::{Client, Subscription, User, UserId, UserStore}; use collections::{hash_map, HashMap, HashSet}; use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt}; use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use rpc::{proto, TypedEnvelope}; -use std::sync::Arc; +use std::{mem, sync::Arc, time::Duration}; use util::ResultExt; +pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); + pub type ChannelId = u64; pub struct ChannelStore { @@ -22,7 +24,8 @@ pub struct ChannelStore { client: Arc, user_store: ModelHandle, _rpc_subscription: Subscription, - _watch_connection_status: Task<()>, + _watch_connection_status: Task>, + disconnect_channel_buffers_task: Option>, _update_channels: Task<()>, } @@ -67,24 +70,20 @@ impl ChannelStore { let rpc_subscription = client.add_message_handler(cx.handle(), Self::handle_update_channels); - let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded(); let mut connection_status = client.status(); + let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded(); let watch_connection_status = cx.spawn_weak(|this, mut cx| async move { while let Some(status) = connection_status.next().await { - if !status.is_connected() { - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if matches!(status, Status::ConnectionLost | Status::SignedOut) { - this.handle_disconnect(cx); - } else { - this.disconnect_buffers(cx); - } - }); - } else { - break; - } + let this = this.upgrade(&cx)?; + if status.is_connected() { + this.update(&mut cx, |this, cx| this.handle_connect(cx)) + .await + .log_err()?; + } else { + this.update(&mut cx, |this, cx| this.handle_disconnect(cx)); } } + Some(()) }); Self { @@ -100,6 +99,7 @@ impl ChannelStore { user_store, _rpc_subscription: rpc_subscription, _watch_connection_status: watch_connection_status, + disconnect_channel_buffers_task: None, _update_channels: cx.spawn_weak(|this, mut cx| async move { while let Some(update_channels) = update_channels_rx.next().await { if let Some(this) = this.upgrade(&cx) { @@ -482,8 +482,102 @@ impl ChannelStore { Ok(()) } - fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) { - self.disconnect_buffers(cx); + fn handle_connect(&mut self, cx: &mut ModelContext) -> Task> { + self.disconnect_channel_buffers_task.take(); + + let mut buffer_versions = Vec::new(); + for buffer in self.opened_buffers.values() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + let channel_buffer = buffer.read(cx); + let buffer = channel_buffer.buffer().read(cx); + buffer_versions.push(proto::ChannelBufferVersion { + channel_id: channel_buffer.channel().id, + epoch: channel_buffer.epoch(), + version: language::proto::serialize_version(&buffer.version()), + }); + } + } + } + + let response = self.client.request(proto::RejoinChannelBuffers { + buffers: buffer_versions, + }); + + cx.spawn(|this, mut cx| async move { + let mut response = response.await?; + + this.update(&mut cx, |this, cx| { + this.opened_buffers.retain(|_, buffer| match buffer { + OpenedChannelBuffer::Open(channel_buffer) => { + let Some(channel_buffer) = channel_buffer.upgrade(cx) else { + return false; + }; + + channel_buffer.update(cx, |channel_buffer, cx| { + let channel_id = channel_buffer.channel().id; + if let Some(remote_buffer) = response + .buffers + .iter_mut() + .find(|buffer| buffer.channel_id == channel_id) + { + let channel_id = channel_buffer.channel().id; + let remote_version = + language::proto::deserialize_version(&remote_buffer.version); + + channel_buffer.replace_collaborators( + mem::take(&mut remote_buffer.collaborators), + cx, + ); + + let operations = channel_buffer + .buffer() + .update(cx, |buffer, cx| { + let outgoing_operations = + buffer.serialize_ops(Some(remote_version), cx); + let incoming_operations = + mem::take(&mut remote_buffer.operations) + .into_iter() + .map(language::proto::deserialize_operation) + .collect::>>()?; + buffer.apply_ops(incoming_operations, cx)?; + anyhow::Ok(outgoing_operations) + }) + .log_err(); + + if let Some(operations) = operations { + let client = this.client.clone(); + cx.background() + .spawn(async move { + let operations = operations.await; + for chunk in + language::proto::split_operations(operations) + { + client + .send(proto::UpdateChannelBuffer { + channel_id, + operations: chunk, + }) + .ok(); + } + }) + .detach(); + return true; + } + } + + channel_buffer.disconnect(cx); + false + }) + } + OpenedChannelBuffer::Loading(_) => true, + }); + }); + anyhow::Ok(()) + }) + } + + fn handle_disconnect(&mut self, cx: &mut ModelContext) { self.channels_by_id.clear(); self.channel_invitations.clear(); self.channel_participants.clear(); @@ -491,16 +585,23 @@ impl ChannelStore { self.channel_paths.clear(); self.outgoing_invites.clear(); cx.notify(); - } - fn disconnect_buffers(&mut self, cx: &mut ModelContext) { - for (_, buffer) in self.opened_buffers.drain() { - if let OpenedChannelBuffer::Open(buffer) = buffer { - if let Some(buffer) = buffer.upgrade(cx) { - buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + self.disconnect_channel_buffers_task.get_or_insert_with(|| { + cx.spawn_weak(|this, mut cx| async move { + cx.background().timer(RECONNECT_TIMEOUT).await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + for (_, buffer) in this.opened_buffers.drain() { + if let OpenedChannelBuffer::Open(buffer) = buffer { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, |buffer, cx| buffer.disconnect(cx)); + } + } + } + }); } - } - } + }) + }); } pub(crate) fn update_channels( diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index f120aea1c58a1745484d9e2eed674b20c8b091f3..587ed058ff88cfdacdede328e08694681e628109 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -10,8 +10,6 @@ impl Database { connection: ConnectionId, ) -> Result { self.transaction(|tx| async move { - let tx = tx; - self.check_user_is_channel_member(channel_id, user_id, &tx) .await?; @@ -70,7 +68,6 @@ impl Database { .await?; collaborators.push(collaborator); - // Assemble the buffer state let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?; Ok(proto::JoinChannelBufferResponse { @@ -78,6 +75,7 @@ impl Database { replica_id: replica_id.to_proto() as u32, base_text, operations, + epoch: buffer.epoch as u64, collaborators: collaborators .into_iter() .map(|collaborator| proto::Collaborator { @@ -91,6 +89,113 @@ impl Database { .await } + pub async fn rejoin_channel_buffers( + &self, + buffers: &[proto::ChannelBufferVersion], + user_id: UserId, + connection_id: ConnectionId, + ) -> Result { + self.transaction(|tx| async move { + let mut response = proto::RejoinChannelBuffersResponse::default(); + for client_buffer in buffers { + let channel_id = ChannelId::from_proto(client_buffer.channel_id); + if self + .check_user_is_channel_member(channel_id, user_id, &*tx) + .await + .is_err() + { + log::info!("user is not a member of channel"); + continue; + } + + let buffer = self.get_channel_buffer(channel_id, &*tx).await?; + let mut collaborators = channel_buffer_collaborator::Entity::find() + .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)) + .all(&*tx) + .await?; + + // If the buffer epoch hasn't changed since the client lost + // connection, then the client's buffer can be syncronized with + // the server's buffer. + if buffer.epoch as u64 != client_buffer.epoch { + continue; + } + + // If there is still a disconnected collaborator for the user, + // update the connection associated with that collaborator, and reuse + // that replica id. + if let Some(ix) = collaborators + .iter() + .position(|c| c.user_id == user_id && c.connection_lost) + { + let self_collaborator = &mut collaborators[ix]; + *self_collaborator = channel_buffer_collaborator::ActiveModel { + id: ActiveValue::Unchanged(self_collaborator.id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId( + connection_id.owner_id as i32, + )), + connection_lost: ActiveValue::Set(false), + ..Default::default() + } + .update(&*tx) + .await?; + } else { + continue; + } + + let client_version = version_from_wire(&client_buffer.version); + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; + + let mut rows = buffer_operation::Entity::find() + .filter( + buffer_operation::Column::BufferId + .eq(buffer.id) + .and(buffer_operation::Column::Epoch.eq(buffer.epoch)), + ) + .stream(&*tx) + .await?; + + // Find the server's version vector and any operations + // that the client has not seen. + let mut server_version = clock::Global::new(); + let mut operations = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let timestamp = clock::Lamport { + replica_id: row.replica_id as u16, + value: row.lamport_timestamp as u32, + }; + server_version.observe(timestamp); + if !client_version.observed(timestamp) { + operations.push(proto::Operation { + variant: Some(operation_from_storage(row, serialization_version)?), + }) + } + } + + response.buffers.push(proto::RejoinedChannelBuffer { + channel_id: client_buffer.channel_id, + version: version_to_wire(&server_version), + operations, + collaborators: collaborators + .into_iter() + .map(|collaborator| proto::Collaborator { + peer_id: Some(collaborator.connection().into()), + user_id: collaborator.user_id.to_proto(), + replica_id: collaborator.replica_id.0 as u32, + }) + .collect(), + }); + } + + Ok(response) + }) + .await + } + pub async fn leave_channel_buffer( &self, channel_id: ChannelId, @@ -103,6 +208,39 @@ impl Database { .await } + pub async fn leave_channel_buffers( + &self, + connection: ConnectionId, + ) -> Result)>> { + self.transaction(|tx| async move { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryChannelIds { + ChannelId, + } + + let channel_ids: Vec = channel_buffer_collaborator::Entity::find() + .select_only() + .column(channel_buffer_collaborator::Column::ChannelId) + .filter(Condition::all().add( + channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), + )) + .into_values::<_, QueryChannelIds>() + .all(&*tx) + .await?; + + let mut result = Vec::new(); + for channel_id in channel_ids { + let collaborators = self + .leave_channel_buffer_internal(channel_id, connection, &*tx) + .await?; + result.push((channel_id, collaborators)); + } + + Ok(result) + }) + .await + } + pub async fn leave_channel_buffer_internal( &self, channel_id: ChannelId, @@ -143,45 +281,12 @@ impl Database { drop(rows); if connections.is_empty() { - self.snapshot_buffer(channel_id, &tx).await?; + self.snapshot_channel_buffer(channel_id, &tx).await?; } Ok(connections) } - pub async fn leave_channel_buffers( - &self, - connection: ConnectionId, - ) -> Result)>> { - self.transaction(|tx| async move { - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryChannelIds { - ChannelId, - } - - let channel_ids: Vec = channel_buffer_collaborator::Entity::find() - .select_only() - .column(channel_buffer_collaborator::Column::ChannelId) - .filter(Condition::all().add( - channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32), - )) - .into_values::<_, QueryChannelIds>() - .all(&*tx) - .await?; - - let mut result = Vec::new(); - for channel_id in channel_ids { - let collaborators = self - .leave_channel_buffer_internal(channel_id, connection, &*tx) - .await?; - result.push((channel_id, collaborators)); - } - - Ok(result) - }) - .await - } - pub async fn get_channel_buffer_collaborators( &self, channel_id: ChannelId, @@ -224,20 +329,9 @@ impl Database { .await? .ok_or_else(|| anyhow!("no such buffer"))?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryVersion { - OperationSerializationVersion, - } - - let serialization_version: i32 = buffer - .find_related(buffer_snapshot::Entity) - .select_only() - .column(buffer_snapshot::Column::OperationSerializationVersion) - .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch)) - .into_values::<_, QueryVersion>() - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("missing buffer snapshot"))?; + let serialization_version = self + .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx) + .await?; let operations = operations .iter() @@ -270,6 +364,38 @@ impl Database { .await } + async fn get_buffer_operation_serialization_version( + &self, + buffer_id: BufferId, + epoch: i32, + tx: &DatabaseTransaction, + ) -> Result { + Ok(buffer_snapshot::Entity::find() + .filter(buffer_snapshot::Column::BufferId.eq(buffer_id)) + .filter(buffer_snapshot::Column::Epoch.eq(epoch)) + .select_only() + .column(buffer_snapshot::Column::OperationSerializationVersion) + .into_values::<_, QueryOperationSerializationVersion>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("missing buffer snapshot"))?) + } + + async fn get_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result { + Ok(channel::Model { + id: channel_id, + ..Default::default() + } + .find_related(buffer::Entity) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such buffer"))?) + } + async fn get_buffer_state( &self, buffer: &buffer::Model, @@ -303,27 +429,20 @@ impl Database { .await?; let mut operations = Vec::new(); while let Some(row) = rows.next().await { - let row = row?; - - let operation = operation_from_storage(row, version)?; operations.push(proto::Operation { - variant: Some(operation), + variant: Some(operation_from_storage(row?, version)?), }) } Ok((base_text, operations)) } - async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> { - let buffer = channel::Model { - id: channel_id, - ..Default::default() - } - .find_related(buffer::Entity) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such buffer"))?; - + async fn snapshot_channel_buffer( + &self, + channel_id: ChannelId, + tx: &DatabaseTransaction, + ) -> Result<()> { + let buffer = self.get_channel_buffer(channel_id, tx).await?; let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?; if operations.is_empty() { return Ok(()); @@ -527,6 +646,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global { version } +fn version_to_wire(version: &clock::Global) -> Vec { + let mut message = Vec::new(); + for entry in version.iter() { + message.push(proto::VectorClockEntry { + replica_id: entry.replica_id as u32, + timestamp: entry.value, + }); + } + message +} + +#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] +enum QueryOperationSerializationVersion { + OperationSerializationVersion, +} + mod storage { #![allow(non_snake_case)] use prost::Message; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 6b44711c42f4a37eea15c437879650a7c269aad5..06aa00c9b8f58495c25fd878720083debd6c4398 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -251,6 +251,7 @@ impl Server { .add_request_handler(join_channel_buffer) .add_request_handler(leave_channel_buffer) .add_message_handler(update_channel_buffer) + .add_request_handler(rejoin_channel_buffers) .add_request_handler(get_channel_members) .add_request_handler(respond_to_channel_invite) .add_request_handler(join_channel) @@ -854,13 +855,12 @@ async fn connection_lost( .await .trace_err(); - leave_channel_buffers_for_session(&session) - .await - .trace_err(); - futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { leave_room_for_session(&session).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); if !session .connection_pool() @@ -2547,6 +2547,23 @@ async fn update_channel_buffer( Ok(()) } +async fn rejoin_channel_buffers( + request: proto::RejoinChannelBuffers, + response: Response, + session: Session, +) -> Result<()> { + let db = session.db().await; + let rejoin_response = db + .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) + .await?; + + // TODO: inform channel buffer collaborators that this user has rejoined. + + response.send(rejoin_response)?; + + Ok(()) +} + async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, diff --git a/crates/collab/src/tests/channel_buffer_tests.rs b/crates/collab/src/tests/channel_buffer_tests.rs index 8ac4dbbd3f1c606b52fb445a1c08ca4f1e8c6883..5ba5b50429063a2ffaceacf7c6a87c66bb9845c9 100644 --- a/crates/collab/src/tests/channel_buffer_tests.rs +++ b/crates/collab/src/tests/channel_buffer_tests.rs @@ -21,20 +21,19 @@ async fn test_core_channel_buffers( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; - let zed_id = server + let channel_id = server .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) .await; // Client A joins the channel buffer let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); // Client A edits the buffer let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer()); - buffer_a.update(cx_a, |buffer, cx| { buffer.edit([(0..0, "hello world")], None, cx) }); @@ -45,17 +44,15 @@ async fn test_core_channel_buffers( buffer.edit([(0..5, "goodbye")], None, cx) }); buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx)); - deterministic.run_until_parked(); - assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world"); + deterministic.run_until_parked(); // Client B joins the channel buffer let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - channel_buffer_b.read_with(cx_b, |buffer, _| { assert_collaborators( buffer.collaborators(), @@ -91,9 +88,7 @@ async fn test_core_channel_buffers( // Client A rejoins the channel buffer let _channel_buffer_a = client_a .channel_store() - .update(cx_a, |channels, cx| { - channels.open_channel_buffer(zed_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); deterministic.run_until_parked(); @@ -136,7 +131,7 @@ async fn test_channel_buffer_replica_ids( let channel_id = server .make_channel( - "zed", + "the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b), (&client_c, cx_c)], ) @@ -160,23 +155,17 @@ async fn test_channel_buffer_replica_ids( // C first so that the replica IDs in the project and the channel buffer are different let channel_buffer_c = client_c .channel_store() - .update(cx_c, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); @@ -286,28 +275,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; - let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await; + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut []) + .await; let channel_buffer_1 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); let channel_buffer_2 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); let channel_buffer_3 = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)); + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)); // All concurrent tasks for opening a channel buffer return the same model handle. - let (channel_buffer_1, channel_buffer_2, channel_buffer_3) = + let (channel_buffer, channel_buffer_2, channel_buffer_3) = future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3) .await .unwrap(); - let model_id = channel_buffer_1.id(); - assert_eq!(channel_buffer_1, channel_buffer_2); - assert_eq!(channel_buffer_1, channel_buffer_3); + let channel_buffer_model_id = channel_buffer.id(); + assert_eq!(channel_buffer, channel_buffer_2); + assert_eq!(channel_buffer, channel_buffer_3); - channel_buffer_1.update(cx_a, |buffer, cx| { + channel_buffer.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, cx| { buffer.edit([(0..0, "hello")], None, cx); }) @@ -315,7 +306,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu deterministic.run_until_parked(); cx_a.update(|_| { - drop(channel_buffer_1); + drop(channel_buffer); drop(channel_buffer_2); drop(channel_buffer_3); }); @@ -324,10 +315,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc, cx_a: &mu // The channel buffer can be reopened after dropping it. let channel_buffer = client_a .channel_store() - .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx)) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - assert_ne!(channel_buffer.id(), model_id); + assert_ne!(channel_buffer.id(), channel_buffer_model_id); channel_buffer.update(cx_a, |buffer, cx| { buffer.buffer().update(cx, |buffer, _| { assert_eq!(buffer.text(), "hello"); @@ -347,22 +338,17 @@ async fn test_channel_buffer_disconnect( let client_b = server.create_client(cx_b, "user_b").await; let channel_id = server - .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) .await; let channel_buffer_a = client_a .channel_store() - .update(cx_a, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); - let channel_buffer_b = client_b .channel_store() - .update(cx_b, |channel, cx| { - channel.open_channel_buffer(channel_id, cx) - }) + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) .await .unwrap(); @@ -375,7 +361,7 @@ async fn test_channel_buffer_disconnect( buffer.channel().as_ref(), &Channel { id: channel_id, - name: "zed".to_string() + name: "the-channel".to_string() } ); assert!(!buffer.is_connected()); @@ -403,13 +389,81 @@ async fn test_channel_buffer_disconnect( buffer.channel().as_ref(), &Channel { id: channel_id, - name: "zed".to_string() + name: "the-channel".to_string() } ); assert!(!buffer.is_connected()); }); } +#[gpui::test] +async fn test_rejoin_channel_buffer( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + let channel_buffer_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + let channel_buffer_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx)) + .await + .unwrap(); + + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "1")], None, cx); + }) + }); + deterministic.run_until_parked(); + + // Client A disconnects. + server.forbid_connections(); + server.disconnect_client(client_a.peer_id().unwrap()); + // deterministic.advance_clock(RECEIVE_TIMEOUT); + + // Both clients make an edit. Both clients see their own edit. + channel_buffer_a.update(cx_a, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(1..1, "2")], None, cx); + }) + }); + channel_buffer_b.update(cx_b, |buffer, cx| { + buffer.buffer().update(cx, |buffer, cx| { + buffer.edit([(0..0, "0")], None, cx); + }) + }); + deterministic.run_until_parked(); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "12"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "01"); + }); + + // Client A reconnects. + server.allow_connections(); + deterministic.advance_clock(RECEIVE_TIMEOUT); + channel_buffer_a.read_with(cx_a, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); + channel_buffer_b.read_with(cx_b, |buffer, cx| { + assert_eq!(buffer.buffer().read(cx).text(), "012"); + }); +} + #[track_caller] fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option]) { assert_eq!( diff --git a/crates/language/src/proto.rs b/crates/language/src/proto.rs index 80eb972f421197ea5dc80c39d06f203c530b707b..c4abe39d4782aafbe90594e3a0bc5de70787fa03 100644 --- a/crates/language/src/proto.rs +++ b/crates/language/src/proto.rs @@ -127,6 +127,31 @@ pub fn serialize_undo_map_entry( } } +pub fn split_operations( + mut operations: Vec, +) -> impl Iterator> { + #[cfg(any(test, feature = "test-support"))] + const CHUNK_SIZE: usize = 5; + + #[cfg(not(any(test, feature = "test-support")))] + const CHUNK_SIZE: usize = 100; + + let mut done = false; + std::iter::from_fn(move || { + if done { + return None; + } + + let operations = operations + .drain(..std::cmp::min(CHUNK_SIZE, operations.len())) + .collect::>(); + if operations.is_empty() { + done = true; + } + Some(operations) + }) +} + pub fn serialize_selections(selections: &Arc<[Selection]>) -> Vec { selections.iter().map(serialize_selection).collect() } diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 5cd13b8be8aca39d73741849d10376e08baf21c9..0690cc9188c129121da3aad95e9081a58c32f54b 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -35,7 +35,7 @@ use language::{ point_to_lsp, proto::{ deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version, - serialize_anchor, serialize_version, + serialize_anchor, serialize_version, split_operations, }, range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction, CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent, @@ -8200,31 +8200,6 @@ impl LspAdapterDelegate for ProjectLspAdapterDelegate { } } -fn split_operations( - mut operations: Vec, -) -> impl Iterator> { - #[cfg(any(test, feature = "test-support"))] - const CHUNK_SIZE: usize = 5; - - #[cfg(not(any(test, feature = "test-support")))] - const CHUNK_SIZE: usize = 100; - - let mut done = false; - std::iter::from_fn(move || { - if done { - return None; - } - - let operations = operations - .drain(..cmp::min(CHUNK_SIZE, operations.len())) - .collect::>(); - if operations.is_empty() { - done = true; - } - Some(operations) - }) -} - fn serialize_symbol(symbol: &Symbol) -> proto::Symbol { proto::Symbol { language_server_name: symbol.language_server_name.0.to_string(), diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 92c85677f6919aff06daef30b57e6482949d021f..fe9093245e83a24741452e0f168bb5c831824be6 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1,6 +1,8 @@ syntax = "proto3"; package zed.messages; +// Looking for a number? Search "// Current max" + message PeerId { uint32 owner_id = 1; uint32 id = 2; @@ -151,6 +153,8 @@ message Envelope { LeaveChannelBuffer leave_channel_buffer = 134; AddChannelBufferCollaborator add_channel_buffer_collaborator = 135; RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136; + RejoinChannelBuffers rejoin_channel_buffers = 139; + RejoinChannelBuffersResponse rejoin_channel_buffers_response = 140; // Current max } } @@ -616,6 +620,12 @@ message BufferVersion { repeated VectorClockEntry version = 2; } +message ChannelBufferVersion { + uint64 channel_id = 1; + repeated VectorClockEntry version = 2; + uint64 epoch = 3; +} + enum FormatTrigger { Save = 0; Manual = 1; @@ -1008,12 +1018,28 @@ message JoinChannelBuffer { uint64 channel_id = 1; } +message RejoinChannelBuffers { + repeated ChannelBufferVersion buffers = 1; +} + +message RejoinChannelBuffersResponse { + repeated RejoinedChannelBuffer buffers = 1; +} + message JoinChannelBufferResponse { uint64 buffer_id = 1; uint32 replica_id = 2; string base_text = 3; repeated Operation operations = 4; repeated Collaborator collaborators = 5; + uint64 epoch = 6; +} + +message RejoinedChannelBuffer { + uint64 channel_id = 1; + repeated VectorClockEntry version = 2; + repeated Operation operations = 3; + repeated Collaborator collaborators = 4; } message LeaveChannelBuffer { diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 2e4dce01e1a3bf5789206c80b3a4574f6e198c0d..a600bc4970dc7e8d5d199ee6f455728620aa4070 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -229,6 +229,8 @@ messages!( (StartLanguageServer, Foreground), (SynchronizeBuffers, Foreground), (SynchronizeBuffersResponse, Foreground), + (RejoinChannelBuffers, Foreground), + (RejoinChannelBuffersResponse, Foreground), (Test, Foreground), (Unfollow, Foreground), (UnshareProject, Foreground), @@ -319,6 +321,7 @@ request_messages!( (SearchProject, SearchProjectResponse), (ShareProject, ShareProjectResponse), (SynchronizeBuffers, SynchronizeBuffersResponse), + (RejoinChannelBuffers, RejoinChannelBuffersResponse), (Test, Test), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack),