proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 /// Type-erased view of a `TypedEnvelope<T>`.
 ///
 /// Allows envelopes with different payload types to be handled through a
 /// single trait object and later downcast back to the concrete type.
 pub trait AnyTypedEnvelope: 'static + Send + Sync {
     /// Returns the [`TypeId`] identifying the payload type.
     fn payload_type_id(&self) -> TypeId;

     /// Returns the name of the payload type.
     fn payload_type_name(&self) -> &'static str;

     /// Borrows the envelope as [`Any`] so it can be downcast by reference.
     fn as_any(&self) -> &dyn Any;

     /// Converts the boxed envelope into a boxed [`Any`] so it can be downcast
     /// by value.
     fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 }
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 /// Generates the glue between the listed message types and [`Envelope`].
 ///
 /// For the given comma-separated list of message type names this expands to:
 /// - `build_typed_envelope`, which turns an [`Envelope`] into a boxed,
 ///   type-erased `TypedEnvelope` for the payload it carries, and
 /// - an [`EnvelopedMessage`] implementation for every listed type.
 ///
 /// Each name must be both a message type and a variant of `envelope::Payload`.
 macro_rules! messages {
     ($($name:ident),* $(,)?) => {
         /// Builds a type-erased `TypedEnvelope` from `envelope`.
         ///
         /// Returns `None` when the envelope has no payload or the payload is
         /// not one of the message types listed in the `messages!` invocation.
         pub fn build_typed_envelope(
             sender_id: ConnectionId,
             envelope: Envelope,
         ) -> Option<Box<dyn AnyTypedEnvelope>> {
             // Take the envelope apart once so every arm can reuse the pieces.
             let Envelope {
                 id: message_id,
                 original_sender_id,
                 payload,
                 ..
             } = envelope;
             let original_sender_id = original_sender_id.map(PeerId);

             match payload {
                 $(
                     Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                         sender_id,
                         original_sender_id,
                         message_id,
                         payload,
                     })),
                 )*
                 _ => None,
             }
         }

         $(
             impl EnvelopedMessage for $name {
                 const NAME: &'static str = std::stringify!($name);

                 fn into_envelope(
                     self,
                     id: u32,
                     responding_to: Option<u32>,
                     original_sender_id: Option<u32>,
                 ) -> Envelope {
                     let payload = Some(envelope::Payload::$name(self));
                     Envelope {
                         id,
                         responding_to,
                         original_sender_id,
                         payload,
                     }
                 }

                 fn from_envelope(envelope: Envelope) -> Option<Self> {
                     match envelope.payload {
                         Some(envelope::Payload::$name(message)) => Some(message),
                         _ => None,
                     }
                 }
             }
         )*
     };
 }
103
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, setting
/// the request's associated `Response` type to the second element of the pair.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type.
///
/// The first argument names the `u64` field that holds the entity id; the
/// generated `remote_entity_id` simply returns that field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddPeer,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    CloseWorktree,
129    Error,
130    GetChannelMessages,
131    GetChannelMessagesResponse,
132    GetChannels,
133    GetChannelsResponse,
134    GetUsers,
135    GetUsersResponse,
136    JoinChannel,
137    JoinChannelResponse,
138    LeaveChannel,
139    OpenBuffer,
140    OpenBufferResponse,
141    OpenWorktree,
142    OpenWorktreeResponse,
143    Ping,
144    RemovePeer,
145    SaveBuffer,
146    SendChannelMessage,
147    SendChannelMessageResponse,
148    ShareWorktree,
149    ShareWorktreeResponse,
150    UpdateBuffer,
151    UpdateWorktree,
152);
153
154request_messages!(
155    (GetChannels, GetChannelsResponse),
156    (GetUsers, GetUsersResponse),
157    (JoinChannel, JoinChannelResponse),
158    (OpenBuffer, OpenBufferResponse),
159    (OpenWorktree, OpenWorktreeResponse),
160    (Ping, Ack),
161    (SaveBuffer, BufferSaved),
162    (UpdateBuffer, Ack),
163    (ShareWorktree, ShareWorktreeResponse),
164    (SendChannelMessage, SendChannelMessageResponse),
165    (GetChannelMessages, GetChannelMessagesResponse),
166);
167
168entity_messages!(
169    worktree_id,
170    AddPeer,
171    BufferSaved,
172    CloseBuffer,
173    CloseWorktree,
174    OpenBuffer,
175    OpenWorktree,
176    RemovePeer,
177    SaveBuffer,
178    UpdateBuffer,
179    UpdateWorktree,
180);
181
182entity_messages!(channel_id, ChannelMessageSent);
183
/// Wrapper around an underlying stream `S` used to exchange protobuf messages.
///
/// Reading and writing are provided by separate `impl` blocks that require `S`
/// to be a WebSocket message stream or sink respectively.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream` without touching it.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
198
199impl<S> MessageStream<S>
200where
201    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
202{
203    /// Write a given protobuf message to the stream.
204    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
205        let mut buffer = Vec::with_capacity(message.encoded_len());
206        message
207            .encode(&mut buffer)
208            .map_err(|err| io::Error::from(err))?;
209        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
210        Ok(())
211    }
212}
213
214impl<S> MessageStream<S>
215where
216    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
217{
218    /// Read a protobuf message of the given type from the stream.
219    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
220        while let Some(bytes) = self.stream.next().await {
221            match bytes? {
222                WebSocketMessage::Binary(bytes) => {
223                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
224                    return Ok(envelope);
225                }
226                WebSocketMessage::Close(_) => break,
227                _ => {}
228            }
229        }
230        Err(WebSocketError::ConnectionClosed)
231    }
232}
233
234impl Into<SystemTime> for Timestamp {
235    fn into(self) -> SystemTime {
236        UNIX_EPOCH
237            .checked_add(Duration::new(self.seconds, self.nanos))
238            .unwrap()
239    }
240}
241
242impl From<SystemTime> for Timestamp {
243    fn from(time: SystemTime) -> Self {
244        let duration = time.duration_since(UNIX_EPOCH).unwrap();
245        Self {
246            seconds: duration.as_secs(),
247            nanos: duration.subsec_nanos(),
248        }
249    }
250}