proto.rs

  1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
  2use prost::Message;
  3use std::{convert::TryInto, io};
  4
  5include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  6
  7pub trait EnvelopedMessage: Sized + Send + 'static {
  8    const NAME: &'static str;
  9    fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
 10    fn matches_envelope(envelope: &Envelope) -> bool;
 11    fn from_envelope(envelope: Envelope) -> Option<Self>;
 12}
 13
 14pub trait RequestMessage: EnvelopedMessage {
 15    type Response: EnvelopedMessage;
 16}
 17
 18macro_rules! message {
 19    ($name:ident) => {
 20        impl EnvelopedMessage for $name {
 21            const NAME: &'static str = std::stringify!($name);
 22
 23            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
 24                Envelope {
 25                    id,
 26                    responding_to,
 27                    payload: Some(envelope::Payload::$name(self)),
 28                }
 29            }
 30
 31            fn matches_envelope(envelope: &Envelope) -> bool {
 32                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 33            }
 34
 35            fn from_envelope(envelope: Envelope) -> Option<Self> {
 36                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 37                    Some(msg)
 38                } else {
 39                    None
 40                }
 41            }
 42        }
 43    };
 44}
 45
 46macro_rules! request_message {
 47    ($req:ident, $resp:ident) => {
 48        message!($req);
 49        message!($resp);
 50        impl RequestMessage for $req {
 51            type Response = $resp;
 52        }
 53    };
 54}
 55
 56request_message!(Auth, AuthResponse);
 57request_message!(ShareWorktree, ShareWorktreeResponse);
 58request_message!(OpenWorktree, OpenWorktreeResponse);
 59request_message!(OpenBuffer, OpenBufferResponse);
 60
 61/// A stream of protobuf messages.
 62pub struct MessageStream<T> {
 63    byte_stream: T,
 64    buffer: Vec<u8>,
 65}
 66
 67impl<T> MessageStream<T> {
 68    pub fn new(byte_stream: T) -> Self {
 69        Self {
 70            byte_stream,
 71            buffer: Default::default(),
 72        }
 73    }
 74
 75    pub fn inner_mut(&mut self) -> &mut T {
 76        &mut self.byte_stream
 77    }
 78}
 79
 80impl<T> MessageStream<T>
 81where
 82    T: AsyncWrite + Unpin,
 83{
 84    /// Write a given protobuf message to the stream.
 85    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
 86        let message_len: u32 = message
 87            .encoded_len()
 88            .try_into()
 89            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
 90        self.buffer.clear();
 91        self.buffer.extend_from_slice(&message_len.to_be_bytes());
 92        message.encode(&mut self.buffer)?;
 93        self.byte_stream.write_all(&self.buffer).await
 94    }
 95}
 96
 97impl<T> MessageStream<T>
 98where
 99    T: AsyncRead + Unpin,
100{
101    /// Read a protobuf message of the given type from the stream.
102    pub async fn read_message(&mut self) -> io::Result<Envelope> {
103        let mut delimiter_buf = [0; 4];
104        self.byte_stream.read_exact(&mut delimiter_buf).await?;
105        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
106        self.buffer.resize(message_len, 0);
107        self.byte_stream.read_exact(&mut self.buffer).await?;
108        Ok(Envelope::decode(self.buffer.as_slice())?)
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let message2 = OpenBuffer {
                worktree_id: 1,
                path: "path".to_string(),
            }
            .into_envelope(5, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    /// With no bytes at all, reading the length prefix must fail with `UnexpectedEof`.
    #[test]
    fn test_read_from_empty_stream() {
        smol::block_on(async {
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// A frame whose body is one byte short must fail with `UnexpectedEof`
    /// instead of decoding a partial envelope.
    #[test]
    fn test_read_truncated_message() {
        smol::block_on(async {
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            });
            let message = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            message_stream.write_message(&message).await.unwrap();
            // Drop the last byte of the body so the frame is incomplete.
            message_stream.inner_mut().bytes.pop();

            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory stream that accepts and yields at most `chunk_size` bytes per
    /// poll, so tests exercise partial reads and writes.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returns Ok(0) once every written byte has been consumed (EOF).
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}