proto.rs

  1use futures_io::{AsyncRead, AsyncWrite};
  2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
  3use prost::Message;
  4use std::io;
  5
  6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  7
  8use from_client as client;
  9use from_server as server;
 10
 11pub trait Request {
 12    type Response;
 13}
 14
 15macro_rules! request_response {
 16    ($req:path, $resp:path) => {
 17        impl Request for $req {
 18            type Response = $resp;
 19        }
 20    };
 21}
 22
 23request_response!(client::Auth, server::AuthResponse);
 24request_response!(client::NewWorktree, server::NewWorktreeResponse);
 25request_response!(client::ShareWorktree, server::ShareWorktreeResponse);
 26
 27/// A stream of protobuf messages.
 28pub struct MessageStream<T> {
 29    byte_stream: T,
 30    buffer: Vec<u8>,
 31}
 32
 33impl<T> MessageStream<T> {
 34    pub fn new(byte_stream: T) -> Self {
 35        Self {
 36            byte_stream,
 37            buffer: Default::default(),
 38        }
 39    }
 40}
 41
 42impl<T> MessageStream<T>
 43where
 44    T: AsyncWrite + Unpin,
 45{
 46    /// Write a given protobuf message to the stream.
 47    pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
 48        self.buffer.clear();
 49        message.encode_length_delimited(&mut self.buffer).unwrap();
 50        self.byte_stream.write_all(&self.buffer).await
 51    }
 52}
 53
 54impl<T> MessageStream<T>
 55where
 56    T: AsyncRead + Unpin,
 57{
 58    /// Read a protobuf message of the given type from the stream.
 59    pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
 60        // Ensure the buffer is large enough to hold the maximum delimiter length
 61        const MAX_DELIMITER_LEN: usize = 10;
 62        self.buffer.clear();
 63        self.buffer.resize(MAX_DELIMITER_LEN, 0);
 64
 65        // Read until a complete length delimiter can be decoded.
 66        let mut read_start_offset = 0;
 67        let (encoded_len, delimiter_len) = loop {
 68            let bytes_read = self
 69                .byte_stream
 70                .read(&mut self.buffer[read_start_offset..])
 71                .await?;
 72            read_start_offset += bytes_read;
 73
 74            let mut buffer = &self.buffer[0..read_start_offset];
 75            match prost::decode_length_delimiter(&mut buffer) {
 76                Err(_) => {
 77                    if read_start_offset >= MAX_DELIMITER_LEN {
 78                        return Err(io::Error::new(
 79                            io::ErrorKind::InvalidData,
 80                            "invalid message length delimiter",
 81                        ));
 82                    }
 83                }
 84                Ok(encoded_len) => {
 85                    let delimiter_len = read_start_offset - buffer.len();
 86                    break (encoded_len, delimiter_len);
 87                }
 88            }
 89        };
 90
 91        // Read the message itself.
 92        self.buffer.resize(delimiter_len + encoded_len, 0);
 93        self.byte_stream
 94            .read_exact(&mut self.buffer[read_start_offset..])
 95            .await?;
 96        let message = M::decode(&self.buffer[delimiter_len..])?;
 97
 98        Ok(message)
 99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// A small client authentication message shared by several tests.
    fn auth_message() -> FromClient {
        FromClient {
            id: 3,
            variant: Some(from_client::Variant::Auth(from_client::Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            })),
        }
    }

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            // In reality there will never be both `FromClient` and `FromServer` messages
            // sent in the same direction on the same stream.
            let client_message = auth_message();
            let long_error = format!(
                "a {}long error message that requires a two-byte length delimiter",
                "very ".repeat(60)
            );
            let server_message = FromServer {
                request_id: Some(4),
                variant: Some(from_server::Variant::Ack(from_server::Ack {
                    error_message: Some(long_error.into()),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream::new(3));
            stream.write_message(&client_message).await.unwrap();
            stream.write_message(&server_message).await.unwrap();

            let decoded_client = stream.read_message::<FromClient>().await.unwrap();
            let decoded_server = stream.read_message::<FromServer>().await.unwrap();
            assert_eq!(decoded_client, client_message);
            assert_eq!(decoded_server, server_message);
        });
    }

    #[test]
    fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
        smol::block_on(async {
            let chunked = ChunkedStream::new(2);

            // This message is so long that its length delimiter requires three bytes,
            // so it won't be delivered in a single read from the chunked byte stream.
            let message = FromServer {
                request_id: Some(4),
                variant: Some(from_server::Variant::Ack(from_server::Ack {
                    error_message: Some("long ".repeat(256 * 256).into()),
                })),
            };
            let delimiter_len = prost::length_delimiter_len(message.encoded_len());
            assert!(delimiter_len > chunked.chunk_size);

            let mut stream = MessageStream::new(chunked);
            stream.write_message(&message).await.unwrap();
            let decoded = stream.read_message::<FromServer>().await.unwrap();
            assert_eq!(decoded, message);
        });
    }

    #[test]
    fn test_protobuf_parse_error() {
        smol::block_on(async {
            let mut stream = MessageStream::new(ChunkedStream::new(2));
            stream.write_message(&auth_message()).await.unwrap();

            // Read the wrong type of message from the stream.
            let error = stream.read_message::<FromServer>().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::InvalidData);
        });
    }

    /// In-memory byte stream that moves at most `chunk_size` bytes per read or write call,
    /// used to exercise partial reads and writes. Reads return 0 once every written byte
    /// has been consumed.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        /// Creates an empty stream with the given per-call transfer limit.
        fn new(chunk_size: usize) -> Self {
            Self {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = &buf[..buf.len().min(self.chunk_size)];
            self.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = &mut *self;
            let unread = &this.bytes[this.read_offset..];
            let count = buf.len().min(this.chunk_size).min(unread.len());
            buf[..count].copy_from_slice(&unread[..count]);
            this.read_offset += count;
            Poll::Ready(Ok(count))
        }
    }
}