Refactor zed-rpc to work with websockets

Antonio Scandurra created

Change summary

Cargo.lock           |  68 +++++++++++++++++++++++
zed-rpc/Cargo.toml   |   1 
zed-rpc/src/lib.rs   |   2 
zed-rpc/src/peer.rs  |  86 ++++++++++++++----------------
zed-rpc/src/proto.rs | 131 +++++++++++----------------------------------
zed-rpc/src/test.rs  |  64 ++++++++++++++++++++++
6 files changed, 208 insertions(+), 144 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -286,6 +286,19 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "async-tungstenite"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8645e929ec7964448a901db9da30cd2ae8c7fecf4d6176af427837531dbbb63b"
+dependencies = [
+ "futures-io",
+ "futures-util",
+ "log",
+ "pin-project-lite",
+ "tungstenite",
+]
+
 [[package]]
 name = "atomic"
 version = "0.5.0"
@@ -1713,6 +1726,12 @@ dependencies = [
  "url",
 ]
 
+[[package]]
+name = "httparse"
+version = "1.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f3a87b616e37e93c22fb19bcd386f02f3af5ea98a25670ad0fce773de23c5e68"
+
 [[package]]
 name = "humantime"
 version = "2.1.0"
@@ -1806,6 +1825,15 @@ dependencies = [
  "adler32",
 ]
 
+[[package]]
+name = "input_buffer"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413"
+dependencies = [
+ "bytes 1.0.1",
+]
+
 [[package]]
 name = "instant"
 version = "0.1.9"
@@ -3304,6 +3332,19 @@ dependencies = [
  "pkg-config",
 ]
 
+[[package]]
+name = "sha-1"
+version = "0.9.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8c4cfa741c5832d0ef7fab46cabed29c2aae926db0b11bb2069edd8db5e64e16"
+dependencies = [
+ "block-buffer",
+ "cfg-if 1.0.0",
+ "cpufeatures",
+ "digest",
+ "opaque-debug",
+]
+
 [[package]]
 name = "sha1"
 version = "0.2.0"
@@ -3920,6 +3961,26 @@ version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "85e00391c1f3d171490a3f8bd79999b0002ae38d3da0d6a3a306c754b053d71b"
 
+[[package]]
+name = "tungstenite"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093"
+dependencies = [
+ "base64 0.13.0",
+ "byteorder",
+ "bytes 1.0.1",
+ "http",
+ "httparse",
+ "input_buffer",
+ "log",
+ "rand 0.8.3",
+ "sha-1",
+ "thiserror",
+ "url",
+ "utf-8",
+]
+
 [[package]]
 name = "typenum"
 version = "1.13.0"
@@ -4057,6 +4118,12 @@ dependencies = [
  "xmlwriter",
 ]
 
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
 [[package]]
 name = "uuid"
 version = "0.5.1"
@@ -4356,6 +4423,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-lock",
+ "async-tungstenite",
  "base64 0.13.0",
  "futures",
  "log",

zed-rpc/Cargo.toml 🔗

@@ -7,6 +7,7 @@ version = "0.1.0"
 [dependencies]
 anyhow = "1.0"
 async-lock = "2.4"
+async-tungstenite = "0.14"
 base64 = "0.13"
 futures = "0.3"
 log = "0.4"

zed-rpc/src/lib.rs 🔗

@@ -2,5 +2,7 @@ pub mod auth;
 mod peer;
 pub mod proto;
 pub mod rest;
+#[cfg(test)]
+mod test;
 
 pub use peer::*;

zed-rpc/src/peer.rs 🔗

@@ -1,7 +1,12 @@
 use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
 use anyhow::{anyhow, Context, Result};
 use async_lock::{Mutex, RwLock};
