proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
// Splice in the protobuf code that the build script writes into Cargo's `OUT_DIR`.
// NOTE(review): presumably this is where `Envelope`, `envelope::Payload`, `Timestamp` and the
// message structs used below are defined, since none of them are imported above. Confirm
// against the build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Pairs each request type with its reply type by implementing [`RequestMessage`] for it.
///
/// Takes a comma-separated list of `(Request, Response)` tuples.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed type.
///
/// The first argument names the `u64` field that each listed type exposes as its remote
/// entity id. The remaining arguments are the message types.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
121
122messages!(
123    AddPeer,
124    BufferSaved,
125    ChannelMessageSent,
126    CloseBuffer,
127    CloseWorktree,
128    Error,
129    GetChannelMessages,
130    GetChannelMessagesResponse,
131    GetChannels,
132    GetChannelsResponse,
133    GetUsers,
134    GetUsersResponse,
135    JoinChannel,
136    JoinChannelResponse,
137    LeaveChannel,
138    OpenBuffer,
139    OpenBufferResponse,
140    OpenWorktree,
141    OpenWorktreeResponse,
142    Ping,
143    Pong,
144    RemovePeer,
145    SaveBuffer,
146    SendChannelMessage,
147    SendChannelMessageResponse,
148    ShareWorktree,
149    ShareWorktreeResponse,
150    UpdateBuffer,
151    UpdateWorktree,
152);
153
154request_messages!(
155    (GetChannels, GetChannelsResponse),
156    (GetUsers, GetUsersResponse),
157    (JoinChannel, JoinChannelResponse),
158    (OpenBuffer, OpenBufferResponse),
159    (OpenWorktree, OpenWorktreeResponse),
160    (Ping, Pong),
161    (SaveBuffer, BufferSaved),
162    (ShareWorktree, ShareWorktreeResponse),
163    (SendChannelMessage, SendChannelMessageResponse),
164    (GetChannelMessages, GetChannelMessagesResponse),
165);
166
167entity_messages!(
168    worktree_id,
169    AddPeer,
170    BufferSaved,
171    CloseBuffer,
172    CloseWorktree,
173    OpenBuffer,
174    OpenWorktree,
175    RemovePeer,
176    SaveBuffer,
177    UpdateBuffer,
178    UpdateWorktree,
179);
180
181entity_messages!(channel_id, ChannelMessageSent);
182
/// Wrapper that carries protobuf envelopes over an underlying stream and/or sink `S`.
///
/// Reading and writing are provided by separate impls, depending on whether `S` is a
/// `Stream` or a `Sink` of WebSocket messages.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream` without touching it.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Mutable access to the wrapped stream, for operations this wrapper does not expose.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
197
198impl<S> MessageStream<S>
199where
200    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
201{
202    /// Write a given protobuf message to the stream.
203    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
204        let mut buffer = Vec::with_capacity(message.encoded_len());
205        message
206            .encode(&mut buffer)
207            .map_err(|err| io::Error::from(err))?;
208        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
209        Ok(())
210    }
211}
212
213impl<S> MessageStream<S>
214where
215    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
216{
217    /// Read a protobuf message of the given type from the stream.
218    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
219        while let Some(bytes) = self.stream.next().await {
220            match bytes? {
221                WebSocketMessage::Binary(bytes) => {
222                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
223                    return Ok(envelope);
224                }
225                WebSocketMessage::Close(_) => break,
226                _ => {}
227            }
228        }
229        Err(WebSocketError::ConnectionClosed)
230    }
231}
232
233impl Into<SystemTime> for Timestamp {
234    fn into(self) -> SystemTime {
235        UNIX_EPOCH
236            .checked_add(Duration::new(self.seconds, self.nanos))
237            .unwrap()
238    }
239}
240
241impl From<SystemTime> for Timestamp {
242    fn from(time: SystemTime) -> Self {
243        let duration = time.duration_since(UNIX_EPOCH).unwrap();
244        Self {
245            seconds: duration.as_secs(),
246            nanos: duration.subsec_nanos(),
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Envelopes written to a `MessageStream` are read back unchanged and in order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let channel = test::Channel::new();
            let ping = Ping { id: 5 }.into_envelope(3, None, None);
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let mut messages = MessageStream::new(channel);
            messages.write_message(&ping).await.unwrap();
            messages.write_message(&open_buffer).await.unwrap();

            let first = messages.read_message().await.unwrap();
            assert_eq!(first, ping);
            let second = messages.read_message().await.unwrap();
            assert_eq!(second, open_buffer);
        });
    }
}