proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    io,
 10    time::{Duration, SystemTime, UNIX_EPOCH},
 11};
 12
// Splice in the Rust source found at `$OUT_DIR/zed.messages.rs` at compile time.
// NOTE(review): presumably this is the prost-generated code for the protobuf
// schema (it would define `Envelope`, `envelope::Payload`, `Timestamp`, `Nonce`
// and every message type named below) — confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 14
/// A message type that can be carried as the payload of an [`Envelope`].
///
/// Implementations are generated by the `messages!` macro in this file; each
/// one maps the type to its own `envelope::Payload` variant.
pub trait EnvelopedMessage: Clone + Serialize + Sized + Send + Sync + 'static {
    /// The message's type name (the `messages!` macro sets this to the
    /// stringified type identifier).
    const NAME: &'static str;
    /// The priority class of this message type; it is what
    /// `AnyTypedEnvelope::is_background` reports on.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] carrying the given message `id`, the id
    /// of the message being responded to (if any), and the original sender's
    /// id (if any).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`, returning `None` when
    /// the envelope's payload is absent or holds a different message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 26
/// A message that is addressed to a particular remote entity.
///
/// Implementations are generated by the `entity_messages!` macro, which
/// returns one of the message's own id fields (`project_id` or `channel_id`
/// for the types registered in this file).
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 30
/// A message that is paired with a response message type.
///
/// The request/response pairs are declared with the `request_messages!` macro.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 34
/// A type-erased view of a `TypedEnvelope<T>`, so that envelopes with
/// different payload types can be handled through one trait object.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// The `TypeId` of the payload's message type.
    fn payload_type_id(&self) -> TypeId;
    /// The payload's message type name (`EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload's message type has `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The id of the peer that originally sent the message, if one was recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 43
 44pub enum MessagePriority {
 45    Foreground,
 46    Background,
 47}
 48
 49impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 50    fn payload_type_id(&self) -> TypeId {
 51        TypeId::of::<T>()
 52    }
 53
 54    fn payload_type_name(&self) -> &'static str {
 55        T::NAME
 56    }
 57
 58    fn as_any(&self) -> &dyn Any {
 59        self
 60    }
 61
 62    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 63        self
 64    }
 65
 66    fn is_background(&self) -> bool {
 67        matches!(T::PRIORITY, MessagePriority::Background)
 68    }
 69
 70    fn original_sender_id(&self) -> Option<PeerId> {
 71        self.original_sender_id
 72    }
 73}
 74
 75macro_rules! messages {
 76    ($(($name:ident, $priority:ident)),* $(,)?) => {
 77        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 78            match envelope.payload {
 79                $(Some(envelope::Payload::$name(payload)) => {
 80                    Some(Box::new(TypedEnvelope {
 81                        sender_id,
 82                        original_sender_id: envelope.original_sender_id.map(PeerId),
 83                        message_id: envelope.id,
 84                        payload,
 85                    }))
 86                }, )*
 87                _ => None
 88            }
 89        }
 90
 91        $(
 92            impl EnvelopedMessage for $name {
 93                const NAME: &'static str = std::stringify!($name);
 94                const PRIORITY: MessagePriority = MessagePriority::$priority;
 95
 96                fn into_envelope(
 97                    self,
 98                    id: u32,
 99                    responding_to: Option<u32>,
100                    original_sender_id: Option<u32>,
101                ) -> Envelope {
102                    Envelope {
103                        id,
104                        responding_to,
105                        original_sender_id,
106                        payload: Some(envelope::Payload::$name(self)),
107                    }
108                }
109
110                fn from_envelope(envelope: Envelope) -> Option<Self> {
111                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
112                        Some(msg)
113                    } else {
114                        None
115                    }
116                }
117            }
118        )*
119    };
120}
121
/// Declares request/response pairs.
///
/// For every `(Request, Response)` tuple this emits
/// `impl RequestMessage for Request` with `Response` as its associated type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
129
/// Implements `EntityMessage` for a group of message types that share the
/// same id field.
///
/// The first argument names the field; every following identifier is a
/// message type whose `remote_entity_id` returns that field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
139
// Every message type that can be placed in an `Envelope`, paired with its
// priority. Each entry gets an `EnvelopedMessage` impl and a match arm in
// `build_typed_envelope`; a payload type that is not listed here makes
// `build_typed_envelope` return `None`.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (ChannelMessageSent, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (GetUsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (StartLanguageServer, Foreground),
    (UpdateLanguageServer, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterWorktree, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UnregisterWorktree, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateWorktree, Foreground),
);
215
// Request/response pairings: each `(Request, Response)` entry makes `Response`
// the `RequestMessage::Response` type of `Request`. Several requests share a
// response type (`Ack`, `OpenBufferResponse`), and `Test` is its own response.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, Ack),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
252
// Messages whose `remote_entity_id` is their `project_id` field.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterWorktree,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    RegisterWorktree,
    UpdateWorktree,
);

