proto.rs

  1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
  2use prost::Message;
  3use std::{convert::TryInto, io};
  4
  5include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  6
  7pub trait EnvelopedMessage: Sized + Send + 'static {
  8    fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
  9    fn matches_envelope(envelope: &Envelope) -> bool;
 10    fn from_envelope(envelope: Envelope) -> Option<Self>;
 11}
 12
 13pub trait RequestMessage: EnvelopedMessage {
 14    type Response: EnvelopedMessage;
 15}
 16
 17macro_rules! message {
 18    ($name:ident) => {
 19        impl EnvelopedMessage for $name {
 20            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
 21                Envelope {
 22                    id,
 23                    responding_to,
 24                    payload: Some(envelope::Payload::$name(self)),
 25                }
 26            }
 27
 28            fn matches_envelope(envelope: &Envelope) -> bool {
 29                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 30            }
 31
 32            fn from_envelope(envelope: Envelope) -> Option<Self> {
 33                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 34                    Some(msg)
 35                } else {
 36                    None
 37                }
 38            }
 39        }
 40    };
 41}
 42
 43macro_rules! request_message {
 44    ($req:ident, $resp:ident) => {
 45        message!($req);
 46        message!($resp);
 47        impl RequestMessage for $req {
 48            type Response = $resp;
 49        }
 50    };
 51}
 52
 53request_message!(Auth, AuthResponse);
 54request_message!(ShareWorktree, ShareWorktreeResponse);
 55request_message!(OpenWorktree, OpenWorktreeResponse);
 56request_message!(OpenBuffer, OpenBufferResponse);
 57
 58/// A stream of protobuf messages.
 59pub struct MessageStream<T> {
 60    byte_stream: T,
 61    buffer: Vec<u8>,
 62}
 63
 64impl<T> MessageStream<T> {
 65    pub fn new(byte_stream: T) -> Self {
 66        Self {
 67            byte_stream,
 68            buffer: Default::default(),
 69        }
 70    }
 71
 72    pub fn inner_mut(&mut self) -> &mut T {
 73        &mut self.byte_stream
 74    }
 75}
 76
 77impl<T> MessageStream<T>
 78where
 79    T: AsyncWrite + Unpin,
 80{
 81    /// Write a given protobuf message to the stream.
 82    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
 83        let message_len: u32 = message
 84            .encoded_len()
 85            .try_into()
 86            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
 87        self.buffer.clear();
 88        self.buffer.extend_from_slice(&message_len.to_be_bytes());
 89        message.encode(&mut self.buffer)?;
 90        self.byte_stream.write_all(&self.buffer).await
 91    }
 92}
 93
 94impl<T> MessageStream<T>
 95where
 96    T: AsyncRead + Unpin,
 97{
 98    /// Read a protobuf message of the given type from the stream.
 99    pub async fn read_message(&mut self) -> io::Result<Envelope> {
100        let mut delimiter_buf = [0; 4];
101        self.byte_stream.read_exact(&mut delimiter_buf).await?;
102        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
103        self.buffer.resize(message_len, 0);
104        self.byte_stream.read_exact(&mut self.buffer).await?;
105        Ok(Envelope::decode(self.buffer.as_slice())?)
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// In-memory stream that moves at most `chunk_size` bytes per poll, so the
    /// message stream has to cope with partial reads and writes.
    struct ChunkedStream {
        /// Every byte written so far.
        bytes: Vec<u8>,
        /// Position in `bytes` of the next byte to hand out on read.
        read_offset: usize,
        /// Upper bound on the bytes transferred by a single poll.
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Accept only the leading chunk of the caller's data.
            let accepted = &buf[..buf.len().min(self.chunk_size)];
            self.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let start = self.read_offset;
            let unread = self.bytes.len() - start;
            let count = unread.min(self.chunk_size).min(buf.len());
            buf[..count].copy_from_slice(&self.bytes[start..start + count]);
            self.read_offset = start + count;
            Poll::Ready(Ok(count))
        }
    }

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let share = ShareWorktree {
                worktree: Some(Worktree {
                    paths: vec![b"ok".to_vec()],
                }),
            }
            .into_envelope(5, None);

            let sent = vec![auth, share];

            let mut stream = MessageStream::new(ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            });

            // Write everything first, then read it back in the same order.
            for envelope in &sent {
                stream.write_message(envelope).await.unwrap();
            }
            for expected in &sent {
                let received = stream.read_message().await.unwrap();
                assert_eq!(&received, expected);
            }
        });
    }
}