proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Implements `RequestMessage` for each `(Request, Response)` pair, setting
/// the request's associated `Response` type to the second element of the pair.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements `EntityMessage` for each listed message type.
///
/// The first argument names the field (shared by all listed types) whose
/// value is returned from `remote_entity_id`.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddPeer,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    CloseWorktree,
129    Error,
130    GetChannelMessages,
131    GetChannelMessagesResponse,
132    GetChannels,
133    GetChannelsResponse,
134    UpdateCollaborators,
135    GetUsers,
136    GetUsersResponse,
137    JoinChannel,
138    JoinChannelResponse,
139    JoinWorktree,
140    JoinWorktreeResponse,
141    LeaveChannel,
142    LeaveWorktree,
143    OpenBuffer,
144    OpenBufferResponse,
145    OpenWorktree,
146    OpenWorktreeResponse,
147    Ping,
148    RemovePeer,
149    SaveBuffer,
150    SendChannelMessage,
151    SendChannelMessageResponse,
152    ShareWorktree,
153    ShareWorktreeResponse,
154    UnshareWorktree,
155    UpdateBuffer,
156    UpdateWorktree,
157);
158
159request_messages!(
160    (GetChannels, GetChannelsResponse),
161    (GetUsers, GetUsersResponse),
162    (JoinChannel, JoinChannelResponse),
163    (OpenBuffer, OpenBufferResponse),
164    (JoinWorktree, JoinWorktreeResponse),
165    (OpenWorktree, OpenWorktreeResponse),
166    (Ping, Ack),
167    (SaveBuffer, BufferSaved),
168    (UpdateBuffer, Ack),
169    (ShareWorktree, ShareWorktreeResponse),
170    (UnshareWorktree, Ack),
171    (SendChannelMessage, SendChannelMessageResponse),
172    (GetChannelMessages, GetChannelMessagesResponse),
173);
174
175entity_messages!(
176    worktree_id,
177    AddPeer,
178    BufferSaved,
179    CloseBuffer,
180    CloseWorktree,
181    OpenBuffer,
182    JoinWorktree,
183    RemovePeer,
184    SaveBuffer,
185    UnshareWorktree,
186    UpdateBuffer,
187    UpdateWorktree,
188);
189
190entity_messages!(channel_id, ChannelMessageSent);
191
/// A stream of protobuf messages layered over an underlying transport `S`.
pub struct MessageStream<S> {
    /// The wrapped transport that frames are sent to and received from.
    stream: S,
    /// Scratch buffer reused across reads and writes for the encoded
    /// (uncompressed) protobuf bytes.
    encoding_buffer: Vec<u8>,
}
197
198impl<S> MessageStream<S> {
199    pub fn new(stream: S) -> Self {
200        Self {
201            stream,
202            encoding_buffer: Vec::new(),
203        }
204    }
205
206    pub fn inner_mut(&mut self) -> &mut S {
207        &mut self.stream
208    }
209}
210
211impl<S> MessageStream<S>
212where
213    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
214{
215    /// Write a given protobuf message to the stream.
216    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
217        self.encoding_buffer.resize(message.encoded_len(), 0);
218        self.encoding_buffer.clear();
219        message
220            .encode(&mut self.encoding_buffer)
221            .map_err(|err| io::Error::from(err))?;
222        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
223        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
224        Ok(())
225    }
226}
227
228impl<S> MessageStream<S>
229where
230    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
231{
232    /// Read a protobuf message of the given type from the stream.
233    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
234        while let Some(bytes) = self.stream.next().await {
235            match bytes? {
236                WebSocketMessage::Binary(bytes) => {
237                    self.encoding_buffer.clear();
238                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
239                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
240                        .map_err(io::Error::from)?;
241                    return Ok(envelope);
242                }
243                WebSocketMessage::Close(_) => break,
244                _ => {}
245            }
246        }
247        Err(WebSocketError::ConnectionClosed)
248    }
249}
250
251impl Into<SystemTime> for Timestamp {
252    fn into(self) -> SystemTime {
253        UNIX_EPOCH
254            .checked_add(Duration::new(self.seconds, self.nanos))
255            .unwrap()
256    }
257}
258
259impl From<SystemTime> for Timestamp {
260    fn from(time: SystemTime) -> Self {
261        let duration = time.duration_since(UNIX_EPOCH).unwrap();
262        Self {
263            seconds: duration.as_secs(),
264            nanos: duration.subsec_nanos(),
265        }
266    }
267}
268
269impl From<u128> for Nonce {
270    fn from(nonce: u128) -> Self {
271        let upper_half = (nonce >> 64) as u64;
272        let lower_half = nonce as u64;
273        Self {
274            upper_half,
275            lower_half,
276        }
277    }
278}
279
280impl From<Nonce> for u128 {
281    fn from(nonce: Nonce) -> Self {
282        let upper_half = (nonce.upper_half as u128) << 64;
283        let lower_half = nonce.lower_half as u128;
284        upper_half | lower_half
285    }
286}