Allow RpcClient to encapsulate arbitrarily many connections

Max Brunsfeld created

Change summary

zed-rpc/src/proto.rs  |   2 
zed/src/rpc_client.rs | 170 +++++++++++++++++++++++++-------------------
zed/src/workspace.rs  |  18 +++-
zed/src/worktree.rs   |  12 ++-
4 files changed, 119 insertions(+), 83 deletions(-)

Detailed changes

zed-rpc/src/proto.rs 🔗

@@ -5,7 +5,7 @@ use std::{convert::TryInto, io};
 
 include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 
-pub trait EnvelopedMessage: Sized {
+pub trait EnvelopedMessage: Sized + Send + 'static {
     fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
     fn from_envelope(envelope: Envelope) -> Option<Self>;
 }

zed/src/rpc_client.rs 🔗

@@ -6,7 +6,7 @@ use postage::{
     prelude::{Sink, Stream},
 };
 use smol::{
-    io::{BoxedWriter, ReadHalf},
+    io::BoxedWriter,
     lock::Mutex,
     prelude::{AsyncRead, AsyncWrite},
 };
@@ -20,89 +20,103 @@ use std::{
 };
 use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
 
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub struct ConnectionId(u32);
+
 pub struct RpcClient {
     response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
-    outgoing: Mutex<MessageStream<BoxedWriter>>,
+    outgoing: Arc<Mutex<HashMap<ConnectionId, MessageStream<BoxedWriter>>>>,
     next_message_id: AtomicU32,
+    next_connection_id: AtomicU32,
     _drop_tx: barrier::Sender,
+    drop_rx: barrier::Receiver,
 }
 
 impl RpcClient {
-    pub fn new<Conn>(conn: Conn, executor: Arc<Background>) -> Arc<Self>
-    where
-        Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
-    {
-        let response_channels = Arc::new(Mutex::new(HashMap::new()));
-        let (conn_rx, conn_tx) = smol::io::split(conn);
+    pub fn new() -> Arc<Self> {
         let (_drop_tx, drop_rx) = barrier::channel();
-
-        executor
-            .spawn(Self::handle_incoming(
-                conn_rx,
-                drop_rx,
-                response_channels.clone(),
-            ))
-            .detach();
-
         Arc::new(Self {
-            response_channels,
-            outgoing: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
+            response_channels: Default::default(),
+            outgoing: Default::default(),
+            next_message_id: Default::default(),
+            next_connection_id: Default::default(),
             _drop_tx,
-            next_message_id: AtomicU32::new(0),
+            drop_rx,
         })
     }
 
-    async fn handle_incoming<Conn>(
-        conn: ReadHalf<Conn>,
-        mut drop_rx: barrier::Receiver,
-        response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
-    ) where
-        Conn: AsyncRead + Unpin,
+    pub async fn connect<Conn>(&self, conn: Conn, executor: Arc<Background>) -> ConnectionId
+    where
+        Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
-        let dropped = drop_rx.recv();
-        smol::pin!(dropped);
+        let connection_id = ConnectionId(
+            self.next_connection_id
+                .fetch_add(1, atomic::Ordering::SeqCst),
+        );
+        let (conn_rx, conn_tx) = smol::io::split(conn);
+        let response_channels = self.response_channels.clone();
+        let mut drop_rx = self.drop_rx.clone();
+        let outgoing = self.outgoing.clone();
 
-        let mut stream = MessageStream::new(conn);
-        loop {
-            let read_message = stream.read_message();
-            smol::pin!(read_message);
+        executor
+            .spawn(async move {
+                let dropped = drop_rx.recv();
+                smol::pin!(dropped);
+
+                let mut stream = MessageStream::new(conn_rx);
+                loop {
+                    let read_message = stream.read_message();
+                    smol::pin!(read_message);
 
-            match futures::future::select(read_message, &mut dropped).await {
-                Either::Left((Ok(incoming), _)) => {
-                    if let Some(responding_to) = incoming.responding_to {
-                        let channel = response_channels.lock().await.remove(&responding_to);
-                        if let Some(mut tx) = channel {
-                            tx.send(incoming).await.ok();
-                        } else {
-                            log::warn!(
-                                "received RPC response to unknown request {}",
-                                responding_to
-                            );
+                    match futures::future::select(read_message, &mut dropped).await {
+                        Either::Left((Ok(incoming), _)) => {
+                            if let Some(responding_to) = incoming.responding_to {
+                                let channel = response_channels.lock().await.remove(&responding_to);
+                                if let Some(mut tx) = channel {
+                                    tx.send(incoming).await.ok();
+                                } else {
+                                    log::warn!(
+                                        "received RPC response to unknown request {}",
+                                        responding_to
+                                    );
+                                }
+                            } else {
+                                // unprompted message from server
+                            }
                         }
-                    } else {
-                        // unprompted message from server
+                        Either::Left((Err(error), _)) => {
+                            log::warn!("received invalid RPC message {:?}", error);
+                        }
+                        Either::Right(_) => break,
                     }
                 }
-                Either::Left((Err(error), _)) => {
-                    log::warn!("received invalid RPC message {:?}", error);
-                }
-                Either::Right(_) => break,
-            }
-        }
+            })
+            .detach();
+
+        outgoing
+            .lock()
+            .await
+            .insert(connection_id, MessageStream::new(Box::pin(conn_tx)));
+
+        connection_id
     }
 
     pub fn request<T: RequestMessage>(
-        self: &Arc<Self>,
+        &self,
+        connection_id: ConnectionId,
         req: T,
     ) -> impl Future<Output = Result<T::Response>> {
-        let this = self.clone();
+        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+        let outgoing = self.outgoing.clone();
+        let response_channels = self.response_channels.clone();
+        let (tx, mut rx) = oneshot::channel();
         async move {
-            let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-            let (tx, mut rx) = oneshot::channel();
-            this.response_channels.lock().await.insert(message_id, tx);
-            this.outgoing
+            response_channels.lock().await.insert(message_id, tx);
+            outgoing
                 .lock()
                 .await
+                .get_mut(&connection_id)
+                .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
                 .write_message(&req.into_envelope(message_id, None))
                 .await?;
             let response = rx
@@ -115,15 +129,18 @@ impl RpcClient {
     }
 
     pub fn send<T: EnvelopedMessage>(
-        self: &Arc<Self>,
+        &self,
+        connection_id: ConnectionId,
         message: T,
     ) -> impl Future<Output = Result<()>> {
-        let this = self.clone();
+        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+        let outgoing = self.outgoing.clone();
         async move {
-            let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-            this.outgoing
+            outgoing
                 .lock()
                 .await
+                .get_mut(&connection_id)
+                .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
                 .write_message(&message.into_envelope(message_id, None))
                 .await?;
             Ok(())
@@ -152,12 +169,16 @@ mod tests {
         let (server_conn, _) = listener.accept().await.unwrap();
 
         let mut server_stream = MessageStream::new(server_conn);
-        let client = RpcClient::new(client_conn, executor.clone());
+        let client = RpcClient::new();
+        let connection_id = client.connect(client_conn, executor.clone()).await;
 
-        let client_req = client.request(proto::Auth {
-            user_id: 42,
-            access_token: "token".to_string(),
-        });
+        let client_req = client.request(
+            connection_id,
+            proto::Auth {
+                user_id: 42,
+                access_token: "token".to_string(),
+            },
+        );
         smol::pin!(client_req);
         let server_req = send_recv(&mut client_req, server_stream.read_message())
             .await
@@ -206,7 +227,8 @@ mod tests {
         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
         let (mut server_conn, _) = listener.accept().await.unwrap();
 
-        let client = RpcClient::new(client_conn, executor.clone());
+        let client = RpcClient::new();
+        client.connect(client_conn, executor.clone()).await;
         drop(client);
 
         // Try sending an empty payload over and over, until the client is dropped and hangs up.
@@ -231,12 +253,16 @@ mod tests {
         let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
         client_conn.close().await.unwrap();
 
-        let client = RpcClient::new(client_conn, executor.clone());
+        let client = RpcClient::new();
+        let conn_id = client.connect(client_conn, executor.clone()).await;
         let err = client
-            .request(proto::Auth {
-                user_id: 42,
-                access_token: "token".to_string(),
-            })
+            .request(
+                conn_id,
+                proto::Auth {
+                    user_id: 42,
+                    access_token: "token".to_string(),
+                },
+            )
             .await
             .unwrap_err();
         assert_eq!(

zed/src/workspace.rs 🔗

@@ -670,13 +670,17 @@ impl Workspace {
             // a TLS stream using `native-tls`.
             let stream = smol::net::TcpStream::connect(rpc_address).await?;
 
-            let rpc_client = RpcClient::new(stream, executor);
+            let rpc_client = RpcClient::new();
+            let connection_id = rpc_client.connect(stream, executor).await;
 
             let auth_response = rpc_client
-                .request(proto::Auth {
-                    user_id: user_id.parse::<u64>()?,
-                    access_token,
-                })
+                .request(
+                    connection_id,
+                    proto::Auth {
+                        user_id: user_id.parse::<u64>()?,
+                        access_token,
+                    },
+                )
                 .await?;
             if !auth_response.credentials_valid {
                 Err(anyhow!("failed to authenticate with RPC server"))?;
@@ -684,7 +688,9 @@ impl Workspace {
 
             let share_task = this.update(&mut cx, |this, cx| {
                 let worktree = this.worktrees.iter().next()?;
-                Some(worktree.update(cx, |worktree, cx| worktree.share(rpc_client, cx)))
+                Some(worktree.update(cx, |worktree, cx| {
+                    worktree.share(rpc_client, connection_id, cx)
+                }))
             });
 
             if let Some(share_task) = share_task {

zed/src/worktree.rs 🔗

@@ -4,7 +4,7 @@ mod ignore;
 
 use crate::{
     editor::{History, Rope},
-    rpc_client::RpcClient,
+    rpc_client::{ConnectionId, RpcClient},
     sum_tree::{self, Cursor, Edit, SumTree},
     util::Bias,
 };
@@ -229,6 +229,7 @@ impl Worktree {
     pub fn share(
         &mut self,
         client: Arc<RpcClient>,
+        connection_id: ConnectionId,
         cx: &mut ModelContext<Self>,
     ) -> Task<anyhow::Result<()>> {
         self.rpc_client = Some(client.clone());
@@ -245,9 +246,12 @@ impl Worktree {
                 .await;
 
             let share_response = client
-                .request(proto::ShareWorktree {
-                    worktree: Some(proto::Worktree { paths }),
-                })
+                .request(
+                    connection_id,
+                    proto::ShareWorktree {
+                        worktree: Some(proto::Worktree { paths }),
+                    },
+                )
                 .await?;
 
             log::info!("sharing worktree {:?}", share_response);