proto.rs

  1use futures_io::{AsyncRead, AsyncWrite};
  2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
  3use prost::Message;
  4use std::{convert::TryInto, io};
  5
  6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  7
  8pub trait EnvelopedMessage: Sized + Send + 'static {
  9    fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
 10    fn from_envelope(envelope: Envelope) -> Option<Self>;
 11}
 12
 13pub trait RequestMessage: EnvelopedMessage {
 14    type Response: EnvelopedMessage;
 15}
 16
 17macro_rules! message {
 18    ($name:ident) => {
 19        impl EnvelopedMessage for $name {
 20            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
 21                Envelope {
 22                    id,
 23                    responding_to,
 24                    payload: Some(envelope::Payload::$name(self)),
 25                }
 26            }
 27
 28            fn from_envelope(envelope: Envelope) -> Option<Self> {
 29                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 30                    Some(msg)
 31                } else {
 32                    None
 33                }
 34            }
 35        }
 36    };
 37}
 38
 39macro_rules! request_message {
 40    ($req:ident, $resp:ident) => {
 41        message!($req);
 42        message!($resp);
 43        impl RequestMessage for $req {
 44            type Response = $resp;
 45        }
 46    };
 47}
 48
 49request_message!(Auth, AuthResponse);
 50request_message!(ShareWorktree, ShareWorktreeResponse);
 51request_message!(OpenWorktree, OpenWorktreeResponse);
 52request_message!(OpenBuffer, OpenBufferResponse);
 53
 54/// A stream of protobuf messages.
 55pub struct MessageStream<T> {
 56    byte_stream: T,
 57    buffer: Vec<u8>,
 58}
 59
 60impl<T> MessageStream<T> {
 61    pub fn new(byte_stream: T) -> Self {
 62        Self {
 63            byte_stream,
 64            buffer: Default::default(),
 65        }
 66    }
 67
 68    pub fn inner_mut(&mut self) -> &mut T {
 69        &mut self.byte_stream
 70    }
 71}
 72
 73impl<T> MessageStream<T>
 74where
 75    T: AsyncWrite + Unpin,
 76{
 77    /// Write a given protobuf message to the stream.
 78    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
 79        let message_len: u32 = message
 80            .encoded_len()
 81            .try_into()
 82            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
 83        self.buffer.clear();
 84        self.buffer.extend_from_slice(&message_len.to_be_bytes());
 85        message.encode(&mut self.buffer)?;
 86        self.byte_stream.write_all(&self.buffer).await
 87    }
 88}
 89
 90impl<T> MessageStream<T>
 91where
 92    T: AsyncRead + Unpin,
 93{
 94    /// Read a protobuf message of the given type from the stream.
 95    pub async fn read_message(&mut self) -> futures_io::Result<Envelope> {
 96        let mut delimiter_buf = [0; 4];
 97        self.byte_stream.read_exact(&mut delimiter_buf).await?;
 98        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
 99        self.buffer.resize(message_len, 0);
100        self.byte_stream.read_exact(&mut self.buffer).await?;
101        Ok(Envelope::decode(self.buffer.as_slice())?)
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Writes two envelopes to a chunked in-memory stream and checks that
    /// reading them back yields equal envelopes in the same order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            };
            let share = ShareWorktree {
                worktree: Some(Worktree {
                    paths: vec![b"ok".to_vec()],
                }),
            };
            let message1 = auth.into_envelope(3, None);
            let message2 = share.into_envelope(5, None);

            // A chunk size of 3 is smaller than the 4-byte length prefix, so
            // every frame takes several polls to write and to read.
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            });

            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();

            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    /// In-memory byte stream that transfers at most `chunk_size` bytes per
    /// poll, so callers have to cope with partial reads and writes.
    struct ChunkedStream {
        /// Every byte written so far.
        bytes: Vec<u8>,
        /// Index into `bytes` of the next byte to hand out on read.
        read_offset: usize,
        /// Upper bound on the number of bytes moved by a single poll.
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = self.chunk_size.min(buf.len());
            self.bytes.extend_from_slice(&buf[..accepted]);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let start = self.read_offset;
            let remaining = self.bytes.len() - start;
            let count = buf.len().min(self.chunk_size).min(remaining);
            let end = start + count;
            buf[..count].copy_from_slice(&self.bytes[start..end]);
            self.read_offset = end;
            Poll::Ready(Ok(count))
        }
    }
}