Allow peers to receive individual messages before starting message loop

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed-rpc/src/peer.rs  | 118 ++++++++++++++++++++++++++++-----------------
zed/src/workspace.rs |   4 
2 files changed, 75 insertions(+), 47 deletions(-)

Detailed changes

zed-rpc/src/peer.rs 🔗

@@ -21,12 +21,14 @@ use std::{
 };
 
 type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
+type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub struct ConnectionId(u32);
 
 struct Connection {
     writer: Mutex<MessageStream<BoxedWriter>>,
+    reader: Mutex<MessageStream<BoxedReader>>,
     response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
     next_message_id: AtomicU32,
 }
@@ -52,7 +54,8 @@ impl<T> TypedEnvelope<T> {
 }
 
 pub struct Peer {
-    connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
+    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
+    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
     message_handlers: RwLock<Vec<MessageHandler>>,
     handler_types: Mutex<HashSet<TypeId>>,
     next_connection_id: AtomicU32,
@@ -62,6 +65,7 @@ impl Peer {
     pub fn new() -> Arc<Self> {
         Arc::new(Self {
             connections: Default::default(),
+            connection_close_barriers: Default::default(),
             message_handlers: Default::default(),
             handler_types: Default::default(),
             next_connection_id: Default::default(),
@@ -102,10 +106,7 @@ impl Peer {
         rx
     }
 
-    pub async fn add_connection<Conn>(
-        self: &Arc<Self>,
-        conn: Conn,
-    ) -> (ConnectionId, impl Future<Output = Result<()>>)
+    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
     where
         Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
@@ -113,26 +114,44 @@ impl Peer {
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
         );
-        let (close_tx, mut close_rx) = barrier::channel();
-        let connection = Arc::new(Connection {
-            writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
-            response_channels: Default::default(),
-            next_message_id: Default::default(),
-        });
+        self.connections.write().await.insert(
+            connection_id,
+            Arc::new(Connection {
+                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
+                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
+                response_channels: Default::default(),
+                next_message_id: Default::default(),
+            }),
+        );
+        connection_id
+    }
 
-        self.connections
+    pub async fn disconnect(&self, connection_id: ConnectionId) {
+        self.connections.write().await.remove(&connection_id);
+        self.connection_close_barriers
             .write()
             .await
-            .insert(connection_id, (connection.clone(), close_tx));
+            .remove(&connection_id);
+    }
 
+    pub fn handle_messages(
+        self: &Arc<Self>,
+        connection_id: ConnectionId,
+    ) -> impl Future<Output = Result<()>> + 'static {
+        let (close_tx, mut close_rx) = barrier::channel();
         let this = self.clone();
-        let handler_future = async move {
+        async move {
+            this.connection_close_barriers
+                .write()
+                .await
+                .insert(connection_id, close_tx);
+            let connection = this.connection(connection_id).await?;
             let closed = close_rx.recv();
             futures::pin_mut!(closed);
 
-            let mut stream = MessageStream::new(conn);
             loop {
-                let read_message = stream.read_message();
+                let mut reader = connection.reader.lock().await;
+                let read_message = reader.read_message();
                 futures::pin_mut!(read_message);
 
                 match futures::future::select(read_message, &mut closed).await {
@@ -181,13 +200,23 @@ impl Peer {
                     Either::Right(_) => return Ok(()),
                 }
             }
-        };
-
-        (connection_id, handler_future)
+        }
     }
 
-    pub async fn disconnect(&self, connection_id: ConnectionId) {
-        self.connections.write().await.remove(&connection_id);
+    pub async fn receive<M: EnvelopedMessage>(
+        self: &Arc<Self>,
+        connection_id: ConnectionId,
+    ) -> Result<TypedEnvelope<M>> {
+        let connection = self.connection(connection_id).await?;
+        let envelope = connection.reader.lock().await.read_message().await?;
+        let id = envelope.id;
+        let payload =
+            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
+        Ok(TypedEnvelope {
+            id,
+            connection_id,
+            payload,
+        })
     }
 
     pub fn request<T: RequestMessage>(
@@ -271,7 +300,6 @@ impl Peer {
             .await
             .get(&id)
             .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
-            .0
             .clone())
     }
 }
@@ -298,22 +326,22 @@ mod tests {
             let server = Peer::new();
             let client1 = Peer::new();
             let client2 = Peer::new();
-            let (client1_conn_id, f1) = client1
+            let client1_conn_id = client1
                 .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                 .await;
-            let (client2_conn_id, f2) = client2
+            let client2_conn_id = client2
                 .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                 .await;
-            let (_, f3) = server
+            let server_conn_id1 = server
                 .add_connection(listener.accept().await.unwrap().0)
                 .await;
-            let (_, f4) = server
+            let server_conn_id2 = server
                 .add_connection(listener.accept().await.unwrap().0)
                 .await;
-            smol::spawn(f1).detach();
-            smol::spawn(f2).detach();
-            smol::spawn(f3).detach();
-            smol::spawn(f4).detach();
+            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
+            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
+            smol::spawn(server.handle_messages(server_conn_id1)).detach();
+            smol::spawn(server.handle_messages(server_conn_id2)).detach();
 
             // define the expected requests and responses
             let request1 = proto::OpenWorktree {
@@ -428,21 +456,21 @@ mod tests {
             let (mut server_conn, _) = listener.accept().await.unwrap();
 
             let client = Peer::new();
-            let (connection_id, handler) = client.add_connection(client_conn).await;
-            smol::spawn(handler).detach();
+            let connection_id = client.add_connection(client_conn).await;
+            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
+                barrier::channel();
+            let handle_messages = client.handle_messages(connection_id);
+            smol::spawn(async move {
+                handle_messages.await.unwrap();
+                incoming_messages_ended_tx.send(()).await.unwrap();
+            })
+            .detach();
             client.disconnect(connection_id).await;
 
-            // Try sending an empty payload over and over, until the client is dropped and hangs up.
-            loop {
-                match server_conn.write(&[]).await {
-                    Ok(_) => {}
-                    Err(err) => {
-                        if err.kind() == io::ErrorKind::BrokenPipe {
-                            break;
-                        }
-                    }
-                }
-            }
+            incoming_messages_ended_rx.recv().await;
+
+            let err = server_conn.write(&[]).await.unwrap_err();
+            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
         });
     }
 
@@ -456,8 +484,8 @@ mod tests {
             client_conn.close().await.unwrap();
 
             let client = Peer::new();
-            let (connection_id, handler) = client.add_connection(client_conn).await;
-            smol::spawn(handler).detach();
+            let connection_id = client.add_connection(client_conn).await;
+            smol::spawn(client.handle_messages(connection_id)).detach();
 
             let err = client
                 .request(

zed/src/workspace.rs 🔗

@@ -691,8 +691,8 @@ impl Workspace {
             // a TLS stream using `native-tls`.
             let stream = smol::net::TcpStream::connect(rpc_address).await?;
 
-            let (connection_id, handler) = rpc.add_connection(stream).await;
-            executor.spawn(handler).detach();
+            let connection_id = rpc.add_connection(stream).await;
+            executor.spawn(rpc.handle_messages(connection_id)).detach();
 
             let auth_response = rpc
                 .request(