proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
// Include the generated protobuf definitions from `$OUT_DIR/zed.messages.rs`.
// `Envelope`, `envelope::Payload`, `Timestamp`, `Nonce` and the message types
// referenced below presumably come from this file (they are not declared here).
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, declaring
/// which message type answers which request.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type. The entity id is
/// read from the field named by the first argument (for example `project_id`).
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    DiskBasedDiagnosticsUpdated,
129    DiskBasedDiagnosticsUpdating,
130    Error,
131    GetChannelMessages,
132    GetChannelMessagesResponse,
133    GetChannels,
134    GetChannelsResponse,
135    GetUsers,
136    GetUsersResponse,
137    JoinChannel,
138    JoinChannelResponse,
139    JoinProject,
140    JoinProjectResponse,
141    LeaveChannel,
142    LeaveProject,
143    OpenBuffer,
144    OpenBufferResponse,
145    RegisterProjectResponse,
146    Ping,
147    RegisterProject,
148    RegisterWorktree,
149    RemoveProjectCollaborator,
150    SaveBuffer,
151    SendChannelMessage,
152    SendChannelMessageResponse,
153    ShareProject,
154    ShareWorktree,
155    UnregisterProject,
156    UnregisterWorktree,
157    UnshareProject,
158    UpdateBuffer,
159    UpdateContacts,
160    UpdateDiagnosticSummary,
161    UpdateWorktree,
162);
163
164request_messages!(
165    (GetChannelMessages, GetChannelMessagesResponse),
166    (GetChannels, GetChannelsResponse),
167    (GetUsers, GetUsersResponse),
168    (JoinChannel, JoinChannelResponse),
169    (JoinProject, JoinProjectResponse),
170    (OpenBuffer, OpenBufferResponse),
171    (Ping, Ack),
172    (RegisterProject, RegisterProjectResponse),
173    (RegisterWorktree, Ack),
174    (SaveBuffer, BufferSaved),
175    (SendChannelMessage, SendChannelMessageResponse),
176    (ShareProject, Ack),
177    (ShareWorktree, Ack),
178    (UpdateBuffer, Ack),
179);
180
181entity_messages!(
182    project_id,
183    AddProjectCollaborator,
184    BufferSaved,
185    CloseBuffer,
186    DiskBasedDiagnosticsUpdated,
187    DiskBasedDiagnosticsUpdating,
188    JoinProject,
189    LeaveProject,
190    OpenBuffer,
191    RemoveProjectCollaborator,
192    SaveBuffer,
193    ShareWorktree,
194    UnregisterWorktree,
195    UnshareProject,
196    UpdateBuffer,
197    UpdateDiagnosticSummary,
198    UpdateWorktree,
199);
200
201entity_messages!(channel_id, ChannelMessageSent);
202
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S`. `S` acts as a sink for [`MessageStream::write_message`]
/// and as a stream for [`MessageStream::read_message`].
pub struct MessageStream<S> {
    /// Scratch buffer holding uncompressed protobuf bytes, reused by both reads
    /// and writes to avoid reallocating per message.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport.
    stream: S,
}
208
209impl<S> MessageStream<S> {
210    pub fn new(stream: S) -> Self {
211        Self {
212            stream,
213            encoding_buffer: Vec::new(),
214        }
215    }
216
217    pub fn inner_mut(&mut self) -> &mut S {
218        &mut self.stream
219    }
220}
221
222impl<S> MessageStream<S>
223where
224    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
225{
226    /// Write a given protobuf message to the stream.
227    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
228        self.encoding_buffer.resize(message.encoded_len(), 0);
229        self.encoding_buffer.clear();
230        message
231            .encode(&mut self.encoding_buffer)
232            .map_err(|err| io::Error::from(err))?;
233        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
234        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
235        Ok(())
236    }
237}
238
239impl<S> MessageStream<S>
240where
241    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
242{
243    /// Read a protobuf message of the given type from the stream.
244    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
245        while let Some(bytes) = self.stream.next().await {
246            match bytes? {
247                WebSocketMessage::Binary(bytes) => {
248                    self.encoding_buffer.clear();
249                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
250                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
251                        .map_err(io::Error::from)?;
252                    return Ok(envelope);
253                }
254                WebSocketMessage::Close(_) => break,
255                _ => {}
256            }
257        }
258        Err(WebSocketError::ConnectionClosed)
259    }
260}
261
262impl Into<SystemTime> for Timestamp {
263    fn into(self) -> SystemTime {
264        UNIX_EPOCH
265            .checked_add(Duration::new(self.seconds, self.nanos))
266            .unwrap()
267    }
268}
269
270impl From<SystemTime> for Timestamp {
271    fn from(time: SystemTime) -> Self {
272        let duration = time.duration_since(UNIX_EPOCH).unwrap();
273        Self {
274            seconds: duration.as_secs(),
275            nanos: duration.subsec_nanos(),
276        }
277    }
278}
279
280impl From<u128> for Nonce {
281    fn from(nonce: u128) -> Self {
282        let upper_half = (nonce >> 64) as u64;
283        let lower_half = nonce as u64;
284        Self {
285            upper_half,
286            lower_half,
287        }
288    }
289}
290
291impl From<Nonce> for u128 {
292    fn from(nonce: Nonce) -> Self {
293        let upper_half = (nonce.upper_half as u128) << 64;
294        let lower_half = nonce.lower_half as u128;
295        upper_half | lower_half
296    }
297}