proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
/// A protobuf message type that can be carried as the payload of an `Envelope`.
///
/// Implementations for the concrete message types are generated by the `messages!` macro.
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
    /// Name of the message type; the `messages!` macro sets it to the type's identifier.
    const NAME: &'static str;
    /// Wraps this message in an `Envelope` carrying the given message `id`, the id of the
    /// message being responded to (if any) and the original sender's id (if any).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts this message type from `envelope`, returning `None` when the envelope's
    /// payload is absent or holds a different message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 24
/// A message that is associated with a remote entity identified by a numeric id.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to. Implementations generated
    /// by `entity_messages!` return one of the message's own id fields (e.g. `project_id`).
    fn remote_entity_id(&self) -> u64;
}
 28
/// A message that is paired with a response message type.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 32
/// Object-safe view of a `TypedEnvelope<T>` whose payload type `T` has been erased.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// Returns the `TypeId` of the payload type.
    fn payload_type_id(&self) -> TypeId;
    /// Returns the payload type's `EnvelopedMessage::NAME`.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `Any` so that it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so that it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Implements `RequestMessage` for each request type, naming its response type.
///
/// Takes a comma-separated list of `(Request, Response)` pairs; a trailing comma is allowed.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements `EntityMessage` for each listed message type.
///
/// The first argument names the field that holds the remote entity id; the remaining
/// arguments are the message types whose `remote_entity_id` returns that field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    ApplyCompletionAdditionalEdits,
126    ApplyCompletionAdditionalEditsResponse,
127    BufferReloaded,
128    BufferSaved,
129    ChannelMessageSent,
130    CloseBuffer,
131    DiskBasedDiagnosticsUpdated,
132    DiskBasedDiagnosticsUpdating,
133    Error,
134    FormatBuffer,
135    GetChannelMessages,
136    GetChannelMessagesResponse,
137    GetChannels,
138    GetChannelsResponse,
139    GetCompletions,
140    GetCompletionsResponse,
141    GetDefinition,
142    GetDefinitionResponse,
143    GetUsers,
144    GetUsersResponse,
145    JoinChannel,
146    JoinChannelResponse,
147    JoinProject,
148    JoinProjectResponse,
149    LeaveChannel,
150    LeaveProject,
151    OpenBuffer,
152    OpenBufferResponse,
153    RegisterProjectResponse,
154    Ping,
155    RegisterProject,
156    RegisterWorktree,
157    RemoveProjectCollaborator,
158    SaveBuffer,
159    SendChannelMessage,
160    SendChannelMessageResponse,
161    ShareProject,
162    ShareWorktree,
163    UnregisterProject,
164    UnregisterWorktree,
165    UnshareProject,
166    UpdateBuffer,
167    UpdateBufferFile,
168    UpdateContacts,
169    UpdateDiagnosticSummary,
170    UpdateWorktree,
171);
172
// Each request type paired with the message type used as its response. Several requests
// share `Ack` as their response type.
request_messages!(
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (FormatBuffer, Ack),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBuffer, OpenBufferResponse),
    (Ping, Ack),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (SaveBuffer, BufferSaved),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, Ack),
    (ShareWorktree, Ack),
    (UpdateBuffer, Ack),
);
196
// Messages whose `remote_entity_id` is their `project_id` field.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CloseBuffer,
    DiskBasedDiagnosticsUpdated,
    DiskBasedDiagnosticsUpdating,
    FormatBuffer,
    GetCompletions,
    GetDefinition,
    JoinProject,
    LeaveProject,
    OpenBuffer,
    RemoveProjectCollaborator,
    SaveBuffer,
    ShareWorktree,
    UnregisterWorktree,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateWorktree,
);

// Messages whose `remote_entity_id` is their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
224
/// A stream of protobuf messages layered over an underlying transport `S`.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer that is reused between calls instead of being reallocated.
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Creates a message stream over `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        Self {
            stream,
            encoding_buffer,
        }
    }

    /// Gives mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
243
244impl<S> MessageStream<S>
245where
246    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
247{
248    /// Write a given protobuf message to the stream.
249    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
250        self.encoding_buffer.resize(message.encoded_len(), 0);
251        self.encoding_buffer.clear();
252        message
253            .encode(&mut self.encoding_buffer)
254            .map_err(|err| io::Error::from(err))?;
255        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
256        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
257        Ok(())
258    }
259}
260
261impl<S> MessageStream<S>
262where
263    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
264{
265    /// Read a protobuf message of the given type from the stream.
266    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
267        while let Some(bytes) = self.stream.next().await {
268            match bytes? {
269                WebSocketMessage::Binary(bytes) => {
270                    self.encoding_buffer.clear();
271                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
272                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
273                        .map_err(io::Error::from)?;
274                    return Ok(envelope);
275                }
276                WebSocketMessage::Close(_) => break,
277                _ => {}
278            }
279        }
280        Err(WebSocketError::ConnectionClosed)
281    }
282}
283
284impl Into<SystemTime> for Timestamp {
285    fn into(self) -> SystemTime {
286        UNIX_EPOCH
287            .checked_add(Duration::new(self.seconds, self.nanos))
288            .unwrap()
289    }
290}
291
292impl From<SystemTime> for Timestamp {
293    fn from(time: SystemTime) -> Self {
294        let duration = time.duration_since(UNIX_EPOCH).unwrap();
295        Self {
296            seconds: duration.as_secs(),
297            nanos: duration.subsec_nanos(),
298        }
299    }
300}
301
302impl From<u128> for Nonce {
303    fn from(nonce: u128) -> Self {
304        let upper_half = (nonce >> 64) as u64;
305        let lower_half = nonce as u64;
306        Self {
307            upper_half,
308            lower_half,
309        }
310    }
311}
312
313impl From<Nonce> for u128 {
314    fn from(nonce: Nonce) -> Self {
315        let upper_half = (nonce.upper_half as u128) << 64;
316        let lower_half = nonce.lower_half as u128;
317        upper_half | lower_half
318    }
319}