-use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use futures::{
+    future::BoxFuture,
+    stream::{SplitSink, SplitStream},
+    FutureExt, StreamExt,
+};
 use postage::{
     mpsc,
     prelude::{Sink, Stream},
@@ -72,13 +77,13 @@ struct Connection {
     response_channels: ResponseChannels,
 }
 
-pub struct ConnectionHandler<Conn> {
+pub struct ConnectionHandler<W, R> {
     peer: Arc<Peer>,
     connection_id: ConnectionId,
     response_channels: ResponseChannels,
     outgoing_rx: mpsc::Receiver<proto::Envelope>,
-    reader: MessageStream<Conn>,
-    writer: MessageStream<Conn>,
+    writer: MessageStream<W>,
+    reader: MessageStream<R>,
 }
 
 type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
@@ -131,10 +136,16 @@ impl Peer {
     pub async fn add_connection<Conn>(
         self: &Arc<Self>,
         conn: Conn,
-    ) -> (ConnectionId, ConnectionHandler<Conn>)
+    ) -> (
+        ConnectionId,
+        ConnectionHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
+    )
     where
-        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
+        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+            + Unpin,
     {
+        let (tx, rx) = conn.split();
         let connection_id = ConnectionId(
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
@@ -150,8 +161,8 @@ impl Peer {
             connection_id,
             response_channels: connection.response_channels.clone(),
             outgoing_rx,
-            reader: MessageStream::new(conn.clone()),
-            writer: MessageStream::new(conn),
+            writer: MessageStream::new(tx),
+            reader: MessageStream::new(rx),
         };
         self.connections
             .write()
@@ -291,9 +302,10 @@ impl Peer {
     }
 }
 
-impl<Conn> ConnectionHandler<Conn>
+impl<W, R> ConnectionHandler<W, R>
 where
-    Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
+    W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
+    R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
 {
     pub async fn run(mut self) -> Result<()> {
         loop {
@@ -402,38 +414,25 @@ impl fmt::Display for PeerId {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::test;
     use postage::oneshot;
-    use smol::{
-        io::AsyncWriteExt,
-        net::unix::{UnixListener, UnixStream},
-    };
-    use std::io;
-    use tempdir::TempDir;
 
     #[test]
     fn test_request_response() {
         smol::block_on(async move {
-            // create socket
-            let socket_dir_path = TempDir::new("test-request-response").unwrap();
-            let socket_path = socket_dir_path.path().join("test.sock");
-            let listener = UnixListener::bind(&socket_path).unwrap();
-
             // create 2 clients connected to 1 server
             let server = Peer::new();
             let client1 = Peer::new();
             let client2 = Peer::new();
-            let (client1_conn_id, task1) = client1
-                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
-                .await;
-            let (client2_conn_id, task2) = client2
-                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
-                .await;
-            let (_, task3) = server
-                .add_connection(listener.accept().await.unwrap().0)
-                .await;
-            let (_, task4) = server
-                .add_connection(listener.accept().await.unwrap().0)
-                .await;
+
+            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
+            let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await;
+            let (_, task2) = server.add_connection(server_to_client_1_conn).await;
+
+            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
+            let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await;
+            let (_, task4) = server.add_connection(server_to_client_2_conn).await;
+
             smol::spawn(task1.run()).detach();
             smol::spawn(task2.run()).detach();
             smol::spawn(task3.run()).detach();
@@ -553,11 +552,7 @@ mod tests {
     #[test]
     fn test_disconnect() {
         smol::block_on(async move {
-            let socket_dir_path = TempDir::new("drop-client").unwrap();
-            let socket_path = socket_dir_path.path().join(".sock");
-            let listener = UnixListener::bind(&socket_path).unwrap();
-            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
-            let (mut server_conn, _) = listener.accept().await.unwrap();
+            let (client_conn, mut server_conn) = test::Channel::bidirectional();
 
             let client = Peer::new();
             let (connection_id, handler) = client.add_connection(client_conn).await;
@@ -571,20 +566,19 @@ mod tests {
             client.disconnect(connection_id).await;
 
             incoming_messages_ended_rx.recv().await;
-
-            let err = server_conn.write(&[]).await.unwrap_err();
-            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
+            assert!(
+                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
+                    .await
+                    .is_err()
+            );
         });
     }
 
     #[test]
     fn test_io_error() {
         smol::block_on(async move {
-            let socket_dir_path = TempDir::new("io-error").unwrap();
-            let socket_path = socket_dir_path.path().join(".sock");
-            let _listener = UnixListener::bind(&socket_path).unwrap();
-            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
-            client_conn.close().await.unwrap();
+            let (client_conn, server_conn) = test::Channel::bidirectional();
+            drop(server_conn);
 
             let client = Peer::new();
             let (connection_id, handler) = client.add_connection(client_conn).await;

zed-rpc/src/proto.rs 🔗

@@ -1,7 +1,7 @@
-use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use futures::{SinkExt as _, StreamExt as _};
 use prost::Message;
 use std::{
-    convert::TryInto,
     io,
     time::{Duration, SystemTime, UNIX_EPOCH},
 };
@@ -81,66 +81,52 @@ message!(AddPeer);
 message!(RemovePeer);
 
 /// A stream of protobuf messages.
-pub struct MessageStream<T> {
-    byte_stream: T,
-    buffer: Vec<u8>,
-    upcoming_message_len: Option<usize>,
+pub struct MessageStream<S> {
+    stream: S,
 }
 
-impl<T> MessageStream<T> {
-    pub fn new(byte_stream: T) -> Self {
-        Self {
-            byte_stream,
-            buffer: Default::default(),
-            upcoming_message_len: None,
-        }
+impl<S> MessageStream<S> {
+    pub fn new(stream: S) -> Self {
+        Self { stream }
     }
 
-    pub fn inner_mut(&mut self) -> &mut T {
-        &mut self.byte_stream
+    pub fn inner_mut(&mut self) -> &mut S {
+        &mut self.stream
     }
 }
 
-impl<T> MessageStream<T>
+impl<S> MessageStream<S>
 where
-    T: AsyncWrite + Unpin,
+    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
 {
     /// Write a given protobuf message to the stream.
-    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
-        let message_len: u32 = message
-            .encoded_len()
-            .try_into()
-            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
-        self.buffer.clear();
-        self.buffer.extend_from_slice(&message_len.to_be_bytes());
-        message.encode(&mut self.buffer)?;
-        self.byte_stream.write_all(&self.buffer).await
+    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
+        let mut buffer = Vec::with_capacity(message.encoded_len());
+        message
+            .encode(&mut buffer)
+            .map_err(|err| io::Error::from(err))?;
+        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+        Ok(())
     }
 }
 
-impl<T> MessageStream<T>
+impl<S> MessageStream<S>
 where
-    T: AsyncRead + Unpin,
+    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
 {
     /// Read a protobuf message of the given type from the stream.
-    pub async fn read_message(&mut self) -> io::Result<Envelope> {
-        loop {
-            if let Some(upcoming_message_len) = self.upcoming_message_len {
-                self.buffer.resize(upcoming_message_len, 0);
-                self.byte_stream.read_exact(&mut self.buffer).await?;
-                self.upcoming_message_len = None;
-                return Ok(Envelope::decode(self.buffer.as_slice())?);
-            } else {
-                self.buffer.resize(4, 0);
-                self.byte_stream.read_exact(&mut self.buffer).await?;
-                self.upcoming_message_len = Some(u32::from_be_bytes([
-                    self.buffer[0],
-                    self.buffer[1],
-                    self.buffer[2],
-                    self.buffer[3],
-                ]) as usize);
+    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
+        while let Some(bytes) = self.stream.next().await {
+            match bytes? {
+                WebSocketMessage::Binary(bytes) => {
+                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
+                    return Ok(envelope);
+                }
+                WebSocketMessage::Close(_) => break,
+                _ => {}
             }
         }
+        Err(WebSocketError::ConnectionClosed)
     }
 }
 
@@ -165,20 +151,12 @@ impl From<SystemTime> for Timestamp {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use std::{
-        pin::Pin,
-        task::{Context, Poll},
-    };
+    use crate::test;
 
     #[test]
     fn test_round_trip_message() {
         smol::block_on(async {
-            let byte_stream = ChunkedStream {
-                bytes: Vec::new(),
-                read_offset: 0,
-                chunk_size: 3,
-            };
-
+            let stream = test::Channel::new();
             let message1 = Auth {
                 user_id: 5,
                 access_token: "the-access-token".into(),
@@ -191,7 +169,7 @@ mod tests {
             }
             .into_envelope(5, None, None);
 
-            let mut message_stream = MessageStream::new(byte_stream);
+            let mut message_stream = MessageStream::new(stream);
             message_stream.write_message(&message1).await.unwrap();
             message_stream.write_message(&message2).await.unwrap();
             let decoded_message1 = message_stream.read_message().await.unwrap();
@@ -200,47 +178,4 @@ mod tests {
             assert_eq!(decoded_message2, message2);
         });
     }
-
-    struct ChunkedStream {
-        bytes: Vec<u8>,
-        read_offset: usize,
-        chunk_size: usize,
-    }
-
-    impl AsyncWrite for ChunkedStream {
-        fn poll_write(
-            mut self: Pin<&mut Self>,
-            _: &mut Context<'_>,
-            buf: &[u8],
-        ) -> Poll<io::Result<usize>> {
-            let bytes_written = buf.len().min(self.chunk_size);
-            self.bytes.extend_from_slice(&buf[0..bytes_written]);
-            Poll::Ready(Ok(bytes_written))
-        }
-
-        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
-            Poll::Ready(Ok(()))
-        }
-
-        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
-            Poll::Ready(Ok(()))
-        }
-    }
-
-    impl AsyncRead for ChunkedStream {
-        fn poll_read(
-            mut self: Pin<&mut Self>,
-            _: &mut Context<'_>,
-            buf: &mut [u8],
-        ) -> Poll<io::Result<usize>> {
-            let bytes_read = buf
-                .len()
-                .min(self.chunk_size)
-                .min(self.bytes.len() - self.read_offset);
-            let end_offset = self.read_offset + bytes_read;
-            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
-            self.read_offset = end_offset;
-            Poll::Ready(Ok(bytes_read))
-        }
-    }
 }

zed-rpc/src/test.rs 🔗

@@ -0,0 +1,64 @@
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use std::{
+    io,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+pub struct Channel {
+    tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
+    rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
+}
+
+impl Channel {
+    pub fn new() -> Self {
+        let (tx, rx) = futures::channel::mpsc::unbounded();
+        Self { tx, rx }
+    }
+
+    pub fn bidirectional() -> (Self, Self) {
+        let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
+        let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
+        let a = Self { tx: a_tx, rx: b_rx };
+        let b = Self { tx: b_tx, rx: a_rx };
+        (a, b)
+    }
+}
+
+impl futures::Sink<WebSocketMessage> for Channel {
+    type Error = WebSocketError;
+
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.tx)
+            .poll_ready(cx)
+            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
+        Pin::new(&mut self.tx)
+            .start_send(item)
+            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.tx)
+            .poll_flush(cx)
+            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.tx)
+            .poll_close(cx)
+            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
+    }
+}
+
+impl futures::Stream for Channel {
+    type Item = Result<WebSocketMessage, WebSocketError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        Pin::new(&mut self.rx)
+            .poll_next(cx)
+            .map(|i| i.map(|i| Ok(i)))
+    }
+}