proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
134messages!(
135    (Ack, Foreground),
136    (AddNotifications, Foreground),
137    (AddProjectCollaborator, Foreground),
138    (ApplyCodeAction, Background),
139    (ApplyCodeActionResponse, Background),
140    (ApplyCompletionAdditionalEdits, Background),
141    (ApplyCompletionAdditionalEditsResponse, Background),
142    (BufferReloaded, Foreground),
143    (BufferSaved, Foreground),
144    (Call, Foreground),
145    (CallCanceled, Foreground),
146    (CancelCall, Foreground),
147    (CopyProjectEntry, Foreground),
148    (CreateBufferForPeer, Foreground),
149    (CreateChannel, Foreground),
150    (CreateChannelResponse, Foreground),
151    (ChannelMessageSent, Foreground),
152    (CreateProjectEntry, Foreground),
153    (CreateRoom, Foreground),
154    (CreateRoomResponse, Foreground),
155    (DeclineCall, Foreground),
156    (DeleteProjectEntry, Foreground),
157    (Error, Foreground),
158    (ExpandProjectEntry, Foreground),
159    (Follow, Foreground),
160    (FollowResponse, Foreground),
161    (FormatBuffers, Foreground),
162    (FormatBuffersResponse, Foreground),
163    (FuzzySearchUsers, Foreground),
164    (GetCodeActions, Background),
165    (GetCodeActionsResponse, Background),
166    (GetHover, Background),
167    (GetHoverResponse, Background),
168    (GetChannelMessages, Background),
169    (GetChannelMessagesResponse, Background),
170    (GetChannelMessagesById, Background),
171    (SendChannelMessage, Background),
172    (SendChannelMessageResponse, Background),
173    (GetCompletions, Background),
174    (GetCompletionsResponse, Background),
175    (GetDefinition, Background),
176    (GetDefinitionResponse, Background),
177    (GetTypeDefinition, Background),
178    (GetTypeDefinitionResponse, Background),
179    (GetDocumentHighlights, Background),
180    (GetDocumentHighlightsResponse, Background),
181    (GetReferences, Background),
182    (GetReferencesResponse, Background),
183    (GetProjectSymbols, Background),
184    (GetProjectSymbolsResponse, Background),
185    (GetUsers, Foreground),
186    (Hello, Foreground),
187    (IncomingCall, Foreground),
188    (InviteChannelMember, Foreground),
189    (UsersResponse, Foreground),
190    (JoinProject, Foreground),
191    (JoinProjectResponse, Foreground),
192    (JoinRoom, Foreground),
193    (JoinRoomResponse, Foreground),
194    (JoinChannelChat, Foreground),
195    (JoinChannelChatResponse, Foreground),
196    (LeaveChannelChat, Foreground),
197    (LeaveProject, Foreground),
198    (LeaveRoom, Foreground),
199    (OpenBufferById, Background),
200    (OpenBufferByPath, Background),
201    (OpenBufferForSymbol, Background),
202    (OpenBufferForSymbolResponse, Background),
203    (OpenBufferResponse, Background),
204    (PerformRename, Background),
205    (PerformRenameResponse, Background),
206    (OnTypeFormatting, Background),
207    (OnTypeFormattingResponse, Background),
208    (InlayHints, Background),
209    (InlayHintsResponse, Background),
210    (ResolveInlayHint, Background),
211    (ResolveInlayHintResponse, Background),
212    (RefreshInlayHints, Foreground),
213    (Ping, Foreground),
214    (PrepareRename, Background),
215    (PrepareRenameResponse, Background),
216    (ExpandProjectEntryResponse, Foreground),
217    (ProjectEntryResponse, Foreground),
218    (RejoinRoom, Foreground),
219    (RejoinRoomResponse, Foreground),
220    (RemoveContact, Foreground),
221    (RemoveChannelMember, Foreground),
222    (RemoveChannelMessage, Foreground),
223    (ReloadBuffers, Foreground),
224    (ReloadBuffersResponse, Foreground),
225    (RemoveProjectCollaborator, Foreground),
226    (RenameProjectEntry, Foreground),
227    (RequestContact, Foreground),
228    (RespondToContactRequest, Foreground),
229    (RespondToChannelInvite, Foreground),
230    (JoinChannel, Foreground),
231    (RoomUpdated, Foreground),
232    (SaveBuffer, Foreground),
233    (RenameChannel, Foreground),
234    (RenameChannelResponse, Foreground),
235    (SetChannelMemberAdmin, Foreground),
236    (SearchProject, Background),
237    (SearchProjectResponse, Background),
238    (ShareProject, Foreground),
239    (ShareProjectResponse, Foreground),
240    (ShowContacts, Foreground),
241    (StartLanguageServer, Foreground),
242    (SynchronizeBuffers, Foreground),
243    (SynchronizeBuffersResponse, Foreground),
244    (RejoinChannelBuffers, Foreground),
245    (RejoinChannelBuffersResponse, Foreground),
246    (Test, Foreground),
247    (Unfollow, Foreground),
248    (UnshareProject, Foreground),
249    (UpdateBuffer, Foreground),
250    (UpdateBufferFile, Foreground),
251    (UpdateContacts, Foreground),
252    (DeleteChannel, Foreground),
253    (MoveChannel, Foreground),
254    (LinkChannel, Foreground),
255    (UnlinkChannel, Foreground),
256    (UpdateChannels, Foreground),
257    (UpdateDiagnosticSummary, Foreground),
258    (UpdateFollowers, Foreground),
259    (UpdateInviteInfo, Foreground),
260    (UpdateLanguageServer, Foreground),
261    (UpdateParticipantLocation, Foreground),
262    (UpdateProject, Foreground),
263    (UpdateProjectCollaborator, Foreground),
264    (UpdateWorktree, Foreground),
265    (UpdateWorktreeSettings, Foreground),
266    (UpdateDiffBase, Foreground),
267    (GetPrivateUserInfo, Foreground),
268    (GetPrivateUserInfoResponse, Foreground),
269    (GetChannelMembers, Foreground),
270    (GetChannelMembersResponse, Foreground),
271    (JoinChannelBuffer, Foreground),
272    (JoinChannelBufferResponse, Foreground),
273    (LeaveChannelBuffer, Background),
274    (UpdateChannelBuffer, Foreground),
275    (UpdateChannelBufferCollaborators, Foreground),
276    (AckBufferOperation, Background),
277    (AckChannelMessage, Background),
278);
279
280request_messages!(
281    (ApplyCodeAction, ApplyCodeActionResponse),
282    (
283        ApplyCompletionAdditionalEdits,
284        ApplyCompletionAdditionalEditsResponse
285    ),
286    (Call, Ack),
287    (CancelCall, Ack),
288    (CopyProjectEntry, ProjectEntryResponse),
289    (CreateProjectEntry, ProjectEntryResponse),
290    (CreateRoom, CreateRoomResponse),
291    (CreateChannel, CreateChannelResponse),
292    (DeclineCall, Ack),
293    (DeleteProjectEntry, ProjectEntryResponse),
294    (ExpandProjectEntry, ExpandProjectEntryResponse),
295    (Follow, FollowResponse),
296    (FormatBuffers, FormatBuffersResponse),
297    (GetCodeActions, GetCodeActionsResponse),
298    (GetHover, GetHoverResponse),
299    (GetCompletions, GetCompletionsResponse),
300    (GetDefinition, GetDefinitionResponse),
301    (GetTypeDefinition, GetTypeDefinitionResponse),
302    (GetDocumentHighlights, GetDocumentHighlightsResponse),
303    (GetReferences, GetReferencesResponse),
304    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
305    (GetProjectSymbols, GetProjectSymbolsResponse),
306    (FuzzySearchUsers, UsersResponse),
307    (GetUsers, UsersResponse),
308    (InviteChannelMember, Ack),
309    (JoinProject, JoinProjectResponse),
310    (JoinRoom, JoinRoomResponse),
311    (JoinChannelChat, JoinChannelChatResponse),
312    (LeaveRoom, Ack),
313    (RejoinRoom, RejoinRoomResponse),
314    (IncomingCall, Ack),
315    (OpenBufferById, OpenBufferResponse),
316    (OpenBufferByPath, OpenBufferResponse),
317    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
318    (Ping, Ack),
319    (PerformRename, PerformRenameResponse),
320    (PrepareRename, PrepareRenameResponse),
321    (OnTypeFormatting, OnTypeFormattingResponse),
322    (InlayHints, InlayHintsResponse),
323    (ResolveInlayHint, ResolveInlayHintResponse),
324    (RefreshInlayHints, Ack),
325    (ReloadBuffers, ReloadBuffersResponse),
326    (RequestContact, Ack),
327    (RemoveChannelMember, Ack),
328    (RemoveContact, Ack),
329    (RespondToContactRequest, Ack),
330    (RespondToChannelInvite, Ack),
331    (SetChannelMemberAdmin, Ack),
332    (SendChannelMessage, SendChannelMessageResponse),
333    (GetChannelMessages, GetChannelMessagesResponse),
334    (GetChannelMessagesById, GetChannelMessagesResponse),
335    (GetChannelMembers, GetChannelMembersResponse),
336    (JoinChannel, JoinRoomResponse),
337    (RemoveChannelMessage, Ack),
338    (DeleteChannel, Ack),
339    (RenameProjectEntry, ProjectEntryResponse),
340    (RenameChannel, RenameChannelResponse),
341    (LinkChannel, Ack),
342    (UnlinkChannel, Ack),
343    (MoveChannel, Ack),
344    (SaveBuffer, BufferSaved),
345    (SearchProject, SearchProjectResponse),
346    (ShareProject, ShareProjectResponse),
347    (SynchronizeBuffers, SynchronizeBuffersResponse),
348    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
349    (Test, Test),
350    (UpdateBuffer, Ack),
351    (UpdateParticipantLocation, Ack),
352    (UpdateProject, Ack),
353    (UpdateWorktree, Ack),
354    (JoinChannelBuffer, JoinChannelBufferResponse),
355    (LeaveChannelBuffer, Ack)
356);
357
358entity_messages!(
359    project_id,
360    AddProjectCollaborator,
361    ApplyCodeAction,
362    ApplyCompletionAdditionalEdits,
363    BufferReloaded,
364    BufferSaved,
365    CopyProjectEntry,
366    CreateBufferForPeer,
367    CreateProjectEntry,
368    DeleteProjectEntry,
369    ExpandProjectEntry,
370    FormatBuffers,
371    GetCodeActions,
372    GetCompletions,
373    GetDefinition,
374    GetTypeDefinition,
375    GetDocumentHighlights,
376    GetHover,
377    GetReferences,
378    GetProjectSymbols,
379    JoinProject,
380    LeaveProject,
381    OpenBufferById,
382    OpenBufferByPath,
383    OpenBufferForSymbol,
384    PerformRename,
385    OnTypeFormatting,
386    InlayHints,
387    ResolveInlayHint,
388    RefreshInlayHints,
389    PrepareRename,
390    ReloadBuffers,
391    RemoveProjectCollaborator,
392    RenameProjectEntry,
393    SaveBuffer,
394    SearchProject,
395    StartLanguageServer,
396    SynchronizeBuffers,
397    UnshareProject,
398    UpdateBuffer,
399    UpdateBufferFile,
400    UpdateDiagnosticSummary,
401    UpdateLanguageServer,
402    UpdateProject,
403    UpdateProjectCollaborator,
404    UpdateWorktree,
405    UpdateWorktreeSettings,
406    UpdateDiffBase
407);
408
409entity_messages!(
410    channel_id,
411    ChannelMessageSent,
412    UpdateChannelBuffer,
413    RemoveChannelMessage,
414    UpdateChannelBufferCollaborators,
415);
416
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch buffer is shrunk back to after every
/// message, so one huge message does not pin a huge allocation.
const MAX_BUFFER_LEN: usize = MIB;
420
/// A stream of protobuf messages layered over a WebSocket-style stream/sink.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch space reused for encoding/decoding; cleared after each message.
    encoding_buffer: Vec<u8>,
}
426
427#[allow(clippy::large_enum_variant)]
428#[derive(Debug)]
429pub enum Message {
430    Envelope(Envelope),
431    Ping,
432    Pong,
433}
434
435impl<S> MessageStream<S> {
436    pub fn new(stream: S) -> Self {
437        Self {
438            stream,
439            encoding_buffer: Vec::new(),
440        }
441    }
442
443    pub fn inner_mut(&mut self) -> &mut S {
444        &mut self.stream
445    }
446}
447
448impl<S> MessageStream<S>
449where
450    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
451{
452    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
453        #[cfg(any(test, feature = "test-support"))]
454        const COMPRESSION_LEVEL: i32 = -7;
455
456        #[cfg(not(any(test, feature = "test-support")))]
457        const COMPRESSION_LEVEL: i32 = 4;
458
459        match message {
460            Message::Envelope(message) => {
461                self.encoding_buffer.reserve(message.encoded_len());
462                message
463                    .encode(&mut self.encoding_buffer)
464                    .map_err(io::Error::from)?;
465                let buffer =
466                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
467                        .unwrap();
468
469                self.encoding_buffer.clear();
470                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
471                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
472            }
473            Message::Ping => {
474                self.stream
475                    .send(WebSocketMessage::Ping(Default::default()))
476                    .await?;
477            }
478            Message::Pong => {
479                self.stream
480                    .send(WebSocketMessage::Pong(Default::default()))
481                    .await?;
482            }
483        }
484
485        Ok(())
486    }
487}
488
489impl<S> MessageStream<S>
490where
491    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
492{
493    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
494        while let Some(bytes) = self.stream.next().await {
495            match bytes? {
496                WebSocketMessage::Binary(bytes) => {
497                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
498                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
499                        .map_err(io::Error::from)?;
500
501                    self.encoding_buffer.clear();
502                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
503                    return Ok(Message::Envelope(envelope));
504                }
505                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
506                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
507                WebSocketMessage::Close(_) => break,
508                _ => {}
509            }
510        }
511        Err(anyhow!("connection closed"))
512    }
513}
514
515impl From<Timestamp> for SystemTime {
516    fn from(val: Timestamp) -> Self {
517        UNIX_EPOCH
518            .checked_add(Duration::new(val.seconds, val.nanos))
519            .unwrap()
520    }
521}
522
523impl From<SystemTime> for Timestamp {
524    fn from(time: SystemTime) -> Self {
525        let duration = time.duration_since(UNIX_EPOCH).unwrap();
526        Self {
527            seconds: duration.as_secs(),
528            nanos: duration.subsec_nanos(),
529        }
530    }
531}
532
533impl From<u128> for Nonce {
534    fn from(nonce: u128) -> Self {
535        let upper_half = (nonce >> 64) as u64;
536        let lower_half = nonce as u64;
537        Self {
538            upper_half,
539            lower_half,
540        }
541    }
542}
543
544impl From<Nonce> for u128 {
545    fn from(nonce: Nonce) -> Self {
546        let upper_half = (nonce.upper_half as u128) << 64;
547        let lower_half = nonce.lower_half as u128;
548        upper_half | lower_half
549    }
550}
551
552pub fn split_worktree_update(
553    mut message: UpdateWorktree,
554    max_chunk_size: usize,
555) -> impl Iterator<Item = UpdateWorktree> {
556    let mut done_files = false;
557
558    let mut repository_map = message
559        .updated_repositories
560        .into_iter()
561        .map(|repo| (repo.work_directory_id, repo))
562        .collect::<HashMap<_, _>>();
563
564    iter::from_fn(move || {
565        if done_files {
566            return None;
567        }
568
569        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
570        let updated_entries: Vec<_> = message
571            .updated_entries
572            .drain(..updated_entries_chunk_size)
573            .collect();
574
575        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
576        let removed_entries = message
577            .removed_entries
578            .drain(..removed_entries_chunk_size)
579            .collect();
580
581        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
582
583        let mut updated_repositories = Vec::new();
584
585        if !repository_map.is_empty() {
586            for entry in &updated_entries {
587                if let Some(repo) = repository_map.remove(&entry.id) {
588                    updated_repositories.push(repo)
589                }
590            }
591        }
592
593        let removed_repositories = if done_files {
594            mem::take(&mut message.removed_repositories)
595        } else {
596            Default::default()
597        };
598
599        if done_files {
600            updated_repositories.extend(mem::take(&mut repository_map).into_values());
601        }
602
603        Some(UpdateWorktree {
604            project_id: message.project_id,
605            worktree_id: message.worktree_id,
606            root_name: message.root_name.clone(),
607            abs_path: message.abs_path.clone(),
608            updated_entries,
609            removed_entries,
610            scan_id: message.scan_id,
611            is_last_update: done_files && message.is_last_update,
612            updated_repositories,
613            removed_repositories,
614        })
615    })
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// The scratch buffer must be shrunk back to `MAX_BUFFER_LEN` after both small and
    /// very large messages, on the write side and the read side.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        // A small envelope, then one whose encoding is far larger than MAX_BUFFER_LEN.
        for &repetitions in &[10usize, 1000000] {
            sink.write(Message::Envelope(Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repetitions),
                    ..Default::default()
                })),
                ..Default::default()
            }))
            .await
            .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// Packing a `PeerId` into a `u64` and back is lossless, including at the extremes.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for &(owner_id, id) in &cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}