Remove gpui dependency from rpc_client

Max Brunsfeld and Nathan Sobo created

Also, avoid any contention between connections.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed/src/rpc_client.rs | 164 ++++++++++++++++++++++++++------------------
zed/src/workspace.rs  |   3 
2 files changed, 98 insertions(+), 69 deletions(-)

Detailed changes

zed/src/rpc_client.rs 🔗

@@ -1,13 +1,12 @@
 use anyhow::{anyhow, Result};
 use futures::future::Either;
-use gpui::executor::Background;
 use postage::{
     barrier, oneshot,
     prelude::{Sink, Stream},
 };
 use smol::{
     io::BoxedWriter,
-    lock::Mutex,
+    lock::{Mutex, RwLock},
     prelude::{AsyncRead, AsyncWrite},
 };
 use std::{
@@ -23,29 +22,27 @@ use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub struct ConnectionId(u32);
 
-pub struct RpcClient {
-    response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
-    outgoing: Arc<Mutex<HashMap<ConnectionId, MessageStream<BoxedWriter>>>>,
+struct RpcConnection {
+    writer: Mutex<MessageStream<BoxedWriter>>,
+    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
     next_message_id: AtomicU32,
+    _close_barrier: barrier::Sender,
+}
+
+pub struct RpcClient {
+    connections: Arc<RwLock<HashMap<ConnectionId, Arc<RpcConnection>>>>,
     next_connection_id: AtomicU32,
-    _drop_tx: barrier::Sender,
-    drop_rx: barrier::Receiver,
 }
 
 impl RpcClient {
     pub fn new() -> Arc<Self> {
-        let (_drop_tx, drop_rx) = barrier::channel();
         Arc::new(Self {
-            response_channels: Default::default(),
-            outgoing: Default::default(),
-            next_message_id: Default::default(),
+            connections: Default::default(),
             next_connection_id: Default::default(),
-            _drop_tx,
-            drop_rx,
         })
     }
 
-    pub async fn connect<Conn>(&self, conn: Conn, executor: Arc<Background>) -> ConnectionId
+    pub async fn add_connection<Conn>(&self, conn: Conn) -> (ConnectionId, impl Future<Output = ()>)
     where
         Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
@@ -53,52 +50,63 @@ impl RpcClient {
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
         );
+        let (close_tx, mut close_rx) = barrier::channel();
         let (conn_rx, conn_tx) = smol::io::split(conn);
-        let response_channels = self.response_channels.clone();
-        let mut drop_rx = self.drop_rx.clone();
-        let outgoing = self.outgoing.clone();
+        let connections = self.connections.clone();
+        let connection = Arc::new(RpcConnection {
+            writer: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
+            response_channels: Default::default(),
+            next_message_id: Default::default(),
+            _close_barrier: close_tx,
+        });
+
+        connections
+            .write()
+            .await
+            .insert(connection_id, connection.clone());
 
-        executor
-            .spawn(async move {
-                let dropped = drop_rx.recv();
-                smol::pin!(dropped);
+        let handler_future = async move {
+            let closed = close_rx.recv();
+            smol::pin!(closed);
 
-                let mut stream = MessageStream::new(conn_rx);
-                loop {
-                    let read_message = stream.read_message();
-                    smol::pin!(read_message);
+            let mut stream = MessageStream::new(conn_rx);
+            loop {
+                let read_message = stream.read_message();
+                smol::pin!(read_message);
 
-                    match futures::future::select(read_message, &mut dropped).await {
-                        Either::Left((Ok(incoming), _)) => {
-                            if let Some(responding_to) = incoming.responding_to {
-                                let channel = response_channels.lock().await.remove(&responding_to);
-                                if let Some(mut tx) = channel {
-                                    tx.send(incoming).await.ok();
-                                } else {
-                                    log::warn!(
-                                        "received RPC response to unknown request {}",
-                                        responding_to
-                                    );
-                                }
+                match futures::future::select(read_message, &mut closed).await {
+                    Either::Left((Ok(incoming), _)) => {
+                        if let Some(responding_to) = incoming.responding_to {
+                            let channel = connection
+                                .response_channels
+                                .lock()
+                                .await
+                                .remove(&responding_to);
+                            if let Some(mut tx) = channel {
+                                tx.send(incoming).await.ok();
                             } else {
-                                // unprompted message from server
+                                log::warn!(
+                                    "received RPC response to unknown request {}",
+                                    responding_to
+                                );
                             }
+                        } else {
+                            // unprompted message from server
                         }
-                        Either::Left((Err(error), _)) => {
-                            log::warn!("received invalid RPC message {:?}", error);
-                        }
-                        Either::Right(_) => break,
                     }
+                    Either::Left((Err(error), _)) => {
+                        log::warn!("received invalid RPC message {:?}", error);
+                    }
+                    Either::Right(_) => break,
                 }
-            })
-            .detach();
+            }
+        };
 
-        outgoing
-            .lock()
-            .await
-            .insert(connection_id, MessageStream::new(Box::pin(conn_tx)));
+        (connection_id, handler_future)
+    }
 
-        connection_id
+    pub async fn disconnect(&self, connection_id: ConnectionId) {
+        self.connections.write().await.remove(&connection_id);
     }
 
     pub fn request<T: RequestMessage>(
@@ -106,17 +114,27 @@ impl RpcClient {
         connection_id: ConnectionId,
         req: T,
     ) -> impl Future<Output = Result<T::Response>> {
-        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let outgoing = self.outgoing.clone();
-        let response_channels = self.response_channels.clone();
+        let connections = self.connections.clone();
         let (tx, mut rx) = oneshot::channel();
         async move {
-            response_channels.lock().await.insert(message_id, tx);
-            outgoing
-                .lock()
+            let connection = connections
+                .read()
                 .await
-                .get_mut(&connection_id)
+                .get(&connection_id)
                 .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
+                .clone();
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .response_channels
+                .lock()
+                .await
+                .insert(message_id, tx);
+            connection
+                .writer
+                .lock()
+                .await
                 .write_message(&req.into_envelope(message_id, None))
                 .await?;
             let response = rx
@@ -124,7 +142,7 @@ impl RpcClient {
                 .await
                 .expect("response channel was unexpectedly dropped");
             T::Response::from_envelope(response)
-                .ok_or_else(|| anyhow!("received response of the wrong t"))
+                .ok_or_else(|| anyhow!("received response of the wrong type"))
         }
     }
 
@@ -133,14 +151,21 @@ impl RpcClient {
         connection_id: ConnectionId,
         message: T,
     ) -> impl Future<Output = Result<()>> {
-        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let outgoing = self.outgoing.clone();
+        let connections = self.connections.clone();
         async move {
-            outgoing
-                .lock()
+            let connection = connections
+                .read()
                 .await
-                .get_mut(&connection_id)
+                .get(&connection_id)
                 .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
+                .clone();
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .writer
+                .lock()
+                .await
                 .write_message(&message.into_envelope(message_id, None))
                 .await?;
             Ok(())
@@ -170,7 +195,8 @@ mod tests {
 
         let mut server_stream = MessageStream::new(server_conn);
         let client = RpcClient::new();
-        let connection_id = client.connect(client_conn, executor.clone()).await;
+        let (connection_id, handler) = client.add_connection(client_conn).await;
+        executor.spawn(handler).detach();
 
         let client_req = client.request(
             connection_id,
@@ -219,7 +245,7 @@ mod tests {
     }
 
     #[gpui::test]
-    async fn test_drop_client(cx: gpui::TestAppContext) {
+    async fn test_disconnect(cx: gpui::TestAppContext) {
         let executor = cx.read(|app| app.background_executor().clone());
         let socket_dir_path = TempDir::new("drop-client").unwrap();
         let socket_path = socket_dir_path.path().join(".sock");
@@ -228,8 +254,9 @@ mod tests {
         let (mut server_conn, _) = listener.accept().await.unwrap();
 
         let client = RpcClient::new();
-        client.connect(client_conn, executor.clone()).await;
-        drop(client);
+        let (connection_id, handler) = client.add_connection(client_conn).await;
+        executor.spawn(handler).detach();
+        client.disconnect(connection_id).await;
 
         // Try sending an empty payload over and over, until the client is dropped and hangs up.
         loop {
@@ -254,10 +281,11 @@ mod tests {
         client_conn.close().await.unwrap();
 
         let client = RpcClient::new();
-        let conn_id = client.connect(client_conn, executor.clone()).await;
+        let (connection_id, handler) = client.add_connection(client_conn).await;
+        executor.spawn(handler).detach();
         let err = client
             .request(
-                conn_id,
+                connection_id,
                 proto::Auth {
                     user_id: 42,
                     access_token: "token".to_string(),

zed/src/workspace.rs 🔗

@@ -671,7 +671,8 @@ impl Workspace {
             let stream = smol::net::TcpStream::connect(rpc_address).await?;
 
             let rpc_client = RpcClient::new();
-            let connection_id = rpc_client.connect(stream, executor).await;
+            let (connection_id, handler) = rpc_client.add_connection(stream).await;
+            executor.spawn(handler).detach();
 
             let auth_response = rpc_client
                 .request(