proto.rs

  1use futures_io::{AsyncRead, AsyncWrite};
  2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
  3use prost::Message;
  4use std::{convert::TryInto, io};
  5
// Splice in the message definitions generated into the build output directory. This
// presumably defines `FromClient`, `from_client`, and `from_server`, which this file uses
// without declaring them. It is likely produced by a prost-based build script; TODO confirm.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  7
  8/// A message that the client can send to the server.
  9pub trait ClientMessage: Sized {
 10    fn to_variant(self) -> from_client::Variant;
 11    fn from_variant(variant: from_client::Variant) -> Option<Self>;
 12}
 13
 14/// A message that the server can send to the client.
 15pub trait ServerMessage: Sized {
 16    fn to_variant(self) -> from_server::Variant;
 17    fn from_variant(variant: from_server::Variant) -> Option<Self>;
 18}
 19
 20/// A message that the client can send to the server, where the server must respond with a single
 21/// message of a certain type.
 22pub trait RequestMessage: ClientMessage {
 23    type Response: ServerMessage;
 24}
 25
 26/// A message that the client can send to the server, where the server must respond with a series of
 27/// messages of a certain type.
 28pub trait SubscribeMessage: ClientMessage {
 29    type Event: ServerMessage;
 30}
 31
 32/// A message that the client can send to the server, where the server will not respond.
 33pub trait SendMessage: ClientMessage {}
 34
 35macro_rules! directed_message {
 36    ($name:ident, $direction_trait:ident, $direction_module:ident) => {
 37        impl $direction_trait for $direction_module::$name {
 38            fn to_variant(self) -> $direction_module::Variant {
 39                $direction_module::Variant::$name(self)
 40            }
 41
 42            fn from_variant(variant: $direction_module::Variant) -> Option<Self> {
 43                if let $direction_module::Variant::$name(msg) = variant {
 44                    Some(msg)
 45                } else {
 46                    None
 47                }
 48            }
 49        }
 50    };
 51}
 52
 53macro_rules! request_message {
 54    ($req:ident, $resp:ident) => {
 55        directed_message!($req, ClientMessage, from_client);
 56        directed_message!($resp, ServerMessage, from_server);
 57        impl RequestMessage for from_client::$req {
 58            type Response = from_server::$resp;
 59        }
 60    };
 61}
 62
 63macro_rules! send_message {
 64    ($msg:ident) => {
 65        directed_message!($msg, ClientMessage, from_client);
 66        impl SendMessage for from_client::$msg {}
 67    };
 68}
 69
 70macro_rules! subscribe_message {
 71    ($subscription:ident, $event:ident) => {
 72        directed_message!($subscription, ClientMessage, from_client);
 73        directed_message!($event, ServerMessage, from_server);
 74        impl SubscribeMessage for from_client::$subscription {
 75            type Event = from_server::$event;
 76        }
 77    };
 78}
 79
 80request_message!(Auth, AuthResponse);
 81request_message!(NewWorktree, NewWorktreeResponse);
 82request_message!(ShareWorktree, ShareWorktreeResponse);
 83send_message!(UploadFile);
 84subscribe_message!(SubscribeToPathRequests, PathRequest);
 85
 86/// A stream of protobuf messages.
 87pub struct MessageStream<T> {
 88    byte_stream: T,
 89    buffer: Vec<u8>,
 90}
 91
 92impl<T> MessageStream<T> {
 93    pub fn new(byte_stream: T) -> Self {
 94        Self {
 95            byte_stream,
 96            buffer: Default::default(),
 97        }
 98    }
 99
100    pub fn inner_mut(&mut self) -> &mut T {
101        &mut self.byte_stream
102    }
103}
104
105impl<T> MessageStream<T>
106where
107    T: AsyncWrite + Unpin,
108{
109    /// Write a given protobuf message to the stream.
110    pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
111        let message_len: u32 = message
112            .encoded_len()
113            .try_into()
114            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
115        self.buffer.clear();
116        self.buffer.extend_from_slice(&message_len.to_be_bytes());
117        message.encode(&mut self.buffer)?;
118        self.byte_stream.write_all(&self.buffer).await
119    }
120}
121
122impl<T> MessageStream<T>
123where
124    T: AsyncRead + Unpin,
125{
126    /// Read a protobuf message of the given type from the stream.
127    pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
128        let mut delimiter_buf = [0; 4];
129        self.byte_stream.read_exact(&mut delimiter_buf).await?;
130        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
131        self.buffer.resize(message_len, 0);
132        self.byte_stream.read_exact(&mut self.buffer).await?;
133        Ok(M::decode(self.buffer.as_slice())?)
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Writes two messages through a stream that moves only a few bytes per call. It then
    /// reads them back and checks that they decode to the originals.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };

            let long_content = format!(
                "a {}long error message that requires a two-byte length delimiter",
                "very ".repeat(60)
            );
            let upload = FromClient {
                id: 4,
                variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
                    path: Vec::new(),
                    content: long_content.into(),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            });

            stream.write_message(&auth).await.unwrap();
            stream.write_message(&upload).await.unwrap();

            let first = stream.read_message::<FromClient>().await.unwrap();
            let second = stream.read_message::<FromClient>().await.unwrap();
            assert_eq!(first, auth);
            assert_eq!(second, upload);
        });
    }

    /// In-memory stream that accepts or yields at most `chunk_size` bytes per call.
    ///
    /// Written bytes are appended to `bytes`; reads consume them from `read_offset` onward.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = self.chunk_size.min(buf.len());
            self.bytes.extend_from_slice(&buf[..accepted]);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let start = self.read_offset;
            let count = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - start);
            buf[..count].copy_from_slice(&self.bytes[start..start + count]);
            self.read_offset = start + count;
            Poll::Ready(Ok(count))
        }
    }
}