proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
/// A protobuf message type that can be wrapped in, and extracted from, an [`Envelope`].
///
/// NOTE(review): implementations are presumably generated by the `messages!`
/// invocation further down in this file — confirm against the macro definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Type name reported through [`AnyTypedEnvelope::payload_type_name`].
    const NAME: &'static str;
    /// Priority of this message type; `Background` makes
    /// [`AnyTypedEnvelope::is_background`] return `true`.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`].
    ///
    /// Judging by the parameter names, `id` is this message's id, `responding_to`
    /// the id of the message it answers (if any), and `original_sender_id` the
    /// peer it originally came from (if any) — TODO confirm with the implementation.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts `Self` from `envelope`; presumably returns `None` when the
    /// envelope carries a different payload type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 32
/// A message addressed to one specific remote entity.
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity this message targets.
    ///
    /// NOTE(review): presumably this returns the field named as the first
    /// argument of the matching `entity_messages!` invocation (`project_id` or
    /// `channel_id` in this file) — confirm against the macro definition.
    fn remote_entity_id(&self) -> u64;
}
 36
/// A message that expects a reply of type [`RequestMessage::Response`].
///
/// NOTE(review): the request/response pairings are presumably declared by the
/// `request_messages!` invocation further down in this file.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
 40
/// Type-erased view of a received envelope, so envelopes carrying different
/// payload types can be stored and dispatched uniformly (e.g. as trait objects).
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// [`TypeId`] of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload message type (its [`EnvelopedMessage::NAME`] for `TypedEnvelope`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` for owned downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type is declared with [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// The peer that originally sent the message, if known.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// Id of the connection the message is attributed to.
    fn sender_id(&self) -> ConnectionId;
    /// Id of this message.
    fn message_id(&self) -> u32;
}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
// `PeerId` is not defined in this file (presumably it comes from the generated
// protobuf code pulled in by the `include!` near the top), so these marker
// traits are implemented by hand rather than derived. Both fields are `u32`s,
// which makes copying trivially cheap and equality a total relation.
impl Copy for PeerId {}

impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Registry of every message type that can travel inside an `Envelope`, each
// paired with its `MessagePriority` variant. `AnyTypedEnvelope::is_background`
// reports `true` for the `Background` entries.
// NOTE(review): the `messages!` macro (defined in the parent module) presumably
// implements `EnvelopedMessage` for each listed type — confirm there.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateChannel, Foreground),
    (CreateChannelResponse, Foreground),
    (ChannelMessageSent, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetChannelMessages, Background),
    (GetChannelMessagesResponse, Background),
    (SendChannelMessage, Background),
    (SendChannelMessageResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (InviteChannelMember, Foreground),
    (UsersResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (JoinChannelChat, Foreground),
    (JoinChannelChatResponse, Foreground),
    (LeaveChannelChat, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (ResolveInlayHint, Background),
    (ResolveInlayHintResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (RemoveChannelMember, Foreground),
    (RemoveChannelMessage, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToChannelInvite, Foreground),
    (JoinChannel, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (RenameChannel, Foreground),
    (RenameChannelResponse, Foreground),
    (SetChannelMemberAdmin, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (RejoinChannelBuffers, Foreground),
    (RejoinChannelBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (DeleteChannel, Foreground),
    (MoveChannel, Foreground),
    (LinkChannel, Foreground),
    (UnlinkChannel, Foreground),
    (UpdateChannels, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
    (GetChannelMembers, Foreground),
    (GetChannelMembersResponse, Foreground),
    (JoinChannelBuffer, Foreground),
    (JoinChannelBufferResponse, Foreground),
    // NOTE(review): the only `Background` entry among the channel-buffer
    // messages (join/update are `Foreground`) — confirm this is intentional.
    (LeaveChannelBuffer, Background),
    (UpdateChannelBuffer, Foreground),
    (UpdateChannelBufferCollaborators, Foreground),
);
275
// Pairs each request type with the response type a peer replies with. Several
// requests share a response type (`Ack`, `ProjectEntryResponse`, `UsersResponse`,
// `JoinRoomResponse`, `OpenBufferResponse`), and `Test` is answered with itself.
// NOTE(review): the macro presumably implements `RequestMessage` with
// `Response = <second type>` — confirm against its definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (CreateChannel, CreateChannelResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (InviteChannelMember, Ack),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (JoinChannelChat, JoinChannelChatResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (ResolveInlayHint, ResolveInlayHintResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveChannelMember, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RespondToChannelInvite, Ack),
    (SetChannelMemberAdmin, Ack),
    (SendChannelMessage, SendChannelMessageResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannelMembers, GetChannelMembersResponse),
    (JoinChannel, JoinRoomResponse),
    (RemoveChannelMessage, Ack),
    (DeleteChannel, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (RenameChannel, RenameChannelResponse),
    (LinkChannel, Ack),
    (UnlinkChannel, Ack),
    (MoveChannel, Ack),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
    (JoinChannelBuffer, JoinChannelBufferResponse),
    (LeaveChannelBuffer, Ack)
);
352
// Messages that target a single project. The leading `project_id` names the
// field that presumably backs `EntityMessage::remote_entity_id` for each type
// listed — NOTE(review): confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    ResolveInlayHint,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);
403
// Messages that target a single channel, keyed by their `channel_id` field
// (presumably returned by `EntityMessage::remote_entity_id` — confirm).
entity_messages!(
    channel_id,
    ChannelMessageSent,
    UpdateChannelBuffer,
    RemoveChannelMessage,
    UpdateChannelBufferCollaborators
);
411
/// Bytes in one kibibyte (2^10).
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte (2^20).
const MIB: usize = KIB << 10;
/// Upper bound on the capacity `MessageStream` keeps in its scratch buffer
/// between messages (it calls `shrink_to` with this value after each one), so a
/// single oversized message does not pin that memory for the connection's lifetime.
const MAX_BUFFER_LEN: usize = MIB;
415
/// A stream of protobuf messages over a websocket transport.
///
/// Wraps a transport `S`. When `S` is a sink of websocket messages,
/// `MessageStream::write` protobuf-encodes and zstd-compresses outgoing
/// [`Message`]s; when `S` is a stream of them, `MessageStream::read`
/// decompresses and decodes incoming ones.
pub struct MessageStream<S> {
    /// The wrapped websocket transport.
    stream: S,
    /// Scratch buffer reused across messages: it holds the protobuf encoding
    /// when writing and the decompressed bytes when reading. After a message is
    /// encoded or decoded it is cleared and its capacity shrunk to at most
    /// `MAX_BUFFER_LEN`.
    encoding_buffer: Vec<u8>,
}
421
/// A protocol-level message read from or written to a [`MessageStream`].
// `Envelope` is much larger than the unit variants; the lint is silenced rather
// than boxing it. NOTE(review): presumably to avoid an extra heap allocation per
// message — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried as a zstd-compressed binary websocket frame.
    Envelope(Envelope),
    /// A websocket ping frame.
    Ping,
    /// A websocket pong frame.
    Pong,
}
429
430impl<S> MessageStream<S> {
431    pub fn new(stream: S) -> Self {
432        Self {
433            stream,
434            encoding_buffer: Vec::new(),
435        }
436    }
437
438    pub fn inner_mut(&mut self) -> &mut S {
439        &mut self.stream
440    }
441}
442
443impl<S> MessageStream<S>
444where
445    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
446{
447    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
448        #[cfg(any(test, feature = "test-support"))]
449        const COMPRESSION_LEVEL: i32 = -7;
450
451        #[cfg(not(any(test, feature = "test-support")))]
452        const COMPRESSION_LEVEL: i32 = 4;
453
454        match message {
455            Message::Envelope(message) => {
456                self.encoding_buffer.reserve(message.encoded_len());
457                message
458                    .encode(&mut self.encoding_buffer)
459                    .map_err(io::Error::from)?;
460                let buffer =
461                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
462                        .unwrap();
463
464                self.encoding_buffer.clear();
465                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
466                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
467            }
468            Message::Ping => {
469                self.stream
470                    .send(WebSocketMessage::Ping(Default::default()))
471                    .await?;
472            }
473            Message::Pong => {
474                self.stream
475                    .send(WebSocketMessage::Pong(Default::default()))
476                    .await?;
477            }
478        }
479
480        Ok(())
481    }
482}
483
484impl<S> MessageStream<S>
485where
486    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
487{
488    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
489        while let Some(bytes) = self.stream.next().await {
490            match bytes? {
491                WebSocketMessage::Binary(bytes) => {
492                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
493                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
494                        .map_err(io::Error::from)?;
495
496                    self.encoding_buffer.clear();
497                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
498                    return Ok(Message::Envelope(envelope));
499                }
500                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
501                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
502                WebSocketMessage::Close(_) => break,
503                _ => {}
504            }
505        }
506        Err(anyhow!("connection closed"))
507    }
508}
509
510impl From<Timestamp> for SystemTime {
511    fn from(val: Timestamp) -> Self {
512        UNIX_EPOCH
513            .checked_add(Duration::new(val.seconds, val.nanos))
514            .unwrap()
515    }
516}
517
518impl From<SystemTime> for Timestamp {
519    fn from(time: SystemTime) -> Self {
520        let duration = time.duration_since(UNIX_EPOCH).unwrap();
521        Self {
522            seconds: duration.as_secs(),
523            nanos: duration.subsec_nanos(),
524        }
525    }
526}
527
528impl From<u128> for Nonce {
529    fn from(nonce: u128) -> Self {
530        let upper_half = (nonce >> 64) as u64;
531        let lower_half = nonce as u64;
532        Self {
533            upper_half,
534            lower_half,
535        }
536    }
537}
538
539impl From<Nonce> for u128 {
540    fn from(nonce: Nonce) -> Self {
541        let upper_half = (nonce.upper_half as u128) << 64;
542        let lower_half = nonce.lower_half as u128;
543        upper_half | lower_half
544    }
545}
546
547pub fn split_worktree_update(
548    mut message: UpdateWorktree,
549    max_chunk_size: usize,
550) -> impl Iterator<Item = UpdateWorktree> {
551    let mut done_files = false;
552
553    let mut repository_map = message
554        .updated_repositories
555        .into_iter()
556        .map(|repo| (repo.work_directory_id, repo))
557        .collect::<HashMap<_, _>>();
558
559    iter::from_fn(move || {
560        if done_files {
561            return None;
562        }
563
564        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
565        let updated_entries: Vec<_> = message
566            .updated_entries
567            .drain(..updated_entries_chunk_size)
568            .collect();
569
570        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
571        let removed_entries = message
572            .removed_entries
573            .drain(..removed_entries_chunk_size)
574            .collect();
575
576        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
577
578        let mut updated_repositories = Vec::new();
579
580        if !repository_map.is_empty() {
581            for entry in &updated_entries {
582                if let Some(repo) = repository_map.remove(&entry.id) {
583                    updated_repositories.push(repo)
584                }
585            }
586        }
587
588        let removed_repositories = if done_files {
589            mem::take(&mut message.removed_repositories)
590        } else {
591            Default::default()
592        };
593
594        if done_files {
595            updated_repositories.extend(mem::take(&mut repository_map).into_values());
596        }
597
598        Some(UpdateWorktree {
599            project_id: message.project_id,
600            worktree_id: message.worktree_id,
601            root_name: message.root_name.clone(),
602            abs_path: message.abs_path.clone(),
603            updated_entries,
604            removed_entries,
605            scan_id: message.scan_id,
606            is_last_update: done_files && message.is_last_update,
607            updated_repositories,
608            removed_repositories,
609        })
610    })
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope wrapping an `UpdateWorktree` whose root name is
    /// `"abcdefg"` repeated `repetitions` times, so the payload size scales with it.
    fn worktree_envelope(repetitions: usize) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// After writing or reading small and very large messages, the scratch buffer
    /// must not retain more than `MAX_BUFFER_LEN` bytes of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1_000_000] {
            sink.write(Message::Envelope(worktree_envelope(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId` must survive a round trip through its packed `u64` form,
    /// including when either half is at its maximum value.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}