Merge pull request #230 from zed-industries/rpc-write-timeout

Max Brunsfeld created

Avoid server deadlocks

Change summary

Cargo.lock               |  11 ++
crates/rpc/Cargo.toml    |   1 
crates/rpc/src/peer.rs   |  10 ++
crates/server/src/rpc.rs | 167 +++++++++++++++--------------------------
4 files changed, 83 insertions(+), 106 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4308,6 +4308,7 @@ dependencies = [
  "rsa",
  "serde 1.0.125",
  "smol",
+ "smol-timeout",
  "tempdir",
  "zstd",
 ]
@@ -4867,6 +4868,16 @@ dependencies = [
  "once_cell",
 ]
 
+[[package]]
+name = "smol-timeout"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "847d777e2c6c166bad26264479e80a9820f3d364fcb4a0e23cd57bbfa8e94961"
+dependencies = [
+ "async-io",
+ "pin-project-lite 0.1.12",
+]
+
 [[package]]
 name = "socket2"
 version = "0.3.19"

crates/rpc/Cargo.toml 🔗

@@ -20,6 +20,7 @@ prost = "0.8"
 rand = "0.8"
 rsa = "0.4"
 serde = { version = "1", features = ["derive"] }
+smol-timeout = "0.6"
 zstd = "0.9"
 
 [build-dependencies]

crates/rpc/src/peer.rs 🔗

@@ -7,6 +7,7 @@ use postage::{
     mpsc,
     prelude::{Sink as _, Stream as _},
 };
+use smol_timeout::TimeoutExt as _;
 use std::{
     collections::HashMap,
     fmt,
@@ -16,6 +17,7 @@ use std::{
         atomic::{self, AtomicU32},
         Arc,
     },
+    time::Duration,
 };
 
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
@@ -90,6 +92,8 @@ struct ConnectionState {
     response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 }
 
