@@ -1,7 +1,7 @@
use futures_io::{AsyncRead, AsyncWrite};
use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
use prost::Message;
-use std::io;
+use std::{convert::TryInto, io};
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
@@ -96,9 +96,14 @@ where
T: AsyncWrite + Unpin,
{
/// Write a given protobuf message to the stream.
- pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
+ pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
+ let message_len: u32 = message
+ .encoded_len()
+ .try_into()
+ .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
self.buffer.clear();
- message.encode_length_delimited(&mut self.buffer).unwrap();
+ self.buffer.extend_from_slice(&message_len.to_be_bytes());
+ message.encode(&mut self.buffer)?;
self.byte_stream.write_all(&self.buffer).await
}
}
@@ -109,44 +114,12 @@ where
{
/// Read a protobuf message of the given type from the stream.
pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
- // Ensure the buffer is large enough to hold the maximum delimiter length
- const MAX_DELIMITER_LEN: usize = 10;
- self.buffer.resize(MAX_DELIMITER_LEN, 0);
-
- // Read until a complete length delimiter can be decoded.
- let mut read_start_offset = 0;
- let (encoded_len, delimiter_len) = loop {
- let bytes_read = self
- .byte_stream
- .read(&mut self.buffer[read_start_offset..])
- .await?;
- read_start_offset += bytes_read;
-
- let mut buffer = &self.buffer[0..read_start_offset];
- match prost::decode_length_delimiter(&mut buffer) {
- Err(_) => {
- if read_start_offset >= MAX_DELIMITER_LEN {
- return Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "invalid message length delimiter",
- ));
- }
- }
- Ok(encoded_len) => {
- let delimiter_len = read_start_offset - buffer.len();
- break (encoded_len, delimiter_len);
- }
- }
- };
-
- // Read the message itself.
- self.buffer.resize(delimiter_len + encoded_len, 0);
- self.byte_stream
- .read_exact(&mut self.buffer[read_start_offset..])
- .await?;
- let message = M::decode(&self.buffer[delimiter_len..])?;
-
- Ok(message)
+ let mut delimiter_buf = [0; 4];
+ self.byte_stream.read_exact(&mut delimiter_buf).await?;
+ let message_len = u32::from_be_bytes(delimiter_buf) as usize;
+ self.buffer.resize(message_len, 0);
+ self.byte_stream.read_exact(&mut self.buffer).await?;
+ Ok(M::decode(self.buffer.as_slice())?)
}
}
@@ -196,60 +169,6 @@ mod tests {
});
}
- #[test]
- fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
- smol::block_on(async {
- let byte_stream = ChunkedStream {
- bytes: Vec::new(),
- read_offset: 0,
- chunk_size: 2,
- };
-
- // This message is so long that its length delimiter requires three bytes,
- // so it won't be delivered in a single read from the chunked byte stream.
- let message = FromClient {
- id: 4,
- variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
- path: Vec::new(),
- content: "long ".repeat(256 * 256).into(),
- })),
- };
- assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size);
-
- let mut message_stream = MessageStream::new(byte_stream);
- message_stream.write_message(&message).await.unwrap();
- let decoded_message = message_stream.read_message::<FromClient>().await.unwrap();
- assert_eq!(decoded_message, message);
- });
- }
-
- #[test]
- fn test_protobuf_parse_error() {
- smol::block_on(async {
- let mut byte_stream = ChunkedStream {
- bytes: Vec::new(),
- read_offset: 0,
- chunk_size: 2,
- };
-
- let message = FromClient {
- id: 3,
- variant: Some(from_client::Variant::Auth(from_client::Auth {
- user_id: 5,
- access_token: "the-access-token".into(),
- })),
- };
-
- byte_stream.write_all(b"omg").await.unwrap();
- let mut message_stream = MessageStream::new(byte_stream);
- message_stream.write_message(&message).await.unwrap();
-
- // Read the wrong type of message from the stream.
- let result = message_stream.read_message::<FromServer>().await;
- assert!(result.is_err());
- });
- }
-
struct ChunkedStream {
bytes: Vec<u8>,
read_offset: usize,
@@ -103,8 +103,16 @@ where
.ok_or_else(|| anyhow!("received response of the wrong t"))
}
- pub async fn send<T: SendMessage>(_: T) -> Result<()> {
- todo!()
+ pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
+ let message_id = self.next_message_id;
+ self.next_message_id += 1;
+ self.stream
+ .write_message(&proto::FromClient {
+ id: message_id,
+ variant: Some(message.to_variant()),
+ })
+ .await?;
+ Ok(())
}
}
@@ -152,6 +160,18 @@ mod tests {
))
);
+ // Respond to another request to ensure requests are properly matched up.
+ server_stream
+ .write_message(&proto::FromServer {
+ request_id: Some(999),
+ variant: Some(proto::from_server::Variant::AuthResponse(
+ proto::from_server::AuthResponse {
+ credentials_valid: false,
+ },
+ )),
+ })
+ .await
+ .unwrap();
server_stream
.write_message(&proto::FromServer {
request_id: Some(server_req.id),