proto.rs

  1use futures_io::{AsyncRead, AsyncWrite};
  2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
  3use prost::Message;
  4use std::io;
  5
  6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  7
  8pub trait Request {
  9    type Response;
 10}
 11
 12impl Request for from_client::Auth {
 13    type Response = from_server::Ack;
 14}
 15
 16/// A stream of protobuf messages.
 17pub struct MessageStream<T> {
 18    byte_stream: T,
 19    buffer: Vec<u8>,
 20}
 21
 22impl<T> MessageStream<T> {
 23    pub fn new(byte_stream: T) -> Self {
 24        Self {
 25            byte_stream,
 26            buffer: Default::default(),
 27        }
 28    }
 29}
 30
 31impl<T> MessageStream<T>
 32where
 33    T: AsyncWrite + Unpin,
 34{
 35    /// Write a given protobuf message to the stream.
 36    pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
 37        self.buffer.clear();
 38        message.encode_length_delimited(&mut self.buffer).unwrap();
 39        self.byte_stream.write_all(&self.buffer).await
 40    }
 41}
 42
 43impl<T> MessageStream<T>
 44where
 45    T: AsyncRead + Unpin,
 46{
 47    /// Read a protobuf message of the given type from the stream.
 48    pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
 49        // Ensure the buffer is large enough to hold the maximum delimiter length
 50        const MAX_DELIMITER_LEN: usize = 10;
 51        self.buffer.clear();
 52        self.buffer.resize(MAX_DELIMITER_LEN, 0);
 53
 54        // Read until a complete length delimiter can be decoded.
 55        let mut read_start_offset = 0;
 56        let (encoded_len, delimiter_len) = loop {
 57            let bytes_read = self
 58                .byte_stream
 59                .read(&mut self.buffer[read_start_offset..])
 60                .await?;
 61            read_start_offset += bytes_read;
 62
 63            let mut buffer = &self.buffer[0..read_start_offset];
 64            match prost::decode_length_delimiter(&mut buffer) {
 65                Err(_) => {
 66                    if read_start_offset >= MAX_DELIMITER_LEN {
 67                        return Err(io::Error::new(
 68                            io::ErrorKind::InvalidData,
 69                            "invalid message length delimiter",
 70                        ));
 71                    }
 72                }
 73                Ok(encoded_len) => {
 74                    let delimiter_len = read_start_offset - buffer.len();
 75                    break (encoded_len, delimiter_len);
 76                }
 77            }
 78        };
 79
 80        // Read the message itself.
 81        self.buffer.resize(delimiter_len + encoded_len, 0);
 82        self.byte_stream
 83            .read_exact(&mut self.buffer[read_start_offset..])
 84            .await?;
 85        let message = M::decode(&self.buffer[delimiter_len..])?;
 86
 87        Ok(message)
 88    }
 89}
 90
 91#[cfg(test)]
 92mod tests {
 93    use super::*;
 94    use std::{
 95        pin::Pin,
 96        task::{Context, Poll},
 97    };
 98
 99    #[test]
100    fn test_round_trip_message() {
101        smol::block_on(async {
102            let byte_stream = ChunkedStream {
103                bytes: Vec::new(),
104                read_offset: 0,
105                chunk_size: 3,
106            };
107
108            // In reality there will never be both `FromClient` and `FromServer` messages
109            // sent in the same direction on the same stream.
110            let message1 = FromClient {
111                id: 3,
112                variant: Some(from_client::Variant::Auth(from_client::Auth {
113                    user_id: 5,
114                    access_token: "the-access-token".into(),
115                })),
116            };
117            let message2 = FromServer {
118                request_id: Some(4),
119                variant: Some(from_server::Variant::Ack(from_server::Ack {
120                    error_message: Some(
121                        format!(
122                            "a {}long error message that requires a two-byte length delimiter",
123                            "very ".repeat(60)
124                        )
125                        .into(),
126                    ),
127                })),
128            };
129
130            let mut message_stream = MessageStream::new(byte_stream);
131            message_stream.write_message(&message1).await.unwrap();
132            message_stream.write_message(&message2).await.unwrap();
133            let decoded_message1 = message_stream.read_message::<FromClient>().await.unwrap();
134            let decoded_message2 = message_stream.read_message::<FromServer>().await.unwrap();
135            assert_eq!(decoded_message1, message1);
136            assert_eq!(decoded_message2, message2);
137        });
138    }
139
140    #[test]
141    fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
142        smol::block_on(async {
143            let byte_stream = ChunkedStream {
144                bytes: Vec::new(),
145                read_offset: 0,
146                chunk_size: 2,
147            };
148
149            // This message is so long that its length delimiter requires three bytes,
150            // so it won't be delivered in a single read from the chunked byte stream.
151            let message = FromServer {
152                request_id: Some(4),
153                variant: Some(from_server::Variant::Ack(from_server::Ack {
154                    error_message: Some("long ".repeat(256 * 256).into()),
155                })),
156            };
157            assert!(prost::length_delimiter_len(message.encoded_len()) > byte_stream.chunk_size);
158
159            let mut message_stream = MessageStream::new(byte_stream);
160            message_stream.write_message(&message).await.unwrap();
161            let decoded_message = message_stream.read_message::<FromServer>().await.unwrap();
162            assert_eq!(decoded_message, message);
163        });
164    }
165
166    #[test]
167    fn test_protobuf_parse_error() {
168        smol::block_on(async {
169            let byte_stream = ChunkedStream {
170                bytes: Vec::new(),
171                read_offset: 0,
172                chunk_size: 2,
173            };
174
175            let message = FromClient {
176                id: 3,
177                variant: Some(from_client::Variant::Auth(from_client::Auth {
178                    user_id: 5,
179                    access_token: "the-access-token".into(),
180                })),
181            };
182
183            let mut message_stream = MessageStream::new(byte_stream);
184            message_stream.write_message(&message).await.unwrap();
185
186            // Read the wrong type of message from the stream.
187            let result = message_stream.read_message::<FromServer>().await;
188            assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
189        });
190    }
191
192    struct ChunkedStream {
193        bytes: Vec<u8>,
194        read_offset: usize,
195        chunk_size: usize,
196    }
197
198    impl AsyncWrite for ChunkedStream {
199        fn poll_write(
200            mut self: Pin<&mut Self>,
201            _: &mut Context<'_>,
202            buf: &[u8],
203        ) -> Poll<io::Result<usize>> {
204            let bytes_written = buf.len().min(self.chunk_size);
205            self.bytes.extend_from_slice(&buf[0..bytes_written]);
206            Poll::Ready(Ok(bytes_written))
207        }
208
209        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
210            Poll::Ready(Ok(()))
211        }
212
213        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
214            Poll::Ready(Ok(()))
215        }
216    }
217
218    impl AsyncRead for ChunkedStream {
219        fn poll_read(
220            mut self: Pin<&mut Self>,
221            _: &mut Context<'_>,
222            buf: &mut [u8],
223        ) -> Poll<io::Result<usize>> {
224            let bytes_read = buf
225                .len()
226                .min(self.chunk_size)
227                .min(self.bytes.len() - self.read_offset);
228            let end_offset = self.read_offset + bytes_read;
229            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
230            self.read_offset = end_offset;
231            Poll::Ready(Ok(bytes_read))
232        }
233    }
234}