proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Declares request/response pairs: for each `(Request, Response)` tuple, implements
/// [`RequestMessage`] for `Request` with `Response` as its associated response type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed type.
///
/// The first argument names the field holding the remote entity id; that field is returned
/// directly from `remote_entity_id`, so it must be a `u64`.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    Error,
129    GetChannelMessages,
130    GetChannelMessagesResponse,
131    GetChannels,
132    GetChannelsResponse,
133    GetUsers,
134    GetUsersResponse,
135    JoinChannel,
136    JoinChannelResponse,
137    JoinProject,
138    JoinProjectResponse,
139    LeaveChannel,
140    LeaveProject,
141    OpenBuffer,
142    OpenBufferResponse,
143    RegisterProjectResponse,
144    Ping,
145    RegisterProject,
146    RegisterWorktree,
147    RemoveProjectCollaborator,
148    SaveBuffer,
149    SendChannelMessage,
150    SendChannelMessageResponse,
151    ShareProject,
152    ShareWorktree,
153    UnregisterProject,
154    UnregisterWorktree,
155    UnshareProject,
156    UpdateBuffer,
157    UpdateContacts,
158    UpdateWorktree,
159);
160
161request_messages!(
162    (GetChannelMessages, GetChannelMessagesResponse),
163    (GetChannels, GetChannelsResponse),
164    (GetUsers, GetUsersResponse),
165    (JoinChannel, JoinChannelResponse),
166    (JoinProject, JoinProjectResponse),
167    (OpenBuffer, OpenBufferResponse),
168    (Ping, Ack),
169    (RegisterProject, RegisterProjectResponse),
170    (RegisterWorktree, Ack),
171    (SaveBuffer, BufferSaved),
172    (SendChannelMessage, SendChannelMessageResponse),
173    (ShareProject, Ack),
174    (ShareWorktree, Ack),
175    (UpdateBuffer, Ack),
176);
177
178entity_messages!(
179    project_id,
180    AddProjectCollaborator,
181    RemoveProjectCollaborator,
182    JoinProject,
183    LeaveProject,
184    BufferSaved,
185    OpenBuffer,
186    CloseBuffer,
187    SaveBuffer,
188    ShareWorktree,
189    UnregisterWorktree,
190    UnshareProject,
191    UpdateBuffer,
192    UpdateWorktree,
193);
194
195entity_messages!(channel_id, ChannelMessageSent);
196
/// A stream of protobuf messages layered over an underlying transport `S`.
///
/// `encoding_buffer` is scratch space that is cleared and reused for each message, so the
/// protobuf encode/decode step can keep its allocation between messages.
pub struct MessageStream<S> {
    stream: S,
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        MessageStream {
            stream,
            encoding_buffer,
        }
    }

    /// Gives mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
215
216impl<S> MessageStream<S>
217where
218    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
219{
220    /// Write a given protobuf message to the stream.
221    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
222        self.encoding_buffer.resize(message.encoded_len(), 0);
223        self.encoding_buffer.clear();
224        message
225            .encode(&mut self.encoding_buffer)
226            .map_err(|err| io::Error::from(err))?;
227        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
228        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
229        Ok(())
230    }
231}
232
233impl<S> MessageStream<S>
234where
235    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
236{
237    /// Read a protobuf message of the given type from the stream.
238    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
239        while let Some(bytes) = self.stream.next().await {
240            match bytes? {
241                WebSocketMessage::Binary(bytes) => {
242                    self.encoding_buffer.clear();
243                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
244                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
245                        .map_err(io::Error::from)?;
246                    return Ok(envelope);
247                }
248                WebSocketMessage::Close(_) => break,
249                _ => {}
250            }
251        }
252        Err(WebSocketError::ConnectionClosed)
253    }
254}
255
256impl Into<SystemTime> for Timestamp {
257    fn into(self) -> SystemTime {
258        UNIX_EPOCH
259            .checked_add(Duration::new(self.seconds, self.nanos))
260            .unwrap()
261    }
262}
263
264impl From<SystemTime> for Timestamp {
265    fn from(time: SystemTime) -> Self {
266        let duration = time.duration_since(UNIX_EPOCH).unwrap();
267        Self {
268            seconds: duration.as_secs(),
269            nanos: duration.subsec_nanos(),
270        }
271    }
272}
273
274impl From<u128> for Nonce {
275    fn from(nonce: u128) -> Self {
276        let upper_half = (nonce >> 64) as u64;
277        let lower_half = nonce as u64;
278        Self {
279            upper_half,
280            lower_half,
281        }
282    }
283}
284
285impl From<Nonce> for u128 {
286    fn from(nonce: Nonce) -> Self {
287        let upper_half = (nonce.upper_half as u128) << 64;
288        let lower_half = nonce.lower_half as u128;
289        upper_half | lower_half
290    }
291}