Dedup channel buffers

Max Brunsfeld created

Change summary

crates/channel/src/channel_buffer.rs            | 68 +++++++--------
crates/channel/src/channel_store.rs             | 81 +++++++++++++++---
crates/collab/src/tests/channel_buffer_tests.rs | 56 +++++++++++++
3 files changed, 155 insertions(+), 50 deletions(-)

Detailed changes

crates/channel/src/channel_buffer.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{Channel, ChannelId, ChannelStore};
 use anyhow::Result;
 use client::Client;
-use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
+use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle};
 use rpc::{proto, TypedEnvelope};
 use std::sync::Arc;
 use util::ResultExt;
@@ -38,46 +38,44 @@ impl Entity for ChannelBuffer {
 }
 
 impl ChannelBuffer {
-    pub(crate) fn new(
+    pub(crate) async fn new(
         channel_store: ModelHandle<ChannelStore>,
         channel_id: ChannelId,
         client: Arc<Client>,
-        cx: &mut AppContext,
-    ) -> Task<Result<ModelHandle<Self>>> {
-        cx.spawn(|mut cx| async move {
-            let response = client
-                .request(proto::JoinChannelBuffer { channel_id })
-                .await?;
-
-            let base_text = response.base_text;
-            let operations = response
-                .operations
-                .into_iter()
-                .map(language::proto::deserialize_operation)
-                .collect::<Result<Vec<_>, _>>()?;
-
-            let collaborators = response.collaborators;
-
-            let buffer = cx.add_model(|_| {
-                language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text)
-            });
-            buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?;
+        mut cx: AsyncAppContext,
+    ) -> Result<ModelHandle<Self>> {
+        let response = client
+            .request(proto::JoinChannelBuffer { channel_id })
+            .await?;
 
-            let subscription = client.subscribe_to_entity(channel_id)?;
+        let base_text = response.base_text;
+        let operations = response
+            .operations
+            .into_iter()
+            .map(language::proto::deserialize_operation)
+            .collect::<Result<Vec<_>, _>>()?;
 
-            anyhow::Ok(cx.add_model(|cx| {
-                cx.subscribe(&buffer, Self::on_buffer_update).detach();
+        let collaborators = response.collaborators;
 
-                Self {
-                    buffer,
-                    client,
-                    channel_id,
-                    channel_store,
-                    collaborators,
-                    _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
-                }
-            }))
-        })
+        let buffer = cx.add_model(|_| {
+            language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text)
+        });
+        buffer.update(&mut cx, |buffer, cx| buffer.apply_ops(operations, cx))?;
+
+        let subscription = client.subscribe_to_entity(channel_id)?;
+
+        anyhow::Ok(cx.add_model(|cx| {
+            cx.subscribe(&buffer, Self::on_buffer_update).detach();
+
+            Self {
+                buffer,
+                client,
+                channel_id,
+                channel_store,
+                collaborators,
+                _subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
+            }
+        }))
     }
 
     async fn handle_update_channel_buffer(

crates/channel/src/channel_store.rs 🔗

@@ -1,20 +1,13 @@
-use anyhow::anyhow;
-use anyhow::Result;
-use client::Status;
-use client::UserId;
-use client::{Client, Subscription, User, UserStore};
-use collections::HashMap;
-use collections::HashSet;
-use futures::channel::mpsc;
-use futures::Future;
-use futures::StreamExt;
-use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
+use crate::channel_buffer::ChannelBuffer;
+use anyhow::{anyhow, Result};
+use client::{Client, Status, Subscription, User, UserId, UserStore};
+use collections::{hash_map, HashMap, HashSet};
+use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt, TryFutureExt};
+use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use rpc::{proto, TypedEnvelope};
 use std::sync::Arc;
 use util::ResultExt;
 
