Avoid deadlocks on rpc state by switching to an RwLock

Max Brunsfeld created

Change summary

zed/src/rpc.rs      | 15 +++++++--------
zed/src/worktree.rs | 41 ++++++++++++++++++++++++-----------------
2 files changed, 31 insertions(+), 25 deletions(-)

Detailed changes

zed/src/rpc.rs 🔗

@@ -5,7 +5,7 @@ use gpui::executor::Background;
 use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
 use lazy_static::lazy_static;
 use postage::prelude::Stream;
-use smol::lock::Mutex;
+use smol::lock::RwLock;
 use std::collections::HashMap;
 use std::time::Duration;
 use std::{convert::TryFrom, future::Future, sync::Arc};
@@ -24,7 +24,7 @@ lazy_static! {
 #[derive(Clone)]
 pub struct Client {
     peer: Arc<Peer>,
-    pub state: Arc<Mutex<ClientState>>,
+    pub state: Arc<RwLock<ClientState>>,
 }
 
 pub struct ClientState {
@@ -35,7 +35,7 @@ pub struct ClientState {
 
 impl ClientState {
     pub fn shared_worktree(
-        &mut self,
+        &self,
         id: u64,
         cx: &mut AsyncAppContext,
     ) -> Result<ModelHandle<Worktree>> {
@@ -43,7 +43,6 @@ impl ClientState {
             if let Some(worktree) = cx.read(|cx| worktree.upgrade(cx)) {
                 Ok(worktree)
             } else {
-                self.shared_worktrees.remove(&id);
                 Err(anyhow!("worktree {} was dropped", id))
             }
         } else {
@@ -56,7 +55,7 @@ impl Client {
     pub fn new(languages: Arc<LanguageRegistry>) -> Self {
         Self {
             peer: Peer::new(),
-            state: Arc::new(Mutex::new(ClientState {
+            state: Arc::new(RwLock::new(ClientState {
                 connection_id: None,
                 shared_worktrees: Default::default(),
                 languages,
@@ -82,7 +81,7 @@ impl Client {
     }
 
     pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
-        if self.state.lock().await.connection_id.is_some() {
+        if self.state.read().await.connection_id.is_some() {
             return Ok(());
         }
 
@@ -139,7 +138,7 @@ impl Client {
             Err(anyhow!("failed to authenticate with RPC server"))?;
         }
 
-        self.state.lock().await.connection_id = Some(connection_id);
+        self.state.write().await.connection_id = Some(connection_id);
         Ok(())
     }
 
@@ -221,7 +220,7 @@ impl Client {
 
     async fn connection_id(&self) -> Result<ConnectionId> {
         self.state
-            .lock()
+            .read()
             .await
             .connection_id
             .ok_or_else(|| anyhow!("not connected"))

zed/src/worktree.rs 🔗

@@ -85,7 +85,11 @@ impl Entity for Worktree {
 
         if let Some((rpc, worktree_id)) = rpc {
             cx.spawn(|_| async move {
-                rpc.state.lock().await.shared_worktrees.remove(&worktree_id);
+                rpc.state
+                    .write()
+                    .await
+                    .shared_worktrees
+                    .remove(&worktree_id);
                 if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await {
                     log::error!("error closing worktree {}: {}", worktree_id, err);
                 }
@@ -191,7 +195,7 @@ impl Worktree {
             })
         });
         rpc.state
-            .lock()
+            .write()
             .await
             .shared_worktrees
             .insert(id, worktree.downgrade());
@@ -343,8 +347,7 @@ impl Worktree {
                 .and_then(|buf| buf.upgrade(&cx))
             {
                 buffer.update(cx, |buffer, cx| {
-                    buffer.did_save(message.version.try_into()?, cx);
-                    Result::<_, anyhow::Error>::Ok(())
+                    buffer.did_save(message.version.try_into()?, cx)
                 })?;
             }
             Ok(())
@@ -798,7 +801,7 @@ impl LocalWorktree {
                 .await?;
 
             rpc.state
-                .lock()
+                .write()
                 .await
                 .shared_worktrees
                 .insert(share_response.worktree_id, handle.downgrade());
@@ -2094,7 +2097,7 @@ mod remote {
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
         rpc.state
-            .lock()
+            .read()
             .await
             .shared_worktree(envelope.payload.worktree_id, cx)?
             .update(cx, |worktree, cx| worktree.add_guest(envelope, cx))
@@ -2106,7 +2109,7 @@ mod remote {
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
         rpc.state
-            .lock()
+            .read()
             .await
             .shared_worktree(envelope.payload.worktree_id, cx)?
             .update(cx, |worktree, cx| worktree.remove_guest(envelope, cx))
@@ -2120,7 +2123,7 @@ mod remote {
         let receipt = envelope.receipt();
         let worktree = rpc
             .state
-            .lock()
+            .read()
             .await
             .shared_worktree(envelope.payload.worktree_id, cx)?;
 
@@ -2145,7 +2148,7 @@ mod remote {
     ) -> anyhow::Result<()> {
         let worktree = rpc
             .state
-            .lock()
+            .read()
             .await
             .shared_worktree(envelope.payload.worktree_id, cx)?;
 
@@ -2163,9 +2166,11 @@ mod remote {
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
         let message = envelope.payload;
-        let mut state = rpc.state.lock().await;
-        let worktree = state.shared_worktree(message.worktree_id, cx)?;
-        worktree.update(cx, |tree, cx| tree.update_buffer(message, cx))?;
+        rpc.state
+            .read()
+            .await
+            .shared_worktree(message.worktree_id, cx)?
+            .update(cx, |tree, cx| tree.update_buffer(message, cx))?;
         Ok(())
     }
 
@@ -2203,11 +2208,13 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
-        let mut state = rpc.state.lock().await;
-        let worktree = state.shared_worktree(envelope.payload.worktree_id, cx)?;
-        worktree.update(cx, |worktree, cx| {
-            worktree.buffer_saved(envelope.payload, cx)
-        })?;
+        rpc.state
+            .read()
+            .await
+            .shared_worktree(envelope.payload.worktree_id, cx)?
+            .update(cx, |worktree, cx| {
+                worktree.buffer_saved(envelope.payload, cx)
+            })?;
         Ok(())
     }
 }