proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    io,
 10    time::{Duration, SystemTime, UNIX_EPOCH},
 11};
 12
 13include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 14
 15pub trait EnvelopedMessage: Clone + Serialize + Sized + Send + Sync + 'static {
 16    const NAME: &'static str;
 17    const PRIORITY: MessagePriority;
 18    fn into_envelope(
 19        self,
 20        id: u32,
 21        responding_to: Option<u32>,
 22        original_sender_id: Option<u32>,
 23    ) -> Envelope;
 24    fn from_envelope(envelope: Envelope) -> Option<Self>;
 25}
 26
 27pub trait EntityMessage: EnvelopedMessage {
 28    fn remote_entity_id(&self) -> u64;
 29}
 30
 31pub trait RequestMessage: EnvelopedMessage {
 32    type Response: EnvelopedMessage;
 33}
 34
 35pub trait AnyTypedEnvelope: 'static + Send + Sync {
 36    fn payload_type_id(&self) -> TypeId;
 37    fn payload_type_name(&self) -> &'static str;
 38    fn as_any(&self) -> &dyn Any;
 39    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 40    fn is_background(&self) -> bool;
 41    fn original_sender_id(&self) -> Option<PeerId>;
 42}
 43
 44pub enum MessagePriority {
 45    Foreground,
 46    Background,
 47}
 48
 49impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 50    fn payload_type_id(&self) -> TypeId {
 51        TypeId::of::<T>()
 52    }
 53
 54    fn payload_type_name(&self) -> &'static str {
 55        T::NAME
 56    }
 57
 58    fn as_any(&self) -> &dyn Any {
 59        self
 60    }
 61
 62    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 63        self
 64    }
 65
 66    fn is_background(&self) -> bool {
 67        matches!(T::PRIORITY, MessagePriority::Background)
 68    }
 69
 70    fn original_sender_id(&self) -> Option<PeerId> {
 71        self.original_sender_id
 72    }
 73}
 74
 75macro_rules! messages {
 76    ($(($name:ident, $priority:ident)),* $(,)?) => {
 77        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 78            match envelope.payload {
 79                $(Some(envelope::Payload::$name(payload)) => {
 80                    Some(Box::new(TypedEnvelope {
 81                        sender_id,
 82                        original_sender_id: envelope.original_sender_id.map(PeerId),
 83                        message_id: envelope.id,
 84                        payload,
 85                    }))
 86                }, )*
 87                _ => None
 88            }
 89        }
 90
 91        $(
 92            impl EnvelopedMessage for $name {
 93                const NAME: &'static str = std::stringify!($name);
 94                const PRIORITY: MessagePriority = MessagePriority::$priority;
 95
 96                fn into_envelope(
 97                    self,
 98                    id: u32,
 99                    responding_to: Option<u32>,
100                    original_sender_id: Option<u32>,
101                ) -> Envelope {
102                    Envelope {
103                        id,
104                        responding_to,
105                        original_sender_id,
106                        payload: Some(envelope::Payload::$name(self)),
107                    }
108                }
109
110                fn from_envelope(envelope: Envelope) -> Option<Self> {
111                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
112                        Some(msg)
113                    } else {
114                        None
115                    }
116                }
117            }
118        )*
119    };
120}
121
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, tying the request
/// type to the message type expected in reply.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
129
/// Implements [`EntityMessage`] for each listed type.
///
/// The remote entity id is read from the field named by the first argument
/// (for example `project_id`).
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    let entity_id = self.$id_field;
                    entity_id
                }
            }
        )*
    };
}
139
140messages!(
141    (Ack, Foreground),
142    (AddProjectCollaborator, Foreground),
143    (ApplyCodeAction, Background),
144    (ApplyCodeActionResponse, Background),
145    (ApplyCompletionAdditionalEdits, Background),
146    (ApplyCompletionAdditionalEditsResponse, Background),
147    (BufferReloaded, Foreground),
148    (BufferSaved, Foreground),
149    (ChannelMessageSent, Foreground),
150    (CreateProjectEntry, Foreground),
151    (DeleteProjectEntry, Foreground),
152    (Error, Foreground),
153    (Follow, Foreground),
154    (FollowResponse, Foreground),
155    (FormatBuffers, Foreground),
156    (FormatBuffersResponse, Foreground),
157    (GetChannelMessages, Foreground),
158    (GetChannelMessagesResponse, Foreground),
159    (GetChannels, Foreground),
160    (GetChannelsResponse, Foreground),
161    (GetCodeActions, Background),
162    (GetCodeActionsResponse, Background),
163    (GetCompletions, Background),
164    (GetCompletionsResponse, Background),
165    (GetDefinition, Background),
166    (GetDefinitionResponse, Background),
167    (GetDocumentHighlights, Background),
168    (GetDocumentHighlightsResponse, Background),
169    (GetReferences, Background),
170    (GetReferencesResponse, Background),
171    (GetProjectSymbols, Background),
172    (GetProjectSymbolsResponse, Background),
173    (GetUsers, Foreground),
174    (GetUsersResponse, Foreground),
175    (JoinChannel, Foreground),
176    (JoinChannelResponse, Foreground),
177    (JoinProject, Foreground),
178    (JoinProjectResponse, Foreground),
179    (LeaveChannel, Foreground),
180    (LeaveProject, Foreground),
181    (OpenBufferById, Background),
182    (OpenBufferByPath, Background),
183    (OpenBufferForSymbol, Background),
184    (OpenBufferForSymbolResponse, Background),
185    (OpenBufferResponse, Background),
186    (PerformRename, Background),
187    (PerformRenameResponse, Background),
188    (PrepareRename, Background),
189    (PrepareRenameResponse, Background),
190    (ProjectEntryResponse, Foreground),
191    (RegisterProjectResponse, Foreground),
192    (Ping, Foreground),
193    (RegisterProject, Foreground),
194    (RegisterWorktree, Foreground),
195    (ReloadBuffers, Foreground),
196    (ReloadBuffersResponse, Foreground),
197    (RemoveProjectCollaborator, Foreground),
198    (RenameProjectEntry, Foreground),
199    (SaveBuffer, Foreground),
200    (SearchProject, Background),
201    (SearchProjectResponse, Background),
202    (SendChannelMessage, Foreground),
203    (SendChannelMessageResponse, Foreground),
204    (ShareProject, Foreground),
205    (StartLanguageServer, Foreground),
206    (Test, Foreground),
207    (Unfollow, Foreground),
208    (UnregisterProject, Foreground),
209    (UnregisterWorktree, Foreground),
210    (UnshareProject, Foreground),
211    (UpdateBuffer, Foreground),
212    (UpdateBufferFile, Foreground),
213    (UpdateContacts, Foreground),
214    (UpdateDiagnosticSummary, Foreground),
215    (UpdateFollowers, Foreground),
216    (UpdateLanguageServer, Foreground),
217    (UpdateWorktree, Foreground),
218);
219
220request_messages!(
221    (ApplyCodeAction, ApplyCodeActionResponse),
222    (
223        ApplyCompletionAdditionalEdits,
224        ApplyCompletionAdditionalEditsResponse
225    ),
226    (CreateProjectEntry, ProjectEntryResponse),
227    (Follow, FollowResponse),
228    (FormatBuffers, FormatBuffersResponse),
229    (GetChannelMessages, GetChannelMessagesResponse),
230    (GetChannels, GetChannelsResponse),
231    (GetCodeActions, GetCodeActionsResponse),
232    (GetCompletions, GetCompletionsResponse),
233    (GetDefinition, GetDefinitionResponse),
234    (GetDocumentHighlights, GetDocumentHighlightsResponse),
235    (GetReferences, GetReferencesResponse),
236    (GetProjectSymbols, GetProjectSymbolsResponse),
237    (GetUsers, GetUsersResponse),
238    (JoinChannel, JoinChannelResponse),
239    (JoinProject, JoinProjectResponse),
240    (OpenBufferById, OpenBufferResponse),
241    (OpenBufferByPath, OpenBufferResponse),
242    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
243    (Ping, Ack),
244    (PerformRename, PerformRenameResponse),
245    (PrepareRename, PrepareRenameResponse),
246    (RegisterProject, RegisterProjectResponse),
247    (RegisterWorktree, Ack),
248    (ReloadBuffers, ReloadBuffersResponse),
249    (RenameProjectEntry, ProjectEntryResponse),
250    (SaveBuffer, BufferSaved),
251    (SearchProject, SearchProjectResponse),
252    (SendChannelMessage, SendChannelMessageResponse),
253    (ShareProject, Ack),
254    (Test, Test),
255    (UpdateBuffer, Ack),
256    (UpdateWorktree, Ack),
257);
258
259entity_messages!(
260    project_id,
261    AddProjectCollaborator,
262    ApplyCodeAction,
263    ApplyCompletionAdditionalEdits,
264    BufferReloaded,
265    BufferSaved,
266    CreateProjectEntry,
267    RenameProjectEntry,
268    DeleteProjectEntry,
269    Follow,
270    FormatBuffers,
271    GetCodeActions,
272    GetCompletions,
273    GetDefinition,
274    GetDocumentHighlights,
275    GetReferences,
276    GetProjectSymbols,
277    JoinProject,
278    LeaveProject,
279    OpenBufferById,
280    OpenBufferByPath,
281    OpenBufferForSymbol,
282    PerformRename,
283    PrepareRename,
284    ReloadBuffers,
285    RemoveProjectCollaborator,
286    SaveBuffer,
287    SearchProject,
288    StartLanguageServer,
289    Unfollow,
290    UnregisterWorktree,
291    UnshareProject,
292    UpdateBuffer,
293    UpdateBufferFile,
294    UpdateDiagnosticSummary,
295    UpdateFollowers,
296    UpdateLanguageServer,
297    RegisterWorktree,
298    UpdateWorktree,
299);
300
301entity_messages!(channel_id, ChannelMessageSent);
302
/// Adapts a WebSocket transport into a stream of protobuf [`Message`]s.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch buffer reused across reads and writes for encoding and decompression.
    encoding_buffer: Vec<u8>,
}
308
309#[derive(Debug)]
310pub enum Message {
311    Envelope(Envelope),
312    Ping,
313    Pong,
314}
315
316impl<S> MessageStream<S> {
317    pub fn new(stream: S) -> Self {
318        Self {
319            stream,
320            encoding_buffer: Vec::new(),
321        }
322    }
323
324    pub fn inner_mut(&mut self) -> &mut S {
325        &mut self.stream
326    }
327}
328
329impl<S> MessageStream<S>
330where
331    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
332{
333    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
334        #[cfg(any(test, feature = "test-support"))]
335        const COMPRESSION_LEVEL: i32 = -7;
336
337        #[cfg(not(any(test, feature = "test-support")))]
338        const COMPRESSION_LEVEL: i32 = 4;
339
340        match message {
341            Message::Envelope(message) => {
342                self.encoding_buffer.resize(message.encoded_len(), 0);
343                self.encoding_buffer.clear();
344                message
345                    .encode(&mut self.encoding_buffer)
346                    .map_err(|err| io::Error::from(err))?;
347                let buffer =
348                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
349                        .unwrap();
350                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
351            }
352            Message::Ping => {
353                self.stream
354                    .send(WebSocketMessage::Ping(Default::default()))
355                    .await?;
356            }
357            Message::Pong => {
358                self.stream
359                    .send(WebSocketMessage::Pong(Default::default()))
360                    .await?;
361            }
362        }
363
364        Ok(())
365    }
366}
367
368impl<S> MessageStream<S>
369where
370    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
371{
372    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
373        while let Some(bytes) = self.stream.next().await {
374            match bytes? {
375                WebSocketMessage::Binary(bytes) => {
376                    self.encoding_buffer.clear();
377                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
378                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
379                        .map_err(io::Error::from)?;
380                    return Ok(Message::Envelope(envelope));
381                }
382                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
383                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
384                WebSocketMessage::Close(_) => break,
385                _ => {}
386            }
387        }
388        Err(anyhow!("connection closed"))
389    }
390}
391
392impl Into<SystemTime> for Timestamp {
393    fn into(self) -> SystemTime {
394        UNIX_EPOCH
395            .checked_add(Duration::new(self.seconds, self.nanos))
396            .unwrap()
397    }
398}
399
400impl From<SystemTime> for Timestamp {
401    fn from(time: SystemTime) -> Self {
402        let duration = time.duration_since(UNIX_EPOCH).unwrap();
403        Self {
404            seconds: duration.as_secs(),
405            nanos: duration.subsec_nanos(),
406        }
407    }
408}
409
410impl From<u128> for Nonce {
411    fn from(nonce: u128) -> Self {
412        let upper_half = (nonce >> 64) as u64;
413        let lower_half = nonce as u64;
414        Self {
415            upper_half,
416            lower_half,
417        }
418    }
419}
420
421impl From<Nonce> for u128 {
422    fn from(nonce: Nonce) -> Self {
423        let upper_half = (nonce.upper_half as u128) << 64;
424        let lower_half = nonce.lower_half as u128;
425        upper_half | lower_half
426    }
427}