proto.rs

  1use futures_io::{AsyncRead, AsyncWrite};
  2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
  3use prost::Message;
  4use std::{convert::TryInto, io};
  5
  6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
  7
  8/// A message that the client can send to the server.
  9pub trait ClientMessage: Sized {
 10    fn to_variant(self) -> from_client::Variant;
 11    fn from_variant(variant: from_client::Variant) -> Option<Self>;
 12}
 13
 14/// A message that the server can send to the client.
 15pub trait ServerMessage: Sized {
 16    fn to_variant(self) -> from_server::Variant;
 17    fn from_variant(variant: from_server::Variant) -> Option<Self>;
 18}
 19
 20/// A message that the client can send to the server, where the server must respond with a single
 21/// message of a certain type.
 22pub trait RequestMessage: ClientMessage {
 23    type Response: ServerMessage;
 24}
 25
 26/// A message that the client can send to the server, where the server must respond with a series of
 27/// messages of a certain type.
 28pub trait SubscribeMessage: ClientMessage {
 29    type Event: ServerMessage;
 30}
 31
 32/// A message that the client can send to the server, where the server will not respond.
 33pub trait SendMessage: ClientMessage {}
 34
 35macro_rules! directed_message {
 36    ($name:ident, $direction_trait:ident, $direction_module:ident) => {
 37        impl $direction_trait for $direction_module::$name {
 38            fn to_variant(self) -> $direction_module::Variant {
 39                $direction_module::Variant::$name(self)
 40            }
 41
 42            fn from_variant(variant: $direction_module::Variant) -> Option<Self> {
 43                if let $direction_module::Variant::$name(msg) = variant {
 44                    Some(msg)
 45                } else {
 46                    None
 47                }
 48            }
 49        }
 50    };
 51}
 52
 53macro_rules! request_message {
 54    ($req:ident, $resp:ident) => {
 55        directed_message!($req, ClientMessage, from_client);
 56        directed_message!($resp, ServerMessage, from_server);
 57        impl RequestMessage for from_client::$req {
 58            type Response = from_server::$resp;
 59        }
 60    };
 61}
 62
 63macro_rules! send_message {
 64    ($msg:ident) => {
 65        directed_message!($msg, ClientMessage, from_client);
 66        impl SendMessage for from_client::$msg {}
 67    };
 68}
 69
 70request_message!(Auth, AuthResponse);
 71request_message!(NewWorktree, NewWorktreeResponse);
 72request_message!(ShareWorktree, ShareWorktreeResponse);
 73send_message!(UploadFile);
 74
 75/// A stream of protobuf messages.
 76pub struct MessageStream<T> {
 77    byte_stream: T,
 78    buffer: Vec<u8>,
 79}
 80
 81impl<T> MessageStream<T> {
 82    pub fn new(byte_stream: T) -> Self {
 83        Self {
 84            byte_stream,
 85            buffer: Default::default(),
 86        }
 87    }
 88
 89    pub fn inner_mut(&mut self) -> &mut T {
 90        &mut self.byte_stream
 91    }
 92}
 93
 94impl<T> MessageStream<T>
 95where
 96    T: AsyncWrite + Unpin,
 97{
 98    /// Write a given protobuf message to the stream.
 99    pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
100        let message_len: u32 = message
101            .encoded_len()
102            .try_into()
103            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
104        self.buffer.clear();
105        self.buffer.extend_from_slice(&message_len.to_be_bytes());
106        message.encode(&mut self.buffer)?;
107        self.byte_stream.write_all(&self.buffer).await
108    }
109}
110
111impl<T> MessageStream<T>
112where
113    T: AsyncRead + Unpin,
114{
115    /// Read a protobuf message of the given type from the stream.
116    pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
117        let mut delimiter_buf = [0; 4];
118        self.byte_stream.read_exact(&mut delimiter_buf).await?;
119        let message_len = u32::from_be_bytes(delimiter_buf) as usize;
120        self.buffer.resize(message_len, 0);
121        self.byte_stream.read_exact(&mut self.buffer).await?;
122        Ok(M::decode(self.buffer.as_slice())?)
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Round-trips two client messages through an in-memory stream that transfers only a
    /// few bytes per poll, so framing and reassembly of partial reads/writes are exercised.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };

            // The encoded message is longer than 255 bytes, so more than one byte of its
            // length prefix is non-zero.
            let content = format!(
                "a {}long error message that requires a two-byte length delimiter",
                "very ".repeat(60)
            );
            let upload = FromClient {
                id: 4,
                variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
                    path: Vec::new(),
                    content: content.into(),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream::with_chunk_size(3));
            stream.write_message(&auth).await.unwrap();
            stream.write_message(&upload).await.unwrap();

            let first = stream.read_message::<FromClient>().await.unwrap();
            let second = stream.read_message::<FromClient>().await.unwrap();
            assert_eq!(first, auth);
            assert_eq!(second, upload);
        });
    }

    /// In-memory byte stream that moves at most `chunk_size` bytes per read or write call.
    /// Writes append to `bytes`; reads consume `bytes` starting at `read_offset`.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        fn with_chunk_size(chunk_size: usize) -> Self {
            ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = &buf[..buf.len().min(self.chunk_size)];
            self.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let start = self.read_offset;
            let count = (self.bytes.len() - start)
                .min(self.chunk_size)
                .min(buf.len());
            buf[..count].copy_from_slice(&self.bytes[start..start + count]);
            self.read_offset = start + count;
            Poll::Ready(Ok(count))
        }
    }
}