proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    const PRIORITY: MessagePriority;
 17    fn into_envelope(
 18        self,
 19        id: u32,
 20        responding_to: Option<u32>,
 21        original_sender_id: Option<u32>,
 22    ) -> Envelope;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34pub trait AnyTypedEnvelope: 'static + Send + Sync {
 35    fn payload_type_id(&self) -> TypeId;
 36    fn payload_type_name(&self) -> &'static str;
 37    fn as_any(&self) -> &dyn Any;
 38    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 39    fn is_background(&self) -> bool;
 40    fn original_sender_id(&self) -> Option<PeerId>;
 41}
 42
 43pub enum MessagePriority {
 44    Foreground,
 45    Background,
 46}
 47
 48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 49    fn payload_type_id(&self) -> TypeId {
 50        TypeId::of::<T>()
 51    }
 52
 53    fn payload_type_name(&self) -> &'static str {
 54        T::NAME
 55    }
 56
 57    fn as_any(&self) -> &dyn Any {
 58        self
 59    }
 60
 61    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 62        self
 63    }
 64
 65    fn is_background(&self) -> bool {
 66        matches!(T::PRIORITY, MessagePriority::Background)
 67    }
 68
 69    fn original_sender_id(&self) -> Option<PeerId> {
 70        self.original_sender_id
 71    }
 72}
 73
 74macro_rules! messages {
 75    ($(($name:ident, $priority:ident)),* $(,)?) => {
 76        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 77            match envelope.payload {
 78                $(Some(envelope::Payload::$name(payload)) => {
 79                    Some(Box::new(TypedEnvelope {
 80                        sender_id,
 81                        original_sender_id: envelope.original_sender_id.map(PeerId),
 82                        message_id: envelope.id,
 83                        payload,
 84                    }))
 85                }, )*
 86                _ => None
 87            }
 88        }
 89
 90        $(
 91            impl EnvelopedMessage for $name {
 92                const NAME: &'static str = std::stringify!($name);
 93                const PRIORITY: MessagePriority = MessagePriority::$priority;
 94
 95                fn into_envelope(
 96                    self,
 97                    id: u32,
 98                    responding_to: Option<u32>,
 99                    original_sender_id: Option<u32>,
100                ) -> Envelope {
101                    Envelope {
102                        id,
103                        responding_to,
104                        original_sender_id,
105                        payload: Some(envelope::Payload::$name(self)),
106                    }
107                }
108
109                fn from_envelope(envelope: Envelope) -> Option<Self> {
110                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
111                        Some(msg)
112                    } else {
113                        None
114                    }
115                }
116            }
117        )*
118    };
119}
120
/// Implements `RequestMessage` for each `(Request, Response)` pair.
///
/// Each `Request` gets `Response` as its associated response type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
128
/// Implements `EntityMessage` for each listed message type.
///
/// `remote_entity_id` returns the value of the named `u64` field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
138
// Registry of every message type that may appear as an `Envelope` payload, each paired
// with its `MessagePriority` variant. The `messages!` macro generates an
// `EnvelopedMessage` impl and a `build_typed_envelope` match arm for every entry, so a
// type missing from this list cannot be sent or received as a typed envelope.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (ChannelMessageSent, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (GetUsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (StartLanguageServer, Foreground),
    (UpdateLanguageServer, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterWorktree, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UnregisterWorktree, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateWorktree, Foreground),
);
214
// Pairs each request type with the message type declared as its `RequestMessage::Response`.
// Several requests share `Ack` as their response, two `OpenBuffer*` requests share
// `OpenBufferResponse`, and `Test` is declared as its own response type.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, Ack),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
251
// Message types whose `EntityMessage::remote_entity_id()` is their `project_id` field.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterWorktree,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    RegisterWorktree,
    UpdateWorktree,
);

// Message types whose `EntityMessage::remote_entity_id()` is their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
292
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` and owns a scratch byte buffer. The buffer is
/// reused across `read` and `write` calls to hold encoded or decompressed envelopes.
pub struct MessageStream<S> {
    /// The underlying WebSocket sink and/or stream.
    stream: S,
    /// Reusable scratch space for envelope bytes.
    encoding_buffer: Vec<u8>,
}
298
299#[derive(Debug)]
300pub enum Message {
301    Envelope(Envelope),
302    Ping,
303    Pong,
304}
305
306impl<S> MessageStream<S> {
307    pub fn new(stream: S) -> Self {
308        Self {
309            stream,
310            encoding_buffer: Vec::new(),
311        }
312    }
313
314    pub fn inner_mut(&mut self) -> &mut S {
315        &mut self.stream
316    }
317}
318
319impl<S> MessageStream<S>
320where
321    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
322{
323    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
324        #[cfg(any(test, feature = "test-support"))]
325        const COMPRESSION_LEVEL: i32 = -7;
326
327        #[cfg(not(any(test, feature = "test-support")))]
328        const COMPRESSION_LEVEL: i32 = 4;
329
330        match message {
331            Message::Envelope(message) => {
332                self.encoding_buffer.resize(message.encoded_len(), 0);
333                self.encoding_buffer.clear();
334                message
335                    .encode(&mut self.encoding_buffer)
336                    .map_err(|err| io::Error::from(err))?;
337                let buffer =
338                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
339                        .unwrap();
340                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
341            }
342            Message::Ping => {
343                self.stream
344                    .send(WebSocketMessage::Ping(Default::default()))
345                    .await?;
346            }
347            Message::Pong => {
348                self.stream
349                    .send(WebSocketMessage::Pong(Default::default()))
350                    .await?;
351            }
352        }
353
354        Ok(())
355    }
356}
357
358impl<S> MessageStream<S>
359where
360    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
361{
362    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
363        while let Some(bytes) = self.stream.next().await {
364            match bytes? {
365                WebSocketMessage::Binary(bytes) => {
366                    self.encoding_buffer.clear();
367                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
368                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
369                        .map_err(io::Error::from)?;
370                    return Ok(Message::Envelope(envelope));
371                }
372                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
373                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
374                WebSocketMessage::Close(_) => break,
375                _ => {}
376            }
377        }
378        Err(anyhow!("connection closed"))
379    }
380}
381
382impl Into<SystemTime> for Timestamp {
383    fn into(self) -> SystemTime {
384        UNIX_EPOCH
385            .checked_add(Duration::new(self.seconds, self.nanos))
386            .unwrap()
387    }
388}
389
390impl From<SystemTime> for Timestamp {
391    fn from(time: SystemTime) -> Self {
392        let duration = time.duration_since(UNIX_EPOCH).unwrap();
393        Self {
394            seconds: duration.as_secs(),
395            nanos: duration.subsec_nanos(),
396        }
397    }
398}
399
400impl From<u128> for Nonce {
401    fn from(nonce: u128) -> Self {
402        let upper_half = (nonce >> 64) as u64;
403        let lower_half = nonce as u64;
404        Self {
405            upper_half,
406            lower_half,
407        }
408    }
409}
410
411impl From<Nonce> for u128 {
412    fn from(nonce: Nonce) -> Self {
413        let upper_half = (nonce.upper_half as u128) << 64;
414        let lower_half = nonce.lower_half as u128;
415        upper_half | lower_half
416    }
417}