proto.rs

  1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
  2use prost::Message;
  3use std::{
  4    convert::TryInto,
  5    io,
  6    time::{Duration, SystemTime, UNIX_EPOCH},
  7};
  8
  9include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 10
 11pub trait EnvelopedMessage: Clone + Sized + Send + 'static {
 12    const NAME: &'static str;
 13    fn into_envelope(
 14        self,
 15        id: u32,
 16        responding_to: Option<u32>,
 17        original_sender_id: Option<u32>,
 18    ) -> Envelope;
 19    fn matches_envelope(envelope: &Envelope) -> bool;
 20    fn from_envelope(envelope: Envelope) -> Option<Self>;
 21}
 22
 23pub trait RequestMessage: EnvelopedMessage {
 24    type Response: EnvelopedMessage;
 25}
 26
 27macro_rules! message {
 28    ($name:ident) => {
 29        impl EnvelopedMessage for $name {
 30            const NAME: &'static str = std::stringify!($name);
 31
 32            fn into_envelope(
 33                self,
 34                id: u32,
 35                responding_to: Option<u32>,
 36                original_sender_id: Option<u32>,
 37            ) -> Envelope {
 38                Envelope {
 39                    id,
 40                    responding_to,
 41                    original_sender_id,
 42                    payload: Some(envelope::Payload::$name(self)),
 43                }
 44            }
 45
 46            fn matches_envelope(envelope: &Envelope) -> bool {
 47                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 48            }
 49
 50            fn from_envelope(envelope: Envelope) -> Option<Self> {
 51                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 52                    Some(msg)
 53                } else {
 54                    None
 55                }
 56            }
 57        }
 58    };
 59}
 60
 61macro_rules! request_message {
 62    ($req:ident, $resp:ident) => {
 63        message!($req);
 64        message!($resp);
 65        impl RequestMessage for $req {
 66            type Response = $resp;
 67        }
 68    };
 69}
 70
 71request_message!(Auth, AuthResponse);
 72request_message!(ShareWorktree, ShareWorktreeResponse);
 73request_message!(OpenWorktree, OpenWorktreeResponse);
 74message!(UpdateWorktree);
 75message!(CloseWorktree);
 76request_message!(OpenBuffer, OpenBufferResponse);
 77message!(CloseBuffer);
 78message!(UpdateBuffer);
 79request_message!(SaveBuffer, BufferSaved);
 80message!(AddPeer);
 81message!(RemovePeer);
 82
 83/// A stream of protobuf messages.
 84pub struct MessageStream<T> {
 85    byte_stream: T,
 86    buffer: Vec<u8>,
 87    upcoming_message_len: Option<usize>,
 88}
 89
 90impl<T> MessageStream<T> {
 91    pub fn new(byte_stream: T) -> Self {
 92        Self {
 93            byte_stream,
 94            buffer: Default::default(),
 95            upcoming_message_len: None,
 96        }
 97    }
 98
 99    pub fn inner_mut(&mut self) -> &mut T {
100        &mut self.byte_stream
101    }
102}
103
104impl<T> MessageStream<T>
105where
106    T: AsyncWrite + Unpin,
107{
108    /// Write a given protobuf message to the stream.
109    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
110        let message_len: u32 = message
111            .encoded_len()
112            .try_into()
113            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
114        self.buffer.clear();
115        self.buffer.extend_from_slice(&message_len.to_be_bytes());
116        message.encode(&mut self.buffer)?;
117        self.byte_stream.write_all(&self.buffer).await
118    }
119}
120
121impl<T> MessageStream<T>
122where
123    T: AsyncRead + Unpin,
124{
125    /// Read a protobuf message of the given type from the stream.
126    pub async fn read_message(&mut self) -> io::Result<Envelope> {
127        loop {
128            if let Some(upcoming_message_len) = self.upcoming_message_len {
129                self.buffer.resize(upcoming_message_len, 0);
130                self.byte_stream.read_exact(&mut self.buffer).await?;
131                self.upcoming_message_len = None;
132                return Ok(Envelope::decode(self.buffer.as_slice())?);
133            } else {
134                self.buffer.resize(4, 0);
135                self.byte_stream.read_exact(&mut self.buffer).await?;
136                self.upcoming_message_len = Some(u32::from_be_bytes([
137                    self.buffer[0],
138                    self.buffer[1],
139                    self.buffer[2],
140                    self.buffer[3],
141                ]) as usize);
142            }
143        }
144    }
145}
146
147impl Into<SystemTime> for Timestamp {
148    fn into(self) -> SystemTime {
149        UNIX_EPOCH
150            .checked_add(Duration::new(self.seconds, self.nanos))
151            .unwrap()
152    }
153}
154
155impl From<SystemTime> for Timestamp {
156    fn from(time: SystemTime) -> Self {
157        let duration = time.duration_since(UNIX_EPOCH).unwrap();
158        Self {
159            seconds: duration.as_secs(),
160            nanos: duration.subsec_nanos(),
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Writes two envelopes through a `MessageStream` whose transport moves at most
    /// 3 bytes per call, then checks that both read back unchanged and in order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);

            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let transport = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };
            let mut stream = MessageStream::new(transport);

            stream.write_message(&auth).await.unwrap();
            stream.write_message(&open_buffer).await.unwrap();

            let first = stream.read_message().await.unwrap();
            let second = stream.read_message().await.unwrap();

            assert_eq!(first, auth);
            assert_eq!(second, open_buffer);
        });
    }

    /// In-memory byte stream that transfers at most `chunk_size` bytes per poll,
    /// forcing callers to handle partial reads and writes.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let accepted = &buf[..buf.len().min(self.chunk_size)];
            self.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returns Ok(0) once every written byte has been consumed (EOF).
            let unread = &self.bytes[self.read_offset..];
            let count = buf.len().min(self.chunk_size).min(unread.len());
            buf[..count].copy_from_slice(&unread[..count]);
            self.read_offset += count;
            Poll::Ready(Ok(count))
        }
    }
}