Make channel notes read-only when disconnected

Max Brunsfeld and Mikayla created

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

crates/channel/src/channel_buffer.rs            |  58 +++++---
crates/channel/src/channel_store.rs             | 115 ++++++++++++------
crates/collab/src/rpc.rs                        |   5 
crates/collab/src/tests/channel_buffer_tests.rs |  76 ++++++++++++
crates/collab/src/tests/channel_tests.rs        |   7 
crates/collab_ui/src/channel_view.rs            |  36 +++--
6 files changed, 217 insertions(+), 80 deletions(-)

Detailed changes

crates/channel/src/channel_buffer.rs 🔗

@@ -1,4 +1,4 @@
-use crate::{Channel, ChannelId, ChannelStore};
+use crate::Channel;
 use anyhow::Result;
 use client::Client;
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle};
@@ -13,39 +13,43 @@ pub(crate) fn init(client: &Arc<Client>) {
 }
 
 pub struct ChannelBuffer {
-    channel_id: ChannelId,
+    pub(crate) channel: Arc<Channel>,
+    connected: bool,
     collaborators: Vec<proto::Collaborator>,
     buffer: ModelHandle<language::Buffer>,
-    channel_store: ModelHandle<ChannelStore>,
     client: Arc<Client>,
-    _subscription: client::Subscription,
+    subscription: Option<client::Subscription>,
 }
 
 pub enum Event {
     CollaboratorsChanged,
+    Disconnected,
 }
 
 impl Entity for ChannelBuffer {
     type Event = Event;
 
     fn release(&mut self, _: &mut AppContext) {
-        self.client
-            .send(proto::LeaveChannelBuffer {
-                channel_id: self.channel_id,
-            })
-            .log_err();
+        if self.connected {
+            self.client
+                .send(proto::LeaveChannelBuffer {
+                    channel_id: self.channel.id,
+                })
+                .log_err();
+        }
     }
 }
 
 impl ChannelBuffer {
     pub(crate) async fn new(
-        channel_store: ModelHandle<ChannelStore>,
-        channel_id: ChannelId,
+        channel: Arc<Channel>,
         client: Arc<Client>,
         mut cx: AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
         let response = client
-            .request(proto::JoinChannelBuffer { channel_id })
+            .request(proto::JoinChannelBuffer {
+                channel_id: channel.id,
+            })
             .await?;
 
         let base_text = response.base_text;
@@ -62,7 +66,7 @@ impl ChannelBuffer {
         });
         buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?;
 
-        let subscription = client.subscribe_to_entity(channel_id)?;
+        let subscription = client.subscribe_to_entity(channel.id)?;
 
         anyhow::Ok(cx.add_model(|cx| {
             cx.subscribe(&buffer, Self::on_buffer_update).detach();
@@ -70,10 +74,10 @@ impl ChannelBuffer {
             Self {
                 buffer,
                 client,
-                channel_id,
-                channel_store,
+                connected: true,
                 collaborators,
-                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
+                channel,
+                subscription: Some(subscription.set_model(&cx.handle(), &mut cx.to_async())),
             }
         }))
     }
@@ -155,7 +159,7 @@ impl ChannelBuffer {
             let operation = language::proto::serialize_operation(operation);
             self.client
                 .send(proto::UpdateChannelBuffer {
-                    channel_id: self.channel_id,
+                    channel_id: self.channel.id,
                     operations: vec![operation],
                 })
                 .log_err();
@@ -170,11 +174,21 @@ impl ChannelBuffer {
         &self.collaborators
     }
 
-    pub fn channel(&self, cx: &AppContext) -> Option<Arc<Channel>> {
-        self.channel_store
-            .read(cx)
-            .channel_for_id(self.channel_id)
-            .cloned()
+    pub fn channel(&self) -> Arc<Channel> {
+        self.channel.clone()
+    }
+
+    pub(crate) fn disconnect(&mut self, cx: &mut ModelContext<Self>) {
+        if self.connected {
+            self.connected = false;
+            self.subscription.take();
+            cx.emit(Event::Disconnected);
+            cx.notify()
+        }
+    }
+
+    pub fn is_connected(&self) -> bool {
+        self.connected
     }
 
     pub fn replica_id(&self, cx: &AppContext) -> u16 {

crates/channel/src/channel_store.rs 🔗

@@ -2,7 +2,7 @@ use crate::channel_buffer::ChannelBuffer;
 use anyhow::{anyhow, Result};
 use client::{Client, Status, Subscription, User, UserId, UserStore};
 use collections::{hash_map, HashMap, HashSet};
-use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt, TryFutureExt};
+use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
 use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use rpc::{proto, TypedEnvelope};
 use std::sync::Arc;
@@ -71,16 +71,14 @@ impl ChannelStore {
         let mut connection_status = client.status();
         let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
             while let Some(status) = connection_status.next().await {
-                if matches!(status, Status::ConnectionLost | Status::SignedOut) {
+                if !status.is_connected() {
                     if let Some(this) = this.upgrade(&cx) {
                         this.update(&mut cx, |this, cx| {
-                            this.channels_by_id.clear();
-                            this.channel_invitations.clear();
-                            this.channel_participants.clear();
-                            this.channels_with_admin_privileges.clear();
-                            this.channel_paths.clear();
-                            this.outgoing_invites.clear();
-                            cx.notify();
+                            if matches!(status, Status::ConnectionLost | Status::SignedOut) {
+                                this.handle_disconnect(cx);
+                            } else {
+                                this.disconnect_buffers(cx);
+                            }
                         });
                     } else {
                         break;
@@ -176,9 +174,17 @@ impl ChannelStore {
                     OpenedChannelBuffer::Loading(task) => break task.clone(),
                 },
                 hash_map::Entry::Vacant(e) => {
+                    let client = self.client.clone();
                     let task = cx
-                        .spawn(|this, cx| {
-                            ChannelBuffer::new(this, channel_id, self.client.clone(), cx)
+                        .spawn(|this, cx| async move {
+                            let channel = this.read_with(&cx, |this, _| {
+                                this.channel_for_id(channel_id).cloned().ok_or_else(|| {
+                                    Arc::new(anyhow!("no channel for id: {}", channel_id))
+                                })
+                            })?;
+
+                            ChannelBuffer::new(channel, client, cx)
+                                .await
                                 .map_err(Arc::new)
                         })
                         .shared();
@@ -187,8 +193,8 @@ impl ChannelStore {
                         let task = task.clone();
                         |this, mut cx| async move {
                             let result = task.await;
-                            this.update(&mut cx, |this, cx| {
-                                if let Ok(buffer) = result {
+                            this.update(&mut cx, |this, cx| match result {
+                                Ok(buffer) => {
                                     cx.observe_release(&buffer, move |this, _, _| {
                                         this.opened_buffers.remove(&channel_id);
                                     })
@@ -197,7 +203,9 @@ impl ChannelStore {
                                         channel_id,
                                         OpenedChannelBuffer::Open(buffer.downgrade()),
                                     );
-                                } else {
+                                }
+                                Err(error) => {
+                                    log::error!("failed to open channel buffer {error:?}");
                                     this.opened_buffers.remove(&channel_id);
                                 }
                             });
@@ -474,6 +482,27 @@ impl ChannelStore {
         Ok(())
     }
 
+    fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) {
+        self.disconnect_buffers(cx);
+        self.channels_by_id.clear();
+        self.channel_invitations.clear();
+        self.channel_participants.clear();
+        self.channels_with_admin_privileges.clear();
+        self.channel_paths.clear();
+        self.outgoing_invites.clear();
+        cx.notify();
+    }
+
+    fn disconnect_buffers(&mut self, cx: &mut ModelContext<ChannelStore>) {
+        for (_, buffer) in self.opened_buffers.drain() {
+            if let OpenedChannelBuffer::Open(buffer) = buffer {
+                if let Some(buffer) = buffer.upgrade(cx) {
+                    buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
+                }
+            }
+        }
+    }
+
     pub(crate) fn update_channels(
         &mut self,
         payload: proto::UpdateChannels,
@@ -508,38 +537,44 @@ impl ChannelStore {
                     .retain(|channel_id, _| !payload.remove_channels.contains(channel_id));
                 self.channels_with_admin_privileges
                     .retain(|channel_id| !payload.remove_channels.contains(channel_id));
-            }
 
-            for channel in payload.channels {
-                if let Some(existing_channel) = self.channels_by_id.get_mut(&channel.id) {
-                    // FIXME: We may be missing a path for this existing channel in certain cases
-                    let existing_channel = Arc::make_mut(existing_channel);
-                    existing_channel.name = channel.name;
-                    continue;
+                for channel_id in &payload.remove_channels {
+                    let channel_id = *channel_id;
+                    if let Some(OpenedChannelBuffer::Open(buffer)) =
+                        self.opened_buffers.remove(&channel_id)
+                    {
+                        if let Some(buffer) = buffer.upgrade(cx) {
+                            buffer.update(cx, ChannelBuffer::disconnect);
+                        }
+                    }
                 }
+            }
 
-                self.channels_by_id.insert(
-                    channel.id,
-                    Arc::new(Channel {
-                        id: channel.id,
-                        name: channel.name,
-                    }),
-                );
-
-                if let Some(parent_id) = channel.parent_id {
-                    let mut ix = 0;
-                    while ix < self.channel_paths.len() {
-                        let path = &self.channel_paths[ix];
-                        if path.ends_with(&[parent_id]) {
-                            let mut new_path = path.clone();
-                            new_path.push(channel.id);
-                            self.channel_paths.insert(ix + 1, new_path);
+            for channel_proto in payload.channels {
+                if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
+                    Arc::make_mut(existing_channel).name = channel_proto.name;
+                } else {
+                    let channel = Arc::new(Channel {
+                        id: channel_proto.id,
+                        name: channel_proto.name,
+                    });
+                    self.channels_by_id.insert(channel.id, channel.clone());
+
+                    if let Some(parent_id) = channel_proto.parent_id {
+                        let mut ix = 0;
+                        while ix < self.channel_paths.len() {
+                            let path = &self.channel_paths[ix];
+                            if path.ends_with(&[parent_id]) {
+                                let mut new_path = path.clone();
+                                new_path.push(channel.id);
+                                self.channel_paths.insert(ix + 1, new_path);
+                                ix += 1;
+                            }
                             ix += 1;
                         }
-                        ix += 1;
+                    } else {
+                        self.channel_paths.push(vec![channel.id]);
                     }
-                } else {
-                    self.channel_paths.push(vec![channel.id]);
                 }
             }
 

crates/collab/src/rpc.rs 🔗

@@ -854,10 +854,13 @@ async fn connection_lost(
         .await
         .trace_err();
 
+    leave_channel_buffers_for_session(&session)
+        .await
+        .trace_err();
+
     futures::select_biased! {
         _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
             leave_room_for_session(&session).await.trace_err();
-            leave_channel_buffers_for_session(&session).await.trace_err();
 
             if !session
                 .connection_pool()

crates/collab/src/tests/channel_buffer_tests.rs 🔗

@@ -1,5 +1,6 @@
 use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
 use call::ActiveCall;
+use channel::Channel;
 use client::UserId;
 use collab_ui::channel_view::ChannelView;
 use collections::HashMap;
@@ -334,6 +335,81 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
     });
 }
 
+#[gpui::test]
+async fn test_channel_buffer_disconnect(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let channel_id = server
+        .make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
+        .await;
+
+    let channel_buffer_a = client_a
+        .channel_store()
+        .update(cx_a, |channel, cx| {
+            channel.open_channel_buffer(channel_id, cx)
+        })
+        .await
+        .unwrap();
+
+    let channel_buffer_b = client_b
+        .channel_store()
+        .update(cx_b, |channel, cx| {
+            channel.open_channel_buffer(channel_id, cx)
+        })
+        .await
+        .unwrap();
+
+    server.forbid_connections();
+    server.disconnect_client(client_a.peer_id().unwrap());
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+
+    channel_buffer_a.update(cx_a, |buffer, _| {
+        assert_eq!(
+            buffer.channel().as_ref(),
+            &Channel {
+                id: channel_id,
+                name: "zed".to_string()
+            }
+        );
+        assert!(!buffer.is_connected());
+    });
+
+    deterministic.run_until_parked();
+
+    server.allow_connections();
+    deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
+
+    deterministic.run_until_parked();
+
+    client_a
+        .channel_store()
+        .update(cx_a, |channel_store, _| {
+            channel_store.remove_channel(channel_id)
+        })
+        .await
+        .unwrap();
+    deterministic.run_until_parked();
+
+    // Channel buffer observed the deletion
+    channel_buffer_b.update(cx_b, |buffer, _| {
+        assert_eq!(
+            buffer.channel().as_ref(),
+            &Channel {
+                id: channel_id,
+                name: "zed".to_string()
+            }
+        );
+        assert!(!buffer.is_connected());
+    });
+}
+
 #[track_caller]
 fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
     assert_eq!(

crates/collab/src/tests/channel_tests.rs 🔗

@@ -799,7 +799,7 @@ async fn test_lost_channel_creation(
 
     deterministic.run_until_parked();
 
-    // Sanity check
+    // Sanity check, B has the invitation
     assert_channel_invitations(
         client_b.channel_store(),
         cx_b,
@@ -811,6 +811,7 @@ async fn test_lost_channel_creation(
         }],
     );
 
+    // A creates a subchannel while the invite is still pending.
     let subchannel_id = client_a
         .channel_store()
         .update(cx_a, |channel_store, cx| {
@@ -841,7 +842,7 @@ async fn test_lost_channel_creation(
         ],
     );
 
-    // Accept the invite
+    // Client B accepts the invite
     client_b
         .channel_store()
         .update(cx_b, |channel_store, _| {
@@ -852,7 +853,7 @@ async fn test_lost_channel_creation(
 
     deterministic.run_until_parked();
 
-    // B should now see the channel
+    // Client B should now see the channel
     assert_channels(
         client_b.channel_store(),
         cx_b,

crates/collab_ui/src/channel_view.rs 🔗

@@ -114,10 +114,18 @@ impl ChannelView {
     fn handle_channel_buffer_event(
         &mut self,
         _: ModelHandle<ChannelBuffer>,
-        _: &channel_buffer::Event,
+        event: &channel_buffer::Event,
         cx: &mut ViewContext<Self>,
     ) {
-        self.refresh_replica_id_map(cx);
+        match event {
+            channel_buffer::Event::CollaboratorsChanged => {
+                self.refresh_replica_id_map(cx);
+            }
+            channel_buffer::Event::Disconnected => self.editor.update(cx, |editor, cx| {
+                editor.set_read_only(true);
+                cx.notify();
+            }),
+        }
     }
 
     /// Build a mapping of channel buffer replica ids to the corresponding
@@ -183,14 +191,13 @@ impl Item for ChannelView {
         style: &theme::Tab,
         cx: &gpui::AppContext,
     ) -> AnyElement<V> {
-        let channel_name = self
-            .channel_buffer
-            .read(cx)
-            .channel(cx)
-            .map_or("[Deleted channel]".to_string(), |channel| {
-                format!("#{}", channel.name)
-            });
-        Label::new(channel_name, style.label.to_owned()).into_any()
+        let channel_name = &self.channel_buffer.read(cx).channel().name;
+        let label = if self.channel_buffer.read(cx).is_connected() {
+            format!("#{}", channel_name)
+        } else {
+            format!("#{} (disconnected)", channel_name)
+        };
+        Label::new(label, style.label.to_owned()).into_any()
     }
 
     fn clone_on_split(&self, _: WorkspaceId, cx: &mut ViewContext<Self>) -> Option<Self> {
@@ -208,8 +215,9 @@ impl FollowableItem for ChannelView {
     }
 
     fn to_state_proto(&self, cx: &AppContext) -> Option<proto::view::Variant> {
-        self.channel_buffer.read(cx).channel(cx).map(|channel| {
-            proto::view::Variant::ChannelView(proto::view::ChannelView {
+        let channel = self.channel_buffer.read(cx).channel();
+        Some(proto::view::Variant::ChannelView(
+            proto::view::ChannelView {
                 channel_id: channel.id,
                 editor: if let Some(proto::view::Variant::Editor(proto)) =
                     self.editor.read(cx).to_state_proto(cx)
@@ -218,8 +226,8 @@ impl FollowableItem for ChannelView {
                 } else {
                     None
                 },
-            })
-        })
+            },
+        ))
     }
 
     fn from_state_proto(