proto.rs

  1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
  2use prost::Message;
  3use std::{convert::TryInto, io};
  4
// Pull in the generated protobuf types this module builds on: `Envelope`, the
// `envelope` module with its `Payload` enum, and the individual message structs
// (`Auth`, `OpenBuffer`, ...). The file lives in Cargo's `OUT_DIR`, so it is
// produced at build time — presumably by a build script running prost-build
// (NOTE(review): confirm against build.rs).
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  6
  7pub trait EnvelopedMessage: Sized + Send + 'static {
  8    const NAME: &'static str;
  9    fn into_envelope(
 10        self,
 11        id: u32,
 12        responding_to: Option<u32>,
 13        original_sender_id: Option<u32>,
 14    ) -> Envelope;
 15    fn matches_envelope(envelope: &Envelope) -> bool;
 16    fn from_envelope(envelope: Envelope) -> Option<Self>;
 17}
 18
 19pub trait RequestMessage: EnvelopedMessage {
 20    type Response: EnvelopedMessage;
 21}
 22
 23macro_rules! message {
 24    ($name:ident) => {
 25        impl EnvelopedMessage for $name {
 26            const NAME: &'static str = std::stringify!($name);
 27
 28            fn into_envelope(
 29                self,
 30                id: u32,
 31                responding_to: Option<u32>,
 32                original_sender_id: Option<u32>,
 33            ) -> Envelope {
 34                Envelope {
 35                    id,
 36                    responding_to,
 37                    original_sender_id,
 38                    payload: Some(envelope::Payload::$name(self)),
 39                }
 40            }
 41
 42            fn matches_envelope(envelope: &Envelope) -> bool {
 43                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 44            }
 45
 46            fn from_envelope(envelope: Envelope) -> Option<Self> {
 47                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 48                    Some(msg)
 49                } else {
 50                    None
 51                }
 52            }
 53        }
 54    };
 55}
 56
 57macro_rules! request_message {
 58    ($req:ident, $resp:ident) => {
 59        message!($req);
 60        message!($resp);
 61        impl RequestMessage for $req {
 62            type Response = $resp;
 63        }
 64    };
 65}
 66
// Message types that can travel in an `Envelope`. Each `request_message!` pair
// also ties the request to its response type via `RequestMessage::Response`.
// `CloseFile` is registered with `message!` alone, so it has no associated
// response type.
request_message!(Auth, AuthResponse);
request_message!(ShareWorktree, ShareWorktreeResponse);
request_message!(OpenWorktree, OpenWorktreeResponse);
request_message!(OpenFile, OpenFileResponse);
message!(CloseFile);
request_message!(OpenBuffer, OpenBufferResponse);
 73
 74/// A stream of protobuf messages.
 75pub struct MessageStream<T> {
 76    byte_stream: T,
 77    buffer: Vec<u8>,
 78}
 79
 80impl<T> MessageStream<T> {
 81    pub fn new(byte_stream: T) -> Self {
 82        Self {
 83            byte_stream,
 84            buffer: Default::default(),
 85        }
 86    }
 87
 88    pub fn inner_mut(&mut self) -> &mut T {
 89        &mut self.byte_stream
 90    }
 91}
 92
 93impl<T> MessageStream<T>
 94where
 95    T: AsyncWrite + Unpin,
 96{
 97    /// Write a given protobuf message to the stream.
 98    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
 99        let message_len: u32 = message
100            .encoded_len()
101            .try_into()
102            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
103        self.buffer.clear();
104        self.buffer.extend_from_slice(&message_len.to_be_bytes());
105        message.encode(&mut self.buffer)?;
106        self.byte_stream.write_all(&self.buffer).await
107    }
108}
109
110impl<T> MessageStream<T>
111where
112    T: AsyncRead + Unpin,
113{
114    /// Read a protobuf message of the given type from the stream.
115    pub async fn read_message(&mut self) -> io::Result<Envelope> {
116        let mut delimiter_buf = [0; 4];
117        self.byte_stream.read_exact(&mut delimiter_buf).await?;
118        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
119        self.buffer.resize(message_len, 0);
120        self.byte_stream.read_exact(&mut self.buffer).await?;
121        Ok(Envelope::decode(self.buffer.as_slice())?)
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);

            let message2 = OpenBuffer {
                worktree_id: 0,
                id: 1,
            }
            .into_envelope(5, None, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    /// A stream that ends partway through a frame must yield `UnexpectedEof`
    /// rather than a partially-filled or garbage message.
    #[test]
    fn test_read_message_truncated_stream() {
        smol::block_on(async {
            // No length prefix at all.
            let mut message_stream = MessageStream::new(chunked(Vec::new()));
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);

            // Only half of the 4-byte length prefix.
            let mut message_stream = MessageStream::new(chunked(vec![0, 0]));
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);

            // A prefix promising 10 body bytes, followed by only 3.
            let mut bytes = 10u32.to_be_bytes().to_vec();
            bytes.extend_from_slice(&[1, 2, 3]);
            let mut message_stream = MessageStream::new(chunked(bytes));
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// Build a `ChunkedStream` that will read back `bytes` three bytes at a time.
    fn chunked(bytes: Vec<u8>) -> ChunkedStream {
        ChunkedStream {
            bytes,
            read_offset: 0,
            chunk_size: 3,
        }
    }

    /// In-memory duplex stream that moves at most `chunk_size` bytes per poll,
    /// exercising the framing code's handling of partial reads and writes.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returning Ok(0) once all written bytes are consumed signals EOF.
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}