proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
/// A message type that can be wrapped in, and recovered from, an [`Envelope`].
///
/// NOTE(review): implementations are presumably generated by the `messages!`
/// macro imported from the parent module — confirm against its definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; this is what
    /// [`AnyTypedEnvelope::payload_type_name`] returns for a typed envelope.
    const NAME: &'static str;
    /// Priority class of the message type; [`AnyTypedEnvelope::is_background`]
    /// checks it against [`MessagePriority::Background`].
    const PRIORITY: MessagePriority;
    /// Consumes the message and wraps it in an [`Envelope`].
    ///
    /// `id` is the id of the new envelope, `responding_to` is the id of the
    /// message this one answers (if any), and `original_sender_id` is the peer
    /// the message originally came from (if any).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Tries to extract a message of this type from `envelope`.
    ///
    /// NOTE(review): presumably returns `None` when the envelope carries a
    /// different payload type — confirm against the generated implementation.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 32
/// An [`EnvelopedMessage`] that is associated with a remote entity.
///
/// NOTE(review): implementations are presumably generated by the
/// `entity_messages!` invocations in this file — confirm against the macro.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 36
/// An [`EnvelopedMessage`] that has an associated response message type.
///
/// NOTE(review): implementations are presumably generated by the
/// `request_messages!` invocation in this file — confirm against the macro.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 40
/// Object-safe view of a `TypedEnvelope<T>` that hides the payload type `T`.
///
/// It exposes the envelope's metadata and allows converting to [`Any`] so the
/// concrete envelope type can be recovered by downcasting.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// Returns the [`TypeId`] of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// Returns the payload message type's [`EnvelopedMessage::NAME`].
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as [`Any`].
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed [`Any`].
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Returns `true` when the payload type's priority is
    /// [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// Returns the envelope's original sender, if one is recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// Returns the id of the connection recorded as the envelope's sender.
    fn sender_id(&self) -> ConnectionId;
    /// Returns the envelope's message id.
    fn message_id(&self) -> u32;
}
 51
/// Priority class attached to every message type through
/// [`EnvelopedMessage::PRIORITY`].
pub enum MessagePriority {
    /// Message types for which [`AnyTypedEnvelope::is_background`] is `false`.
    Foreground,
    /// Message types for which [`AnyTypedEnvelope::is_background`] is `true`.
    Background,
}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
// `PeerId` consists of two `u32` fields (`owner_id` and `id`), so it can be
// copied freely; `PeerId::as_u64` relies on this by taking `self` by value.
impl Copy for PeerId {}

