Use synchronous locks for `Peer` state

Antonio Scandurra and Nathan Sobo created

We hold these locks for a short amount of time anyway, and using an
async lock could cause parallel sends to happen in an order different
than the order in which `send`/`request` was called.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/client/src/client.rs | 10 +++---
crates/client/src/test.rs   |  4 +-
crates/rpc/src/peer.rs      | 60 ++++++++++++++++----------------------
crates/server/src/rpc.rs    |  6 +-
4 files changed, 36 insertions(+), 44 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -661,9 +661,9 @@ impl Client {
         })
     }
 
-    pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
+    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
         let conn_id = self.connection_id()?;
-        self.peer.disconnect(conn_id).await;
+        self.peer.disconnect(conn_id);
         self.set_status(Status::SignedOut, cx);
         Ok(())
     }
@@ -764,7 +764,7 @@ mod tests {
         let ping = server.receive::<proto::Ping>().await.unwrap();
         server.respond(ping.receipt(), proto::Ack {}).await;
 
-        client.disconnect(&cx.to_async()).await.unwrap();
+        client.disconnect(&cx.to_async()).unwrap();
         assert!(server.receive::<proto::Ping>().await.is_err());
     }
 
@@ -783,7 +783,7 @@ mod tests {
         assert_eq!(server.auth_count(), 1);
 
         server.forbid_connections();
-        server.disconnect().await;
+        server.disconnect();
         while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
 
         server.allow_connections();
@@ -792,7 +792,7 @@ mod tests {
         assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
 
         server.forbid_connections();
-        server.disconnect().await;
+        server.disconnect();
         while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
 
         // Clear cached credentials after authentication fails

crates/client/src/test.rs 🔗

@@ -72,8 +72,8 @@ impl FakeServer {
         server
     }
 
-    pub async fn disconnect(&self) {
-        self.peer.disconnect(self.connection_id()).await;
+    pub fn disconnect(&self) {
+        self.peer.disconnect(self.connection_id());
         self.connection_id.lock().take();
         self.incoming.lock().take();
     }

crates/rpc/src/peer.rs 🔗

@@ -1,8 +1,8 @@
 use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
 use super::Connection;
 use anyhow::{anyhow, Context, Result};
-use async_lock::{Mutex, RwLock};
 use futures::FutureExt as _;
+use parking_lot::{Mutex, RwLock};
 use postage::{
     mpsc,
     prelude::{Sink as _, Stream as _},
@@ -133,7 +133,7 @@ impl Peer {
                         incoming = read_message => match incoming {
                             Ok(incoming) => {
                                 if let Some(responding_to) = incoming.responding_to {
-                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
+                                    let channel = response_channels.lock().as_mut().unwrap().remove(&responding_to);
                                     if let Some(mut tx) = channel {
                                         tx.send(incoming).await.ok();
                                     } else {
@@ -169,25 +169,24 @@ impl Peer {
                 }
             };
 
-            response_channels.lock().await.take();
-            this.connections.write().await.remove(&connection_id);
+            response_channels.lock().take();
+            this.connections.write().remove(&connection_id);
             result
         };
 
         self.connections
             .write()
-            .await
             .insert(connection_id, connection_state);
 
         (connection_id, handle_io, incoming_rx)
     }
 
-    pub async fn disconnect(&self, connection_id: ConnectionId) {
-        self.connections.write().await.remove(&connection_id);
+    pub fn disconnect(&self, connection_id: ConnectionId) {
+        self.connections.write().remove(&connection_id);
     }
 
-    pub async fn reset(&self) {
-        self.connections.write().await.clear();
+    pub fn reset(&self) {
+        self.connections.write().clear();
     }
 
     pub fn request<T: RequestMessage>(
@@ -216,12 +215,11 @@ impl Peer {
         let this = self.clone();
         let (tx, mut rx) = mpsc::channel(1);
         async move {
-            let mut connection = this.connection_state(receiver_id).await?;
+            let mut connection = this.connection_state(receiver_id)?;
             let message_id = connection.next_message_id.fetch_add(1, SeqCst);
             connection
                 .response_channels
                 .lock()
-                .await
                 .as_mut()
                 .ok_or_else(|| anyhow!("connection was closed"))?
                 .insert(message_id, tx);
@@ -250,7 +248,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection_state(receiver_id).await?;
+            let mut connection = this.connection_state(receiver_id)?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -270,7 +268,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection_state(receiver_id).await?;
+            let mut connection = this.connection_state(receiver_id)?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -289,7 +287,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection_state(receipt.sender_id).await?;
+            let mut connection = this.connection_state(receipt.sender_id)?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -308,7 +306,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let mut connection = this.connection_state(receipt.sender_id).await?;
+            let mut connection = this.connection_state(receipt.sender_id)?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -320,18 +318,12 @@ impl Peer {
         }
     }
 
-    fn connection_state(
-        self: &Arc<Self>,
-        connection_id: ConnectionId,
-    ) -> impl Future<Output = Result<ConnectionState>> {
-        let this = self.clone();
-        async move {
-            let connections = this.connections.read().await;
-            let connection = connections
-                .get(&connection_id)
-                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
-            Ok(connection.clone())
-        }
+    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
+        let connections = self.connections.read();
+        let connection = connections
+            .get(&connection_id)
+            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
+        Ok(connection.clone())
     }
 }
 
@@ -398,7 +390,7 @@ mod tests {
                 proto::OpenBufferResponse {
                     buffer: Some(proto::Buffer {
                         id: 101,
-                        content: "path/one content".to_string(),
+                        visible_text: "path/one content".to_string(),
                         ..Default::default()
                     }),
                 }
@@ -419,14 +411,14 @@ mod tests {
                 proto::OpenBufferResponse {
                     buffer: Some(proto::Buffer {
                         id: 102,
-                        content: "path/two content".to_string(),
+                        visible_text: "path/two content".to_string(),
                         ..Default::default()
                     }),
                 }
             );
 
-            client1.disconnect(client1_conn_id).await;
-            client2.disconnect(client1_conn_id).await;
+            client1.disconnect(client1_conn_id);
+            client2.disconnect(client1_conn_id);
 
             async fn handle_messages(
                 mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
@@ -448,7 +440,7 @@ mod tests {
                                 proto::OpenBufferResponse {
                                     buffer: Some(proto::Buffer {
                                         id: 101,
-                                        content: "path/one content".to_string(),
+                                        visible_text: "path/one content".to_string(),
                                         ..Default::default()
                                     }),
                                 }
@@ -458,7 +450,7 @@ mod tests {
                                 proto::OpenBufferResponse {
                                     buffer: Some(proto::Buffer {
                                         id: 102,
-                                        content: "path/two content".to_string(),
+                                        visible_text: "path/two content".to_string(),
                                         ..Default::default()
                                     }),
                                 }
@@ -502,7 +494,7 @@ mod tests {
             })
             .detach();
 
-            client.disconnect(connection_id).await;
+            client.disconnect(connection_id);
 
             io_ended_rx.recv().await;
             messages_ended_rx.recv().await;

crates/server/src/rpc.rs 🔗

@@ -174,7 +174,7 @@ impl Server {
     }
 
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
-        self.peer.disconnect(connection_id).await;
+        self.peer.disconnect(connection_id);
         let removed_connection = self.state_mut().remove_connection(connection_id)?;
 
         for (project_id, project) in removed_connection.hosted_projects {
@@ -1801,7 +1801,7 @@ mod tests {
             .await;
 
         // Drop client B's connection and ensure client A observes client B leaving the worktree.
-        client_b.disconnect(&cx_b.to_async()).await.unwrap();
+        client_b.disconnect(&cx_b.to_async()).unwrap();
         project_a
             .condition(&cx_a, |p, _| p.collaborators().len() == 0)
             .await;
@@ -2833,7 +2833,7 @@ mod tests {
 
     impl Drop for TestServer {
         fn drop(&mut self) {
-            task::block_on(self.peer.reset());
+            self.peer.reset();
         }
     }