Use a fixed-length delimiter for encoding/decoding messages in RPC

Antonio Scandurra and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

zed-rpc/src/proto.rs  | 109 +++++---------------------------------------
zed/src/rpc_client.rs |  24 +++++++++
2 files changed, 36 insertions(+), 97 deletions(-)

Detailed changes

zed-rpc/src/proto.rs 🔗

@@ -1,7 +1,7 @@
 use futures_io::{AsyncRead, AsyncWrite};
 use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
 use prost::Message;
-use std::io;
+use std::{convert::TryInto, io};
 
 include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 
@@ -96,9 +96,14 @@ where
     T: AsyncWrite + Unpin,
 {
     /// Write a given protobuf message to the stream.
-    pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
+    pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
+        let message_len: u32 = message
+            .encoded_len()
+            .try_into()
+            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
         self.buffer.clear();
-        message.encode_length_delimited(&mut self.buffer).unwrap();
+        self.buffer.extend_from_slice(&message_len.to_be_bytes());
+        message.encode(&mut self.buffer)?;
         self.byte_stream.write_all(&self.buffer).await
     }
 }
@@ -109,44 +114,12 @@ where
 {
     /// Read a protobuf message of the given type from the stream.
     pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
-        // Ensure the buffer is large enough to hold the maximum delimiter length
-        const MAX_DELIMITER_LEN: usize = 10;
-        self.buffer.resize(MAX_DELIMITER_LEN, 0);
-
-        // Read until a complete length delimiter can be decoded.
-        let mut read_start_offset = 0;
-        let (encoded_len, delimiter_len) = loop {
-            let bytes_read = self
-                .byte_stream
-                .read(&mut self.buffer[read_start_offset..])
-                .await?;
-            read_start_offset += bytes_read;
-
-            let mut buffer = &self.buffer[0..read_start_offset];
-            match prost::decode_length_delimiter(&mut buffer) {
-                Err(_) => {
-                    if read_start_offset >= MAX_DELIMITER_LEN {
-                        return Err(io::Error::new(
-                            io::ErrorKind::InvalidData,
-                            "invalid message length delimiter",
-                        ));
-                    }
-                }
-                Ok(encoded_len) => {
-                    let delimiter_len = read_start_offset - buffer.len();
-                    break (encoded_len, delimiter_len);
-                }
-            }
-        };
-
-        // Read the message itself.
-        self.buffer.resize(delimiter_len + encoded_len, 0);
-        self.byte_stream
-            .read_exact(&mut self.buffer[read_start_offset..])
-            .await?;
-        let message = M::decode(&self.buffer[delimiter_len..])?;
-
-        Ok(message)
+        let mut delimiter_buf = [0; 4];
+        self.byte_stream.read_exact(&mut delimiter_buf).await?;
+        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
+        self.buffer.resize(message_len, 0);
+        self.byte_stream.read_exact(&mut self.buffer).await?;
+        Ok(M::decode(self.buffer.as_slice())?)
     }
 }
 
@@ -196,60 +169,6 @@ mod tests {
         });
     }
 
-    #[test]
-    fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
-        smol::block_on(async {
-            let byte_stream = ChunkedStream {
-                bytes: Vec::new(),
-                read_offset: 0,
-                chunk_size: 2,
-            };
-
-            // This message is so long that its length delimiter requires three bytes,
-            // so it won't be delivered in a single read from the chunked byte stream.
-            let message = FromClient {
-                id: 4,
-                variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
-                    path: Vec::new(),
-                    content: "long ".repeat(256 * 256).into(),
-                })),
-            };
-            assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size);
-
-            let mut message_stream = MessageStream::new(byte_stream);
-            message_stream.write_message(&message).await.unwrap();
-            let decoded_message = message_stream.read_message::<FromClient>().await.unwrap();
-            assert_eq!(decoded_message, message);
-        });
-    }
-
-    #[test]
-    fn test_protobuf_parse_error() {
-        smol::block_on(async {
-            let mut byte_stream = ChunkedStream {
-                bytes: Vec::new(),
-                read_offset: 0,
-                chunk_size: 2,
-            };
-
-            let message = FromClient {
-                id: 3,
-                variant: Some(from_client::Variant::Auth(from_client::Auth {
-                    user_id: 5,
-                    access_token: "the-access-token".into(),
-                })),
-            };
-
-            byte_stream.write_all(b"omg").await.unwrap();
-            let mut message_stream = MessageStream::new(byte_stream);
-            message_stream.write_message(&message).await.unwrap();
-
-            // Read the wrong type of message from the stream.
-            let result = message_stream.read_message::<FromServer>().await;
-            assert!(result.is_err());
-        });
-    }
-
     struct ChunkedStream {
         bytes: Vec<u8>,
         read_offset: usize,

zed/src/rpc_client.rs 🔗

@@ -103,8 +103,16 @@ where
             .ok_or_else(|| anyhow!("received response of the wrong t"))
     }
 
-    pub async fn send<T: SendMessage>(_: T) -> Result<()> {
-        todo!()
+    pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
+        let message_id = self.next_message_id;
+        self.next_message_id += 1;
+        self.stream
+            .write_message(&proto::FromClient {
+                id: message_id,
+                variant: Some(message.to_variant()),
+            })
+            .await?;
+        Ok(())
     }
 }
 
@@ -152,6 +160,18 @@ mod tests {
             ))
         );
 
+        // Respond to another request to ensure requests are properly matched up.
+        server_stream
+            .write_message(&proto::FromServer {
+                request_id: Some(999),
+                variant: Some(proto::from_server::Variant::AuthResponse(
+                    proto::from_server::AuthResponse {
+                        credentials_valid: false,
+                    },
+                )),
+            })
+            .await
+            .unwrap();
         server_stream
             .write_message(&proto::FromServer {
                 request_id: Some(server_req.id),