proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Table of message types, each paired with its `MessagePriority`, handed to the
// `messages!` macro imported from the parent module.
// NOTE(review): the macro definition is not visible here; presumably it implements
// `EnvelopedMessage` (setting `NAME`/`PRIORITY`) for every listed type — confirm there.
// A type listed as `Background` is reported as such by `AnyTypedEnvelope::is_background`
// (via `T::PRIORITY`), assuming the macro wires the priority through.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateChannel, Foreground),
    (CreateChannelResponse, Foreground),
    (ChannelMessageSent, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetChannelMessages, Background),
    (GetChannelMessagesResponse, Background),
    (SendChannelMessage, Background),
    (SendChannelMessageResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (InviteChannelMember, Foreground),
    (UsersResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (JoinChannelChat, Foreground),
    (JoinChannelChatResponse, Foreground),
    (LeaveChannelChat, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (ResolveCompletionDocumentation, Background),
    (ResolveCompletionDocumentationResponse, Background),
    (ResolveInlayHint, Background),
    (ResolveInlayHintResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (RemoveChannelMember, Foreground),
    (RemoveChannelMessage, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToChannelInvite, Foreground),
    (JoinChannel, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (RenameChannel, Foreground),
    (RenameChannelResponse, Foreground),
    (SetChannelMemberAdmin, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (RejoinChannelBuffers, Foreground),
    (RejoinChannelBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (DeleteChannel, Foreground),
    (MoveChannel, Foreground),
    (LinkChannel, Foreground),
    (UnlinkChannel, Foreground),
    (UpdateChannels, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
    (GetChannelMembers, Foreground),
    (GetChannelMembersResponse, Foreground),
    (JoinChannelBuffer, Foreground),
    (JoinChannelBufferResponse, Foreground),
    (LeaveChannelBuffer, Background),
    (UpdateChannelBuffer, Foreground),
    (UpdateChannelBufferCollaborators, Foreground),
    (AckBufferOperation, Background),
    (AckChannelMessage, Background),
);
279
// Pairs each request type with the message type sent back in reply, handed to the
// `request_messages!` macro imported from the parent module.
// NOTE(review): presumably the macro implements `RequestMessage` with
// `type Response = <second element>` for each pair — confirm against its definition.
// Both elements of every pair must also appear in the `messages!` table above.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (CreateChannel, CreateChannelResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (InviteChannelMember, Ack),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (JoinChannelChat, JoinChannelChatResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (
        ResolveCompletionDocumentation,
        ResolveCompletionDocumentationResponse
    ),
    (ResolveInlayHint, ResolveInlayHintResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveChannelMember, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RespondToChannelInvite, Ack),
    (SetChannelMemberAdmin, Ack),
    (SendChannelMessage, SendChannelMessageResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannelMembers, GetChannelMembersResponse),
    (JoinChannel, JoinRoomResponse),
    (RemoveChannelMessage, Ack),
    (DeleteChannel, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (RenameChannel, RenameChannelResponse),
    (LinkChannel, Ack),
    (UnlinkChannel, Ack),
    (MoveChannel, Ack),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
    (JoinChannelBuffer, JoinChannelBufferResponse),
    (LeaveChannelBuffer, Ack)
);
360
// Message types scoped to a project, keyed by their `project_id` field.
// NOTE(review): presumably `entity_messages!` implements
// `EntityMessage::remote_entity_id` by returning the named field — confirm against the
// macro definition in the parent module.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    ResolveCompletionDocumentation,
    ResolveInlayHint,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);
412
// Message types scoped to a channel, keyed by their `channel_id` field (same macro as
// the project-scoped table; the trailing comma is accepted by the macro as written).
entity_messages!(
    channel_id,
    ChannelMessageSent,
    UpdateChannelBuffer,
    RemoveChannelMessage,
    UpdateChannelBufferCollaborators,
);
420
/// One kibibyte, in bytes.
const KIB: usize = 1024;
/// One mebibyte, in bytes.
const MIB: usize = 1024 * KIB;
/// Capacity that `MessageStream`'s scratch `encoding_buffer` is shrunk back to after
/// each message, so one oversized message does not pin a large allocation forever.
const MAX_BUFFER_LEN: usize = 1 * MIB;
424
/// A stream of protobuf messages layered over a websocket sink and/or stream `S`.
pub struct MessageStream<S> {
    /// Reusable scratch space for encoding/decoding envelopes; cleared after every
    /// message and shrunk back to at most `MAX_BUFFER_LEN` bytes of capacity.
    encoding_buffer: Vec<u8>,
    /// The wrapped websocket sink/stream.
    stream: S,
}
430
431#[allow(clippy::large_enum_variant)]
432#[derive(Debug)]
433pub enum Message {
434    Envelope(Envelope),
435    Ping,
436    Pong,
437}
438
439impl<S> MessageStream<S> {
440    pub fn new(stream: S) -> Self {
441        Self {
442            stream,
443            encoding_buffer: Vec::new(),
444        }
445    }
446
447    pub fn inner_mut(&mut self) -> &mut S {
448        &mut self.stream
449    }
450}
451
452impl<S> MessageStream<S>
453where
454    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
455{
456    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
457        #[cfg(any(test, feature = "test-support"))]
458        const COMPRESSION_LEVEL: i32 = -7;
459
460        #[cfg(not(any(test, feature = "test-support")))]
461        const COMPRESSION_LEVEL: i32 = 4;
462
463        match message {
464            Message::Envelope(message) => {
465                self.encoding_buffer.reserve(message.encoded_len());
466                message
467                    .encode(&mut self.encoding_buffer)
468                    .map_err(io::Error::from)?;
469                let buffer =
470                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
471                        .unwrap();
472
473                self.encoding_buffer.clear();
474                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
475                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
476            }
477            Message::Ping => {
478                self.stream
479                    .send(WebSocketMessage::Ping(Default::default()))
480                    .await?;
481            }
482            Message::Pong => {
483                self.stream
484                    .send(WebSocketMessage::Pong(Default::default()))
485                    .await?;
486            }
487        }
488
489        Ok(())
490    }
491}
492
493impl<S> MessageStream<S>
494where
495    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
496{
497    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
498        while let Some(bytes) = self.stream.next().await {
499            match bytes? {
500                WebSocketMessage::Binary(bytes) => {
501                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
502                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
503                        .map_err(io::Error::from)?;
504
505                    self.encoding_buffer.clear();
506                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
507                    return Ok(Message::Envelope(envelope));
508                }
509                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
510                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
511                WebSocketMessage::Close(_) => break,
512                _ => {}
513            }
514        }
515        Err(anyhow!("connection closed"))
516    }
517}
518
519impl From<Timestamp> for SystemTime {
520    fn from(val: Timestamp) -> Self {
521        UNIX_EPOCH
522            .checked_add(Duration::new(val.seconds, val.nanos))
523            .unwrap()
524    }
525}
526
527impl From<SystemTime> for Timestamp {
528    fn from(time: SystemTime) -> Self {
529        let duration = time.duration_since(UNIX_EPOCH).unwrap();
530        Self {
531            seconds: duration.as_secs(),
532            nanos: duration.subsec_nanos(),
533        }
534    }
535}
536
537impl From<u128> for Nonce {
538    fn from(nonce: u128) -> Self {
539        let upper_half = (nonce >> 64) as u64;
540        let lower_half = nonce as u64;
541        Self {
542            upper_half,
543            lower_half,
544        }
545    }
546}
547
548impl From<Nonce> for u128 {
549    fn from(nonce: Nonce) -> Self {
550        let upper_half = (nonce.upper_half as u128) << 64;
551        let lower_half = nonce.lower_half as u128;
552        upper_half | lower_half
553    }
554}
555
556pub fn split_worktree_update(
557    mut message: UpdateWorktree,
558    max_chunk_size: usize,
559) -> impl Iterator<Item = UpdateWorktree> {
560    let mut done_files = false;
561
562    let mut repository_map = message
563        .updated_repositories
564        .into_iter()
565        .map(|repo| (repo.work_directory_id, repo))
566        .collect::<HashMap<_, _>>();
567
568    iter::from_fn(move || {
569        if done_files {
570            return None;
571        }
572
573        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
574        let updated_entries: Vec<_> = message
575            .updated_entries
576            .drain(..updated_entries_chunk_size)
577            .collect();
578
579        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
580        let removed_entries = message
581            .removed_entries
582            .drain(..removed_entries_chunk_size)
583            .collect();
584
585        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
586
587        let mut updated_repositories = Vec::new();
588
589        if !repository_map.is_empty() {
590            for entry in &updated_entries {
591                if let Some(repo) = repository_map.remove(&entry.id) {
592                    updated_repositories.push(repo)
593                }
594            }
595        }
596
597        let removed_repositories = if done_files {
598            mem::take(&mut message.removed_repositories)
599        } else {
600            Default::default()
601        };
602
603        if done_files {
604            updated_repositories.extend(mem::take(&mut repository_map).into_values());
605        }
606
607        Some(UpdateWorktree {
608            project_id: message.project_id,
609            worktree_id: message.worktree_id,
610            root_name: message.root_name.clone(),
611            abs_path: message.abs_path.clone(),
612            updated_entries,
613            removed_entries,
614            scan_id: message.scan_id,
615            is_last_update: done_files && message.is_last_update,
616            updated_repositories,
617            removed_repositories,
618        })
619    })
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    /// Writing and then reading a small and a very large envelope must never leave the
    /// scratch buffer holding more than `MAX_BUFFER_LEN` bytes of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeat_count in [10, 1000000] {
            let envelope = Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repeat_count),
                    ..Default::default()
                })),
                ..Default::default()
            };
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId::as_u64` / `PeerId::from_u64` round-trip, including the extreme values of
    /// both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}