Notify host when guests close buffers

Nathan Sobo and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

zed-rpc/proto/zed.proto  |  17 +++---
zed-rpc/src/peer.rs      |   4 
zed/src/editor/buffer.rs |  57 ++++++++++++++++++---
zed/src/rpc.rs           |  50 +++++++++---------
zed/src/workspace.rs     |  16 +----
zed/src/worktree.rs      | 112 +++++++++++++++++++----------------------
6 files changed, 142 insertions(+), 114 deletions(-)

Detailed changes

zed-rpc/proto/zed.proto 🔗

@@ -61,18 +61,18 @@ message OpenBuffer {
 }
 
 message OpenBufferResponse {
-    uint64 buffer_id = 1;
-    Buffer buffer = 2;
+    Buffer buffer = 1;
 }
 
 message CloseBuffer {
-    uint64 id = 1;
+    uint64 worktree_id = 1;
+    uint64 buffer_id = 2;
 }
 
 message User {
-    string github_login = 1;
-    string avatar_url = 2;
-    uint64 id = 3;
+    uint64 id = 1;
+    string github_login = 2;
+    string avatar_url = 3;
 }
 
 message Worktree {
@@ -90,8 +90,9 @@ message Entry {
 }
 
 message Buffer {
-    string content = 1;
-    repeated Operation history = 2;
+    uint64 id = 1;
+    string content = 2;
+    repeated Operation history = 3;
 }
 
 message Operation {

zed-rpc/src/peer.rs 🔗

@@ -442,8 +442,8 @@ mod tests {
                 path: "path/two".to_string(),
             };
             let response3 = proto::OpenBufferResponse {
-                buffer_id: 2,
                 buffer: Some(proto::Buffer {
+                    id: 2,
                     content: "path/two content".to_string(),
                     history: vec![],
                 }),
@@ -453,8 +453,8 @@ mod tests {
                 path: "path/one".to_string(),
             };
             let response4 = proto::OpenBufferResponse {
-                buffer_id: 1,
                 buffer: Some(proto::Buffer {
+                    id: 1,
                     content: "path/one content".to_string(),
                     history: vec![],
                 }),

zed/src/editor/buffer.rs 🔗

@@ -16,6 +16,7 @@ use zed_rpc::proto;
 use crate::{
     language::{Language, Tree},
     operation_queue::{self, OperationQueue},
+    rpc,
     settings::{StyleId, ThemeMap},
     sum_tree::{self, FilterCursor, SumTree},
     time::{self, ReplicaId},
@@ -117,6 +118,7 @@ pub struct Buffer {
     undo_map: UndoMap,
     history: History,
     file: Option<File>,
+    rpc: Option<rpc::Client>,
     language: Option<Arc<Language>>,
     syntax_tree: Mutex<Option<SyntaxTree>>,
     is_parsing: bool,
@@ -125,6 +127,7 @@ pub struct Buffer {
     deferred_ops: OperationQueue<Operation>,
     deferred_replicas: HashSet<ReplicaId>,
     replica_id: ReplicaId,
+    remote_id: Option<u64>,
     local_clock: time::Local,
     lamport_clock: time::Lamport,
 }
@@ -416,7 +419,15 @@ impl Buffer {
         base_text: T,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        Self::build(replica_id, History::new(base_text.into()), None, None, cx)
+        Self::build(
+            replica_id,
+            History::new(base_text.into()),
+            None,
+            None,
+            None,
+            None,
+            cx,
+        )
     }
 
     pub fn from_history(
@@ -426,13 +437,15 @@ impl Buffer {
         language: Option<Arc<Language>>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        Self::build(replica_id, history, file, language, cx)
+        Self::build(replica_id, history, file, None, None, language, cx)
     }
 
     fn build(
         replica_id: ReplicaId,
         history: History,
         file: Option<File>,
+        rpc: Option<rpc::Client>,
+        remote_id: Option<u64>,
         language: Option<Arc<Language>>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
@@ -469,6 +482,7 @@ impl Buffer {
             undo_map: Default::default(),
             history,
             file,
+            rpc,
             syntax_tree: Mutex::new(None),
             is_parsing: false,
             language,
@@ -478,6 +492,7 @@ impl Buffer {
             deferred_ops: OperationQueue::new(),
             deferred_replicas: HashSet::default(),
             replica_id,
+            remote_id,
             local_clock: time::Local::new(replica_id),
             lamport_clock: time::Lamport::new(replica_id),
         };
@@ -496,19 +511,22 @@ impl Buffer {
 
     pub fn from_proto(
         replica_id: ReplicaId,
-        remote_buffer: proto::Buffer,
+        message: proto::Buffer,
         file: Option<File>,
+        rpc: rpc::Client,
         language: Option<Arc<Language>>,
         cx: &mut ModelContext<Self>,
     ) -> Result<Self> {
         let mut buffer = Buffer::build(
             replica_id,
-            History::new(remote_buffer.content.into()),
+            History::new(message.content.into()),
             file,
+            Some(rpc),
+            Some(message.id),
             language,
             cx,
         );
-        let ops = remote_buffer
+        let ops = message
             .history
             .into_iter()
             .filter_map(|op| op.variant)
@@ -542,7 +560,7 @@ impl Buffer {
         Ok(buffer)
     }
 
-    pub fn to_proto(&self) -> proto::Buffer {
+    pub fn to_proto(&self, cx: &mut ModelContext<Self>) -> proto::Buffer {
         let ops = self
             .history
             .ops
@@ -577,6 +595,7 @@ impl Buffer {
             })
             .collect();
         proto::Buffer {
+            id: cx.model_id() as u64,
             content: self.history.base_text.to_string(),
             history: ops,
         }
@@ -730,7 +749,7 @@ impl Buffer {
 
                     // Parse the current text in a background thread.
                     let new_tree = cx
-                        .background_executor()
+                        .background()
                         .spawn({
                             let language = language.clone();
                             async move { Self::parse_text(&new_text, new_tree, &language) }
@@ -818,7 +837,7 @@ impl Buffer {
         // TODO: it would be nice to not allocate here.
         let old_text = self.text();
         let base_version = self.version();
-        cx.background_executor().spawn(async move {
+        cx.background().spawn(async move {
             let changes = TextDiff::from_lines(old_text.as_str(), new_text.as_ref())
                 .iter_all_changes()
                 .map(|c| (c.tag(), c.value().len()))
@@ -1778,11 +1797,13 @@ impl Clone for Buffer {
             selections_last_update: self.selections_last_update.clone(),
             deferred_ops: self.deferred_ops.clone(),
             file: self.file.clone(),
+            rpc: self.rpc.clone(),
             language: self.language.clone(),
             syntax_tree: Mutex::new(self.syntax_tree.lock().clone()),
             is_parsing: false,
             deferred_replicas: self.deferred_replicas.clone(),
             replica_id: self.replica_id,
+            remote_id: self.remote_id.clone(),
             local_clock: self.local_clock.clone(),
             lamport_clock: self.lamport_clock.clone(),
         }
@@ -1919,6 +1940,26 @@ pub enum Event {
 
 impl Entity for Buffer {
     type Event = Event;
+
+    fn release(&mut self, cx: &mut gpui::MutableAppContext) {
+        if let (Some(buffer_id), Some(file)) = (self.remote_id, self.file.as_ref()) {
+            let rpc = self.rpc.clone().unwrap();
+            let worktree_id = file.worktree_id() as u64;
+            cx.background()
+                .spawn(async move {
+                    if let Err(error) = rpc
+                        .send(proto::CloseBuffer {
+                            worktree_id,
+                            buffer_id,
+                        })
+                        .await
+                    {
+                        log::error!("error closing remote buffer: {}", error);
+                    };
+                })
+                .detach();
+        }
+    }
 }
 
 impl<'a, F: Fn(&FragmentSummary) -> bool> Iterator for Edits<'a, F> {

zed/src/rpc.rs 🔗

@@ -30,7 +30,7 @@ pub struct Client {
 pub struct ClientState {
     connection_id: Option<ConnectionId>,
     pub shared_worktrees: HashMap<u64, ModelHandle<Worktree>>,
-    pub shared_buffers: HashMap<PeerId, HashMap<usize, ModelHandle<Buffer>>>,
+    pub shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
     pub language_registry: Arc<LanguageRegistry>,
 }
 
@@ -64,12 +64,12 @@ impl Client {
         .detach();
     }
 
-    pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<ConnectionId> {
-        if let Some(connection_id) = self.state.lock().await.connection_id {
-            return Ok(connection_id);
+    pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
+        if self.state.lock().await.connection_id.is_some() {
+            return Ok(());
         }
 
-        let (user_id, access_token) = Self::login(cx.platform(), &cx.background_executor()).await?;
+        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
 
         let mut response = surf::get(format!(
             "{}{}",
@@ -88,13 +88,8 @@ impl Client {
             .await
             .context("failed to parse rpc address response")?;
 
-        self.connect(
-            &address,
-            user_id.parse()?,
-            access_token,
-            &cx.background_executor(),
-        )
-        .await
+        self.connect(&address, user_id.parse()?, access_token, &cx.background())
+            .await
     }
 
     pub async fn connect(
@@ -103,7 +98,7 @@ impl Client {
         user_id: i32,
         access_token: String,
         executor: &Arc<Background>,
-    ) -> surf::Result<ConnectionId> {
+    ) -> surf::Result<()> {
         // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
         // a TLS stream using `native-tls`.
         let stream = smol::net::TcpStream::connect(&address).await?;
@@ -129,7 +124,8 @@ impl Client {
             Err(anyhow!("failed to authenticate with RPC server"))?;
         }
 
-        Ok(connection_id)
+        self.state.lock().await.connection_id = Some(connection_id);
+        Ok(())
     }
 
     pub fn login(
@@ -208,20 +204,22 @@ impl Client {
         })
     }
 
-    pub fn send<T: EnvelopedMessage>(
-        &self,
-        connection_id: ConnectionId,
-        message: T,
-    ) -> impl Future<Output = Result<()>> {
-        self.peer.send(connection_id, message)
+    async fn connection_id(&self) -> Result<ConnectionId> {
+        self.state
+            .lock()
+            .await
+            .connection_id
+            .ok_or_else(|| anyhow!("not connected"))
     }
 
-    pub fn request<T: RequestMessage>(
-        &self,
-        connection_id: ConnectionId,
-        request: T,
-    ) -> impl Future<Output = Result<T::Response>> {
-        self.peer.request(connection_id, request)
+    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
+        self.peer.send(self.connection_id().await?, message).await
+    }
+
+    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
+        self.peer
+            .request(self.connection_id().await?, request)
+            .await
     }
 
     pub fn respond<T: RequestMessage>(

zed/src/workspace.rs 🔗

@@ -630,13 +630,13 @@ impl Workspace {
         let platform = cx.platform();
 
         let task = cx.spawn(|this, mut cx| async move {
-            let connection_id = rpc.log_in_and_connect(&cx).await?;
+            rpc.log_in_and_connect(&cx).await?;
 
             let share_task = this.update(&mut cx, |this, cx| {
                 let worktree = this.worktrees.iter().next()?;
                 worktree.update(cx, |worktree, cx| {
                     let worktree = worktree.as_local_mut()?;
-                    Some(worktree.share(rpc, connection_id, cx))
+                    Some(worktree.share(rpc, cx))
                 })
             });
 
@@ -661,7 +661,7 @@ impl Workspace {
         let rpc = self.rpc.clone();
 
         let task = cx.spawn(|this, mut cx| async move {
-            let connection_id = rpc.log_in_and_connect(&cx).await?;
+            rpc.log_in_and_connect(&cx).await?;
 
             let worktree_url = cx
                 .platform()
@@ -671,14 +671,8 @@ impl Workspace {
                 .ok_or_else(|| anyhow!("failed to decode worktree url"))?;
             log::info!("read worktree url from clipboard: {}", worktree_url.text());
 
-            let worktree = Worktree::remote(
-                rpc.clone(),
-                connection_id,
-                worktree_id,
-                access_token,
-                &mut cx,
-            )
-            .await?;
+            let worktree =
+                Worktree::remote(rpc.clone(), worktree_id, access_token, &mut cx).await?;
             this.update(&mut cx, |workspace, cx| {
                 cx.observe_model(&worktree, |_, _, cx| cx.notify());
                 workspace.worktrees.insert(worktree);

zed/src/worktree.rs 🔗

@@ -6,7 +6,7 @@ use self::{char_bag::CharBag, ignore::IgnoreStack};
 use crate::{
     editor::{Buffer, History, Rope},
     language::LanguageRegistry,
-    rpc::{self, proto, ConnectionId},
+    rpc::{self, proto},
     sum_tree::{self, Cursor, Edit, SumTree},
     time::ReplicaId,
     util::Bias,
@@ -71,19 +71,15 @@ impl Worktree {
 
     pub async fn remote(
         rpc: rpc::Client,
-        connection_id: ConnectionId,
         id: u64,
         access_token: String,
         cx: &mut AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
         let open_worktree_response = rpc
-            .request(
-                connection_id,
-                proto::OpenWorktree {
-                    worktree_id: id,
-                    access_token,
-                },
-            )
+            .request(proto::OpenWorktree {
+                worktree_id: id,
+                access_token,
+            })
             .await?;
         let worktree_message = open_worktree_response
             .worktree
@@ -97,7 +93,6 @@ impl Worktree {
                     id,
                     worktree_message,
                     rpc,
-                    connection_id,
                     replica_id as ReplicaId,
                     cx,
                 ))
@@ -150,6 +145,25 @@ impl Worktree {
         }
     }
 
+    pub fn has_open_buffer(&self, path: impl AsRef<Path>, cx: &AppContext) -> bool {
+        let open_buffers = match self {
+            Worktree::Local(worktree) => &worktree.open_buffers,
+            Worktree::Remote(worktree) => &worktree.open_buffers,
+        };
+
+        let path = path.as_ref();
+        open_buffers
+            .values()
+            .find(|buffer| {
+                if let Some(file) = buffer.upgrade(cx).and_then(|buffer| buffer.read(cx).file()) {
+                    file.path.as_ref() == path
+                } else {
+                    false
+                }
+            })
+            .is_some()
+    }
+
     pub fn save(
         &self,
         path: &Path,
@@ -415,7 +429,7 @@ impl LocalWorktree {
 
     fn load(&self, path: &Path, cx: &AppContext) -> Task<Result<String>> {
         let abs_path = self.absolutize(path);
-        cx.background_executor().spawn(async move {
+        cx.background().spawn(async move {
             let mut file = fs::File::open(&abs_path)?;
             let mut contents = String::new();
             file.read_to_string(&mut contents)?;
@@ -434,7 +448,7 @@ impl LocalWorktree {
         let background_snapshot = self.background_snapshot.clone();
         let save = {
             let path = path.clone();
-            cx.background_executor().spawn(async move {
+            cx.background().spawn(async move {
                 let buffer_size = content.summary().bytes.min(10 * 1024);
                 let file = fs::File::create(&abs_path)?;
                 let mut writer = io::BufWriter::with_capacity(buffer_size, &file);
@@ -472,7 +486,6 @@ impl LocalWorktree {
     pub fn share(
         &mut self,
         client: rpc::Client,
-        connection_id: ConnectionId,
         cx: &mut ModelContext<Worktree>,
     ) -> Task<anyhow::Result<(u64, String)>> {
         self.rpc = Some(client.clone());
@@ -481,7 +494,7 @@ impl LocalWorktree {
         let handle = cx.handle();
         cx.spawn(|_this, cx| async move {
             let entries = cx
-                .background_executor()
+                .background()
                 .spawn(async move {
                     snapshot
                         .entries
@@ -499,12 +512,9 @@ impl LocalWorktree {
                 .await;
 
             let share_response = client
-                .request(
-                    connection_id,
-                    proto::ShareWorktree {
-                        worktree: Some(proto::Worktree { root_name, entries }),
-                    },
-                )
+                .request(proto::ShareWorktree {
+                    worktree: Some(proto::Worktree { root_name, entries }),
+                })
                 .await?;
 
             client
@@ -538,9 +548,8 @@ pub struct RemoteWorktree {
     remote_id: u64,
     snapshot: Snapshot,
     rpc: rpc::Client,
-    connection_id: ConnectionId,
     replica_id: ReplicaId,
-    open_buffers: HashMap<u64, WeakModelHandle<Buffer>>,
+    open_buffers: HashMap<usize, WeakModelHandle<Buffer>>,
 }
 
 impl RemoteWorktree {
@@ -548,7 +557,6 @@ impl RemoteWorktree {
         remote_id: u64,
         worktree: proto::Worktree,
         rpc: rpc::Client,
-        connection_id: ConnectionId,
         replica_id: ReplicaId,
         cx: &mut ModelContext<Worktree>,
     ) -> Self {
@@ -596,7 +604,6 @@ impl RemoteWorktree {
             remote_id,
             snapshot,
             rpc,
-            connection_id,
             replica_id,
             open_buffers: Default::default(),
         }
@@ -625,7 +632,6 @@ impl RemoteWorktree {
 
         let rpc = self.rpc.clone();
         let replica_id = self.replica_id;
-        let connection_id = self.connection_id;
         let remote_worktree_id = self.remote_id;
         let path = path.to_string_lossy().to_string();
         cx.spawn(|this, mut cx| async move {
@@ -635,22 +641,21 @@ impl RemoteWorktree {
                 let file = File::new(handle, Path::new(&path).into());
                 let language = language_registry.select_language(&path).cloned();
                 let response = rpc
-                    .request(
-                        connection_id,
-                        proto::OpenBuffer {
-                            worktree_id: remote_worktree_id as u64,
-                            path,
-                        },
-                    )
+                    .request(proto::OpenBuffer {
+                        worktree_id: remote_worktree_id as u64,
+                        path,
+                    })
                     .await?;
-                let buffer_id = response.buffer_id;
                 let remote_buffer = response.buffer.ok_or_else(|| anyhow!("empty buffer"))?;
+                let buffer_id = remote_buffer.id;
                 let buffer = cx.add_model(|cx| {
-                    Buffer::from_proto(replica_id, remote_buffer, Some(file), language, cx).unwrap()
+                    Buffer::from_proto(replica_id, remote_buffer, Some(file), rpc, language, cx)
+                        .unwrap()
                 });
                 this.update(&mut cx, |this, _| {
                     let this = this.as_remote_mut().unwrap();
-                    this.open_buffers.insert(buffer_id, buffer.downgrade());
+                    this.open_buffers
+                        .insert(buffer_id as usize, buffer.downgrade());
                 });
                 Ok(buffer)
             }
@@ -1764,13 +1769,12 @@ mod remote {
             .shared_buffers
             .entry(peer_id)
             .or_default()
-            .insert(buffer.id(), buffer.clone());
+            .insert(buffer.id() as u64, buffer.clone());
 
         rpc.respond(
             request.receipt(),
             proto::OpenBufferResponse {
-                buffer_id: buffer.id() as u64,
-                buffer: Some(buffer.read_with(cx, |buf, _| buf.to_proto())),
+                buffer: Some(buffer.update(cx, |buf, cx| buf.to_proto(cx))),
             },
         )
         .await?;
@@ -1779,28 +1783,18 @@ mod remote {
     }
 
     pub async fn close_buffer(
-        _request: TypedEnvelope<proto::CloseBuffer>,
-        _rpc: &rpc::Client,
-        _cx: &mut AsyncAppContext,
+        message: TypedEnvelope<proto::CloseBuffer>,
+        rpc: &rpc::Client,
+        _: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
-        // let message = &request.payload;
-        // let peer_id = request
-        //     .original_sender_id
-        //     .ok_or_else(|| anyhow!("missing original sender id"))?;
-        // let mut state = rpc.state.lock().await;
-        // if let Some((_, ref_counts)) = state
-        //     .shared_files
-        //     .iter_mut()
-        //     .find(|(file, _)| file.id() == message.id)
-        // {
-        //     if let Some(count) = ref_counts.get_mut(&peer_id) {
-        //         *count -= 1;
-        //         if *count == 0 {
-        //             ref_counts.remove(&peer_id);
-        //         }
-        //     }
-        // }
-
+        let peer_id = message
+            .original_sender_id
+            .ok_or_else(|| anyhow!("missing original sender id"))?;
+        let message = &message.payload;
+        let mut state = rpc.state.lock().await;
+        state.shared_buffers.entry(peer_id).and_modify(|buffers| {
+            buffers.remove(&message.buffer_id);
+        });
         Ok(())
     }
 }