proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    const PRIORITY: MessagePriority;
 17    fn into_envelope(
 18        self,
 19        id: u32,
 20        responding_to: Option<u32>,
 21        original_sender_id: Option<u32>,
 22    ) -> Envelope;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34pub trait AnyTypedEnvelope: 'static + Send + Sync {
 35    fn payload_type_id(&self) -> TypeId;
 36    fn payload_type_name(&self) -> &'static str;
 37    fn as_any(&self) -> &dyn Any;
 38    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 39    fn is_background(&self) -> bool;
 40    fn original_sender_id(&self) -> Option<PeerId>;
 41}
 42
 43pub enum MessagePriority {
 44    Foreground,
 45    Background,
 46}
 47
 48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 49    fn payload_type_id(&self) -> TypeId {
 50        TypeId::of::<T>()
 51    }
 52
 53    fn payload_type_name(&self) -> &'static str {
 54        T::NAME
 55    }
 56
 57    fn as_any(&self) -> &dyn Any {
 58        self
 59    }
 60
 61    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 62        self
 63    }
 64
 65    fn is_background(&self) -> bool {
 66        matches!(T::PRIORITY, MessagePriority::Background)
 67    }
 68
 69    fn original_sender_id(&self) -> Option<PeerId> {
 70        self.original_sender_id
 71    }
 72}
 73
 74macro_rules! messages {
 75    ($(($name:ident, $priority:ident)),* $(,)?) => {
 76        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 77            match envelope.payload {
 78                $(Some(envelope::Payload::$name(payload)) => {
 79                    Some(Box::new(TypedEnvelope {
 80                        sender_id,
 81                        original_sender_id: envelope.original_sender_id.map(PeerId),
 82                        message_id: envelope.id,
 83                        payload,
 84                    }))
 85                }, )*
 86                _ => None
 87            }
 88        }
 89
 90        $(
 91            impl EnvelopedMessage for $name {
 92                const NAME: &'static str = std::stringify!($name);
 93                const PRIORITY: MessagePriority = MessagePriority::$priority;
 94
 95                fn into_envelope(
 96                    self,
 97                    id: u32,
 98                    responding_to: Option<u32>,
 99                    original_sender_id: Option<u32>,
100                ) -> Envelope {
101                    Envelope {
102                        id,
103                        responding_to,
104                        original_sender_id,
105                        payload: Some(envelope::Payload::$name(self)),
106                    }
107                }
108
109                fn from_envelope(envelope: Envelope) -> Option<Self> {
110                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
111                        Some(msg)
112                    } else {
113                        None
114                    }
115                }
116            }
117        )*
118    };
119}
120
/// Implements `RequestMessage` for each `(Request, Response)` pair, fixing the request's
/// `Response` associated type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
128
/// Implements `EntityMessage` for each listed message type, returning the named `u64` field
/// as the remote entity id.
///
/// Usage: `entity_messages!(field_name, MessageA, MessageB, ...)`.
macro_rules! entity_messages {
    ($id_field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
138
139messages!(
140    (Ack, Foreground),
141    (AddProjectCollaborator, Foreground),
142    (ApplyCodeAction, Background),
143    (ApplyCodeActionResponse, Background),
144    (ApplyCompletionAdditionalEdits, Background),
145    (ApplyCompletionAdditionalEditsResponse, Background),
146    (BufferReloaded, Foreground),
147    (BufferSaved, Foreground),
148    (ChannelMessageSent, Foreground),
149    (Error, Foreground),
150    (FormatBuffers, Foreground),
151    (FormatBuffersResponse, Foreground),
152    (GetChannelMessages, Foreground),
153    (GetChannelMessagesResponse, Foreground),
154    (GetChannels, Foreground),
155    (GetChannelsResponse, Foreground),
156    (GetCodeActions, Background),
157    (GetCodeActionsResponse, Background),
158    (GetCompletions, Background),
159    (GetCompletionsResponse, Background),
160    (GetDefinition, Background),
161    (GetDefinitionResponse, Background),
162    (GetDocumentHighlights, Background),
163    (GetDocumentHighlightsResponse, Background),
164    (GetReferences, Background),
165    (GetReferencesResponse, Background),
166    (GetProjectSymbols, Background),
167    (GetProjectSymbolsResponse, Background),
168    (GetUsers, Foreground),
169    (GetUsersResponse, Foreground),
170    (JoinChannel, Foreground),
171    (JoinChannelResponse, Foreground),
172    (JoinProject, Foreground),
173    (JoinProjectResponse, Foreground),
174    (StartLanguageServer, Foreground),
175    (UpdateLanguageServer, Foreground),
176    (LeaveChannel, Foreground),
177    (LeaveProject, Foreground),
178    (OpenBuffer, Background),
179    (OpenBufferForSymbol, Background),
180    (OpenBufferForSymbolResponse, Background),
181    (OpenBufferResponse, Background),
182    (PerformRename, Background),
183    (PerformRenameResponse, Background),
184    (PrepareRename, Background),
185    (PrepareRenameResponse, Background),
186    (RegisterProjectResponse, Foreground),
187    (Ping, Foreground),
188    (RegisterProject, Foreground),
189    (RegisterWorktree, Foreground),
190    (RemoveProjectCollaborator, Foreground),
191    (SaveBuffer, Foreground),
192    (SearchProject, Background),
193    (SearchProjectResponse, Background),
194    (SendChannelMessage, Foreground),
195    (SendChannelMessageResponse, Foreground),
196    (ShareProject, Foreground),
197    (Test, Foreground),
198    (UnregisterProject, Foreground),
199    (UnregisterWorktree, Foreground),
200    (UnshareProject, Foreground),
201    (UpdateBuffer, Background),
202    (UpdateBufferFile, Foreground),
203    (UpdateContacts, Foreground),
204    (UpdateDiagnosticSummary, Foreground),
205    (UpdateWorktree, Foreground),
206);
207
208request_messages!(
209    (ApplyCodeAction, ApplyCodeActionResponse),
210    (
211        ApplyCompletionAdditionalEdits,
212        ApplyCompletionAdditionalEditsResponse
213    ),
214    (FormatBuffers, FormatBuffersResponse),
215    (GetChannelMessages, GetChannelMessagesResponse),
216    (GetChannels, GetChannelsResponse),
217    (GetCodeActions, GetCodeActionsResponse),
218    (GetCompletions, GetCompletionsResponse),
219    (GetDefinition, GetDefinitionResponse),
220    (GetDocumentHighlights, GetDocumentHighlightsResponse),
221    (GetReferences, GetReferencesResponse),
222    (GetProjectSymbols, GetProjectSymbolsResponse),
223    (GetUsers, GetUsersResponse),
224    (JoinChannel, JoinChannelResponse),
225    (JoinProject, JoinProjectResponse),
226    (OpenBuffer, OpenBufferResponse),
227    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
228    (Ping, Ack),
229    (PerformRename, PerformRenameResponse),
230    (PrepareRename, PrepareRenameResponse),
231    (RegisterProject, RegisterProjectResponse),
232    (RegisterWorktree, Ack),
233    (SaveBuffer, BufferSaved),
234    (SearchProject, SearchProjectResponse),
235    (SendChannelMessage, SendChannelMessageResponse),
236    (ShareProject, Ack),
237    (Test, Test),
238    (UpdateBuffer, Ack),
239    (UpdateWorktree, Ack),
240);
241
242entity_messages!(
243    project_id,
244    AddProjectCollaborator,
245    ApplyCodeAction,
246    ApplyCompletionAdditionalEdits,
247    BufferReloaded,
248    BufferSaved,
249    FormatBuffers,
250    GetCodeActions,
251    GetCompletions,
252    GetDefinition,
253    GetDocumentHighlights,
254    GetReferences,
255    GetProjectSymbols,
256    JoinProject,
257    LeaveProject,
258    OpenBuffer,
259    OpenBufferForSymbol,
260    PerformRename,
261    PrepareRename,
262    RemoveProjectCollaborator,
263    SaveBuffer,
264    SearchProject,
265    StartLanguageServer,
266    UnregisterWorktree,
267    UnshareProject,
268    UpdateBuffer,
269    UpdateBufferFile,
270    UpdateDiagnosticSummary,
271    UpdateLanguageServer,
272    RegisterWorktree,
273    UpdateWorktree,
274);
275
276entity_messages!(channel_id, ChannelMessageSent);
277
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` (used as a sink for writes and as a stream for reads) and
/// keeps one scratch buffer that is reused for every encode and decode.
pub struct MessageStream<S> {
    /// The underlying WebSocket transport.
    stream: S,
    /// Scratch space for encoded/decoded envelope bytes, cleared and reused per message.
    encoding_buffer: Vec<u8>,
}
283
284#[derive(Debug)]
285pub enum Message {
286    Envelope(Envelope),
287    Ping,
288    Pong,
289}
290
291impl<S> MessageStream<S> {
292    pub fn new(stream: S) -> Self {
293        Self {
294            stream,
295            encoding_buffer: Vec::new(),
296        }
297    }
298
299    pub fn inner_mut(&mut self) -> &mut S {
300        &mut self.stream
301    }
302}
303
304impl<S> MessageStream<S>
305where
306    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
307{
308    pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
309        #[cfg(any(test, feature = "test-support"))]
310        const COMPRESSION_LEVEL: i32 = -7;
311
312        #[cfg(not(any(test, feature = "test-support")))]
313        const COMPRESSION_LEVEL: i32 = 4;
314
315        match message {
316            Message::Envelope(message) => {
317                self.encoding_buffer.resize(message.encoded_len(), 0);
318                self.encoding_buffer.clear();
319                message
320                    .encode(&mut self.encoding_buffer)
321                    .map_err(|err| io::Error::from(err))?;
322                let buffer =
323                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
324                        .unwrap();
325                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
326            }
327            Message::Ping => {
328                self.stream
329                    .send(WebSocketMessage::Ping(Default::default()))
330                    .await?;
331            }
332            Message::Pong => {
333                self.stream
334                    .send(WebSocketMessage::Pong(Default::default()))
335                    .await?;
336            }
337        }
338
339        Ok(())
340    }
341}
342
343impl<S> MessageStream<S>
344where
345    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
346{
347    pub async fn read(&mut self) -> Result<Message, WebSocketError> {
348        while let Some(bytes) = self.stream.next().await {
349            match bytes? {
350                WebSocketMessage::Binary(bytes) => {
351                    self.encoding_buffer.clear();
352                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
353                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
354                        .map_err(io::Error::from)?;
355                    return Ok(Message::Envelope(envelope));
356                }
357                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
358                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
359                WebSocketMessage::Close(_) => break,
360                _ => {}
361            }
362        }
363        Err(WebSocketError::ConnectionClosed)
364    }
365}
366
367impl Into<SystemTime> for Timestamp {
368    fn into(self) -> SystemTime {
369        UNIX_EPOCH
370            .checked_add(Duration::new(self.seconds, self.nanos))
371            .unwrap()
372    }
373}
374
375impl From<SystemTime> for Timestamp {
376    fn from(time: SystemTime) -> Self {
377        let duration = time.duration_since(UNIX_EPOCH).unwrap();
378        Self {
379            seconds: duration.as_secs(),
380            nanos: duration.subsec_nanos(),
381        }
382    }
383}
384
385impl From<u128> for Nonce {
386    fn from(nonce: u128) -> Self {
387        let upper_half = (nonce >> 64) as u64;
388        let lower_half = nonce as u64;
389        Self {
390            upper_half,
391            lower_half,
392        }
393    }
394}
395
396impl From<Nonce> for u128 {
397    fn from(nonce: Nonce) -> Self {
398        let upper_half = (nonce.upper_half as u128) << 64;
399        let lower_half = nonce.lower_half as u128;
400        upper_half | lower_half
401    }
402}