proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Pulls in the protobuf types generated at build time into
// `$OUT_DIR/zed.messages.rs` (presumably by a prost build script — confirm in
// build.rs). Types used below such as `Envelope`, `envelope::Payload`,
// `Timestamp`, `Nonce` and `UpdateWorktree` are not imported from elsewhere in
// this file and presumably come from here.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
/// A protobuf message type that can be wrapped in, and recovered from, an
/// [`Envelope`].
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; surfaced by `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Dispatch priority of this message type; `AnyTypedEnvelope::is_background`
    /// checks it for `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] tagged with message id `id`.
    ///
    /// NOTE(review): `responding_to` presumably carries the id of the request
    /// this message answers, and `original_sender_id` the peer that first sent
    /// it — confirm against the `messages!` macro that implements this.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts a `Self` from `envelope`.
    ///
    /// NOTE(review): presumably returns `None` when the envelope's payload is a
    /// different message type — confirm against the `messages!` macro.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 28
/// A message addressed to a specific remote entity.
///
/// The `entity_messages!` invocations in this file name `project_id` or
/// `channel_id` for each type; presumably that field is what
/// `remote_entity_id` returns — confirm against the macro.
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 32
/// A message that expects a reply of type [`RequestMessage::Response`].
///
/// The request/response pairs are declared via `request_messages!` in this file.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
 36
