proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Declares request/response pairings: for every `(Request, Response)` tuple,
/// implements [`RequestMessage`] for `Request` with `Response` as its
/// associated response type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type, reading the
/// remote entity id from the field named by the first argument.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    BufferReloaded,
126    BufferSaved,
127    ChannelMessageSent,
128    CloseBuffer,
129    DiskBasedDiagnosticsUpdated,
130    DiskBasedDiagnosticsUpdating,
131    Error,
132    FormatBuffer,
133    GetChannelMessages,
134    GetChannelMessagesResponse,
135    GetChannels,
136    GetChannelsResponse,
137    GetDefinition,
138    GetDefinitionResponse,
139    GetUsers,
140    GetUsersResponse,
141    JoinChannel,
142    JoinChannelResponse,
143    JoinProject,
144    JoinProjectResponse,
145    LeaveChannel,
146    LeaveProject,
147    OpenBuffer,
148    OpenBufferResponse,
149    RegisterProjectResponse,
150    Ping,
151    RegisterProject,
152    RegisterWorktree,
153    RemoveProjectCollaborator,
154    SaveBuffer,
155    SendChannelMessage,
156    SendChannelMessageResponse,
157    ShareProject,
158    ShareWorktree,
159    UnregisterProject,
160    UnregisterWorktree,
161    UnshareProject,
162    UpdateBuffer,
163    UpdateBufferFile,
164    UpdateContacts,
165    UpdateDiagnosticSummary,
166    UpdateWorktree,
167);
168
169request_messages!(
170    (FormatBuffer, Ack),
171    (GetChannelMessages, GetChannelMessagesResponse),
172    (GetChannels, GetChannelsResponse),
173    (GetDefinition, GetDefinitionResponse),
174    (GetUsers, GetUsersResponse),
175    (JoinChannel, JoinChannelResponse),
176    (JoinProject, JoinProjectResponse),
177    (OpenBuffer, OpenBufferResponse),
178    (Ping, Ack),
179    (RegisterProject, RegisterProjectResponse),
180    (RegisterWorktree, Ack),
181    (SaveBuffer, BufferSaved),
182    (SendChannelMessage, SendChannelMessageResponse),
183    (ShareProject, Ack),
184    (ShareWorktree, Ack),
185    (UpdateBuffer, Ack),
186);
187
188entity_messages!(
189    project_id,
190    AddProjectCollaborator,
191    BufferReloaded,
192    BufferSaved,
193    CloseBuffer,
194    DiskBasedDiagnosticsUpdated,
195    DiskBasedDiagnosticsUpdating,
196    FormatBuffer,
197    GetDefinition,
198    JoinProject,
199    LeaveProject,
200    OpenBuffer,
201    RemoveProjectCollaborator,
202    SaveBuffer,
203    ShareWorktree,
204    UnregisterWorktree,
205    UnshareProject,
206    UpdateBuffer,
207    UpdateBufferFile,
208    UpdateDiagnosticSummary,
209    UpdateWorktree,
210);
211
212entity_messages!(channel_id, ChannelMessageSent);
213
/// Wraps a WebSocket transport and reads/writes protobuf [`Envelope`]s over it.
pub struct MessageStream<S> {
    /// The underlying transport (a sink and/or stream of WebSocket messages).
    stream: S,
    /// Scratch space reused across reads and writes to avoid reallocating
    /// for every message.
    encoding_buffer: Vec<u8>,
}
219
220impl<S> MessageStream<S> {
221    pub fn new(stream: S) -> Self {
222        Self {
223            stream,
224            encoding_buffer: Vec::new(),
225        }
226    }
227
228    pub fn inner_mut(&mut self) -> &mut S {
229        &mut self.stream
230    }
231}
232
233impl<S> MessageStream<S>
234where
235    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
236{
237    /// Write a given protobuf message to the stream.
238    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
239        self.encoding_buffer.resize(message.encoded_len(), 0);
240        self.encoding_buffer.clear();
241        message
242            .encode(&mut self.encoding_buffer)
243            .map_err(|err| io::Error::from(err))?;
244        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
245        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
246        Ok(())
247    }
248}
249
250impl<S> MessageStream<S>
251where
252    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
253{
254    /// Read a protobuf message of the given type from the stream.
255    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
256        while let Some(bytes) = self.stream.next().await {
257            match bytes? {
258                WebSocketMessage::Binary(bytes) => {
259                    self.encoding_buffer.clear();
260                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
261                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
262                        .map_err(io::Error::from)?;
263                    return Ok(envelope);
264                }
265                WebSocketMessage::Close(_) => break,
266                _ => {}
267            }
268        }
269        Err(WebSocketError::ConnectionClosed)
270    }
271}
272
273impl Into<SystemTime> for Timestamp {
274    fn into(self) -> SystemTime {
275        UNIX_EPOCH
276            .checked_add(Duration::new(self.seconds, self.nanos))
277            .unwrap()
278    }
279}
280
281impl From<SystemTime> for Timestamp {
282    fn from(time: SystemTime) -> Self {
283        let duration = time.duration_since(UNIX_EPOCH).unwrap();
284        Self {
285            seconds: duration.as_secs(),
286            nanos: duration.subsec_nanos(),
287        }
288    }
289}
290
291impl From<u128> for Nonce {
292    fn from(nonce: u128) -> Self {
293        let upper_half = (nonce >> 64) as u64;
294        let lower_half = nonce as u64;
295        Self {
296            upper_half,
297            lower_half,
298        }
299    }
300}
301
302impl From<Nonce> for u128 {
303    fn from(nonce: Nonce) -> Self {
304        let upper_half = (nonce.upper_half as u128) << 64;
305        let lower_half = nonce.lower_half as u128;
306        upper_half | lower_half
307    }
308}