proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    const PRIORITY: MessagePriority;
 17    fn into_envelope(
 18        self,
 19        id: u32,
 20        responding_to: Option<u32>,
 21        original_sender_id: Option<u32>,
 22    ) -> Envelope;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34pub trait AnyTypedEnvelope: 'static + Send + Sync {
 35    fn payload_type_id(&self) -> TypeId;
 36    fn payload_type_name(&self) -> &'static str;
 37    fn as_any(&self) -> &dyn Any;
 38    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 39    fn is_background(&self) -> bool;
 40    fn original_sender_id(&self) -> Option<PeerId>;
 41}
 42
 43pub enum MessagePriority {
 44    Foreground,
 45    Background,
 46}
 47
 48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 49    fn payload_type_id(&self) -> TypeId {
 50        TypeId::of::<T>()
 51    }
 52
 53    fn payload_type_name(&self) -> &'static str {
 54        T::NAME
 55    }
 56
 57    fn as_any(&self) -> &dyn Any {
 58        self
 59    }
 60
 61    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 62        self
 63    }
 64
 65    fn is_background(&self) -> bool {
 66        matches!(T::PRIORITY, MessagePriority::Background)
 67    }
 68
 69    fn original_sender_id(&self) -> Option<PeerId> {
 70        self.original_sender_id
 71    }
 72}
 73
 74macro_rules! messages {
 75    ($(($name:ident, $priority:ident)),* $(,)?) => {
 76        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 77            match envelope.payload {
 78                $(Some(envelope::Payload::$name(payload)) => {
 79                    Some(Box::new(TypedEnvelope {
 80                        sender_id,
 81                        original_sender_id: envelope.original_sender_id.map(PeerId),
 82                        message_id: envelope.id,
 83                        payload,
 84                    }))
 85                }, )*
 86                _ => None
 87            }
 88        }
 89
 90        $(
 91            impl EnvelopedMessage for $name {
 92                const NAME: &'static str = std::stringify!($name);
 93                const PRIORITY: MessagePriority = MessagePriority::$priority;
 94
 95                fn into_envelope(
 96                    self,
 97                    id: u32,
 98                    responding_to: Option<u32>,
 99                    original_sender_id: Option<u32>,
100                ) -> Envelope {
101                    Envelope {
102                        id,
103                        responding_to,
104                        original_sender_id,
105                        payload: Some(envelope::Payload::$name(self)),
106                    }
107                }
108
109                fn from_envelope(envelope: Envelope) -> Option<Self> {
110                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
111                        Some(msg)
112                    } else {
113                        None
114                    }
115                }
116            }
117        )*
118    };
119}
120
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, fixing the message type
/// sent back in reply to the request.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
128
/// Implements [`EntityMessage`] for each listed message type. Every type reports the field
/// named by the first argument as its remote entity id.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
138
139messages!(
140    (Ack, Foreground),
141    (AddProjectCollaborator, Foreground),
142    (ApplyCodeAction, Background),
143    (ApplyCodeActionResponse, Background),
144    (ApplyCompletionAdditionalEdits, Background),
145    (ApplyCompletionAdditionalEditsResponse, Background),
146    (BufferReloaded, Foreground),
147    (BufferSaved, Foreground),
148    (ChannelMessageSent, Foreground),
149    (CloseBuffer, Foreground),
150    (DiskBasedDiagnosticsUpdated, Background),
151    (DiskBasedDiagnosticsUpdating, Background),
152    (Error, Foreground),
153    (FormatBuffers, Foreground),
154    (FormatBuffersResponse, Foreground),
155    (GetChannelMessages, Foreground),
156    (GetChannelMessagesResponse, Foreground),
157    (GetChannels, Foreground),
158    (GetChannelsResponse, Foreground),
159    (GetCodeActions, Background),
160    (GetCodeActionsResponse, Background),
161    (GetCompletions, Background),
162    (GetCompletionsResponse, Background),
163    (GetDefinition, Background),
164    (GetDefinitionResponse, Background),
165    (GetDocumentHighlights, Background),
166    (GetDocumentHighlightsResponse, Background),
167    (GetReferences, Background),
168    (GetReferencesResponse, Background),
169    (GetProjectSymbols, Background),
170    (GetProjectSymbolsResponse, Background),
171    (GetUsers, Foreground),
172    (GetUsersResponse, Foreground),
173    (JoinChannel, Foreground),
174    (JoinChannelResponse, Foreground),
175    (JoinProject, Foreground),
176    (JoinProjectResponse, Foreground),
177    (LeaveChannel, Foreground),
178    (LeaveProject, Foreground),
179    (OpenBuffer, Background),
180    (OpenBufferForSymbol, Background),
181    (OpenBufferForSymbolResponse, Background),
182    (OpenBufferResponse, Background),
183    (PerformRename, Background),
184    (PerformRenameResponse, Background),
185    (PrepareRename, Background),
186    (PrepareRenameResponse, Background),
187    (RegisterProjectResponse, Foreground),
188    (Ping, Foreground),
189    (RegisterProject, Foreground),
190    (RegisterWorktree, Foreground),
191    (RemoveProjectCollaborator, Foreground),
192    (SaveBuffer, Foreground),
193    (SearchProject, Background),
194    (SearchProjectResponse, Background),
195    (SendChannelMessage, Foreground),
196    (SendChannelMessageResponse, Foreground),
197    (ShareProject, Foreground),
198    (Test, Foreground),
199    (UnregisterProject, Foreground),
200    (UnregisterWorktree, Foreground),
201    (UnshareProject, Foreground),
202    (UpdateBuffer, Background),
203    (UpdateBufferFile, Foreground),
204    (UpdateContacts, Foreground),
205    (UpdateDiagnosticSummary, Foreground),
206    (UpdateWorktree, Foreground),
207);
208
209request_messages!(
210    (ApplyCodeAction, ApplyCodeActionResponse),
211    (
212        ApplyCompletionAdditionalEdits,
213        ApplyCompletionAdditionalEditsResponse
214    ),
215    (FormatBuffers, FormatBuffersResponse),
216    (GetChannelMessages, GetChannelMessagesResponse),
217    (GetChannels, GetChannelsResponse),
218    (GetCodeActions, GetCodeActionsResponse),
219    (GetCompletions, GetCompletionsResponse),
220    (GetDefinition, GetDefinitionResponse),
221    (GetDocumentHighlights, GetDocumentHighlightsResponse),
222    (GetReferences, GetReferencesResponse),
223    (GetProjectSymbols, GetProjectSymbolsResponse),
224    (GetUsers, GetUsersResponse),
225    (JoinChannel, JoinChannelResponse),
226    (JoinProject, JoinProjectResponse),
227    (OpenBuffer, OpenBufferResponse),
228    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
229    (Ping, Ack),
230    (PerformRename, PerformRenameResponse),
231    (PrepareRename, PrepareRenameResponse),
232    (RegisterProject, RegisterProjectResponse),
233    (RegisterWorktree, Ack),
234    (SaveBuffer, BufferSaved),
235    (SearchProject, SearchProjectResponse),
236    (SendChannelMessage, SendChannelMessageResponse),
237    (ShareProject, Ack),
238    (Test, Test),
239    (UpdateBuffer, Ack),
240    (UpdateWorktree, Ack),
241);
242
243entity_messages!(
244    project_id,
245    AddProjectCollaborator,
246    ApplyCodeAction,
247    ApplyCompletionAdditionalEdits,
248    BufferReloaded,
249    BufferSaved,
250    CloseBuffer,
251    DiskBasedDiagnosticsUpdated,
252    DiskBasedDiagnosticsUpdating,
253    FormatBuffers,
254    GetCodeActions,
255    GetCompletions,
256    GetDefinition,
257    GetDocumentHighlights,
258    GetReferences,
259    GetProjectSymbols,
260    JoinProject,
261    LeaveProject,
262    OpenBuffer,
263    OpenBufferForSymbol,
264    PerformRename,
265    PrepareRename,
266    RemoveProjectCollaborator,
267    SaveBuffer,
268    SearchProject,
269    UnregisterWorktree,
270    UnshareProject,
271    UpdateBuffer,
272    UpdateBufferFile,
273    UpdateDiagnosticSummary,
274    RegisterWorktree,
275    UpdateWorktree,
276);
277
278entity_messages!(channel_id, ChannelMessageSent);
279
/// Sends and receives protobuf [`Envelope`]s, zstd-compressed, over a WebSocket-style
/// stream/sink, plus bare ping/pong frames.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across reads and writes for encoded/decoded envelope bytes.
    encoding_buffer: Vec<u8>,
}
285
286#[derive(Debug)]
287pub enum Message {
288    Envelope(Envelope),
289    Ping,
290    Pong,
291}
292
293impl<S> MessageStream<S> {
294    pub fn new(stream: S) -> Self {
295        Self {
296            stream,
297            encoding_buffer: Vec::new(),
298        }
299    }
300
301    pub fn inner_mut(&mut self) -> &mut S {
302        &mut self.stream
303    }
304}
305
306impl<S> MessageStream<S>
307where
308    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
309{
310    pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
311        #[cfg(any(test, feature = "test-support"))]
312        const COMPRESSION_LEVEL: i32 = -7;
313
314        #[cfg(not(any(test, feature = "test-support")))]
315        const COMPRESSION_LEVEL: i32 = 4;
316
317        match message {
318            Message::Envelope(message) => {
319                self.encoding_buffer.resize(message.encoded_len(), 0);
320                self.encoding_buffer.clear();
321                message
322                    .encode(&mut self.encoding_buffer)
323                    .map_err(|err| io::Error::from(err))?;
324                let buffer =
325                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
326                        .unwrap();
327                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
328            }
329            Message::Ping => {
330                self.stream
331                    .send(WebSocketMessage::Ping(Default::default()))
332                    .await?;
333            }
334            Message::Pong => {
335                self.stream
336                    .send(WebSocketMessage::Pong(Default::default()))
337                    .await?;
338            }
339        }
340
341        Ok(())
342    }
343}
344
345impl<S> MessageStream<S>
346where
347    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
348{
349    pub async fn read(&mut self) -> Result<Message, WebSocketError> {
350        while let Some(bytes) = self.stream.next().await {
351            match bytes? {
352                WebSocketMessage::Binary(bytes) => {
353                    self.encoding_buffer.clear();
354                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
355                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
356                        .map_err(io::Error::from)?;
357                    return Ok(Message::Envelope(envelope));
358                }
359                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
360                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
361                WebSocketMessage::Close(_) => break,
362                _ => {}
363            }
364        }
365        Err(WebSocketError::ConnectionClosed)
366    }
367}
368
369impl Into<SystemTime> for Timestamp {
370    fn into(self) -> SystemTime {
371        UNIX_EPOCH
372            .checked_add(Duration::new(self.seconds, self.nanos))
373            .unwrap()
374    }
375}
376
377impl From<SystemTime> for Timestamp {
378    fn from(time: SystemTime) -> Self {
379        let duration = time.duration_since(UNIX_EPOCH).unwrap();
380        Self {
381            seconds: duration.as_secs(),
382            nanos: duration.subsec_nanos(),
383        }
384    }
385}
386
387impl From<u128> for Nonce {
388    fn from(nonce: u128) -> Self {
389        let upper_half = (nonce >> 64) as u64;
390        let lower_half = nonce as u64;
391        Self {
392            upper_half,
393            lower_half,
394        }
395    }
396}
397
398impl From<Nonce> for u128 {
399    fn from(nonce: Nonce) -> Self {
400        let upper_half = (nonce.upper_half as u128) << 64;
401        let lower_half = nonce.lower_half as u128;
402        upper_half | lower_half
403    }
404}