Use an async `Mutex` to guard access to write stream in `RpcClient`

Antonio Scandurra created

Change summary

zed/src/rpc_client.rs | 115 ++++++++++++++++----------------------------
1 file changed, 43 insertions(+), 72 deletions(-)

Detailed changes

zed/src/rpc_client.rs 🔗

@@ -1,18 +1,17 @@
 use anyhow::{anyhow, Result};
 use futures::future::Either;
 use gpui::executor::Background;
-use parking_lot::Mutex;
 use postage::{
-    barrier, mpsc, oneshot,
+    barrier, mpsc,
     prelude::{Sink, Stream},
 };
 use smol::{
     io::{ReadHalf, WriteHalf},
+    lock::Mutex,
     prelude::{AsyncRead, AsyncWrite},
 };
 use std::{
     collections::HashMap,
-    io,
     sync::{
         atomic::{self, AtomicI32},
         Arc,
@@ -22,21 +21,20 @@ use zed_rpc::proto::{
     self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
 };
 
-pub struct RpcClient {
+pub struct RpcClient<Conn> {
     response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
-    outgoing_tx: mpsc::Sender<(proto::FromClient, oneshot::Sender<io::Result<()>>)>,
+    outgoing: Mutex<MessageStream<WriteHalf<Conn>>>,
     next_message_id: AtomicI32,
     _drop_tx: barrier::Sender,
 }
 
-impl RpcClient {
-    pub fn new<Conn>(conn: Conn, executor: Arc<Background>) -> Self
-    where
-        Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
-    {
+impl<Conn> RpcClient<Conn>
+where
+    Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
+{
+    pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
         let response_channels = Arc::new(Mutex::new(HashMap::new()));
         let (conn_rx, conn_tx) = smol::io::split(conn);
-        let (outgoing_tx, outgoing_rx) = mpsc::channel(32);
         let (_drop_tx, drop_rx) = barrier::channel();
 
         executor
@@ -47,27 +45,21 @@ impl RpcClient {
             ))
             .detach();
 
-        executor
-            .spawn(Self::handle_outgoing(conn_tx, outgoing_rx))
-            .detach();
-
         Self {
             response_channels,
-            outgoing_tx,
+            outgoing: Mutex::new(MessageStream::new(conn_tx)),
             _drop_tx,
             next_message_id: AtomicI32::new(0),
         }
     }
 
-    async fn handle_incoming<Conn>(
+    async fn handle_incoming(
         conn: ReadHalf<Conn>,
         mut drop_rx: barrier::Receiver,
         response_channels: Arc<
             Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>,
         >,
-    ) where
-        Conn: AsyncRead + Unpin,
-    {
+    ) {
         let mut stream = MessageStream::new(conn);
         loop {
             let read_message = stream.read_message::<proto::FromServer>();
@@ -78,11 +70,14 @@ impl RpcClient {
                 Either::Left((Ok(incoming), _)) => {
                     if let Some(variant) = incoming.variant {
                         if let Some(request_id) = incoming.request_id {
-                            let channel = response_channels.lock().remove(&request_id);
+                            let channel = response_channels.lock().await.remove(&request_id);
                             if let Some((mut tx, oneshot)) = channel {
                                 if tx.send(variant).await.is_ok() {
                                     if !oneshot {
-                                        response_channels.lock().insert(request_id, (tx, false));
+                                        response_channels
+                                            .lock()
+                                            .await
+                                            .insert(request_id, (tx, false));
                                     }
                                 }
                             } else {
@@ -104,36 +99,21 @@ impl RpcClient {
         }
     }
 
-    async fn handle_outgoing<Conn>(
-        conn: WriteHalf<Conn>,
-        mut outgoing_rx: mpsc::Receiver<(proto::FromClient, oneshot::Sender<io::Result<()>>)>,
-    ) where
-        Conn: AsyncWrite + Unpin,
-    {
-        let mut stream = MessageStream::new(conn);
-        while let Some((message, mut result_tx)) = outgoing_rx.recv().await {
-            let result = stream.write_message(&message).await;
-            result_tx.send(result).await.unwrap();
-        }
-    }
-
     pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
         let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let (result_tx, mut result_rx) = oneshot::channel();
         let (tx, mut rx) = mpsc::channel(1);
-        self.response_channels.lock().insert(message_id, (tx, true));
-        self.outgoing_tx
-            .clone()
-            .send((
-                proto::FromClient {
-                    id: message_id,
-                    variant: Some(req.to_variant()),
-                },
-                result_tx,
-            ))
+        self.response_channels
+            .lock()
+            .await
+            .insert(message_id, (tx, true));
+        self.outgoing
+            .lock()
             .await
-            .ok();
-        result_rx.recv().await.unwrap()?;
+            .write_message(&proto::FromClient {
+                id: message_id,
+                variant: Some(req.to_variant()),
+            })
+            .await?;
         let response = rx
             .recv()
             .await
@@ -144,19 +124,14 @@ impl RpcClient {
 
     pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
         let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let (result_tx, mut result_rx) = oneshot::channel();
-        self.outgoing_tx
-            .clone()
-            .send((
-                proto::FromClient {
-                    id: message_id,
-                    variant: Some(message.to_variant()),
-                },
-                result_tx,
-            ))
+        self.outgoing
+            .lock()
             .await
-            .ok();
-        result_rx.recv().await.unwrap()?;
+            .write_message(&proto::FromClient {
+                id: message_id,
+                variant: Some(message.to_variant()),
+            })
+            .await?;
         Ok(())
     }
 
@@ -165,23 +140,19 @@ impl RpcClient {
         subscription: T,
     ) -> Result<impl Stream<Item = Result<T::Event>>> {
         let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let (result_tx, mut result_rx) = oneshot::channel();
         let (tx, rx) = mpsc::channel(256);
         self.response_channels
             .lock()
+            .await
             .insert(message_id, (tx, false));
-        self.outgoing_tx
-            .clone()
-            .send((
-                proto::FromClient {
-                    id: message_id,
-                    variant: Some(subscription.to_variant()),
-                },
-                result_tx,
-            ))
+        self.outgoing
+            .lock()
             .await
-            .ok();
-        result_rx.recv().await.unwrap()?;
+            .write_message(&proto::FromClient {
+                id: message_id,
+                variant: Some(subscription.to_variant()),
+            })
+            .await?;
         Ok(rx.map(|event| {
             T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
         }))