proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
// Pulls in the protobuf-generated types used throughout this file (`Envelope`,
// `envelope::Payload`, `Timestamp`, and every message struct registered below) from a
// file written into Cargo's `OUT_DIR` — presumably by this crate's build script via
// prost; TODO confirm.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
/// Registers protobuf message types with the envelope layer.
///
/// Every listed `$name` must also be a variant of `envelope::Payload`. For each one this
/// generates an `EnvelopedMessage` impl and one arm of the generated
/// `build_typed_envelope` function.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        // Converts an `Envelope` into a boxed, type-erased `TypedEnvelope` tagged with
        // `sender_id`. Returns `None` when the payload is absent or is a variant that was
        // not passed to this macro.
        // NOTE(review): `envelope.responding_to` is not carried over — `TypedEnvelope`
        // has no field for it here; confirm callers don't need it.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            match envelope.payload {
                $(Some(envelope::Payload::$name(payload)) => {
                    Some(Box::new(TypedEnvelope {
                        sender_id,
                        // Wrap the raw `u32` id from the wire in `PeerId`.
                        original_sender_id: envelope.original_sender_id.map(PeerId),
                        message_id: envelope.id,
                        payload,
                    }))
                }, )*
                _ => None
            }
        }

        $(
            impl EnvelopedMessage for $name {
                // The type's identifier as written in the invocation, e.g. "Ping".
                const NAME: &'static str = std::stringify!($name);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload: Some(envelope::Payload::$name(self)),
                    }
                }

                // `None` if the payload is missing or is another message type.
                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
                        Some(msg)
                    } else {
                        None
                    }
                }
            }
        )*
    };
}
103
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, making `Response`
/// the request's associated `Response` type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type, returning the value of the
/// named field (`$field`) as the remote entity id.
///
/// Usage: `entity_messages!(field_name, TypeA, TypeB, ...)`.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
// Every message type that can travel inside an `Envelope`. Each entry gets an
// `EnvelopedMessage` impl; any `envelope::Payload` variant missing from this list is
// rejected (`None`) by `build_typed_envelope`. Keep the list alphabetized.
messages!(
    AddPeer,
    BufferSaved,
    ChannelMessageSent,
    CloseBuffer,
    CloseWorktree,
    GetChannels,
    GetChannelsResponse,
    GetUsers,
    GetUsersResponse,
    JoinChannel,
    JoinChannelResponse,
    LeaveChannel,
    OpenBuffer,
    OpenBufferResponse,
    OpenWorktree,
    OpenWorktreeResponse,
    Ping,
    Pong,
    RemovePeer,
    SaveBuffer,
    SendChannelMessage,
    SendChannelMessageResponse,
    ShareWorktree,
    ShareWorktreeResponse,
    UpdateBuffer,
    UpdateWorktree,
);
150
151request_messages!(
152    (GetChannels, GetChannelsResponse),
153    (GetUsers, GetUsersResponse),
154    (JoinChannel, JoinChannelResponse),
155    (OpenBuffer, OpenBufferResponse),
156    (OpenWorktree, OpenWorktreeResponse),
157    (Ping, Pong),
158    (SaveBuffer, BufferSaved),
159    (ShareWorktree, ShareWorktreeResponse),
160    (SendChannelMessage, SendChannelMessageResponse),
161);
162
// Messages whose `remote_entity_id()` is their `worktree_id` field.
entity_messages!(
    worktree_id,
    AddPeer,
    BufferSaved,
    CloseBuffer,
    CloseWorktree,
    OpenBuffer,
    OpenWorktree,
    RemovePeer,
    SaveBuffer,
    UpdateBuffer,
    UpdateWorktree,
);

// Messages whose `remote_entity_id()` is their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
178
/// Wraps a WebSocket stream/sink `S` and exchanges protobuf `Envelope`s over it, one
/// envelope per binary frame.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`; no I/O is performed.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let MessageStream { stream } = self;
        stream
    }
}
193
194impl<S> MessageStream<S>
195where
196    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
197{
198    /// Write a given protobuf message to the stream.
199    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
200        let mut buffer = Vec::with_capacity(message.encoded_len());
201        message
202            .encode(&mut buffer)
203            .map_err(|err| io::Error::from(err))?;
204        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
205        Ok(())
206    }
207}
208
209impl<S> MessageStream<S>
210where
211    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
212{
213    /// Read a protobuf message of the given type from the stream.
214    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
215        while let Some(bytes) = self.stream.next().await {
216            match bytes? {
217                WebSocketMessage::Binary(bytes) => {
218                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
219                    return Ok(envelope);
220                }
221                WebSocketMessage::Close(_) => break,
222                _ => {}
223            }
224        }
225        Err(WebSocketError::ConnectionClosed)
226    }
227}
228
229impl Into<SystemTime> for Timestamp {
230    fn into(self) -> SystemTime {
231        UNIX_EPOCH
232            .checked_add(Duration::new(self.seconds, self.nanos))
233            .unwrap()
234    }
235}
236
237impl From<SystemTime> for Timestamp {
238    fn from(time: SystemTime) -> Self {
239        let duration = time.duration_since(UNIX_EPOCH).unwrap();
240        Self {
241            seconds: duration.as_secs(),
242            nanos: duration.subsec_nanos(),
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Writes two envelopes through an in-memory channel, then checks that reading them
    /// back yields identical envelopes in the same order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let ping = Ping { id: 5 }.into_envelope(3, None, None);
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);
            let sent = [ping, open_buffer];

            let mut message_stream = MessageStream::new(test::Channel::new());
            for envelope in &sent {
                message_stream.write_message(envelope).await.unwrap();
            }
            for expected in &sent {
                let received = message_stream.read_message().await.unwrap();
                assert_eq!(&received, expected);
            }
        });
    }
}