Store shared buffers on LocalWorktree

Nathan Sobo created

It's okay for our domain objects to model remote state. We should minimize what we need to store in the rpc::ClientState struct.

Change summary

zed-rpc/src/peer.rs |   9 ++
zed/src/rpc.rs      |  14 +--
zed/src/worktree.rs | 163 ++++++++++++++++++++++++++--------------------
3 files changed, 107 insertions(+), 79 deletions(-)

Detailed changes

zed-rpc/src/peer.rs 🔗

@@ -51,11 +51,18 @@ pub struct Receipt<T> {
 
 pub struct TypedEnvelope<T> {
     pub sender_id: ConnectionId,
-    pub original_sender_id: Option<PeerId>,
+    original_sender_id: Option<PeerId>,
     pub message_id: u32,
     pub payload: T,
 }
 
+impl<T> TypedEnvelope<T> {
+    pub fn original_sender_id(&self) -> Result<PeerId> {
+        self.original_sender_id
+            .ok_or_else(|| anyhow!("missing original_sender_id"))
+    }
+}
+
 impl<T: RequestMessage> TypedEnvelope<T> {
     pub fn receipt(&self) -> Receipt<T> {
         Receipt {

zed/src/rpc.rs 🔗

@@ -1,5 +1,5 @@
 use super::util::SurfResultExt as _;
-use crate::{editor::Buffer, language::LanguageRegistry, worktree::Worktree};
+use crate::{language::LanguageRegistry, worktree::Worktree};
 use anyhow::{anyhow, Context, Result};
 use gpui::executor::Background;
 use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
@@ -29,22 +29,21 @@ pub struct Client {
 
 pub struct ClientState {
     connection_id: Option<ConnectionId>,
-    pub remote_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
-    pub shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
+    pub shared_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
     pub languages: Arc<LanguageRegistry>,
 }
 
 impl ClientState {
-    pub fn remote_worktree(
+    pub fn shared_worktree(
         &mut self,
         id: u64,
         cx: &mut AsyncAppContext,
     ) -> Result<ModelHandle<Worktree>> {
-        if let Some(worktree) = self.remote_worktrees.get(&id) {
+        if let Some(worktree) = self.shared_worktrees.get(&id) {
             if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
                 Ok(worktree)
             } else {
-                self.remote_worktrees.remove(&id);
+                self.shared_worktrees.remove(&id);
                 Err(anyhow!("worktree {} was dropped", id))
             }
         } else {
@@ -59,8 +58,7 @@ impl Client {
             peer: Peer::new(),
             state: Arc::new(Mutex::new(ClientState {
                 connection_id: None,
-                remote_worktrees: Default::default(),
-                shared_buffers: Default::default(),
+                shared_worktrees: Default::default(),
                 languages,
             })),
         }

zed/src/worktree.rs 🔗

@@ -45,6 +45,7 @@ use std::{
     },
     time::{Duration, SystemTime},
 };
+use zed_rpc::{PeerId, TypedEnvelope};
 
 lazy_static! {
     static ref GITIGNORE: &'static OsStr = OsStr::new(".gitignore");
@@ -116,7 +117,7 @@ impl Worktree {
         rpc.state
             .lock()
             .await
-            .remote_worktrees
+            .shared_worktrees
             .insert(id, worktree.downgrade());
         Ok(worktree)
     }
@@ -244,6 +245,7 @@ pub struct LocalWorktree {
     poll_scheduled: bool,
     rpc: Option<(rpc::Client, u64)>,
     open_buffers: HashMap<usize, WeakModelHandle<Buffer>>,
+    shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
     languages: Arc<LanguageRegistry>,
 }
 
@@ -280,6 +282,7 @@ impl LocalWorktree {
             _event_stream_handle: event_stream_handle,
             poll_scheduled: false,
             open_buffers: Default::default(),
+            shared_buffers: Default::default(),
             rpc: None,
             languages,
         };
@@ -355,14 +358,64 @@ impl LocalWorktree {
                     Buffer::from_history(0, History::new(contents.into()), Some(file), language, cx)
                 });
                 this.update(&mut cx, |this, _| {
-                    let this = this.as_local_mut().unwrap();
+                    let this = this
+                        .as_local_mut()
+                        .ok_or_else(|| anyhow!("must be a local worktree"))?;
                     this.open_buffers.insert(buffer.id(), buffer.downgrade());
-                });
-                Ok(buffer)
+                    Ok(buffer)
+                })
             }
         })
     }
 
+    pub fn open_remote_buffer(
+        &mut self,
+        envelope: TypedEnvelope<proto::OpenBuffer>,
+        cx: &mut ModelContext<Worktree>,
+    ) -> Task<Result<proto::OpenBufferResponse>> {
+        let peer_id = envelope.original_sender_id();
+        let path = Path::new(&envelope.payload.path);
+
+        let buffer = self.open_buffer(path, cx);
+
+        cx.spawn(|this, mut cx| async move {
+            let buffer = buffer.await?;
+            this.update(&mut cx, |this, cx| {
+                this.as_local_mut()
+                    .unwrap()
+                    .shared_buffers
+                    .entry(peer_id?)
+                    .or_default()
+                    .insert(buffer.id() as u64, buffer.clone());
+
+                Ok(proto::OpenBufferResponse {
+                    buffer: Some(buffer.update(cx.as_mut(), |buffer, cx| buffer.to_proto(cx))),
+                })
+            })
+        })
+    }
+
+    pub fn close_remote_buffer(
+        &mut self,
+        envelope: TypedEnvelope<proto::CloseBuffer>,
+        cx: &mut ModelContext<Worktree>,
+    ) -> Result<()> {
+        if let Some(shared_buffers) = self.shared_buffers.get_mut(&envelope.original_sender_id()?) {
+            shared_buffers.remove(&envelope.payload.buffer_id);
+        }
+
+        Ok(())
+    }
+
+    pub fn remove_guest(
+        &mut self,
+        envelope: TypedEnvelope<proto::RemoveGuest>,
+        cx: &mut ModelContext<Worktree>,
+    ) -> Result<()> {
+        self.shared_buffers.remove(&envelope.original_sender_id()?);
+        Ok(())
+    }
+
     pub fn scan_complete(&self) -> impl Future<Output = ()> {
         let mut scan_state_rx = self.scan_state.1.clone();
         async move {
@@ -588,7 +641,7 @@ impl LocalWorktree {
             rpc.state
                 .lock()
                 .await
-                .remote_worktrees
+                .shared_worktrees
                 .insert(share_response.worktree_id, handle.downgrade());
 
             log::info!("sharing worktree {:?}", share_response);
@@ -1877,37 +1930,23 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
-        let message = &envelope.payload;
-        let peer_id = envelope
-            .original_sender_id
-            .ok_or_else(|| anyhow!("missing original sender id"))?;
-
-        let mut state = rpc.state.lock().await;
-        let worktree = state
-            .remote_worktrees
-            .get(&message.worktree_id)
-            .and_then(|worktree| cx.read(|cx| worktree.upgrade(cx)))
-            .ok_or_else(|| anyhow!("worktree {} not found", message.worktree_id))?
-            .clone();
+        let receipt = envelope.receipt();
+        let worktree = rpc
+            .state
+            .lock()
+            .await
+            .shared_worktree(envelope.payload.worktree_id, cx)?;
 
-        let buffer = worktree
+        let response = worktree
             .update(cx, |worktree, cx| {
-                worktree.open_buffer(Path::new(&message.path), cx)
+                worktree
+                    .as_local_mut()
+                    .unwrap()
+                    .open_remote_buffer(envelope, cx)
             })
             .await?;
-        state
-            .shared_buffers
-            .entry(peer_id)
-            .or_default()
-            .insert(buffer.id() as u64, buffer.clone());
-
-        rpc.respond(
-            envelope.receipt(),
-            proto::OpenBufferResponse {
-                buffer: Some(buffer.update(cx, |buf, cx| buf.to_proto(cx))),
-            },
-        )
-        .await?;
+
+        rpc.respond(receipt, response).await?;
 
         Ok(())
     }
@@ -1915,17 +1954,20 @@ mod remote {
     pub async fn close_buffer(
         envelope: TypedEnvelope<proto::CloseBuffer>,
         rpc: &rpc::Client,
-        _: &mut AsyncAppContext,
+        cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
-        let peer_id = envelope
-            .original_sender_id
-            .ok_or_else(|| anyhow!("missing original sender id"))?;
-        let message = &envelope.payload;
-        let mut state = rpc.state.lock().await;
-        state.shared_buffers.entry(peer_id).and_modify(|buffers| {
-            buffers.remove(&message.buffer_id);
-        });
-        Ok(())
+        let worktree = rpc
+            .state
+            .lock()
+            .await
+            .shared_worktree(envelope.payload.worktree_id, cx)?;
+
+        worktree.update(cx, |worktree, cx| {
+            worktree
+                .as_local_mut()
+                .unwrap()
+                .close_remote_buffer(envelope, cx)
+        })
     }
 
     pub async fn update_buffer(
@@ -1935,7 +1977,7 @@ mod remote {
     ) -> anyhow::Result<()> {
         let message = envelope.payload;
         let mut state = rpc.state.lock().await;
-        match state.remote_worktree(message.worktree_id, cx) {
+        match state.shared_worktree(message.worktree_id, cx) {
             Ok(worktree) => {
                 if let Some(buffer) =
                     worktree.read_with(cx, |w, cx| w.buffer(message.buffer_id, cx))
@@ -1969,33 +2011,14 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
-        let peer_id = envelope.original_sender_id.unwrap();
-        let message = envelope.payload;
-        let mut state = rpc.state.lock().await;
-        match state.remote_worktree(message.worktree_id, cx) {
-            Ok(worktree) => {
-                let mut peer_buffers = state.shared_buffers.get_mut(&peer_id);
-                let buffers =
-                    worktree.read_with(cx, |worktree, cx| worktree.buffers(cx).collect::<Vec<_>>());
-                for buffer in buffers {
-                    buffer.update(cx, |buffer, cx| {
-                        buffer.peer_left(message.replica_id as ReplicaId, cx);
-                        if let Some(peer_buffers) = &mut peer_buffers {
-                            peer_buffers.remove(&buffer.remote_id());
-                        }
-                    });
-                }
-
-                if let Some(peer_buffers) = peer_buffers {
-                    if peer_buffers.is_empty() {
-                        state.shared_buffers.remove(&peer_id);
-                    }
-                }
-            }
-            Err(error) => log::error!("{}", error),
-        }
-
-        Ok(())
+        rpc.state
+            .lock()
+            .await
+            .shared_worktree(envelope.payload.worktree_id, cx)?
+            .update(cx, |worktree, cx| match worktree {
+                Worktree::Local(worktree) => worktree.remove_guest(envelope, cx),
+                Worktree::Remote(_) => todo!(),
+            })
     }
 }