proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Declares the reply type for each request by implementing [`RequestMessage`]
/// for every `(Request, Response)` pair given.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed type, using the field named by
/// the first argument as the remote entity id.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    let Self { $id_field: id, .. } = self;
                    *id
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    DiskBasedDiagnosticsUpdated,
129    DiskBasedDiagnosticsUpdating,
130    Error,
131    FormatBuffer,
132    GetChannelMessages,
133    GetChannelMessagesResponse,
134    GetChannels,
135    GetChannelsResponse,
136    GetUsers,
137    GetUsersResponse,
138    JoinChannel,
139    JoinChannelResponse,
140    JoinProject,
141    JoinProjectResponse,
142    LeaveChannel,
143    LeaveProject,
144    OpenBuffer,
145    OpenBufferResponse,
146    RegisterProjectResponse,
147    Ping,
148    RegisterProject,
149    RegisterWorktree,
150    RemoveProjectCollaborator,
151    SaveBuffer,
152    SendChannelMessage,
153    SendChannelMessageResponse,
154    ShareProject,
155    ShareWorktree,
156    UnregisterProject,
157    UnregisterWorktree,
158    UnshareProject,
159    UpdateBuffer,
160    UpdateContacts,
161    UpdateDiagnosticSummary,
162    UpdateWorktree,
163);
164
165request_messages!(
166    (FormatBuffer, Ack),
167    (GetChannelMessages, GetChannelMessagesResponse),
168    (GetChannels, GetChannelsResponse),
169    (GetUsers, GetUsersResponse),
170    (JoinChannel, JoinChannelResponse),
171    (JoinProject, JoinProjectResponse),
172    (OpenBuffer, OpenBufferResponse),
173    (Ping, Ack),
174    (RegisterProject, RegisterProjectResponse),
175    (RegisterWorktree, Ack),
176    (SaveBuffer, BufferSaved),
177    (SendChannelMessage, SendChannelMessageResponse),
178    (ShareProject, Ack),
179    (ShareWorktree, Ack),
180    (UpdateBuffer, Ack),
181);
182
183entity_messages!(
184    project_id,
185    AddProjectCollaborator,
186    BufferSaved,
187    CloseBuffer,
188    DiskBasedDiagnosticsUpdated,
189    DiskBasedDiagnosticsUpdating,
190    FormatBuffer,
191    JoinProject,
192    LeaveProject,
193    OpenBuffer,
194    RemoveProjectCollaborator,
195    SaveBuffer,
196    ShareWorktree,
197    UnregisterWorktree,
198    UnshareProject,
199    UpdateBuffer,
200    UpdateDiagnosticSummary,
201    UpdateWorktree,
202);
203
204entity_messages!(channel_id, ChannelMessageSent);
205
/// A stream of protobuf messages.
///
/// Wraps a transport `S` together with a scratch byte buffer that the message
/// reading and writing methods reuse between messages.
pub struct MessageStream<S> {
    stream: S,
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wrap `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        MessageStream {
            stream,
            encoding_buffer,
        }
    }

    /// Mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
224
225impl<S> MessageStream<S>
226where
227    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
228{
229    /// Write a given protobuf message to the stream.
230    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
231        self.encoding_buffer.resize(message.encoded_len(), 0);
232        self.encoding_buffer.clear();
233        message
234            .encode(&mut self.encoding_buffer)
235            .map_err(|err| io::Error::from(err))?;
236        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
237        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
238        Ok(())
239    }
240}
241
242impl<S> MessageStream<S>
243where
244    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
245{
246    /// Read a protobuf message of the given type from the stream.
247    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
248        while let Some(bytes) = self.stream.next().await {
249            match bytes? {
250                WebSocketMessage::Binary(bytes) => {
251                    self.encoding_buffer.clear();
252                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
253                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
254                        .map_err(io::Error::from)?;
255                    return Ok(envelope);
256                }
257                WebSocketMessage::Close(_) => break,
258                _ => {}
259            }
260        }
261        Err(WebSocketError::ConnectionClosed)
262    }
263}
264
265impl Into<SystemTime> for Timestamp {
266    fn into(self) -> SystemTime {
267        UNIX_EPOCH
268            .checked_add(Duration::new(self.seconds, self.nanos))
269            .unwrap()
270    }
271}
272
273impl From<SystemTime> for Timestamp {
274    fn from(time: SystemTime) -> Self {
275        let duration = time.duration_since(UNIX_EPOCH).unwrap();
276        Self {
277            seconds: duration.as_secs(),
278            nanos: duration.subsec_nanos(),
279        }
280    }
281}
282
283impl From<u128> for Nonce {
284    fn from(nonce: u128) -> Self {
285        let upper_half = (nonce >> 64) as u64;
286        let lower_half = nonce as u64;
287        Self {
288            upper_half,
289            lower_half,
290        }
291    }
292}
293
294impl From<Nonce> for u128 {
295    fn from(nonce: Nonce) -> Self {
296        let upper_half = (nonce.upper_half as u128) << 64;
297        let lower_half = nonce.lower_half as u128;
298        upper_half | lower_half
299    }
300}