proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    fmt::Debug,
 10    io,
 11    time::{Duration, SystemTime, UNIX_EPOCH},
 12};
 13
// Pulls in the generated protocol types used throughout this file (e.g.
// `Envelope`, the `envelope::Payload` enum, `Timestamp`, `Nonce`), which are
// written to Cargo's `OUT_DIR` at build time — presumably by a prost build
// script; confirm in build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 15
 16pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 17    const NAME: &'static str;
 18    const PRIORITY: MessagePriority;
 19    fn into_envelope(
 20        self,
 21        id: u32,
 22        responding_to: Option<u32>,
 23        original_sender_id: Option<u32>,
 24    ) -> Envelope;
 25    fn from_envelope(envelope: Envelope) -> Option<Self>;
 26}
 27
 28pub trait EntityMessage: EnvelopedMessage {
 29    fn remote_entity_id(&self) -> u64;
 30}
 31
 32pub trait RequestMessage: EnvelopedMessage {
 33    type Response: EnvelopedMessage;
 34}
 35
 36pub trait AnyTypedEnvelope: 'static + Send + Sync {
 37    fn payload_type_id(&self) -> TypeId;
 38    fn payload_type_name(&self) -> &'static str;
 39    fn as_any(&self) -> &dyn Any;
 40    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 41    fn is_background(&self) -> bool;
 42    fn original_sender_id(&self) -> Option<PeerId>;
 43}
 44
 45pub enum MessagePriority {
 46    Foreground,
 47    Background,
 48}
 49
 50impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 51    fn payload_type_id(&self) -> TypeId {
 52        TypeId::of::<T>()
 53    }
 54
 55    fn payload_type_name(&self) -> &'static str {
 56        T::NAME
 57    }
 58
 59    fn as_any(&self) -> &dyn Any {
 60        self
 61    }
 62
 63    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 64        self
 65    }
 66
 67    fn is_background(&self) -> bool {
 68        matches!(T::PRIORITY, MessagePriority::Background)
 69    }
 70
 71    fn original_sender_id(&self) -> Option<PeerId> {
 72        self.original_sender_id
 73    }
 74}
 75
 76macro_rules! messages {
 77    ($(($name:ident, $priority:ident)),* $(,)?) => {
 78        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 79            match envelope.payload {
 80                $(Some(envelope::Payload::$name(payload)) => {
 81                    Some(Box::new(TypedEnvelope {
 82                        sender_id,
 83                        original_sender_id: envelope.original_sender_id.map(PeerId),
 84                        message_id: envelope.id,
 85                        payload,
 86                    }))
 87                }, )*
 88                _ => None
 89            }
 90        }
 91
 92        $(
 93            impl EnvelopedMessage for $name {
 94                const NAME: &'static str = std::stringify!($name);
 95                const PRIORITY: MessagePriority = MessagePriority::$priority;
 96
 97                fn into_envelope(
 98                    self,
 99                    id: u32,
100                    responding_to: Option<u32>,
101                    original_sender_id: Option<u32>,
102                ) -> Envelope {
103                    Envelope {
104                        id,
105                        responding_to,
106                        original_sender_id,
107                        payload: Some(envelope::Payload::$name(self)),
108                    }
109                }
110
111                fn from_envelope(envelope: Envelope) -> Option<Self> {
112                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
113                        Some(msg)
114                    } else {
115                        None
116                    }
117                }
118            }
119        )*
120    };
121}
122
/// Declares the reply type of each request by implementing `RequestMessage`
/// for every `(Request, Response)` pair.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
130
/// Implements `EntityMessage` for each listed message type, taking the
/// remote entity id from the field named by the first argument.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    let Self { $id_field, .. } = self;
                    *$id_field
                }
            }
        )*
    };
}
140
141messages!(
142    (Ack, Foreground),
143    (AddProjectCollaborator, Foreground),
144    (ApplyCodeAction, Background),
145    (ApplyCodeActionResponse, Background),
146    (ApplyCompletionAdditionalEdits, Background),
147    (ApplyCompletionAdditionalEditsResponse, Background),
148    (BufferReloaded, Foreground),
149    (BufferSaved, Foreground),
150    (ChannelMessageSent, Foreground),
151    (CreateProjectEntry, Foreground),
152    (DeleteProjectEntry, Foreground),
153    (Error, Foreground),
154    (Follow, Foreground),
155    (FollowResponse, Foreground),
156    (FormatBuffers, Foreground),
157    (FormatBuffersResponse, Foreground),
158    (GetChannelMessages, Foreground),
159    (GetChannelMessagesResponse, Foreground),
160    (GetChannels, Foreground),
161    (GetChannelsResponse, Foreground),
162    (GetCodeActions, Background),
163    (GetCodeActionsResponse, Background),
164    (GetCompletions, Background),
165    (GetCompletionsResponse, Background),
166    (GetDefinition, Background),
167    (GetDefinitionResponse, Background),
168    (GetDocumentHighlights, Background),
169    (GetDocumentHighlightsResponse, Background),
170    (GetReferences, Background),
171    (GetReferencesResponse, Background),
172    (GetProjectSymbols, Background),
173    (GetProjectSymbolsResponse, Background),
174    (GetUsers, Foreground),
175    (GetUsersResponse, Foreground),
176    (JoinChannel, Foreground),
177    (JoinChannelResponse, Foreground),
178    (JoinProject, Foreground),
179    (JoinProjectResponse, Foreground),
180    (LeaveChannel, Foreground),
181    (LeaveProject, Foreground),
182    (OpenBufferById, Background),
183    (OpenBufferByPath, Background),
184    (OpenBufferForSymbol, Background),
185    (OpenBufferForSymbolResponse, Background),
186    (OpenBufferResponse, Background),
187    (PerformRename, Background),
188    (PerformRenameResponse, Background),
189    (PrepareRename, Background),
190    (PrepareRenameResponse, Background),
191    (ProjectEntryResponse, Foreground),
192    (RegisterProjectResponse, Foreground),
193    (Ping, Foreground),
194    (RegisterProject, Foreground),
195    (RegisterWorktree, Foreground),
196    (ReloadBuffers, Foreground),
197    (ReloadBuffersResponse, Foreground),
198    (RemoveProjectCollaborator, Foreground),
199    (RenameProjectEntry, Foreground),
200    (SaveBuffer, Foreground),
201    (SearchProject, Background),
202    (SearchProjectResponse, Background),
203    (SendChannelMessage, Foreground),
204    (SendChannelMessageResponse, Foreground),
205    (ShareProject, Foreground),
206    (StartLanguageServer, Foreground),
207    (Test, Foreground),
208    (Unfollow, Foreground),
209    (UnregisterProject, Foreground),
210    (UnregisterWorktree, Foreground),
211    (UnshareProject, Foreground),
212    (UpdateBuffer, Foreground),
213    (UpdateBufferFile, Foreground),
214    (UpdateContacts, Foreground),
215    (UpdateDiagnosticSummary, Foreground),
216    (UpdateFollowers, Foreground),
217    (UpdateLanguageServer, Foreground),
218    (UpdateWorktree, Foreground),
219);
220
221request_messages!(
222    (ApplyCodeAction, ApplyCodeActionResponse),
223    (
224        ApplyCompletionAdditionalEdits,
225        ApplyCompletionAdditionalEditsResponse
226    ),
227    (CreateProjectEntry, ProjectEntryResponse),
228    (DeleteProjectEntry, ProjectEntryResponse),
229    (Follow, FollowResponse),
230    (FormatBuffers, FormatBuffersResponse),
231    (GetChannelMessages, GetChannelMessagesResponse),
232    (GetChannels, GetChannelsResponse),
233    (GetCodeActions, GetCodeActionsResponse),
234    (GetCompletions, GetCompletionsResponse),
235    (GetDefinition, GetDefinitionResponse),
236    (GetDocumentHighlights, GetDocumentHighlightsResponse),
237    (GetReferences, GetReferencesResponse),
238    (GetProjectSymbols, GetProjectSymbolsResponse),
239    (GetUsers, GetUsersResponse),
240    (JoinChannel, JoinChannelResponse),
241    (JoinProject, JoinProjectResponse),
242    (OpenBufferById, OpenBufferResponse),
243    (OpenBufferByPath, OpenBufferResponse),
244    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
245    (Ping, Ack),
246    (PerformRename, PerformRenameResponse),
247    (PrepareRename, PrepareRenameResponse),
248    (RegisterProject, RegisterProjectResponse),
249    (RegisterWorktree, Ack),
250    (ReloadBuffers, ReloadBuffersResponse),
251    (RenameProjectEntry, ProjectEntryResponse),
252    (SaveBuffer, BufferSaved),
253    (SearchProject, SearchProjectResponse),
254    (SendChannelMessage, SendChannelMessageResponse),
255    (ShareProject, Ack),
256    (Test, Test),
257    (UpdateBuffer, Ack),
258    (UpdateWorktree, Ack),
259);
260
261entity_messages!(
262    project_id,
263    AddProjectCollaborator,
264    ApplyCodeAction,
265    ApplyCompletionAdditionalEdits,
266    BufferReloaded,
267    BufferSaved,
268    CreateProjectEntry,
269    RenameProjectEntry,
270    DeleteProjectEntry,
271    Follow,
272    FormatBuffers,
273    GetCodeActions,
274    GetCompletions,
275    GetDefinition,
276    GetDocumentHighlights,
277    GetReferences,
278    GetProjectSymbols,
279    JoinProject,
280    LeaveProject,
281    OpenBufferById,
282    OpenBufferByPath,
283    OpenBufferForSymbol,
284    PerformRename,
285    PrepareRename,
286    ReloadBuffers,
287    RemoveProjectCollaborator,
288    SaveBuffer,
289    SearchProject,
290    StartLanguageServer,
291    Unfollow,
292    UnregisterWorktree,
293    UnshareProject,
294    UpdateBuffer,
295    UpdateBufferFile,
296    UpdateDiagnosticSummary,
297    UpdateFollowers,
298    UpdateLanguageServer,
299    RegisterWorktree,
300    UpdateWorktree,
301);
302
303entity_messages!(channel_id, ChannelMessageSent);
304
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S`. Envelopes travel as zstd-compressed,
/// protobuf-encoded binary frames (see the `read`/`write` impls).
pub struct MessageStream<S> {
    /// The wrapped WebSocket sink and/or stream.
    stream: S,
    /// Scratch buffer reused across messages for encoding and decoding.
    encoding_buffer: Vec<u8>,
}
310
311#[derive(Debug)]
312pub enum Message {
313    Envelope(Envelope),
314    Ping,
315    Pong,
316}
317
318impl<S> MessageStream<S> {
319    pub fn new(stream: S) -> Self {
320        Self {
321            stream,
322            encoding_buffer: Vec::new(),
323        }
324    }
325
326    pub fn inner_mut(&mut self) -> &mut S {
327        &mut self.stream
328    }
329}
330
331impl<S> MessageStream<S>
332where
333    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
334{
335    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
336        #[cfg(any(test, feature = "test-support"))]
337        const COMPRESSION_LEVEL: i32 = -7;
338
339        #[cfg(not(any(test, feature = "test-support")))]
340        const COMPRESSION_LEVEL: i32 = 4;
341
342        match message {
343            Message::Envelope(message) => {
344                self.encoding_buffer.resize(message.encoded_len(), 0);
345                self.encoding_buffer.clear();
346                message
347                    .encode(&mut self.encoding_buffer)
348                    .map_err(|err| io::Error::from(err))?;
349                let buffer =
350                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
351                        .unwrap();
352                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
353            }
354            Message::Ping => {
355                self.stream
356                    .send(WebSocketMessage::Ping(Default::default()))
357                    .await?;
358            }
359            Message::Pong => {
360                self.stream
361                    .send(WebSocketMessage::Pong(Default::default()))
362                    .await?;
363            }
364        }
365
366        Ok(())
367    }
368}
369
370impl<S> MessageStream<S>
371where
372    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
373{
374    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
375        while let Some(bytes) = self.stream.next().await {
376            match bytes? {
377                WebSocketMessage::Binary(bytes) => {
378                    self.encoding_buffer.clear();
379                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
380                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
381                        .map_err(io::Error::from)?;
382                    return Ok(Message::Envelope(envelope));
383                }
384                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
385                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
386                WebSocketMessage::Close(_) => break,
387                _ => {}
388            }
389        }
390        Err(anyhow!("connection closed"))
391    }
392}
393
394impl Into<SystemTime> for Timestamp {
395    fn into(self) -> SystemTime {
396        UNIX_EPOCH
397            .checked_add(Duration::new(self.seconds, self.nanos))
398            .unwrap()
399    }
400}
401
402impl From<SystemTime> for Timestamp {
403    fn from(time: SystemTime) -> Self {
404        let duration = time.duration_since(UNIX_EPOCH).unwrap();
405        Self {
406            seconds: duration.as_secs(),
407            nanos: duration.subsec_nanos(),
408        }
409    }
410}
411
412impl From<u128> for Nonce {
413    fn from(nonce: u128) -> Self {
414        let upper_half = (nonce >> 64) as u64;
415        let lower_half = nonce as u64;
416        Self {
417            upper_half,
418            lower_half,
419        }
420    }
421}
422
423impl From<Nonce> for u128 {
424    fn from(nonce: Nonce) -> Self {
425        let upper_half = (nonce.upper_half as u128) << 64;
426        let lower_half = nonce.lower_half as u128;
427        upper_half | lower_half
428    }
429}