-use crate::channel_buffer::ChannelBuffer;
-
 pub type ChannelId = u64;
 
 pub struct ChannelStore {
@@ -25,6 +18,7 @@ pub struct ChannelStore {
     channels_with_admin_privileges: HashSet<ChannelId>,
     outgoing_invites: HashSet<(ChannelId, UserId)>,
     update_channels_tx: mpsc::UnboundedSender<proto::UpdateChannels>,
+    opened_buffers: HashMap<ChannelId, OpenedChannelBuffer>,
     client: Arc<Client>,
     user_store: ModelHandle<UserStore>,
     _rpc_subscription: Subscription,
@@ -59,6 +53,11 @@ pub enum ChannelMemberStatus {
     NotMember,
 }
 
+enum OpenedChannelBuffer {
+    Open(WeakModelHandle<ChannelBuffer>),
+    Loading(Shared<Task<Result<ModelHandle<ChannelBuffer>, Arc<anyhow::Error>>>>),
+}
+
 impl ChannelStore {
     pub fn new(
         client: Arc<Client>,
@@ -89,6 +88,7 @@ impl ChannelStore {
                 }
             }
         });
+
         Self {
             channels_by_id: HashMap::default(),
             channel_invitations: Vec::default(),
@@ -96,6 +96,7 @@ impl ChannelStore {
             channel_participants: Default::default(),
             channels_with_admin_privileges: Default::default(),
             outgoing_invites: Default::default(),
+            opened_buffers: Default::default(),
             update_channels_tx,
             client,
             user_store,
@@ -154,11 +155,61 @@ impl ChannelStore {
     }
 
     pub fn open_channel_buffer(
-        &self,
+        &mut self,
         channel_id: ChannelId,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<ModelHandle<ChannelBuffer>>> {
-        ChannelBuffer::new(cx.handle(), channel_id, self.client.clone(), cx)
+        // Make sure that a given channel buffer is only opened once per
+        // app instance, even if this method is called multiple times
+        // with the same channel id while the first task is still running.
+        let task = loop {
+            match self.opened_buffers.entry(channel_id) {
+                hash_map::Entry::Occupied(e) => match e.get() {
+                    OpenedChannelBuffer::Open(buffer) => {
+                        if let Some(buffer) = buffer.upgrade(cx) {
+                            break Task::ready(Ok(buffer)).shared();
+                        } else {
+                            self.opened_buffers.remove(&channel_id);
+                            continue;
+                        }
+                    }
+                    OpenedChannelBuffer::Loading(task) => break task.clone(),
+                },
+                hash_map::Entry::Vacant(e) => {
+                    let task = cx
+                        .spawn(|this, cx| {
+                            ChannelBuffer::new(this, channel_id, self.client.clone(), cx)
+                                .map_err(Arc::new)
+                        })
+                        .shared();
+                    e.insert(OpenedChannelBuffer::Loading(task.clone()));
+                    cx.spawn({
+                        let task = task.clone();
+                        |this, mut cx| async move {
+                            let result = task.await;
+                            this.update(&mut cx, |this, cx| {
+                                if let Ok(buffer) = result {
+                                    cx.observe_release(&buffer, move |this, _, _| {
+                                        this.opened_buffers.remove(&channel_id);
+                                    })
+                                    .detach();
+                                    this.opened_buffers.insert(
+                                        channel_id,
+                                        OpenedChannelBuffer::Open(buffer.downgrade()),
+                                    );
+                                } else {
+                                    this.opened_buffers.remove(&channel_id);
+                                }
+                            });
+                        }
+                    })
+                    .detach();
+                    break task;
+                }
+            }
+        };
+        cx.foreground()
+            .spawn(async move { task.await.map_err(|error| anyhow!("{}", error)) })
     }
 
     pub fn is_user_admin(&self, channel_id: ChannelId) -> bool {

crates/collab/src/tests/channel_buffer_tests.rs 🔗

@@ -3,6 +3,7 @@ use call::ActiveCall;
 use client::UserId;
 use collab_ui::channel_view::ChannelView;
 use collections::HashMap;
+use futures::future;
 use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
 use rpc::{proto, RECEIVE_TIMEOUT};
 use serde_json::json;
@@ -283,6 +284,61 @@ async fn test_channel_buffer_replica_ids(
     });
 }
 
+#[gpui::test]
+async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mut TestAppContext) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+
+    let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await;
+
+    let channel_buffer_1 = client_a
+        .channel_store()
+        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
+    let channel_buffer_2 = client_a
+        .channel_store()
+        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
+    let channel_buffer_3 = client_a
+        .channel_store()
+        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
+
+    // All concurrent tasks for opening a channel buffer return the same model handle.
+    let (channel_buffer_1, channel_buffer_2, channel_buffer_3) =
+        future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3)
+            .await
+            .unwrap();
+    let model_id = channel_buffer_1.id();
+    assert_eq!(channel_buffer_1, channel_buffer_2);
+    assert_eq!(channel_buffer_1, channel_buffer_3);
+
+    channel_buffer_1.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "hello")], None, cx);
+        })
+    });
+    deterministic.run_until_parked();
+
+    cx_a.update(|_| {
+        drop(channel_buffer_1);
+        drop(channel_buffer_2);
+        drop(channel_buffer_3);
+    });
+    deterministic.run_until_parked();
+
+    // The channel buffer can be reopened after dropping it.
+    let channel_buffer = client_a
+        .channel_store()
+        .update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
+        .await
+        .unwrap();
+    assert_ne!(channel_buffer.id(), model_id);
+    channel_buffer.update(cx_a, |buffer, cx| {
+        buffer.buffer().update(cx, |buffer, _| {
+            assert_eq!(buffer.text(), "hello");
+        })
+    });
+}
+
 #[track_caller]
 fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
     assert_eq!(