/// Type-erased view of a [`TypedEnvelope`], so envelopes carrying different
/// message types can be stored and dispatched uniformly and later downcast.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload message type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` for downcasting by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` for downcasting by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The peer that originally sent the message, if one was recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Registers every protobuf message type together with its `MessagePriority`.
// NOTE(review): the `messages!` macro is defined outside this file; presumably
// it implements `EnvelopedMessage` for each type with the given `PRIORITY` —
// confirm there. The list is only roughly alphabetical; it is left in its
// original order because the macro expansion is not visible here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RemoveContact, Foreground),
    (Ping, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
183
// Pairs each request message with the message type sent back in reply.
// NOTE(review): presumably `request_messages!` implements `RequestMessage`
// with `type Response = <second element>` — confirm in the macro definition.
// Many requests are answered with a bare `Ack`; `Test` replies with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateWorktree, Ack),
);
236
// Message types scoped to a single project. The first argument names the
// field (`project_id`) that presumably backs `EntityMessage::remote_entity_id`
// for each listed type — confirm against the `entity_messages!` macro.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
    UpdateDiffBase
);
284
// Channel-scoped messages, keyed by their `channel_id` field (see the
// project-scoped `entity_messages!` invocation for the same mechanism).
entity_messages!(channel_id, ChannelMessageSent);
286
// Binary byte-size units.
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;
/// Largest capacity a `MessageStream` keeps in its scratch buffer once an
/// envelope has been written or read; any excess is released with `shrink_to`.
const MAX_BUFFER_LEN: usize = MIB;
290
/// A stream of protobuf messages.
///
/// Wraps a transport `S` of websocket frames and converts between [`Message`]s
/// and frames: envelopes travel as zstd-compressed protobuf in binary frames,
/// pings/pongs as websocket control frames.
pub struct MessageStream<S> {
    /// The wrapped transport (a `Sink` and/or `Stream` of websocket messages).
    stream: S,
    /// Scratch buffer reused for protobuf encoding and decoding; its capacity
    /// is shrunk to at most `MAX_BUFFER_LEN` after each envelope.
    encoding_buffer: Vec<u8>,
}
296
/// One unit of traffic exchanged over a [`MessageStream`].
// `Envelope` is presumably far larger than the unit variants; clippy's
// `large_enum_variant` lint is silenced here rather than boxing the envelope.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried in a zstd-compressed binary frame.
    Envelope(Envelope),
    /// A websocket ping control frame.
    Ping,
    /// A websocket pong control frame.
    Pong,
}
304
305impl<S> MessageStream<S> {
306    pub fn new(stream: S) -> Self {
307        Self {
308            stream,
309            encoding_buffer: Vec::new(),
310        }
311    }
312
313    pub fn inner_mut(&mut self) -> &mut S {
314        &mut self.stream
315    }
316}
317
318impl<S> MessageStream<S>
319where
320    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
321{
322    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
323        #[cfg(any(test, feature = "test-support"))]
324        const COMPRESSION_LEVEL: i32 = -7;
325
326        #[cfg(not(any(test, feature = "test-support")))]
327        const COMPRESSION_LEVEL: i32 = 4;
328
329        match message {
330            Message::Envelope(message) => {
331                self.encoding_buffer.reserve(message.encoded_len());
332                message
333                    .encode(&mut self.encoding_buffer)
334                    .map_err(io::Error::from)?;
335                let buffer =
336                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
337                        .unwrap();
338
339                self.encoding_buffer.clear();
340                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
341                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
342            }
343            Message::Ping => {
344                self.stream
345                    .send(WebSocketMessage::Ping(Default::default()))
346                    .await?;
347            }
348            Message::Pong => {
349                self.stream
350                    .send(WebSocketMessage::Pong(Default::default()))
351                    .await?;
352            }
353        }
354
355        Ok(())
356    }
357}
358
359impl<S> MessageStream<S>
360where
361    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
362{
363    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
364        while let Some(bytes) = self.stream.next().await {
365            match bytes? {
366                WebSocketMessage::Binary(bytes) => {
367                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
368                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
369                        .map_err(io::Error::from)?;
370
371                    self.encoding_buffer.clear();
372                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
373                    return Ok(Message::Envelope(envelope));
374                }
375                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
376                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
377                WebSocketMessage::Close(_) => break,
378                _ => {}
379            }
380        }
381        Err(anyhow!("connection closed"))
382    }
383}
384
385impl From<Timestamp> for SystemTime {
386    fn from(val: Timestamp) -> Self {
387        UNIX_EPOCH
388            .checked_add(Duration::new(val.seconds, val.nanos))
389            .unwrap()
390    }
391}
392
393impl From<SystemTime> for Timestamp {
394    fn from(time: SystemTime) -> Self {
395        let duration = time.duration_since(UNIX_EPOCH).unwrap();
396        Self {
397            seconds: duration.as_secs(),
398            nanos: duration.subsec_nanos(),
399        }
400    }
401}
402
403impl From<u128> for Nonce {
404    fn from(nonce: u128) -> Self {
405        let upper_half = (nonce >> 64) as u64;
406        let lower_half = nonce as u64;
407        Self {
408            upper_half,
409            lower_half,
410        }
411    }
412}
413
414impl From<Nonce> for u128 {
415    fn from(nonce: Nonce) -> Self {
416        let upper_half = (nonce.upper_half as u128) << 64;
417        let lower_half = nonce.lower_half as u128;
418        upper_half | lower_half
419    }
420}
421
422pub fn split_worktree_update(
423    mut message: UpdateWorktree,
424    max_chunk_size: usize,
425) -> impl Iterator<Item = UpdateWorktree> {
426    let mut done = false;
427    iter::from_fn(move || {
428        if done {
429            return None;
430        }
431
432        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
433        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
434        done = message.updated_entries.is_empty();
435        Some(UpdateWorktree {
436            project_id: message.project_id,
437            worktree_id: message.worktree_id,
438            root_name: message.root_name.clone(),
439            abs_path: message.abs_path.clone(),
440            updated_entries,
441            removed_entries: mem::take(&mut message.removed_entries),
442            scan_id: message.scan_id,
443            is_last_update: done && message.is_last_update,
444        })
445    })
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a `MessageStream` never retains more than `MAX_BUFFER_LEN`
    /// bytes of scratch capacity after processing an envelope. Both a small and
    /// a very large envelope are tested, on the writing and the reading side.
    #[gpui::test]
    async fn test_buffer_size() {
        // Connect a writer and a reader through an in-memory channel.
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        // Small envelope: 70-byte root name.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(10),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        // Large envelope: ~7 MB root name. The `reserve` in `write` grows the
        // buffer past MAX_BUFFER_LEN, so this checks that it is shrunk again.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(1000000),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        // Read both envelopes back and check the decoder's buffer the same way.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
    }
}