proto.rs

  1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
  2use prost::Message;
  3use std::{convert::TryInto, io};
  4
// Pull in the message types generated into OUT_DIR at build time. Presumably
// this is prost output for the `zed.messages` protobuf package and defines
// `Envelope`, the `envelope::Payload` enum and the message structs used below
// — TODO confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  6
  7pub trait EnvelopedMessage: Sized + Send + 'static {
  8    const NAME: &'static str;
  9    fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
 10    fn matches_envelope(envelope: &Envelope) -> bool;
 11    fn from_envelope(envelope: Envelope) -> Option<Self>;
 12}
 13
 14pub trait RequestMessage: EnvelopedMessage {
 15    type Response: EnvelopedMessage;
 16}
 17
 18macro_rules! message {
 19    ($name:ident) => {
 20        impl EnvelopedMessage for $name {
 21            const NAME: &'static str = std::stringify!($name);
 22
 23            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
 24                Envelope {
 25                    id,
 26                    responding_to,
 27                    payload: Some(envelope::Payload::$name(self)),
 28                }
 29            }
 30
 31            fn matches_envelope(envelope: &Envelope) -> bool {
 32                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 33            }
 34
 35            fn from_envelope(envelope: Envelope) -> Option<Self> {
 36                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 37                    Some(msg)
 38                } else {
 39                    None
 40                }
 41            }
 42        }
 43    };
 44}
 45
 46macro_rules! request_message {
 47    ($req:ident, $resp:ident) => {
 48        message!($req);
 49        message!($resp);
 50        impl RequestMessage for $req {
 51            type Response = $resp;
 52        }
 53    };
 54}
 55
 56request_message!(Auth, AuthResponse);
 57request_message!(ShareWorktree, ShareWorktreeResponse);
 58request_message!(OpenWorktree, OpenWorktreeResponse);
 59request_message!(OpenFile, OpenFileResponse);
 60message!(CloseFile);
 61request_message!(OpenBuffer, OpenBufferResponse);
 62
 63/// A stream of protobuf messages.
 64pub struct MessageStream<T> {
 65    byte_stream: T,
 66    buffer: Vec<u8>,
 67}
 68
 69impl<T> MessageStream<T> {
 70    pub fn new(byte_stream: T) -> Self {
 71        Self {
 72            byte_stream,
 73            buffer: Default::default(),
 74        }
 75    }
 76
 77    pub fn inner_mut(&mut self) -> &mut T {
 78        &mut self.byte_stream
 79    }
 80}
 81
 82impl<T> MessageStream<T>
 83where
 84    T: AsyncWrite + Unpin,
 85{
 86    /// Write a given protobuf message to the stream.
 87    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
 88        let message_len: u32 = message
 89            .encoded_len()
 90            .try_into()
 91            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
 92        self.buffer.clear();
 93        self.buffer.extend_from_slice(&message_len.to_be_bytes());
 94        message.encode(&mut self.buffer)?;
 95        self.byte_stream.write_all(&self.buffer).await
 96    }
 97}
 98
 99impl<T> MessageStream<T>
100where
101    T: AsyncRead + Unpin,
102{
103    /// Read a protobuf message of the given type from the stream.
104    pub async fn read_message(&mut self) -> io::Result<Envelope> {
105        let mut delimiter_buf = [0; 4];
106        self.byte_stream.read_exact(&mut delimiter_buf).await?;
107        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
108        self.buffer.resize(message_len, 0);
109        self.byte_stream.read_exact(&mut self.buffer).await?;
110        Ok(Envelope::decode(self.buffer.as_slice())?)
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let message2 = OpenBuffer {
                worktree_id: 1,
                path: "path".to_string(),
            }
            .into_envelope(5, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    #[test]
    fn test_read_message_with_truncated_payload() {
        smol::block_on(async {
            // The length prefix promises 100 payload bytes, but only 3 follow.
            let mut bytes = 100u32.to_be_bytes().to_vec();
            bytes.extend_from_slice(&[1, 2, 3]);
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes,
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    #[test]
    fn test_read_message_with_truncated_length_prefix() {
        smol::block_on(async {
            // Only 2 of the 4 length-prefix bytes are present.
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: vec![0, 0],
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory byte stream that transfers at most `chunk_size` bytes per
    /// read or write call, so framing code is exercised across partial I/O.
    /// Reads return `Ok(0)` once every written byte has been consumed.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}