proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn matches_envelope(envelope: &Envelope) -> bool;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34pub trait AnyTypedEnvelope: 'static + Send + Sync {
 35    fn payload_type_id(&self) -> TypeId;
 36    fn payload_type_name(&self) -> &'static str;
 37    fn as_any(&self) -> &dyn Any;
 38    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 39}
 40
 41impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 42    fn payload_type_id(&self) -> TypeId {
 43        TypeId::of::<T>()
 44    }
 45
 46    fn payload_type_name(&self) -> &'static str {
 47        T::NAME
 48    }
 49
 50    fn as_any(&self) -> &dyn Any {
 51        self
 52    }
 53
 54    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 55        self
 56    }
 57}
 58
 59macro_rules! messages {
 60    ($($name:ident),* $(,)?) => {
 61        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 62            match envelope.payload {
 63                $(Some(envelope::Payload::$name(payload)) => {
 64                    Some(Box::new(TypedEnvelope {
 65                        sender_id,
 66                        original_sender_id: envelope.original_sender_id.map(PeerId),
 67                        message_id: envelope.id,
 68                        payload,
 69                    }))
 70                }, )*
 71                _ => None
 72            }
 73        }
 74
 75        $(
 76            impl EnvelopedMessage for $name {
 77                const NAME: &'static str = std::stringify!($name);
 78
 79                fn into_envelope(
 80                    self,
 81                    id: u32,
 82                    responding_to: Option<u32>,
 83                    original_sender_id: Option<u32>,
 84                ) -> Envelope {
 85                    Envelope {
 86                        id,
 87                        responding_to,
 88                        original_sender_id,
 89                        payload: Some(envelope::Payload::$name(self)),
 90                    }
 91                }
 92
 93                fn matches_envelope(envelope: &Envelope) -> bool {
 94                    matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 95                }
 96
 97                fn from_envelope(envelope: Envelope) -> Option<Self> {
 98                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 99                        Some(msg)
100                    } else {
101                        None
102                    }
103                }
104            }
105        )*
106    };
107}
108
/// Pairs each request message type with its reply type by implementing
/// [`RequestMessage`] for the request.
///
/// Usage: `request_messages!((Request, Response), ...)`.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
116
/// Implements [`EntityMessage`] for every listed message type. The field named
/// by the first argument is returned as the message's remote entity id.
///
/// Usage: `entity_messages!(id_field, MessageA, MessageB, ...)`.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
126
127messages!(
128    AddPeer,
129    BufferSaved,
130    ChannelMessageSent,
131    CloseBuffer,
132    CloseWorktree,
133    GetChannels,
134    GetChannelsResponse,
135    GetUsers,
136    GetUsersResponse,
137    JoinChannel,
138    JoinChannelResponse,
139    LeaveChannel,
140    OpenBuffer,
141    OpenBufferResponse,
142    OpenWorktree,
143    OpenWorktreeResponse,
144    Ping,
145    Pong,
146    RemovePeer,
147    SaveBuffer,
148    SendChannelMessage,
149    SendChannelMessageResponse,
150    ShareWorktree,
151    ShareWorktreeResponse,
152    UpdateBuffer,
153    UpdateWorktree,
154);
155
156request_messages!(
157    (GetChannels, GetChannelsResponse),
158    (GetUsers, GetUsersResponse),
159    (JoinChannel, JoinChannelResponse),
160    (OpenBuffer, OpenBufferResponse),
161    (OpenWorktree, OpenWorktreeResponse),
162    (Ping, Pong),
163    (SaveBuffer, BufferSaved),
164    (ShareWorktree, ShareWorktreeResponse),
165    (SendChannelMessage, SendChannelMessageResponse),
166);
167
168entity_messages!(
169    worktree_id,
170    AddPeer,
171    BufferSaved,
172    CloseBuffer,
173    CloseWorktree,
174    OpenBuffer,
175    OpenWorktree,
176    RemovePeer,
177    SaveBuffer,
178    UpdateBuffer,
179    UpdateWorktree,
180);
181
182entity_messages!(channel_id, ChannelMessageSent);
183
/// Wraps a WebSocket stream and/or sink and exchanges protobuf [`Envelope`]s
/// over it.
///
/// Writing is available when `S` is a WebSocket `Sink`, and reading when `S`
/// is a WebSocket `Stream`.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`; no I/O is performed.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Returns mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
198
199impl<S> MessageStream<S>
200where
201    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
202{
203    /// Write a given protobuf message to the stream.
204    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
205        let mut buffer = Vec::with_capacity(message.encoded_len());
206        message
207            .encode(&mut buffer)
208            .map_err(|err| io::Error::from(err))?;
209        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
210        Ok(())
211    }
212}
213
214impl<S> MessageStream<S>
215where
216    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
217{
218    /// Read a protobuf message of the given type from the stream.
219    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
220        while let Some(bytes) = self.stream.next().await {
221            match bytes? {
222                WebSocketMessage::Binary(bytes) => {
223                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
224                    return Ok(envelope);
225                }
226                WebSocketMessage::Close(_) => break,
227                _ => {}
228            }
229        }
230        Err(WebSocketError::ConnectionClosed)
231    }
232}
233
234impl Into<SystemTime> for Timestamp {
235    fn into(self) -> SystemTime {
236        UNIX_EPOCH
237            .checked_add(Duration::new(self.seconds, self.nanos))
238            .unwrap()
239    }
240}
241
242impl From<SystemTime> for Timestamp {
243    fn from(time: SystemTime) -> Self {
244        let duration = time.duration_since(UNIX_EPOCH).unwrap();
245        Self {
246            seconds: duration.as_secs(),
247            nanos: duration.subsec_nanos(),
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Envelopes written to a `MessageStream` are read back unchanged and in order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let sent = [
                Ping { id: 5 }.into_envelope(3, None, None),
                OpenBuffer {
                    worktree_id: 0,
                    path: "some/path".to_string(),
                }
                .into_envelope(5, None, None),
            ];

            let mut message_stream = MessageStream::new(test::Channel::new());
            for envelope in &sent {
                message_stream.write_message(envelope).await.unwrap();
            }

            // Read everything back before comparing, mirroring the write-then-read order.
            let mut received = Vec::with_capacity(sent.len());
            for _ in 0..sent.len() {
                received.push(message_stream.read_message().await.unwrap());
            }
            assert_eq!(received, sent);
        });
    }
}