+const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
+
 impl Peer {
     pub fn new() -> Arc<Self> {
         Arc::new(Self {
@@ -155,8 +159,10 @@ impl Peer {
                         },
                         outgoing = outgoing_rx.recv().fuse() => match outgoing {
                             Some(outgoing) => {
-                                if let Err(result) = writer.write_message(&outgoing).await {
-                                    break 'outer Err(result).context("failed to write RPC message")
+                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
+                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
+                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
+                                    _ => {}
                                 }
                             }
                             None => break 'outer Ok(()),

crates/server/src/rpc.rs 🔗

@@ -6,9 +6,10 @@ use super::{
     AppState,
 };
 use anyhow::anyhow;
-use async_std::{sync::RwLock, task};
+use async_std::task;
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use futures::{future::BoxFuture, FutureExt};
+use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
 use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
 use rpc::{
     proto::{self, AnyTypedEnvelope, EnvelopedMessage},
@@ -23,7 +24,7 @@ use std::{
     sync::Arc,
     time::Instant,
 };
-use store::{JoinedWorktree, Store, Worktree};
+use store::{Store, Worktree};
 use surf::StatusCode;
 use tide::log;
 use tide::{
@@ -116,9 +117,7 @@ impl Server {
         async move {
             let (connection_id, handle_io, mut incoming_rx) =
                 this.peer.add_connection(connection).await;
-            this.state_mut()
-                .await
-                .add_connection(connection_id, user_id);
+            this.state_mut().add_connection(connection_id, user_id);
             if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
                 log::error!("error updating collaborators for {:?}: {}", user_id, err);
             }
@@ -168,7 +167,7 @@ impl Server {
 
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id).await;
-        let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
+        let removed_connection = self.state_mut().remove_connection(connection_id)?;
 
         for (worktree_id, worktree) in removed_connection.hosted_worktrees {
             if let Some(share) = worktree.share {
@@ -213,10 +212,7 @@ impl Server {
         request: TypedEnvelope<proto::OpenWorktree>,
     ) -> tide::Result<()> {
         let receipt = request.receipt();
-        let host_user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
+        let host_user_id = self.state().user_id_for_connection(request.sender_id)?;
 
         let mut collaborator_user_ids = HashSet::new();
         collaborator_user_ids.insert(host_user_id);
@@ -236,7 +232,7 @@ impl Server {
         }
 
         let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
-        let worktree_id = self.state_mut().await.add_worktree(Worktree {
+        let worktree_id = self.state_mut().add_worktree(Worktree {
             host_connection_id: request.sender_id,
             collaborator_user_ids: collaborator_user_ids.clone(),
             root_name: request.payload.root_name,
@@ -259,7 +255,6 @@ impl Server {
         let worktree_id = request.payload.worktree_id;
         let worktree = self
             .state_mut()
-            .await
             .remove_worktree(worktree_id, request.sender_id)?;
 
         if let Some(share) = worktree.share {
@@ -294,7 +289,6 @@ impl Server {
 
         let collaborator_user_ids =
             self.state_mut()
-                .await
                 .share_worktree(worktree.id, request.sender_id, entries);
         if let Some(collaborator_user_ids) = collaborator_user_ids {
             self.peer
@@ -322,7 +316,6 @@ impl Server {
         let worktree_id = request.payload.worktree_id;
         let worktree = self
             .state_mut()
-            .await
             .unshare_worktree(worktree_id, request.sender_id)?;
 
         broadcast(request.sender_id, worktree.connection_ids, |conn_id| {
@@ -341,22 +334,17 @@ impl Server {
         request: TypedEnvelope<proto::JoinWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
-        let user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
-
-        let mut state = self.state_mut().await;
-        match state.join_worktree(request.sender_id, user_id, worktree_id) {
-            Ok(JoinedWorktree {
-                replica_id,
-                worktree,
-            }) => {
-                let share = worktree.share()?;
+
+        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let response_data = self
+            .state_mut()
+            .join_worktree(request.sender_id, user_id, worktree_id)
+            .and_then(|joined| {
+                let share = joined.worktree.share()?;
                 let peer_count = share.guest_connection_ids.len();
                 let mut peers = Vec::with_capacity(peer_count);
                 peers.push(proto::Peer {
-                    peer_id: worktree.host_connection_id.0,
+                    peer_id: joined.worktree.host_connection_id.0,
                     replica_id: 0,
                 });
                 for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids {
@@ -370,16 +358,19 @@ impl Server {
                 let response = proto::JoinWorktreeResponse {
                     worktree: Some(proto::Worktree {
                         id: worktree_id,
-                        root_name: worktree.root_name.clone(),
+                        root_name: joined.worktree.root_name.clone(),
                         entries: share.entries.values().cloned().collect(),
                     }),
-                    replica_id: replica_id as u32,
+                    replica_id: joined.replica_id as u32,
                     peers,
                 };
-                let connection_ids = worktree.connection_ids();
-                let collaborator_user_ids = worktree.collaborator_user_ids.clone();
-                drop(state);
+                let connection_ids = joined.worktree.connection_ids();
+                let collaborator_user_ids = joined.worktree.collaborator_user_ids.clone();
+                Ok((response, connection_ids, collaborator_user_ids))
+            });
 
+        match response_data {
+            Ok((response, connection_ids, collaborator_user_ids)) => {
                 broadcast(request.sender_id, connection_ids, |conn_id| {
                     self.peer.send(
                         conn_id,
@@ -398,7 +389,6 @@ impl Server {
                     .await?;
             }
             Err(error) => {
-                drop(state);
                 self.peer
                     .respond_with_error(
                         request.receipt(),
@@ -419,10 +409,7 @@ impl Server {
     ) -> tide::Result<()> {
         let sender_id = request.sender_id;
         let worktree_id = request.payload.worktree_id;
-        let worktree = self
-            .state_mut()
-            .await
-            .leave_worktree(sender_id, worktree_id);
+        let worktree = self.state_mut().leave_worktree(sender_id, worktree_id);
         if let Some(worktree) = worktree {
             broadcast(sender_id, worktree.connection_ids, |conn_id| {
                 self.peer.send(
@@ -444,7 +431,7 @@ impl Server {
         mut self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
     ) -> tide::Result<()> {
-        let connection_ids = self.state_mut().await.update_worktree(
+        let connection_ids = self.state_mut().update_worktree(
             request.sender_id,
             request.payload.worktree_id,
             &request.payload.removed_entries,
@@ -467,7 +454,6 @@ impl Server {
         let receipt = request.receipt();
         let host_connection_id = self
             .state()
-            .await
             .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
         let response = self
             .peer
@@ -483,7 +469,6 @@ impl Server {
     ) -> tide::Result<()> {
         let host_connection_id = self
             .state()
-            .await
             .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
         self.peer
             .forward_send(request.sender_id, host_connection_id, request.payload)
@@ -498,7 +483,7 @@ impl Server {
         let host;
         let guests;
         {
-            let state = self.state().await;
+            let state = self.state();
             host = state
                 .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
             guests = state
@@ -532,16 +517,13 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateBuffer>,
     ) -> tide::Result<()> {
-        broadcast(
-            request.sender_id,
-            self.state()
-                .await
-                .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
-            |connection_id| {
-                self.peer
-                    .forward_send(request.sender_id, connection_id, request.payload.clone())
-            },
-        )
+        let receiver_ids = self
+            .state()
+            .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?;
+        broadcast(request.sender_id, receiver_ids, |connection_id| {
+            self.peer
+                .forward_send(request.sender_id, connection_id, request.payload.clone())
+        })
         .await?;
         self.peer.respond(request.receipt(), proto::Ack {}).await?;
         Ok(())
@@ -551,17 +533,13 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::BufferSaved>,
     ) -> tide::Result<()> {
-        broadcast(
-            request.sender_id,
-            self.store
-                .read()
-                .await
-                .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
-            |connection_id| {
-                self.peer
-                    .forward_send(request.sender_id, connection_id, request.payload.clone())
-            },
-        )
+        let receiver_ids = self
+            .state()
+            .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?;
+        broadcast(request.sender_id, receiver_ids, |connection_id| {
+            self.peer
+                .forward_send(request.sender_id, connection_id, request.payload.clone())
+        })
         .await?;
         Ok(())
     }
@@ -570,10 +548,7 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetChannels>,
     ) -> tide::Result<()> {
-        let user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
+        let user_id = self.state().user_id_for_connection(request.sender_id)?;
         let channels = self.app_state.db.get_accessible_channels(user_id).await?;
         self.peer
             .respond(
@@ -622,20 +597,20 @@ impl Server {
     ) -> tide::Result<()> {
         let mut send_futures = Vec::new();
 
-        let state = self.state().await;
-        for user_id in user_ids {
-            let collaborators = state.collaborators_for_user(*user_id);
-            for connection_id in state.connection_ids_for_user(*user_id) {
-                send_futures.push(self.peer.send(
-                    connection_id,
-                    proto::UpdateCollaborators {
-                        collaborators: collaborators.clone(),
-                    },
-                ));
+        {
+            let state = self.state();
+            for user_id in user_ids {
+                let collaborators = state.collaborators_for_user(*user_id);
+                for connection_id in state.connection_ids_for_user(*user_id) {
+                    send_futures.push(self.peer.send(
+                        connection_id,
+                        proto::UpdateCollaborators {
+                            collaborators: collaborators.clone(),
+                        },
+                    ));
+                }
             }
         }
-
-        drop(state);
         futures::future::try_join_all(send_futures).await?;
 
         Ok(())
@@ -645,10 +620,7 @@ impl Server {
         mut self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
     ) -> tide::Result<()> {
-        let user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
+        let user_id = self.state().user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -659,9 +631,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state_mut()
-            .await
-            .join_channel(request.sender_id, channel_id);
+        self.state_mut().join_channel(request.sender_id, channel_id);
         let messages = self
             .app_state
             .db
@@ -692,10 +662,7 @@ impl Server {
         mut self: Arc<Self>,
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> tide::Result<()> {
-        let user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
+        let user_id = self.state().user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -707,7 +674,6 @@ impl Server {
         }
 
         self.state_mut()
-            .await
             .leave_channel(request.sender_id, channel_id);
 
         Ok(())
@@ -722,7 +688,7 @@ impl Server {
         let user_id;
         let connection_ids;
         {
-            let state = self.state().await;
+            let state = self.state();
             user_id = state.user_id_for_connection(request.sender_id)?;
             if let Some(ids) = state.channel_connection_ids(channel_id) {
                 connection_ids = ids;
@@ -809,10 +775,7 @@ impl Server {
         self: Arc<Self>,
         request: TypedEnvelope<proto::GetChannelMessages>,
     ) -> tide::Result<()> {
-        let user_id = self
-            .state()
-            .await
-            .user_id_for_connection(request.sender_id)?;
+        let user_id = self.state().user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -853,15 +816,11 @@ impl Server {
         Ok(())
     }
 
-    fn state<'a>(
-        self: &'a Arc<Self>,
-    ) -> impl Future<Output = async_std::sync::RwLockReadGuard<'a, Store>> {
+    fn state<'a>(self: &'a Arc<Self>) -> RwLockReadGuard<'a, Store> {
         self.store.read()
     }
 
-    fn state_mut<'a>(
-        self: &'a mut Arc<Self>,
-    ) -> impl Future<Output = async_std::sync::RwLockWriteGuard<'a, Store>> {
+    fn state_mut<'a>(self: &'a mut Arc<Self>) -> RwLockWriteGuard<'a, Store> {
         self.store.write()
     }
 }
@@ -961,7 +920,7 @@ mod tests {
         github, AppState, Config,
     };
     use ::rpc::Peer;
-    use async_std::{sync::RwLockReadGuard, task};
+    use async_std::task;
     use gpui::{ModelHandle, TestAppContext};
     use parking_lot::Mutex;
     use postage::{mpsc, watch};
@@ -2372,7 +2331,7 @@ mod tests {
         }
 
         async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
-            self.server.store.read().await
+            self.server.store.read()
         }
 
         async fn condition<F>(&mut self, mut predicate: F)
@@ -2380,7 +2339,7 @@ mod tests {
             F: FnMut(&Store) -> bool,
         {
             async_std::future::timeout(Duration::from_millis(500), async {
-                while !(predicate)(&*self.server.store.read().await) {
+                while !(predicate)(&*self.server.store.read()) {
                     self.notifications.recv().await;
                 }
             })