Change `RpcClient` methods to take shared references

Antonio Scandurra and Max Brunsfeld created

This will make it easier to spawn a future on gpui's executors
when calling `RpcClient` methods.

Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

Cargo.lock            |  15 +++
zed/Cargo.toml        |   2 
zed/src/rpc_client.rs | 179 +++++++++++++++++++++++---------------------
zed/src/workspace.rs  |   7 -
zed/src/worktree.rs   |   4 
5 files changed, 113 insertions(+), 94 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1350,6 +1350,7 @@ checksum = "da9052a1a50244d8d5aa9bf55cbc2fb6f357c86cc52e46c62ed390a7180cf150"
 dependencies = [
  "futures-channel",
  "futures-core",
+ "futures-executor",
  "futures-io",
  "futures-sink",
  "futures-task",
@@ -1372,6 +1373,17 @@ version = "0.3.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "79e5145dde8da7d1b3892dad07a9c98fc04bc39892b1ecc9692cf53e2b780a65"
 
+[[package]]
+name = "futures-executor"
+version = "0.3.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9e59fdc009a4b3096bf94f740a0f2424c082521f20a9b08c5c07c48d90fd9b9"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-io"
 version = "0.3.12"
@@ -1423,6 +1435,7 @@ version = "0.3.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "632a8cd0f2a4b3fdea1657f08bde063848c3bd00f9bbf6e256b8be78802e624b"
 dependencies = [
+ "futures-channel",
  "futures-core",
  "futures-io",
  "futures-macro",
@@ -4304,7 +4317,7 @@ dependencies = [
  "easy-parallel",
  "env_logger",
  "fsevent",
- "futures-core",
+ "futures",
  "gpui",
  "http-auth-basic",
  "ignore",

zed/Cargo.toml 🔗

@@ -21,7 +21,7 @@ ctor = "0.1.20"
 dirs = "3.0"
 easy-parallel = "3.1.0"
 fsevent = { path = "../fsevent" }
-futures-core = "0.3"
+futures = "0.3"
 gpui = { path = "../gpui" }
 http-auth-basic = "0.1.3"
 ignore = "0.4"

zed/src/rpc_client.rs 🔗

@@ -1,106 +1,112 @@
 use anyhow::{anyhow, Result};
+use futures::FutureExt;
 use gpui::executor::Background;
 use parking_lot::Mutex;
 use postage::{
-    mpsc, oneshot,
+    mpsc,
     prelude::{Sink, Stream},
 };
-use smol::{
-    future::FutureExt,
-    io::WriteHalf,
-    prelude::{AsyncRead, AsyncWrite},
+use smol::prelude::{AsyncRead, AsyncWrite};
+use std::{
+    collections::HashMap,
+    io,
+    sync::{
+        atomic::{self, AtomicI32},
+        Arc,
+    },
 };
-use std::{collections::HashMap, sync::Arc};
 use zed_rpc::proto::{
     self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
 };
 
-pub struct RpcClient<Conn> {
-    stream: MessageStream<WriteHalf<Conn>>,
+pub struct RpcClient {
     response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
-    next_message_id: i32,
-    _drop_tx: oneshot::Sender<()>,
+    outgoing_tx: mpsc::Sender<proto::FromClient>,
+    next_message_id: AtomicI32,
 }
 
-impl<Conn> RpcClient<Conn>
-where
-    Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
-{
-    pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
-        let (conn_rx, conn_tx) = smol::io::split(conn);
-        let (drop_tx, mut drop_rx) = oneshot::channel();
+impl RpcClient {
+    pub fn new<Conn>(conn: Conn, executor: Arc<Background>) -> Self
+    where
+        Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
+    {
         let response_channels = Arc::new(Mutex::new(HashMap::new()));
-        let client = Self {
-            next_message_id: 0,
-            stream: MessageStream::new(conn_tx),
-            response_channels: response_channels.clone(),
-            _drop_tx: drop_tx,
-        };
+        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32);
 
-        executor
-            .spawn::<Result<()>, _>(async move {
-                enum Message {
-                    Message(proto::FromServer),
-                    ClientDropped,
-                }
+        {
+            let response_channels = response_channels.clone();
+            executor
+                .spawn(async move {
+                    let (conn_rx, conn_tx) = smol::io::split(conn);
+                    let mut stream_tx = MessageStream::new(conn_tx);
+                    let mut stream_rx = MessageStream::new(conn_rx);
+                    loop {
+                        futures::select! {
+                            incoming = stream_rx.read_message::<proto::FromServer>().fuse() => {
+                                Self::handle_incoming(incoming, &response_channels).await;
+                            }
+                            outgoing = outgoing_rx.recv().fuse() => {
+                                if let Some(outgoing) = outgoing {
+                                    stream_tx.write_message(&outgoing).await;
+                                } else {
+                                    break;
+                                }
+                            }
+                        }
+                    }
+                })
+                .detach();
+        }
 
-                let mut stream = MessageStream::new(conn_rx);
-                let client_dropped = async move {
-                    assert!(drop_rx.recv().await.is_none());
-                    Ok(Message::ClientDropped) as Result<_>
-                };
-                smol::pin!(client_dropped);
-                loop {
-                    let message = async {
-                        Ok(Message::Message(
-                            stream.read_message::<proto::FromServer>().await?,
-                        ))
-                    };
+        Self {
+            response_channels,
+            outgoing_tx,
+            next_message_id: AtomicI32::new(0),
+        }
+    }
 
-                    match message.race(&mut client_dropped).await? {
-                        Message::Message(message) => {
-                            if let Some(variant) = message.variant {
-                                if let Some(request_id) = message.request_id {
-                                    let channel = response_channels.lock().remove(&request_id);
-                                    if let Some((mut tx, oneshot)) = channel {
-                                        if tx.send(variant).await.is_ok() {
-                                            if !oneshot {
-                                                response_channels
-                                                    .lock()
-                                                    .insert(request_id, (tx, false));
-                                            }
-                                        }
-                                    } else {
-                                        log::warn!(
-                                            "received RPC response to unknown request id {}",
-                                            request_id
-                                        );
-                                    }
+    async fn handle_incoming(
+        incoming: io::Result<proto::FromServer>,
+        response_channels: &Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>,
+    ) {
+        match incoming {
+            Ok(incoming) => {
+                if let Some(variant) = incoming.variant {
+                    if let Some(request_id) = incoming.request_id {
+                        let channel = response_channels.lock().remove(&request_id);
+                        if let Some((mut tx, oneshot)) = channel {
+                            if tx.send(variant).await.is_ok() {
+                                if !oneshot {
+                                    response_channels.lock().insert(request_id, (tx, false));
                                 }
-                            } else {
-                                log::warn!("received RPC message with no content");
                             }
+                        } else {
+                            log::warn!(
+                                "received RPC response to unknown request id {}",
+                                request_id
+                            );
                         }
-                        Message::ClientDropped => break Ok(()),
                     }
+                } else {
+                    log::warn!("received RPC message with no content");
                 }
-            })
-            .detach();
-
-        client
+            }
+            Err(error) => log::warn!("invalid incoming RPC message {:?}", error),
+        }
     }
 
-    pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
-        let message_id = self.next_message_id;
-        self.next_message_id += 1;
+    pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
+        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
         let (tx, mut rx) = mpsc::channel(1);
         self.response_channels.lock().insert(message_id, (tx, true));
-        self.stream
-            .write_message(&proto::FromClient {
+        self.outgoing_tx
+            .clone()
+            .send(proto::FromClient {
                 id: message_id,
                 variant: Some(req.to_variant()),
             })
-            .await?;
+            .await
+            .unwrap();
         let response = rx
             .recv()
             .await
@@ -109,15 +115,16 @@ where
             .ok_or_else(|| anyhow!("received response of the wrong t"))
     }
 
-    pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
-        let message_id = self.next_message_id;
-        self.next_message_id += 1;
-        self.stream
-            .write_message(&proto::FromClient {
+    pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
+        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+        self.outgoing_tx
+            .clone()
+            .send(proto::FromClient {
                 id: message_id,
                 variant: Some(message.to_variant()),
             })
-            .await?;
+            .await
+            .unwrap();
         Ok(())
     }
 
@@ -125,19 +132,19 @@ where
         &mut self,
         subscription: T,
     ) -> Result<impl Stream<Item = Result<T::Event>>> {
-        let message_id = self.next_message_id;
-        self.next_message_id += 1;
+        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
         let (tx, rx) = mpsc::channel(256);
         self.response_channels
             .lock()
             .insert(message_id, (tx, false));
-        self.stream
-            .write_message(&proto::FromClient {
+        self.outgoing_tx
+            .clone()
+            .send(proto::FromClient {
                 id: message_id,
                 variant: Some(subscription.to_variant()),
             })
-            .await?;
-
+            .await
+            .unwrap();
         Ok(rx.map(|event| {
             T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
         }))
@@ -165,7 +172,7 @@ mod tests {
         let (server_conn, _) = listener.accept().await.unwrap();
 
         let mut server_stream = MessageStream::new(server_conn);
-        let mut client = RpcClient::new(client_conn, executor.clone());
+        let client = RpcClient::new(client_conn, executor.clone());
 
         let client_req = client.request(proto::from_client::Auth {
             user_id: 42,

zed/src/workspace.rs 🔗

@@ -8,7 +8,6 @@ use crate::{
     worktree::{FileHandle, Worktree, WorktreeHandle},
     AppState,
 };
-use futures_core::Future;
 use gpui::{
     color::rgbu, elements::*, json::to_string_pretty, keymap::Binding, AnyViewHandle, AppContext,
     ClipboardItem, Entity, ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, Task,
@@ -19,10 +18,10 @@ pub use pane::*;
 pub use pane_group::*;
 use postage::watch;
 use smol::prelude::*;
-use std::{collections::HashMap, path::PathBuf};
 use std::{
-    collections::{hash_map::Entry, HashSet},
-    path::Path,
+    collections::{hash_map::Entry, HashMap, HashSet},
+    future::Future,
+    path::{Path, PathBuf},
     sync::Arc,
 };
 

zed/src/worktree.rs 🔗

@@ -1207,7 +1207,7 @@ pub trait WorktreeHandle {
     fn flush_fs_events<'a>(
         &self,
         cx: &'a gpui::TestAppContext,
-    ) -> futures_core::future::LocalBoxFuture<'a, ()>;
+    ) -> futures::future::LocalBoxFuture<'a, ()>;
 }
 
 impl WorktreeHandle for ModelHandle<Worktree> {
@@ -1268,7 +1268,7 @@ impl WorktreeHandle for ModelHandle<Worktree> {
     fn flush_fs_events<'a>(
         &self,
         cx: &'a gpui::TestAppContext,
-    ) -> futures_core::future::LocalBoxFuture<'a, ()> {
+    ) -> futures::future::LocalBoxFuture<'a, ()> {
         use smol::future::FutureExt;
 
         let filename = "fs-event-sentinel";