Serialize RPC sends and responses using a channel

Max Brunsfeld created

Change summary

zed-rpc/src/peer.rs  | 344 ++++++++++++++++++++++-----------------------
zed-rpc/src/proto.rs |  25 ++
zed/src/rpc.rs       |   6 
3 files changed, 190 insertions(+), 185 deletions(-)

Detailed changes

zed-rpc/src/peer.rs 🔗

@@ -1,12 +1,9 @@
 use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use async_lock::{Mutex, RwLock};
-use futures::{
-    future::{BoxFuture, Either},
-    AsyncRead, AsyncWrite, FutureExt,
-};
+use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
 use postage::{
-    barrier, mpsc, oneshot,
+    mpsc,
     prelude::{Sink, Stream},
 };
 use std::{
@@ -15,29 +12,18 @@ use std::{
     fmt,
     future::Future,
     marker::PhantomData,
-    pin::Pin,
     sync::{
         atomic::{self, AtomicU32},
         Arc,
     },
 };
 
-type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
-type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
-
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub struct ConnectionId(pub u32);
 
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub struct PeerId(pub u32);
 
-struct Connection {
-    writer: Mutex<MessageStream<BoxedWriter>>,
-    reader: Mutex<MessageStream<BoxedReader>>,
-    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
-    next_message_id: AtomicU32,
-}
-
 type MessageHandler = Box<
     dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 >;
@@ -74,18 +60,34 @@ impl<T: RequestMessage> TypedEnvelope<T> {
 }
 
 pub struct Peer {
-    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
-    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
+    connections: RwLock<HashMap<ConnectionId, Connection>>,
     message_handlers: RwLock<Vec<MessageHandler>>,
     handler_types: Mutex<HashSet<TypeId>>,
     next_connection_id: AtomicU32,
 }
 
+#[derive(Clone)]
+struct Connection {
+    outgoing_tx: mpsc::Sender<proto::Envelope>,
+    next_message_id: Arc<AtomicU32>,
+    response_channels: ResponseChannels,
+}
+
+pub struct ConnectionHandler<Conn> {
+    peer: Arc<Peer>,
+    connection_id: ConnectionId,
+    response_channels: ResponseChannels,
+    outgoing_rx: mpsc::Receiver<proto::Envelope>,
+    reader: MessageStream<Conn>,
+    writer: MessageStream<Conn>,
+}
+
+type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
+
 impl Peer {
     pub fn new() -> Arc<Self> {
         Arc::new(Self {
             connections: Default::default(),
-            connection_close_barriers: Default::default(),
             message_handlers: Default::default(),
             handler_types: Default::default(),
             next_connection_id: Default::default(),
@@ -127,7 +129,10 @@ impl Peer {
         rx
     }
 
-    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
+    pub async fn add_connection<Conn>(
+        self: &Arc<Self>,
+        conn: Conn,
+    ) -> (ConnectionId, ConnectionHandler<Conn>)
     where
         Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
@@ -135,120 +140,37 @@ impl Peer {
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
         );
-        self.connections.write().await.insert(
+        let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
+        let connection = Connection {
+            outgoing_tx,
+            next_message_id: Default::default(),
+            response_channels: Default::default(),
+        };
+        let handler = ConnectionHandler {
+            peer: self.clone(),
             connection_id,
-            Arc::new(Connection {
-                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
-                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
-                response_channels: Default::default(),
-                next_message_id: Default::default(),
-            }),
-        );
-        connection_id
+            response_channels: connection.response_channels.clone(),
+            outgoing_rx,
+            reader: MessageStream::new(conn.clone()),
+            writer: MessageStream::new(conn),
+        };
+        self.connections
+            .write()
+            .await
+            .insert(connection_id, connection);
+        (connection_id, handler)
     }
 
     pub async fn disconnect(&self, connection_id: ConnectionId) {
         self.connections.write().await.remove(&connection_id);
-        self.connection_close_barriers
-            .write()
-            .await
-            .remove(&connection_id);
     }
 
     pub async fn reset(&self) {
         self.connections.write().await.clear();
-        self.connection_close_barriers.write().await.clear();
         self.handler_types.lock().await.clear();
         self.message_handlers.write().await.clear();
     }
 
-    pub fn handle_messages(
-        self: &Arc<Self>,
-        connection_id: ConnectionId,
-    ) -> impl Future<Output = Result<()>> + 'static {
-        let (close_tx, mut close_rx) = barrier::channel();
-        let this = self.clone();
-        async move {
-            this.connection_close_barriers
-                .write()
-                .await
-                .insert(connection_id, close_tx);
-            let connection = this.connection(connection_id).await?;
-            let closed = close_rx.recv();
-            futures::pin_mut!(closed);
-
-            loop {
-                let mut reader = connection.reader.lock().await;
-                let read_message = reader.read_message();
-                futures::pin_mut!(read_message);
-
-                match futures::future::select(read_message, &mut closed).await {
-                    Either::Left((Ok(incoming), _)) => {
-                        if let Some(responding_to) = incoming.responding_to {
-                            let channel = connection
-                                .response_channels
-                                .lock()
-                                .await
-                                .remove(&responding_to);
-                            if let Some(mut tx) = channel {
-                                tx.send(incoming).await.ok();
-                            } else {
-                                log::warn!(
-                                    "received RPC response to unknown request {}",
-                                    responding_to
-                                );
-                            }
-                        } else {
-                            let mut envelope = Some(incoming);
-                            let mut handler_index = None;
-                            let mut handler_was_dropped = false;
-                            for (i, handler) in
-                                this.message_handlers.read().await.iter().enumerate()
-                            {
-                                if let Some(future) = handler(&mut envelope, connection_id) {
-                                    handler_was_dropped = future.await;
-                                    handler_index = Some(i);
-                                    break;
-                                }
-                            }
-
-                            if let Some(handler_index) = handler_index {
-                                if handler_was_dropped {
-                                    drop(this.message_handlers.write().await.remove(handler_index));
-                                }
-                            } else {
-                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
-                            }
-                        }
-                    }
-                    Either::Left((Err(error), _)) => {
-                        log::warn!("received invalid RPC message: {}", error);
-                        Err(error)?;
-                    }
-                    Either::Right(_) => return Ok(()),
-                }
-            }
-        }
-    }
-
-    pub async fn receive<M: EnvelopedMessage>(
-        self: &Arc<Self>,
-        connection_id: ConnectionId,
-    ) -> Result<TypedEnvelope<M>> {
-        let connection = self.connection(connection_id).await?;
-        let envelope = connection.reader.lock().await.read_message().await?;
-        let original_sender_id = envelope.original_sender_id;
-        let message_id = envelope.id;
-        let payload =
-            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
-        Ok(TypedEnvelope {
-            sender_id: connection_id,
-            original_sender_id: original_sender_id.map(PeerId),
-            message_id,
-            payload,
-        })
-    }
-
     pub fn request<T: RequestMessage>(
         self: &Arc<Self>,
         receiver_id: ConnectionId,
@@ -273,9 +195,9 @@ impl Peer {
         request: T,
     ) -> impl Future<Output = Result<T::Response>> {
         let this = self.clone();
-        let (tx, mut rx) = oneshot::channel();
+        let (tx, mut rx) = mpsc::channel(1);
         async move {
-            let connection = this.connection(receiver_id).await?;
+            let mut connection = this.connection(receiver_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -285,19 +207,13 @@ impl Peer {
                 .await
                 .insert(message_id, tx);
             connection
-                .writer
-                .lock()
-                .await
-                .write_message(&request.into_envelope(
-                    message_id,
-                    None,
-                    original_sender_id.map(|id| id.0),
-                ))
+                .outgoing_tx
+                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
                 .await?;
             let response = rx
                 .recv()
                 .await
-                .expect("response channel was unexpectedly dropped");
+                .ok_or_else(|| anyhow!("connection was closed"))?;
             T::Response::from_envelope(response)
                 .ok_or_else(|| anyhow!("received response of the wrong type"))
         }
@@ -305,20 +221,18 @@ impl Peer {
 
     pub fn send<T: EnvelopedMessage>(
         self: &Arc<Self>,
-        connection_id: ConnectionId,
+        receiver_id: ConnectionId,
         message: T,
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let connection = this.connection(connection_id).await?;
+            let mut connection = this.connection(receiver_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
             connection
-                .writer
-                .lock()
-                .await
-                .write_message(&message.into_envelope(message_id, None, None))
+                .outgoing_tx
+                .send(message.into_envelope(message_id, None, None))
                 .await?;
             Ok(())
         }
@@ -332,15 +246,13 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let connection = this.connection(receiver_id).await?;
+            let mut connection = this.connection(receiver_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
             connection
-                .writer
-                .lock()
-                .await
-                .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
+                .outgoing_tx
+                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
                 .await?;
             Ok(())
         }
@@ -353,28 +265,114 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let connection = this.connection(receipt.sender_id).await?;
+            let mut connection = this.connection(receipt.sender_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
             connection
-                .writer
-                .lock()
-                .await
-                .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
+                .outgoing_tx
+                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
                 .await?;
             Ok(())
         }
     }
 
-    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
-        Ok(self
-            .connections
-            .read()
-            .await
-            .get(&id)
-            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
-            .clone())
+    fn connection(
+        self: &Arc<Self>,
+        connection_id: ConnectionId,
+    ) -> impl Future<Output = Result<Connection>> {
+        let this = self.clone();
+        async move {
+            let connections = this.connections.read().await;
+            let connection = connections
+                .get(&connection_id)
+                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
+            Ok(connection.clone())
+        }
+    }
+}
+
+impl<Conn> ConnectionHandler<Conn>
+where
+    Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
+{
+    pub async fn run(mut self) -> Result<()> {
+        loop {
+            let read_message = self.reader.read_message().fuse();
+            futures::pin_mut!(read_message);
+            loop {
+                futures::select! {
+                    incoming = read_message => match incoming {
+                        Ok(incoming) => {
+                            Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
+                            break;
+                        }
+                        Err(error) => {
+                            self.response_channels.lock().await.clear();
+                            Err(error).context("received invalid RPC message")?;
+                        }
+                    },
+                    outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
+                        Some(outgoing) => {
+                            if let Err(result) = self.writer.write_message(&outgoing).await {
+                                self.response_channels.lock().await.clear();
+                                Err(result).context("failed to write RPC message")?;
+                            }
+                        }
+                        None => return Ok(()),
+                    }
+                }
+            }
+        }
+    }
+
+    pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
+        let envelope = self.reader.read_message().await?;
+        let original_sender_id = envelope.original_sender_id;
+        let message_id = envelope.id;
+        let payload =
+            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
+        Ok(TypedEnvelope {
+            sender_id: self.connection_id,
+            original_sender_id: original_sender_id.map(PeerId),
+            message_id,
+            payload,
+        })
+    }
+
+    async fn handle_incoming_message(
+        message: proto::Envelope,
+        peer: &Arc<Peer>,
+        connection_id: ConnectionId,
+        response_channels: &ResponseChannels,
+    ) {
+        if let Some(responding_to) = message.responding_to {
+            let channel = response_channels.lock().await.remove(&responding_to);
+            if let Some(mut tx) = channel {
+                tx.send(message).await.ok();
+            } else {
+                log::warn!("received RPC response to unknown request {}", responding_to);
+            }
+        } else {
+            let mut envelope = Some(message);
+            let mut handler_index = None;
+            let mut handler_was_dropped = false;
+            for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
+                if let Some(future) = handler(&mut envelope, connection_id) {
+                    handler_was_dropped = future.await;
+                    handler_index = Some(i);
+                    break;
+                }
+            }
+
+            if let Some(handler_index) = handler_index {
+                if handler_was_dropped {
+                    drop(peer.message_handlers.write().await.remove(handler_index));
+                }
+            } else {
+                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
+            }
+        }
     }
 }
 
@@ -412,22 +410,22 @@ mod tests {
             let server = Peer::new();
             let client1 = Peer::new();
             let client2 = Peer::new();
-            let client1_conn_id = client1
+            let (client1_conn_id, task1) = client1
                 .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                 .await;
-            let client2_conn_id = client2
+            let (client2_conn_id, task2) = client2
                 .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                 .await;
-            let server_conn_id1 = server
+            let (_, task3) = server
                 .add_connection(listener.accept().await.unwrap().0)
                 .await;
-            let server_conn_id2 = server
+            let (_, task4) = server
                 .add_connection(listener.accept().await.unwrap().0)
                 .await;
-            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
-            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
-            smol::spawn(server.handle_messages(server_conn_id1)).detach();
-            smol::spawn(server.handle_messages(server_conn_id2)).detach();
+            smol::spawn(task1.run()).detach();
+            smol::spawn(task2.run()).detach();
+            smol::spawn(task3.run()).detach();
+            smol::spawn(task4.run()).detach();
 
             // define the expected requests and responses
             let request1 = proto::Auth {
@@ -548,12 +546,11 @@ mod tests {
             let (mut server_conn, _) = listener.accept().await.unwrap();
 
             let client = Peer::new();
-            let connection_id = client.add_connection(client_conn).await;
+            let (connection_id, handler) = client.add_connection(client_conn).await;
             let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
-                barrier::channel();
-            let handle_messages = client.handle_messages(connection_id);
+                postage::barrier::channel();
             smol::spawn(async move {
-                handle_messages.await.ok();
+                handler.run().await.ok();
                 incoming_messages_ended_tx.send(()).await.unwrap();
             })
             .detach();
@@ -576,8 +573,8 @@ mod tests {
             client_conn.close().await.unwrap();
 
             let client = Peer::new();
-            let connection_id = client.add_connection(client_conn).await;
-            smol::spawn(client.handle_messages(connection_id)).detach();
+            let (connection_id, handler) = client.add_connection(client_conn).await;
+            smol::spawn(handler.run()).detach();
 
             let err = client
                 .request(
@@ -589,10 +586,7 @@ mod tests {
                 )
                 .await
                 .unwrap_err();
-            assert_eq!(
-                err.downcast_ref::<io::Error>().unwrap().kind(),
-                io::ErrorKind::BrokenPipe
-            );
+            assert_eq!(err.to_string(), "connection was closed");
         });
     }
 }

zed-rpc/src/proto.rs 🔗

@@ -82,6 +82,7 @@ message!(RemoveGuest);
 pub struct MessageStream<T> {
     byte_stream: T,
     buffer: Vec<u8>,
+    upcoming_message_len: Option<usize>,
 }
 
 impl<T> MessageStream<T> {
@@ -89,6 +90,7 @@ impl<T> MessageStream<T> {
         Self {
             byte_stream,
             buffer: Default::default(),
+            upcoming_message_len: None,
         }
     }
 
@@ -120,12 +122,23 @@ where
 {
     /// Read a protobuf message of the given type from the stream.
     pub async fn read_message(&mut self) -> io::Result<Envelope> {
-        let mut delimiter_buf = [0; 4];
-        self.byte_stream.read_exact(&mut delimiter_buf).await?;
-        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
-        self.buffer.resize(message_len, 0);
-        self.byte_stream.read_exact(&mut self.buffer).await?;
-        Ok(Envelope::decode(self.buffer.as_slice())?)
+        loop {
+            if let Some(upcoming_message_len) = self.upcoming_message_len {
+                self.buffer.resize(upcoming_message_len, 0);
+                self.byte_stream.read_exact(&mut self.buffer).await?;
+                self.upcoming_message_len = None;
+                return Ok(Envelope::decode(self.buffer.as_slice())?);
+            } else {
+                self.buffer.resize(4, 0);
+                self.byte_stream.read_exact(&mut self.buffer).await?;
+                self.upcoming_message_len = Some(u32::from_be_bytes([
+                    self.buffer[0],
+                    self.buffer[1],
+                    self.buffer[2],
+                    self.buffer[3],
+                ]) as usize);
+            }
+        }
     }
 }
 

zed/src/rpc.rs 🔗

@@ -121,10 +121,8 @@ impl Client {
         let stream = smol::net::TcpStream::connect(&address).await?;
         log::info!("connected to rpc address {}", address);
 
-        let connection_id = self.peer.add_connection(stream).await;
-        executor
-            .spawn(self.peer.handle_messages(connection_id))
-            .detach();
+        let (connection_id, handler) = self.peer.add_connection(stream).await;
+        executor.spawn(handler.run()).detach();
 
         let auth_response = self
             .peer