proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    const PRIORITY: MessagePriority;
 17    fn into_envelope(
 18        self,
 19        id: u32,
 20        responding_to: Option<u32>,
 21        original_sender_id: Option<u32>,
 22    ) -> Envelope;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34pub trait AnyTypedEnvelope: 'static + Send + Sync {
 35    fn payload_type_id(&self) -> TypeId;
 36    fn payload_type_name(&self) -> &'static str;
 37    fn as_any(&self) -> &dyn Any;
 38    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 39    fn is_background(&self) -> bool;
 40    fn original_sender_id(&self) -> Option<PeerId>;
 41}
 42
 43pub enum MessagePriority {
 44    Foreground,
 45    Background,
 46}
 47
 48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 49    fn payload_type_id(&self) -> TypeId {
 50        TypeId::of::<T>()
 51    }
 52
 53    fn payload_type_name(&self) -> &'static str {
 54        T::NAME
 55    }
 56
 57    fn as_any(&self) -> &dyn Any {
 58        self
 59    }
 60
 61    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 62        self
 63    }
 64
 65    fn is_background(&self) -> bool {
 66        matches!(T::PRIORITY, MessagePriority::Background)
 67    }
 68
 69    fn original_sender_id(&self) -> Option<PeerId> {
 70        self.original_sender_id
 71    }
 72}
 73
 74macro_rules! messages {
 75    ($(($name:ident, $priority:ident)),* $(,)?) => {
 76        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 77            match envelope.payload {
 78                $(Some(envelope::Payload::$name(payload)) => {
 79                    Some(Box::new(TypedEnvelope {
 80                        sender_id,
 81                        original_sender_id: envelope.original_sender_id.map(PeerId),
 82                        message_id: envelope.id,
 83                        payload,
 84                    }))
 85                }, )*
 86                _ => None
 87            }
 88        }
 89
 90        $(
 91            impl EnvelopedMessage for $name {
 92                const NAME: &'static str = std::stringify!($name);
 93                const PRIORITY: MessagePriority = MessagePriority::$priority;
 94
 95                fn into_envelope(
 96                    self,
 97                    id: u32,
 98                    responding_to: Option<u32>,
 99                    original_sender_id: Option<u32>,
100                ) -> Envelope {
101                    Envelope {
102                        id,
103                        responding_to,
104                        original_sender_id,
105                        payload: Some(envelope::Payload::$name(self)),
106                    }
107                }
108
109                fn from_envelope(envelope: Envelope) -> Option<Self> {
110                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
111                        Some(msg)
112                    } else {
113                        None
114                    }
115                }
116            }
117        )*
118    };
119}
120
// Implements `RequestMessage` for each `(Request, Response)` pair, recording
// which message type is sent back in reply.
macro_rules! request_messages {
    ($(($request:ident, $reply:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $reply;
            }
        )*
    };
}
128
// Implements `EntityMessage` for each listed message type, reporting the value
// of the named field (e.g. `project_id`) as the remote entity id.
macro_rules! entity_messages {
    ($entity_field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$entity_field
                }
            }
        )*
    };
}
138
139messages!(
140    (Ack, Foreground),
141    (AddProjectCollaborator, Foreground),
142    (ApplyCodeAction, Background),
143    (ApplyCodeActionResponse, Background),
144    (ApplyCompletionAdditionalEdits, Background),
145    (ApplyCompletionAdditionalEditsResponse, Background),
146    (BufferReloaded, Foreground),
147    (BufferSaved, Foreground),
148    (ChannelMessageSent, Foreground),
149    (Error, Foreground),
150    (Follow, Foreground),
151    (FollowResponse, Foreground),
152    (FormatBuffers, Foreground),
153    (FormatBuffersResponse, Foreground),
154    (GetChannelMessages, Foreground),
155    (GetChannelMessagesResponse, Foreground),
156    (GetChannels, Foreground),
157    (GetChannelsResponse, Foreground),
158    (GetCodeActions, Background),
159    (GetCodeActionsResponse, Background),
160    (GetCompletions, Background),
161    (GetCompletionsResponse, Background),
162    (GetDefinition, Background),
163    (GetDefinitionResponse, Background),
164    (GetDocumentHighlights, Background),
165    (GetDocumentHighlightsResponse, Background),
166    (GetReferences, Background),
167    (GetReferencesResponse, Background),
168    (GetProjectSymbols, Background),
169    (GetProjectSymbolsResponse, Background),
170    (GetUsers, Foreground),
171    (GetUsersResponse, Foreground),
172    (JoinChannel, Foreground),
173    (JoinChannelResponse, Foreground),
174    (JoinProject, Foreground),
175    (JoinProjectResponse, Foreground),
176    (StartLanguageServer, Foreground),
177    (UpdateLanguageServer, Foreground),
178    (LeaveChannel, Foreground),
179    (LeaveProject, Foreground),
180    (OpenBufferById, Background),
181    (OpenBufferByPath, Background),
182    (OpenBufferForSymbol, Background),
183    (OpenBufferForSymbolResponse, Background),
184    (OpenBufferResponse, Background),
185    (PerformRename, Background),
186    (PerformRenameResponse, Background),
187    (PrepareRename, Background),
188    (PrepareRenameResponse, Background),
189    (RegisterProjectResponse, Foreground),
190    (Ping, Foreground),
191    (RegisterProject, Foreground),
192    (RegisterWorktree, Foreground),
193    (RemoveProjectCollaborator, Foreground),
194    (SaveBuffer, Foreground),
195    (SearchProject, Background),
196    (SearchProjectResponse, Background),
197    (SendChannelMessage, Foreground),
198    (SendChannelMessageResponse, Foreground),
199    (ShareProject, Foreground),
200    (Test, Foreground),
201    (Unfollow, Foreground),
202    (UnregisterProject, Foreground),
203    (UnregisterWorktree, Foreground),
204    (UnshareProject, Foreground),
205    (UpdateBuffer, Foreground),
206    (UpdateBufferFile, Foreground),
207    (UpdateContacts, Foreground),
208    (UpdateDiagnosticSummary, Foreground),
209    (UpdateFollowers, Foreground),
210    (UpdateWorktree, Foreground),
211);
212
213request_messages!(
214    (ApplyCodeAction, ApplyCodeActionResponse),
215    (
216        ApplyCompletionAdditionalEdits,
217        ApplyCompletionAdditionalEditsResponse
218    ),
219    (Follow, FollowResponse),
220    (FormatBuffers, FormatBuffersResponse),
221    (GetChannelMessages, GetChannelMessagesResponse),
222    (GetChannels, GetChannelsResponse),
223    (GetCodeActions, GetCodeActionsResponse),
224    (GetCompletions, GetCompletionsResponse),
225    (GetDefinition, GetDefinitionResponse),
226    (GetDocumentHighlights, GetDocumentHighlightsResponse),
227    (GetReferences, GetReferencesResponse),
228    (GetProjectSymbols, GetProjectSymbolsResponse),
229    (GetUsers, GetUsersResponse),
230    (JoinChannel, JoinChannelResponse),
231    (JoinProject, JoinProjectResponse),
232    (OpenBufferById, OpenBufferResponse),
233    (OpenBufferByPath, OpenBufferResponse),
234    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
235    (Ping, Ack),
236    (PerformRename, PerformRenameResponse),
237    (PrepareRename, PrepareRenameResponse),
238    (RegisterProject, RegisterProjectResponse),
239    (RegisterWorktree, Ack),
240    (SaveBuffer, BufferSaved),
241    (SearchProject, SearchProjectResponse),
242    (SendChannelMessage, SendChannelMessageResponse),
243    (ShareProject, Ack),
244    (Test, Test),
245    (UpdateBuffer, Ack),
246    (UpdateWorktree, Ack),
247);
248
249entity_messages!(
250    project_id,
251    AddProjectCollaborator,
252    ApplyCodeAction,
253    ApplyCompletionAdditionalEdits,
254    BufferReloaded,
255    BufferSaved,
256    Follow,
257    FormatBuffers,
258    GetCodeActions,
259    GetCompletions,
260    GetDefinition,
261    GetDocumentHighlights,
262    GetReferences,
263    GetProjectSymbols,
264    JoinProject,
265    LeaveProject,
266    OpenBufferById,
267    OpenBufferByPath,
268    OpenBufferForSymbol,
269    PerformRename,
270    PrepareRename,
271    RemoveProjectCollaborator,
272    SaveBuffer,
273    SearchProject,
274    StartLanguageServer,
275    Unfollow,
276    UnregisterWorktree,
277    UnshareProject,
278    UpdateBuffer,
279    UpdateBufferFile,
280    UpdateDiagnosticSummary,
281    UpdateFollowers,
282    UpdateLanguageServer,
283    RegisterWorktree,
284    UpdateWorktree,
285);
286
287entity_messages!(channel_id, ChannelMessageSent);
288
/// Exchanges protobuf [`Message`]s over an underlying transport `S`.
///
/// `encoding_buffer` is scratch space reused between reads and writes, so
/// each message does not need its own fresh allocation.
pub struct MessageStream<S> {
    stream: S,
    encoding_buffer: Vec<u8>,
}
294
295#[derive(Debug)]
296pub enum Message {
297    Envelope(Envelope),
298    Ping,
299    Pong,
300}
301
302impl<S> MessageStream<S> {
303    pub fn new(stream: S) -> Self {
304        Self {
305            stream,
306            encoding_buffer: Vec::new(),
307        }
308    }
309
310    pub fn inner_mut(&mut self) -> &mut S {
311        &mut self.stream
312    }
313}
314
315impl<S> MessageStream<S>
316where
317    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
318{
319    pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
320        #[cfg(any(test, feature = "test-support"))]
321        const COMPRESSION_LEVEL: i32 = -7;
322
323        #[cfg(not(any(test, feature = "test-support")))]
324        const COMPRESSION_LEVEL: i32 = 4;
325
326        match message {
327            Message::Envelope(message) => {
328                self.encoding_buffer.resize(message.encoded_len(), 0);
329                self.encoding_buffer.clear();
330                message
331                    .encode(&mut self.encoding_buffer)
332                    .map_err(|err| io::Error::from(err))?;
333                let buffer =
334                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
335                        .unwrap();
336                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
337            }
338            Message::Ping => {
339                self.stream
340                    .send(WebSocketMessage::Ping(Default::default()))
341                    .await?;
342            }
343            Message::Pong => {
344                self.stream
345                    .send(WebSocketMessage::Pong(Default::default()))
346                    .await?;
347            }
348        }
349
350        Ok(())
351    }
352}
353
354impl<S> MessageStream<S>
355where
356    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
357{
358    pub async fn read(&mut self) -> Result<Message, WebSocketError> {
359        while let Some(bytes) = self.stream.next().await {
360            match bytes? {
361                WebSocketMessage::Binary(bytes) => {
362                    self.encoding_buffer.clear();
363                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
364                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
365                        .map_err(io::Error::from)?;
366                    return Ok(Message::Envelope(envelope));
367                }
368                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
369                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
370                WebSocketMessage::Close(_) => break,
371                _ => {}
372            }
373        }
374        Err(WebSocketError::ConnectionClosed)
375    }
376}
377
378impl Into<SystemTime> for Timestamp {
379    fn into(self) -> SystemTime {
380        UNIX_EPOCH
381            .checked_add(Duration::new(self.seconds, self.nanos))
382            .unwrap()
383    }
384}
385
386impl From<SystemTime> for Timestamp {
387    fn from(time: SystemTime) -> Self {
388        let duration = time.duration_since(UNIX_EPOCH).unwrap();
389        Self {
390            seconds: duration.as_secs(),
391            nanos: duration.subsec_nanos(),
392        }
393    }
394}
395
396impl From<u128> for Nonce {
397    fn from(nonce: u128) -> Self {
398        let upper_half = (nonce >> 64) as u64;
399        let lower_half = nonce as u64;
400        Self {
401            upper_half,
402            lower_half,
403        }
404    }
405}
406
407impl From<Nonce> for u128 {
408    fn from(nonce: Nonce) -> Self {
409        let upper_half = (nonce.upper_half as u128) << 64;
410        let lower_half = nonce.lower_half as u128;
411        upper_half | lower_half
412    }
413}