Restructure RPC state to also keep track of remote worktrees on guests

Antonio Scandurra created

Change summary

zed-rpc/proto/zed.proto  |   6 +
zed-rpc/src/peer.rs      |   2 
zed-rpc/src/proto.rs     |   1 
zed/src/editor/buffer.rs |   6 +
zed/src/rpc.rs           |  25 ++++++
zed/src/worktree.rs      | 133 ++++++++++++++++++++++++++++++-----------
6 files changed, 133 insertions(+), 40 deletions(-)

Detailed changes

zed-rpc/proto/zed.proto 🔗

@@ -16,6 +16,7 @@ message Envelope {
         OpenBufferResponse open_buffer_response = 11;
         CloseBuffer close_buffer = 12;
         UpdateBuffer update_buffer = 13;
+        RemoveGuest remove_guest = 14;
     }
 }
 
@@ -49,11 +50,14 @@ message OpenWorktreeResponse {
 
 message AddGuest {
     uint64 worktree_id = 1;
-    User user = 2;
+    uint32 replica_id = 2;
+    User user = 3;
 }
 
 message RemoveGuest {
     uint64 worktree_id = 1;
+    uint32 peer_id = 2;
+    uint32 replica_id = 3;
 }
 
 message OpenBuffer {

zed-rpc/src/peer.rs 🔗

@@ -29,7 +29,7 @@ type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 pub struct ConnectionId(u32);
 
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
-pub struct PeerId(u32);
+pub struct PeerId(pub u32);
 
 struct Connection {
     writer: Mutex<MessageStream<BoxedWriter>>,

zed-rpc/src/proto.rs 🔗

@@ -74,6 +74,7 @@ request_message!(OpenWorktree, OpenWorktreeResponse);
 request_message!(OpenBuffer, OpenBufferResponse);
 message!(CloseBuffer);
 message!(UpdateBuffer);
+message!(RemoveGuest);
 
 /// A stream of protobuf messages.
 pub struct MessageStream<T> {

zed/src/editor/buffer.rs 🔗

@@ -1398,6 +1398,12 @@ impl Buffer {
         self.operations.push(operation);
     }
 
+    pub fn peer_left(&mut self, replica_id: ReplicaId, cx: &mut ModelContext<Self>) {
+        self.selections
+            .retain(|set_id, _| set_id.replica_id != replica_id);
+        cx.notify();
+    }
+
     pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
         let was_dirty = self.is_dirty(cx.as_ref());
         let old_version = self.version.clone();

zed/src/rpc.rs 🔗

@@ -2,7 +2,7 @@ use super::util::SurfResultExt as _;
 use crate::{editor::Buffer, language::LanguageRegistry, worktree::Worktree};
 use anyhow::{anyhow, Context, Result};
 use gpui::executor::Background;
-use gpui::{AsyncAppContext, ModelHandle, Task};
+use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
 use lazy_static::lazy_static;
 use postage::prelude::Stream;
 use smol::lock::Mutex;
@@ -29,18 +29,37 @@ pub struct Client {
 
 pub struct ClientState {
     connection_id: Option<ConnectionId>,
-    pub shared_worktrees: HashMap<u64, ModelHandle<Worktree>>,
+    pub remote_worktrees: HashMap<u64, WeakModelHandle<Worktree>>,
     pub shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
     pub language_registry: Arc<LanguageRegistry>,
 }
 
+impl ClientState {
+    pub fn remote_worktree(
+        &mut self,
+        id: u64,
+        cx: &mut AsyncAppContext,
+    ) -> Result<ModelHandle<Worktree>> {
+        if let Some(worktree) = self.remote_worktrees.get(&id) {
+            if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
+                Ok(worktree)
+            } else {
+                self.remote_worktrees.remove(&id);
+                Err(anyhow!("worktree {} was dropped", id))
+            }
+        } else {
+            Err(anyhow!("worktree {} does not exist", id))
+        }
+    }
+}
+
 impl Client {
     pub fn new(language_registry: Arc<LanguageRegistry>) -> Self {
         Self {
             peer: Peer::new(),
             state: Arc::new(Mutex::new(ClientState {
                 connection_id: None,
-                shared_worktrees: Default::default(),
+                remote_worktrees: Default::default(),
                 shared_buffers: Default::default(),
                 language_registry,
             })),

zed/src/worktree.rs 🔗

@@ -47,6 +47,7 @@ pub fn init(cx: &mut MutableAppContext, rpc: rpc::Client) {
     rpc.on_message(remote::open_buffer, cx);
     rpc.on_message(remote::close_buffer, cx);
     rpc.on_message(remote::update_buffer, cx);
+    rpc.on_message(remote::remove_guest, cx);
 }
 
 #[derive(Clone, Debug)]
@@ -88,17 +89,23 @@ impl Worktree {
         let replica_id = open_worktree_response
             .replica_id
             .ok_or_else(|| anyhow!("empty replica id"))?;
-        Ok(cx.update(|cx| {
+        let worktree = cx.update(|cx| {
             cx.add_model(|cx| {
                 Worktree::Remote(RemoteWorktree::new(
                     id,
                     worktree_message,
-                    rpc,
+                    rpc.clone(),
                     replica_id as ReplicaId,
                     cx,
                 ))
             })
-        }))
+        });
+        rpc.state
+            .lock()
+            .await
+            .remote_worktrees
+            .insert(id, worktree.downgrade());
+        Ok(worktree)
     }
 
     pub fn as_local(&self) -> Option<&LocalWorktree> {
@@ -165,6 +172,30 @@ impl Worktree {
             .is_some()
     }
 
+    pub fn buffer(&self, id: u64, cx: &AppContext) -> Option<ModelHandle<Buffer>> {
+        let open_buffers = match self {
+            Worktree::Local(worktree) => &worktree.open_buffers,
+            Worktree::Remote(worktree) => &worktree.open_buffers,
+        };
+
+        open_buffers
+            .get(&(id as usize))
+            .and_then(|buffer| buffer.upgrade(cx))
+    }
+
+    pub fn buffers<'a>(
+        &'a self,
+        cx: &'a AppContext,
+    ) -> impl 'a + Iterator<Item = ModelHandle<Buffer>> {
+        let open_buffers = match self {
+            Worktree::Local(worktree) => &worktree.open_buffers,
+            Worktree::Remote(worktree) => &worktree.open_buffers,
+        };
+        open_buffers
+            .values()
+            .filter_map(move |buffer| buffer.upgrade(cx))
+    }
+
     pub fn save(
         &self,
         path: &Path,
@@ -520,8 +551,8 @@ impl LocalWorktree {
             rpc.state
                 .lock()
                 .await
-                .shared_worktrees
-                .insert(share_response.worktree_id, handle);
+                .remote_worktrees
+                .insert(share_response.worktree_id, handle.downgrade());
 
             log::info!("sharing worktree {:?}", share_response);
 
@@ -1784,10 +1815,9 @@ impl<'a> Iterator for ChildEntriesIter<'a> {
 }
 
 mod remote {
-    use std::convert::TryInto;
-
     use super::*;
     use crate::rpc::TypedEnvelope;
+    use std::convert::TryInto;
 
     pub async fn open_buffer(
         envelope: TypedEnvelope<proto::OpenBuffer>,
@@ -1801,8 +1831,9 @@ mod remote {
 
         let mut state = rpc.state.lock().await;
         let worktree = state
-            .shared_worktrees
+            .remote_worktrees
             .get(&message.worktree_id)
+            .and_then(|worktree| cx.read(|cx| worktree.upgrade(cx)))
             .ok_or_else(|| anyhow!("worktree {} not found", message.worktree_id))?
             .clone();
 
@@ -1853,36 +1884,68 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
-        let peer_id = envelope
-            .original_sender_id
-            .ok_or_else(|| anyhow!("missing original sender id"))?;
         let message = envelope.payload;
-        if let Some(buffer) = rpc
-            .state
-            .lock()
-            .await
-            .shared_buffers
-            .get(&peer_id)
-            .and_then(|buffers| buffers.get(&message.buffer_id))
-            .cloned()
-        {
-            if let Err(error) = buffer.update(cx, |buffer, cx| {
-                let ops = message
-                    .operations
-                    .into_iter()
-                    .map(|op| op.try_into())
-                    .collect::<anyhow::Result<Vec<_>>>()?;
-                buffer.apply_ops(ops, cx)?;
-                Ok::<(), anyhow::Error>(())
-            }) {
-                log::error!("error applying buffer operations {}", error);
+        let mut state = rpc.state.lock().await;
+        match state.remote_worktree(message.worktree_id, cx) {
+            Ok(worktree) => {
+                if let Some(buffer) =
+                    worktree.read_with(cx, |w, cx| w.buffer(message.buffer_id, cx))
+                {
+                    if let Err(error) = buffer.update(cx, |buffer, cx| {
+                        let ops = message
+                            .operations
+                            .into_iter()
+                            .map(|op| op.try_into())
+                            .collect::<anyhow::Result<Vec<_>>>()?;
+                        buffer.apply_ops(ops, cx)?;
+                        Ok::<(), anyhow::Error>(())
+                    }) {
+                        log::error!("error applying buffer operations {}", error);
+                    }
+                } else {
+                    log::error!(
+                        "invalid buffer {} in update buffer message",
+                        message.buffer_id
+                    );
+                }
             }
-        } else {
-            log::error!(
-                "invalid buffer {} in update buffer message",
-                message.buffer_id
-            );
+            Err(error) => log::error!("{}", error),
         }
+
+        Ok(())
+    }
+
+    pub async fn remove_guest(
+        envelope: TypedEnvelope<proto::RemoveGuest>,
+        rpc: &rpc::Client,
+        cx: &mut AsyncAppContext,
+    ) -> anyhow::Result<()> {
+        let peer_id = envelope.original_sender_id.unwrap();
+        let message = envelope.payload;
+        let mut state = rpc.state.lock().await;
+        match state.remote_worktree(message.worktree_id, cx) {
+            Ok(worktree) => {
+                let mut peer_buffers = state.shared_buffers.get_mut(&peer_id);
+                let buffers =
+                    worktree.read_with(cx, |worktree, cx| worktree.buffers(cx).collect::<Vec<_>>());
+                for buffer in buffers {
+                    buffer.update(cx, |buffer, cx| {
+                        buffer.peer_left(message.replica_id as ReplicaId, cx);
+                        if let Some(peer_buffers) = &mut peer_buffers {
+                            peer_buffers.remove(&buffer.remote_id());
+                        }
+                    });
+                }
+
+                if let Some(peer_buffers) = peer_buffers {
+                    if peer_buffers.is_empty() {
+                        state.shared_buffers.remove(&peer_id);
+                    }
+                }
+            }
+            Err(error) => log::error!("{}", error),
+        }
+
         Ok(())
     }
 }