proto.rs

  1use futures_io::{AsyncRead, AsyncWrite};
  2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
  3use prost::Message;
  4use std::{convert::TryInto, io};
  5
  6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  7
  8pub trait EnvelopedMessage: Sized + Send + 'static {
  9    fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
 10    fn matches_envelope(envelope: &Envelope) -> bool;
 11    fn from_envelope(envelope: Envelope) -> Option<Self>;
 12}
 13
 14pub trait RequestMessage: EnvelopedMessage {
 15    type Response: EnvelopedMessage;
 16}
 17
 /// Implements [`EnvelopedMessage`] for the type `$name`.
 ///
 /// `$name` must also be the name of an `envelope::Payload` variant that wraps
 /// a value of that type.
 macro_rules! message {
     ($name:ident) => {
         impl EnvelopedMessage for $name {
             fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
                 let payload = Some(envelope::Payload::$name(self));
                 Envelope {
                     id,
                     responding_to,
                     payload,
                 }
             }

             fn matches_envelope(envelope: &Envelope) -> bool {
                 matches!(envelope.payload, Some(envelope::Payload::$name(_)))
             }

             fn from_envelope(envelope: Envelope) -> Option<Self> {
                 match envelope.payload {
                     Some(envelope::Payload::$name(message)) => Some(message),
                     _ => None,
                 }
             }
         }
     };
 }
 43
 /// Declares a request/response pair.
 ///
 /// Both `$request` and `$response` get an [`EnvelopedMessage`] impl via
 /// `message!`, and `$response` becomes `$request`'s
 /// [`RequestMessage::Response`].
 macro_rules! request_message {
     ($request:ident, $response:ident) => {
         impl RequestMessage for $request {
             type Response = $response;
         }
         message!($request);
         message!($response);
     };
 }
 53
 54request_message!(Auth, AuthResponse);
 55request_message!(ShareWorktree, ShareWorktreeResponse);
 56request_message!(OpenWorktree, OpenWorktreeResponse);
 57request_message!(OpenBuffer, OpenBufferResponse);
 58
 59/// A stream of protobuf messages.
 /// A stream of length-prefixed protobuf [`Envelope`]s on top of a byte stream `T`.
 ///
 /// On the wire, every message is a 4-byte big-endian length followed by that
 /// many bytes of encoded message (see `write_message` / `read_message`).
 pub struct MessageStream<T> {
     /// Underlying transport that framed messages are written to and read from.
     byte_stream: T,
     /// Scratch space that is cleared and reused for each message.
     buffer: Vec<u8>,
 }
 64
 65impl<T> MessageStream<T> {
 66    pub fn new(byte_stream: T) -> Self {
 67        Self {
 68            byte_stream,
 69            buffer: Default::default(),
 70        }
 71    }
 72
 73    pub fn inner_mut(&mut self) -> &mut T {
 74        &mut self.byte_stream
 75    }
 76}
 77
 78impl<T> MessageStream<T>
 79where
 80    T: AsyncWrite + Unpin,
 81{
 82    /// Write a given protobuf message to the stream.
 83    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
 84        let message_len: u32 = message
 85            .encoded_len()
 86            .try_into()
 87            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
 88        self.buffer.clear();
 89        self.buffer.extend_from_slice(&message_len.to_be_bytes());
 90        message.encode(&mut self.buffer)?;
 91        self.byte_stream.write_all(&self.buffer).await
 92    }
 93}
 94
 95impl<T> MessageStream<T>
 96where
 97    T: AsyncRead + Unpin,
 98{
 99    /// Read a protobuf message of the given type from the stream.
100    pub async fn read_message(&mut self) -> futures_io::Result<Envelope> {
101        let mut delimiter_buf = [0; 4];
102        self.byte_stream.read_exact(&mut delimiter_buf).await?;
103        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
104        self.buffer.resize(message_len, 0);
105        self.byte_stream.read_exact(&mut self.buffer).await?;
106        Ok(Envelope::decode(self.buffer.as_slice())?)
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Two envelopes are written back-to-back and read through a transport
    /// that moves at most three bytes per call. Both must come back unchanged.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);
            let share = ShareWorktree {
                worktree: Some(Worktree {
                    paths: vec![b"ok".to_vec()],
                }),
            }
            .into_envelope(5, None);

            let mut stream = MessageStream::new(ChunkedStream::with_chunk_size(3));
            stream.write_message(&auth).await.unwrap();
            stream.write_message(&share).await.unwrap();

            let read_auth = stream.read_message().await.unwrap();
            let read_share = stream.read_message().await.unwrap();
            assert_eq!(read_auth, auth);
            assert_eq!(read_share, share);
        });
    }

    /// In-memory stream that transfers at most `max_chunk` bytes per read or
    /// write call, forcing callers to cope with partial transfers.
    struct ChunkedStream {
        /// Everything written so far. Reads consume it starting at `cursor`.
        bytes: Vec<u8>,
        /// Index of the next byte that a read will return.
        cursor: usize,
        /// Upper bound on the bytes moved by a single read or write.
        max_chunk: usize,
    }

    impl ChunkedStream {
        fn with_chunk_size(max_chunk: usize) -> Self {
            ChunkedStream {
                bytes: Vec::new(),
                cursor: 0,
                max_chunk,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let accepted = &buf[..buf.len().min(this.max_chunk)];
            this.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let unread = &this.bytes[this.cursor..];
            let count = unread.len().min(this.max_chunk).min(buf.len());
            buf[..count].copy_from_slice(&unread[..count]);
            this.cursor += count;
            Poll::Ready(Ok(count))
        }
    }
}