// Equality on two integer fields is total, so `PeerId` can be `Eq`, which the
// `Ord` implementation requires.
impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Declares every message type handled by this module together with its
// priority. Each entry is `(MessageType, Priority)`, where the priority is
// spelled like a `MessagePriority` variant name.
// NOTE(review): `messages!` is imported from the parent module; it presumably
// implements `EnvelopedMessage` for each listed type — confirm against the
// macro definition before relying on this.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateChannel, Foreground),
    (CreateChannelResponse, Foreground),
    (ChannelMessageSent, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetChannelMessages, Background),
    (GetChannelMessagesResponse, Background),
    (SendChannelMessage, Background),
    (SendChannelMessageResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (InviteChannelMember, Foreground),
    (UsersResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (JoinChannelChat, Foreground),
    (JoinChannelChatResponse, Foreground),
    (LeaveChannelChat, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (ResolveCompletionDocumentation, Background),
    (ResolveCompletionDocumentationResponse, Background),
    (ResolveInlayHint, Background),
    (ResolveInlayHintResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (RemoveChannelMember, Foreground),
    (RemoveChannelMessage, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToChannelInvite, Foreground),
    (JoinChannel, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (RenameChannel, Foreground),
    (RenameChannelResponse, Foreground),
    (SetChannelMemberRole, Foreground),
    (SetChannelVisibility, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (RejoinChannelBuffers, Foreground),
    (RejoinChannelBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (DeleteChannel, Foreground),
    (MoveChannel, Foreground),
    (LinkChannel, Foreground),
    (UnlinkChannel, Foreground),
    (UpdateChannels, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
    (GetChannelMembers, Foreground),
    (GetChannelMembersResponse, Foreground),
    (JoinChannelBuffer, Foreground),
    (JoinChannelBufferResponse, Foreground),
    (LeaveChannelBuffer, Background),
    (UpdateChannelBuffer, Foreground),
    (UpdateChannelBufferCollaborators, Foreground),
    (AckBufferOperation, Background),
    (AckChannelMessage, Background),
);
280
// Pairs each request message type with the message type sent back in reply.
// Each entry is `(Request, Response)`; several requests share `Ack` as their
// response.
// NOTE(review): `request_messages!` is imported from the parent module; it
// presumably implements `RequestMessage` with `Response` set to the second
// element — confirm against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (CreateChannel, CreateChannelResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (InviteChannelMember, Ack),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (JoinChannelChat, JoinChannelChatResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (
        ResolveCompletionDocumentation,
        ResolveCompletionDocumentationResponse
    ),
    (ResolveInlayHint, ResolveInlayHintResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveChannelMember, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RespondToChannelInvite, Ack),
    (SetChannelMemberRole, Ack),
    (SetChannelVisibility, Ack),
    (SendChannelMessage, SendChannelMessageResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannelMembers, GetChannelMembersResponse),
    (JoinChannel, JoinRoomResponse),
    (RemoveChannelMessage, Ack),
    (DeleteChannel, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (RenameChannel, RenameChannelResponse),
    (LinkChannel, Ack),
    (UnlinkChannel, Ack),
    (MoveChannel, Ack),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
    (JoinChannelBuffer, JoinChannelBufferResponse),
    (LeaveChannelBuffer, Ack)
);
362
// Lists the message types that are tied to a project. The first argument,
// `project_id`, names a field; the remaining arguments are message types.
// NOTE(review): `entity_messages!` is imported from the parent module; it
// presumably implements `EntityMessage` so that `remote_entity_id` returns the
// named field — confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    ResolveCompletionDocumentation,
    ResolveInlayHint,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);
414
// Lists the message types that are tied to a channel. The first argument,
// `channel_id`, names a field; the remaining arguments are message types.
// NOTE(review): presumably makes `EntityMessage::remote_entity_id` return the
// `channel_id` field of these messages — confirm against the macro definition.
entity_messages!(
    channel_id,
    ChannelMessageSent,
    UpdateChannelBuffer,
    RemoveChannelMessage,
    UpdateChannelBufferCollaborators,
);
422
/// Number of bytes in one kibibyte.
const KIB: usize = 1024;
/// Number of bytes in one mebibyte.
const MIB: usize = KIB * 1024;
/// Capacity, in bytes, that a `MessageStream`'s scratch buffer is shrunk back
/// to after each envelope has been written or read.
const MAX_BUFFER_LEN: usize = MIB;
426
/// A stream of protobuf messages.
///
/// Wraps a WebSocket sink and/or stream `S`: `write` is available when `S` is
/// a sink of WebSocket messages and `read` when `S` is a stream of them.
pub struct MessageStream<S> {
    /// The wrapped WebSocket sink and/or stream.
    stream: S,
    /// Scratch buffer holding the uncompressed protobuf bytes of the envelope
    /// currently being written or read; cleared after each envelope.
    encoding_buffer: Vec<u8>,
}
432
/// A single item written to or read from a [`MessageStream`].
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf [`Envelope`]; carried as a zstd-compressed binary WebSocket
    /// message.
    Envelope(Envelope),
    /// Corresponds to a WebSocket ping message.
    Ping,
    /// Corresponds to a WebSocket pong message.
    Pong,
}
440
441impl<S> MessageStream<S> {
442    pub fn new(stream: S) -> Self {
443        Self {
444            stream,
445            encoding_buffer: Vec::new(),
446        }
447    }
448
449    pub fn inner_mut(&mut self) -> &mut S {
450        &mut self.stream
451    }
452}
453
454impl<S> MessageStream<S>
455where
456    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
457{
458    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
459        #[cfg(any(test, feature = "test-support"))]
460        const COMPRESSION_LEVEL: i32 = -7;
461
462        #[cfg(not(any(test, feature = "test-support")))]
463        const COMPRESSION_LEVEL: i32 = 4;
464
465        match message {
466            Message::Envelope(message) => {
467                self.encoding_buffer.reserve(message.encoded_len());
468                message
469                    .encode(&mut self.encoding_buffer)
470                    .map_err(io::Error::from)?;
471                let buffer =
472                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
473                        .unwrap();
474
475                self.encoding_buffer.clear();
476                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
477                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
478            }
479            Message::Ping => {
480                self.stream
481                    .send(WebSocketMessage::Ping(Default::default()))
482                    .await?;
483            }
484            Message::Pong => {
485                self.stream
486                    .send(WebSocketMessage::Pong(Default::default()))
487                    .await?;
488            }
489        }
490
491        Ok(())
492    }
493}
494
495impl<S> MessageStream<S>
496where
497    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
498{
499    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
500        while let Some(bytes) = self.stream.next().await {
501            match bytes? {
502                WebSocketMessage::Binary(bytes) => {
503                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
504                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
505                        .map_err(io::Error::from)?;
506
507                    self.encoding_buffer.clear();
508                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
509                    return Ok(Message::Envelope(envelope));
510                }
511                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
512                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
513                WebSocketMessage::Close(_) => break,
514                _ => {}
515            }
516        }
517        Err(anyhow!("connection closed"))
518    }
519}
520
521impl From<Timestamp> for SystemTime {
522    fn from(val: Timestamp) -> Self {
523        UNIX_EPOCH
524            .checked_add(Duration::new(val.seconds, val.nanos))
525            .unwrap()
526    }
527}
528
529impl From<SystemTime> for Timestamp {
530    fn from(time: SystemTime) -> Self {
531        let duration = time.duration_since(UNIX_EPOCH).unwrap();
532        Self {
533            seconds: duration.as_secs(),
534            nanos: duration.subsec_nanos(),
535        }
536    }
537}
538
539impl From<u128> for Nonce {
540    fn from(nonce: u128) -> Self {
541        let upper_half = (nonce >> 64) as u64;
542        let lower_half = nonce as u64;
543        Self {
544            upper_half,
545            lower_half,
546        }
547    }
548}
549
550impl From<Nonce> for u128 {
551    fn from(nonce: Nonce) -> Self {
552        let upper_half = (nonce.upper_half as u128) << 64;
553        let lower_half = nonce.lower_half as u128;
554        upper_half | lower_half
555    }
556}
557
558pub fn split_worktree_update(
559    mut message: UpdateWorktree,
560    max_chunk_size: usize,
561) -> impl Iterator<Item = UpdateWorktree> {
562    let mut done_files = false;
563
564    let mut repository_map = message
565        .updated_repositories
566        .into_iter()
567        .map(|repo| (repo.work_directory_id, repo))
568        .collect::<HashMap<_, _>>();
569
570    iter::from_fn(move || {
571        if done_files {
572            return None;
573        }
574
575        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
576        let updated_entries: Vec<_> = message
577            .updated_entries
578            .drain(..updated_entries_chunk_size)
579            .collect();
580
581        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
582        let removed_entries = message
583            .removed_entries
584            .drain(..removed_entries_chunk_size)
585            .collect();
586
587        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
588
589        let mut updated_repositories = Vec::new();
590
591        if !repository_map.is_empty() {
592            for entry in &updated_entries {
593                if let Some(repo) = repository_map.remove(&entry.id) {
594                    updated_repositories.push(repo)
595                }
596            }
597        }
598
599        let removed_repositories = if done_files {
600            mem::take(&mut message.removed_repositories)
601        } else {
602            Default::default()
603        };
604
605        if done_files {
606            updated_repositories.extend(mem::take(&mut repository_map).into_values());
607        }
608
609        Some(UpdateWorktree {
610            project_id: message.project_id,
611            worktree_id: message.worktree_id,
612            root_name: message.root_name.clone(),
613            abs_path: message.abs_path.clone(),
614            updated_entries,
615            removed_entries,
616            scan_id: message.scan_id,
617            is_last_update: done_files && message.is_last_update,
618            updated_repositories,
619            removed_repositories,
620        })
621    })
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that the scratch buffer's capacity stays within `MAX_BUFFER_LEN`
    // after writing and after reading both a small and a very large envelope.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        // Small message: a 70-byte root name.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(10),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        // Large message: a 7,000,000-byte root name, well above `MAX_BUFFER_LEN`.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(1000000),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        // Read both messages back through the receiving half of the channel.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
    }

    // Checks that `PeerId::as_u64` and `PeerId::from_u64` round-trip, including
    // at the `u32::MAX` boundary of each field.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let peer_id = PeerId {
            owner_id: 10,
            id: 3,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: u32::MAX,
            id: 3,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: 10,
            id: u32::MAX,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: u32::MAX,
            id: u32::MAX,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
    }
}