proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Declares request/response pairings. Each `(Request, Response)` tuple expands to
/// `impl RequestMessage for Request { type Response = Response; }`.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
111
/// Implements `EntityMessage` for each listed message type. The remote entity id
/// is read from the field named by the first argument (`$id_field`).
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    let Self { $id_field: entity_id, .. } = self;
                    *entity_id
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddPeer,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    CloseWorktree,
129    Error,
130    GetChannelMessages,
131    GetChannelMessagesResponse,
132    GetChannels,
133    GetChannelsResponse,
134    GetUsers,
135    GetUsersResponse,
136    JoinChannel,
137    JoinChannelResponse,
138    LeaveChannel,
139    OpenBuffer,
140    OpenBufferResponse,
141    OpenWorktree,
142    OpenWorktreeResponse,
143    Ping,
144    RemovePeer,
145    SaveBuffer,
146    SendChannelMessage,
147    SendChannelMessageResponse,
148    ShareWorktree,
149    ShareWorktreeResponse,
150    UpdateBuffer,
151    UpdateWorktree,
152);
153
154request_messages!(
155    (GetChannels, GetChannelsResponse),
156    (GetUsers, GetUsersResponse),
157    (JoinChannel, JoinChannelResponse),
158    (OpenBuffer, OpenBufferResponse),
159    (OpenWorktree, OpenWorktreeResponse),
160    (Ping, Ack),
161    (SaveBuffer, BufferSaved),
162    (UpdateBuffer, Ack),
163    (ShareWorktree, ShareWorktreeResponse),
164    (SendChannelMessage, SendChannelMessageResponse),
165    (GetChannelMessages, GetChannelMessagesResponse),
166);
167
168entity_messages!(
169    worktree_id,
170    AddPeer,
171    BufferSaved,
172    CloseBuffer,
173    CloseWorktree,
174    OpenBuffer,
175    OpenWorktree,
176    RemovePeer,
177    SaveBuffer,
178    UpdateBuffer,
179    UpdateWorktree,
180);
181
182entity_messages!(channel_id, ChannelMessageSent);
183
/// Wraps an underlying stream/sink and exchanges protobuf messages over it.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
198
199impl<S> MessageStream<S>
200where
201    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
202{
203    /// Write a given protobuf message to the stream.
204    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
205        let mut buffer = Vec::with_capacity(message.encoded_len());
206        message
207            .encode(&mut buffer)
208            .map_err(|err| io::Error::from(err))?;
209        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
210        Ok(())
211    }
212}
213
214impl<S> MessageStream<S>
215where
216    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
217{
218    /// Read a protobuf message of the given type from the stream.
219    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
220        while let Some(bytes) = self.stream.next().await {
221            match bytes? {
222                WebSocketMessage::Binary(bytes) => {
223                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
224                    return Ok(envelope);
225                }
226                WebSocketMessage::Close(_) => break,
227                _ => {}
228            }
229        }
230        Err(WebSocketError::ConnectionClosed)
231    }
232}
233
234impl Into<SystemTime> for Timestamp {
235    fn into(self) -> SystemTime {
236        UNIX_EPOCH
237            .checked_add(Duration::new(self.seconds, self.nanos))
238            .unwrap()
239    }
240}
241
242impl From<SystemTime> for Timestamp {
243    fn from(time: SystemTime) -> Self {
244        let duration = time.duration_since(UNIX_EPOCH).unwrap();
245        Self {
246            seconds: duration.as_secs(),
247            nanos: duration.subsec_nanos(),
248        }
249    }
250}
251
252impl From<u128> for Nonce {
253    fn from(nonce: u128) -> Self {
254        let upper_half = (nonce >> 64) as u64;
255        let lower_half = nonce as u64;
256        Self {
257            upper_half,
258            lower_half,
259        }
260    }
261}
262
263impl From<Nonce> for u128 {
264    fn from(nonce: Nonce) -> Self {
265        let upper_half = (nonce.upper_half as u128) << 64;
266        let lower_half = nonce.lower_half as u128;
267        upper_half | lower_half
268    }
269}