// Messages whose `remote_entity_id` is their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
293
/// A stream of protobuf messages.
///
/// Wraps an underlying WebSocket sink and/or stream `S`; `write` and `read`
/// convert between [`Message`] values and WebSocket frames.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across calls: it holds the encoded protobuf bytes
    /// in `write` and the decompressed bytes in `read`.
    encoding_buffer: Vec<u8>,
}
299
/// A unit of traffic exchanged through a [`MessageStream`].
#[derive(Debug)]
pub enum Message {
    /// A protobuf `Envelope`, transmitted as a zstd-compressed binary frame.
    Envelope(Envelope),
    /// A WebSocket ping frame (written with an empty payload; any payload is
    /// discarded when read).
    Ping,
    /// A WebSocket pong frame (written with an empty payload; any payload is
    /// discarded when read).
    Pong,
}
306
307impl<S> MessageStream<S> {
308    pub fn new(stream: S) -> Self {
309        Self {
310            stream,
311            encoding_buffer: Vec::new(),
312        }
313    }
314
315    pub fn inner_mut(&mut self) -> &mut S {
316        &mut self.stream
317    }
318}
319
320impl<S> MessageStream<S>
321where
322    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
323{
324    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
325        #[cfg(any(test, feature = "test-support"))]
326        const COMPRESSION_LEVEL: i32 = -7;
327
328        #[cfg(not(any(test, feature = "test-support")))]
329        const COMPRESSION_LEVEL: i32 = 4;
330
331        match message {
332            Message::Envelope(message) => {
333                self.encoding_buffer.resize(message.encoded_len(), 0);
334                self.encoding_buffer.clear();
335                message
336                    .encode(&mut self.encoding_buffer)
337                    .map_err(|err| io::Error::from(err))?;
338                let buffer =
339                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
340                        .unwrap();
341                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
342            }
343            Message::Ping => {
344                self.stream
345                    .send(WebSocketMessage::Ping(Default::default()))
346                    .await?;
347            }
348            Message::Pong => {
349                self.stream
350                    .send(WebSocketMessage::Pong(Default::default()))
351                    .await?;
352            }
353        }
354
355        Ok(())
356    }
357}
358
359impl<S> MessageStream<S>
360where
361    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
362{
363    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
364        while let Some(bytes) = self.stream.next().await {
365            match bytes? {
366                WebSocketMessage::Binary(bytes) => {
367                    self.encoding_buffer.clear();
368                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
369                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
370                        .map_err(io::Error::from)?;
371                    return Ok(Message::Envelope(envelope));
372                }
373                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
374                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
375                WebSocketMessage::Close(_) => break,
376                _ => {}
377            }
378        }
379        Err(anyhow!("connection closed"))
380    }
381}
382
383impl Into<SystemTime> for Timestamp {
384    fn into(self) -> SystemTime {
385        UNIX_EPOCH
386            .checked_add(Duration::new(self.seconds, self.nanos))
387            .unwrap()
388    }
389}
390
391impl From<SystemTime> for Timestamp {
392    fn from(time: SystemTime) -> Self {
393        let duration = time.duration_since(UNIX_EPOCH).unwrap();
394        Self {
395            seconds: duration.as_secs(),
396            nanos: duration.subsec_nanos(),
397        }
398    }
399}
400
401impl From<u128> for Nonce {
402    fn from(nonce: u128) -> Self {
403        let upper_half = (nonce >> 64) as u64;
404        let lower_half = nonce as u64;
405        Self {
406            upper_half,
407            lower_half,
408        }
409    }
410}
411
412impl From<Nonce> for u128 {
413    fn from(nonce: Nonce) -> Self {
414        let upper_half = (nonce.upper_half as u128) << 64;
415        let lower_half = nonce.lower_half as u128;
416        upper_half | lower_half
417    }
418}