Add `Server::{state,state_mut}` to catch most deadlocks statically

Antonio Scandurra created

Change summary

server/src/rpc.rs | 150 +++++++++++++++++++++++-------------------------
1 file changed, 71 insertions(+), 79 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -112,12 +112,11 @@ impl Server {
         addr: String,
         user_id: UserId,
     ) -> impl Future<Output = ()> {
-        let this = self.clone();
+        let mut this = self.clone();
         async move {
             let (connection_id, handle_io, mut incoming_rx) =
                 this.peer.add_connection(connection).await;
-            this.store
-                .write()
+            this.state_mut()
                 .await
                 .add_connection(connection_id, user_id);
             if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
@@ -167,9 +166,9 @@ impl Server {
         }
     }
 
-    async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
+    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id).await;
-        let removed_connection = self.store.write().await.remove_connection(connection_id)?;
+        let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
 
         for (worktree_id, worktree) in removed_connection.hosted_worktrees {
             if let Some(share) = worktree.share {
@@ -210,13 +209,12 @@ impl Server {
     }
 
     async fn open_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         request: TypedEnvelope<proto::OpenWorktree>,
     ) -> tide::Result<()> {
         let receipt = request.receipt();
         let host_user_id = self
-            .store
-            .read()
+            .state()
             .await
             .user_id_for_connection(request.sender_id)?;
 
@@ -238,7 +236,7 @@ impl Server {
         }
 
         let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
-        let worktree_id = self.store.write().await.add_worktree(Worktree {
+        let worktree_id = self.state_mut().await.add_worktree(Worktree {
             host_connection_id: request.sender_id,
             collaborator_user_ids: collaborator_user_ids.clone(),
             root_name: request.payload.root_name,
@@ -255,13 +253,12 @@ impl Server {
     }
 
     async fn close_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         request: TypedEnvelope<proto::CloseWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
         let worktree = self
-            .store
-            .write()
+            .state_mut()
             .await
             .remove_worktree(worktree_id, request.sender_id)?;
 
@@ -282,7 +279,7 @@ impl Server {
     }
 
     async fn share_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         mut request: TypedEnvelope<proto::ShareWorktree>,
     ) -> tide::Result<()> {
         let worktree = request
@@ -296,8 +293,7 @@ impl Server {
             .collect();
 
         let collaborator_user_ids =
-            self.store
-                .write()
+            self.state_mut()
                 .await
                 .share_worktree(worktree.id, request.sender_id, entries);
         if let Some(collaborator_user_ids) = collaborator_user_ids {
@@ -320,13 +316,12 @@ impl Server {
     }
 
     async fn unshare_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         request: TypedEnvelope<proto::UnshareWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
         let worktree = self
-            .store
-            .write()
+            .state_mut()
             .await
             .unshare_worktree(worktree_id, request.sender_id)?;
 
@@ -342,20 +337,16 @@ impl Server {
     }
 
     async fn join_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         request: TypedEnvelope<proto::JoinWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
         let user_id = self
-            .store
-            .read()
+            .state()
             .await
             .user_id_for_connection(request.sender_id)?;
 
-        let response;
-        let connection_ids;
-        let collaborator_user_ids;
-        let mut state = self.store.write().await;
+        let mut state = self.state_mut().await;
         match state.join_worktree(request.sender_id, user_id, worktree_id) {
             Ok(JoinedWorktree {
                 replica_id,
@@ -376,7 +367,7 @@ impl Server {
                         });
                     }
                 }
-                response = proto::JoinWorktreeResponse {
+                let response = proto::JoinWorktreeResponse {
                     worktree: Some(proto::Worktree {
                         id: worktree_id,
                         root_name: worktree.root_name.clone(),
@@ -385,10 +376,29 @@ impl Server {
                     replica_id: replica_id as u32,
                     peers,
                 };
-                connection_ids = worktree.connection_ids();
-                collaborator_user_ids = worktree.collaborator_user_ids.clone();
+                let connection_ids = worktree.connection_ids();
+                let collaborator_user_ids = worktree.collaborator_user_ids.clone();
+                drop(state);
+
+                broadcast(request.sender_id, connection_ids, |conn_id| {
+                    self.peer.send(
+                        conn_id,
+                        proto::AddPeer {
+                            worktree_id,
+                            peer: Some(proto::Peer {
+                                peer_id: request.sender_id.0,
+                                replica_id: response.replica_id,
+                            }),
+                        },
+                    )
+                })
+                .await?;
+                self.peer.respond(request.receipt(), response).await?;
+                self.update_collaborators_for_users(&collaborator_user_ids)
+                    .await?;
             }
             Err(error) => {
+                drop(state);
                 self.peer
                     .respond_with_error(
                         request.receipt(),
@@ -397,44 +407,23 @@ impl Server {
                         },
                     )
                     .await?;
-                return Ok(());
             }
         }
 
-        drop(state);
-        broadcast(request.sender_id, connection_ids, |conn_id| {
-            self.peer.send(
-                conn_id,
-                proto::AddPeer {
-                    worktree_id,
-                    peer: Some(proto::Peer {
-                        peer_id: request.sender_id.0,
-                        replica_id: response.replica_id,
-                    }),
-                },
-            )
-        })
-        .await?;
-        self.peer.respond(request.receipt(), response).await?;
-        self.update_collaborators_for_users(&collaborator_user_ids)
-            .await?;
-
         Ok(())
     }
 
     async fn leave_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         request: TypedEnvelope<proto::LeaveWorktree>,
     ) -> tide::Result<()> {
         let sender_id = request.sender_id;
         let worktree_id = request.payload.worktree_id;
-
-        if let Some(worktree) = self
-            .store
-            .write()
+        let worktree = self
+            .state_mut()
             .await
-            .leave_worktree(sender_id, worktree_id)
-        {
+            .leave_worktree(sender_id, worktree_id);
+        if let Some(worktree) = worktree {
             broadcast(sender_id, worktree.connection_ids, |conn_id| {
                 self.peer.send(
                     conn_id,
@@ -452,10 +441,10 @@ impl Server {
     }
 
     async fn update_worktree(
-        self: Arc<Server>,
+        mut self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
     ) -> tide::Result<()> {
-        let connection_ids = self.store.write().await.update_worktree(
+        let connection_ids = self.state_mut().await.update_worktree(
             request.sender_id,
             request.payload.worktree_id,
             &request.payload.removed_entries,
@@ -477,8 +466,7 @@ impl Server {
     ) -> tide::Result<()> {
         let receipt = request.receipt();
         let host_connection_id = self
-            .store
-            .read()
+            .state()
             .await
             .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
         let response = self
@@ -494,8 +482,7 @@ impl Server {
         request: TypedEnvelope<proto::CloseBuffer>,
     ) -> tide::Result<()> {
         let host_connection_id = self
-            .store
-            .read()
+            .state()
             .await
             .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
         self.peer
@@ -511,7 +498,7 @@ impl Server {
         let host;
         let guests;
         {
-            let state = self.store.read().await;
+            let state = self.state().await;
             host = state
                 .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
             guests = state
@@ -547,8 +534,7 @@ impl Server {
     ) -> tide::Result<()> {
         broadcast(
             request.sender_id,
-            self.store
-                .read()
+            self.state()
                 .await
                 .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
             |connection_id| {
@@ -585,8 +571,7 @@ impl Server {
         request: TypedEnvelope<proto::GetChannels>,
     ) -> tide::Result<()> {
         let user_id = self
-            .store
-            .read()
+            .state()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channels = self.app_state.db.get_accessible_channels(user_id).await?;
@@ -637,7 +622,7 @@ impl Server {
     ) -> tide::Result<()> {
         let mut send_futures = Vec::new();
 
-        let state = self.store.read().await;
+        let state = self.state().await;
         for user_id in user_ids {
             let collaborators = state.collaborators_for_user(*user_id);
             for connection_id in state.connection_ids_for_user(*user_id) {
@@ -657,12 +642,11 @@ impl Server {
     }
 
     async fn join_channel(
-        self: Arc<Self>,
+        mut self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
     ) -> tide::Result<()> {
         let user_id = self
-            .store
-            .read()
+            .state()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -675,8 +659,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.store
-            .write()
+        self.state_mut()
             .await
             .join_channel(request.sender_id, channel_id);
         let messages = self
@@ -706,12 +689,11 @@ impl Server {
     }
 
     async fn leave_channel(
-        self: Arc<Self>,
+        mut self: Arc<Self>,
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> tide::Result<()> {
         let user_id = self
-            .store
-            .read()
+            .state()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -724,8 +706,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.store
-            .write()
+        self.state_mut()
             .await
             .leave_channel(request.sender_id, channel_id);
 
@@ -741,7 +722,7 @@ impl Server {
         let user_id;
         let connection_ids;
         {
-            let state = self.store.read().await;
+            let state = self.state().await;
             user_id = state.user_id_for_connection(request.sender_id)?;
             if let Some(ids) = state.channel_connection_ids(channel_id) {
                 connection_ids = ids;
@@ -829,8 +810,7 @@ impl Server {
         request: TypedEnvelope<proto::GetChannelMessages>,
     ) -> tide::Result<()> {
         let user_id = self
-            .store
-            .read()
+            .state()
             .await
             .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -872,6 +852,18 @@ impl Server {
             .await?;
         Ok(())
     }
+
+    fn state<'a>(
+        self: &'a Arc<Self>,
+    ) -> impl Future<Output = async_std::sync::RwLockReadGuard<'a, Store>> {
+        self.store.read()
+    }
+
+    fn state_mut<'a>(
+        self: &'a mut Arc<Self>,
+    ) -> impl Future<Output = async_std::sync::RwLockWriteGuard<'a, Store>> {
+        self.store.write()
+    }
 }
 
 pub async fn broadcast<F, T>(