proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Implements [`RequestMessage`] for each `(request, response)` pair, declaring
/// `response` as the reply type of `request`.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type, reporting the value of
/// the field named by the first argument as the remote entity id.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    BufferSaved,
126    ChannelMessageSent,
127    CloseBuffer,
128    DiskBasedDiagnosticsUpdated,
129    Error,
130    GetChannelMessages,
131    GetChannelMessagesResponse,
132    GetChannels,
133    GetChannelsResponse,
134    GetUsers,
135    GetUsersResponse,
136    JoinChannel,
137    JoinChannelResponse,
138    JoinProject,
139    JoinProjectResponse,
140    LeaveChannel,
141    LeaveProject,
142    OpenBuffer,
143    OpenBufferResponse,
144    RegisterProjectResponse,
145    Ping,
146    RegisterProject,
147    RegisterWorktree,
148    RemoveProjectCollaborator,
149    SaveBuffer,
150    SendChannelMessage,
151    SendChannelMessageResponse,
152    ShareProject,
153    ShareWorktree,
154    UnregisterProject,
155    UnregisterWorktree,
156    UnshareProject,
157    UpdateBuffer,
158    UpdateContacts,
159    UpdateDiagnosticSummary,
160    UpdateWorktree,
161);
162
163request_messages!(
164    (GetChannelMessages, GetChannelMessagesResponse),
165    (GetChannels, GetChannelsResponse),
166    (GetUsers, GetUsersResponse),
167    (JoinChannel, JoinChannelResponse),
168    (JoinProject, JoinProjectResponse),
169    (OpenBuffer, OpenBufferResponse),
170    (Ping, Ack),
171    (RegisterProject, RegisterProjectResponse),
172    (RegisterWorktree, Ack),
173    (SaveBuffer, BufferSaved),
174    (SendChannelMessage, SendChannelMessageResponse),
175    (ShareProject, Ack),
176    (ShareWorktree, Ack),
177    (UpdateBuffer, Ack),
178);
179
180entity_messages!(
181    project_id,
182    AddProjectCollaborator,
183    BufferSaved,
184    CloseBuffer,
185    DiskBasedDiagnosticsUpdated,
186    JoinProject,
187    LeaveProject,
188    OpenBuffer,
189    RemoveProjectCollaborator,
190    SaveBuffer,
191    ShareWorktree,
192    UnregisterWorktree,
193    UnshareProject,
194    UpdateBuffer,
195    UpdateDiagnosticSummary,
196    UpdateWorktree,
197);
198
199entity_messages!(channel_id, ChannelMessageSent);
200
/// A stream of protobuf messages layered over `S`.
///
/// Envelopes can be written when `S` is a WebSocket sink and read when `S` is a
/// WebSocket stream; one scratch buffer is reused for both directions.
pub struct MessageStream<S> {
    stream: S,
    // Scratch space reused across reads and writes.
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`; the scratch buffer starts out empty.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        MessageStream {
            stream,
            encoding_buffer,
        }
    }

    /// Gives mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
219
220impl<S> MessageStream<S>
221where
222    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
223{
224    /// Write a given protobuf message to the stream.
225    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
226        self.encoding_buffer.resize(message.encoded_len(), 0);
227        self.encoding_buffer.clear();
228        message
229            .encode(&mut self.encoding_buffer)
230            .map_err(|err| io::Error::from(err))?;
231        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
232        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
233        Ok(())
234    }
235}
236
237impl<S> MessageStream<S>
238where
239    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
240{
241    /// Read a protobuf message of the given type from the stream.
242    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
243        while let Some(bytes) = self.stream.next().await {
244            match bytes? {
245                WebSocketMessage::Binary(bytes) => {
246                    self.encoding_buffer.clear();
247                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
248                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
249                        .map_err(io::Error::from)?;
250                    return Ok(envelope);
251                }
252                WebSocketMessage::Close(_) => break,
253                _ => {}
254            }
255        }
256        Err(WebSocketError::ConnectionClosed)
257    }
258}
259
260impl Into<SystemTime> for Timestamp {
261    fn into(self) -> SystemTime {
262        UNIX_EPOCH
263            .checked_add(Duration::new(self.seconds, self.nanos))
264            .unwrap()
265    }
266}
267
268impl From<SystemTime> for Timestamp {
269    fn from(time: SystemTime) -> Self {
270        let duration = time.duration_since(UNIX_EPOCH).unwrap();
271        Self {
272            seconds: duration.as_secs(),
273            nanos: duration.subsec_nanos(),
274        }
275    }
276}
277
278impl From<u128> for Nonce {
279    fn from(nonce: u128) -> Self {
280        let upper_half = (nonce >> 64) as u64;
281        let lower_half = nonce as u64;
282        Self {
283            upper_half,
284            lower_half,
285        }
286    }
287}
288
289impl From<Nonce> for u128 {
290    fn from(nonce: Nonce) -> Self {
291        let upper_half = (nonce.upper_half as u128) << 64;
292        let lower_half = nonce.lower_half as u128;
293        upper_half | lower_half
294    }
295}