proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
134messages!(
135    (Ack, Foreground),
136    (AddProjectCollaborator, Foreground),
137    (ApplyCodeAction, Background),
138    (ApplyCodeActionResponse, Background),
139    (ApplyCompletionAdditionalEdits, Background),
140    (ApplyCompletionAdditionalEditsResponse, Background),
141    (BufferReloaded, Foreground),
142    (BufferSaved, Foreground),
143    (Call, Foreground),
144    (CallCanceled, Foreground),
145    (CancelCall, Foreground),
146    (CopyProjectEntry, Foreground),
147    (CreateBufferForPeer, Foreground),
148    (CreateChannel, Foreground),
149    (CreateChannelResponse, Foreground),
150    (ChannelMessageSent, Foreground),
151    (CreateProjectEntry, Foreground),
152    (CreateRoom, Foreground),
153    (CreateRoomResponse, Foreground),
154    (DeclineCall, Foreground),
155    (DeleteProjectEntry, Foreground),
156    (Error, Foreground),
157    (ExpandProjectEntry, Foreground),
158    (Follow, Foreground),
159    (FollowResponse, Foreground),
160    (FormatBuffers, Foreground),
161    (FormatBuffersResponse, Foreground),
162    (FuzzySearchUsers, Foreground),
163    (GetCodeActions, Background),
164    (GetCodeActionsResponse, Background),
165    (GetHover, Background),
166    (GetHoverResponse, Background),
167    (GetChannelMessages, Background),
168    (GetChannelMessagesResponse, Background),
169    (SendChannelMessage, Background),
170    (SendChannelMessageResponse, Background),
171    (GetCompletions, Background),
172    (GetCompletionsResponse, Background),
173    (GetDefinition, Background),
174    (GetDefinitionResponse, Background),
175    (GetTypeDefinition, Background),
176    (GetTypeDefinitionResponse, Background),
177    (GetDocumentHighlights, Background),
178    (GetDocumentHighlightsResponse, Background),
179    (GetReferences, Background),
180    (GetReferencesResponse, Background),
181    (GetProjectSymbols, Background),
182    (GetProjectSymbolsResponse, Background),
183    (GetUsers, Foreground),
184    (Hello, Foreground),
185    (IncomingCall, Foreground),
186    (InviteChannelMember, Foreground),
187    (UsersResponse, Foreground),
188    (JoinProject, Foreground),
189    (JoinProjectResponse, Foreground),
190    (JoinRoom, Foreground),
191    (JoinRoomResponse, Foreground),
192    (JoinChannelChat, Foreground),
193    (JoinChannelChatResponse, Foreground),
194    (LeaveChannelChat, Foreground),
195    (LeaveProject, Foreground),
196    (LeaveRoom, Foreground),
197    (OpenBufferById, Background),
198    (OpenBufferByPath, Background),
199    (OpenBufferForSymbol, Background),
200    (OpenBufferForSymbolResponse, Background),
201    (OpenBufferResponse, Background),
202    (PerformRename, Background),
203    (PerformRenameResponse, Background),
204    (OnTypeFormatting, Background),
205    (OnTypeFormattingResponse, Background),
206    (InlayHints, Background),
207    (InlayHintsResponse, Background),
208    (ResolveInlayHint, Background),
209    (ResolveInlayHintResponse, Background),
210    (RefreshInlayHints, Foreground),
211    (Ping, Foreground),
212    (PrepareRename, Background),
213    (PrepareRenameResponse, Background),
214    (ExpandProjectEntryResponse, Foreground),
215    (ProjectEntryResponse, Foreground),
216    (RejoinRoom, Foreground),
217    (RejoinRoomResponse, Foreground),
218    (RemoveContact, Foreground),
219    (RemoveChannelMember, Foreground),
220    (RemoveChannelMessage, Foreground),
221    (ReloadBuffers, Foreground),
222    (ReloadBuffersResponse, Foreground),
223    (RemoveProjectCollaborator, Foreground),
224    (RenameProjectEntry, Foreground),
225    (RequestContact, Foreground),
226    (RespondToContactRequest, Foreground),
227    (RespondToChannelInvite, Foreground),
228    (JoinChannel, Foreground),
229    (RoomUpdated, Foreground),
230    (SaveBuffer, Foreground),
231    (RenameChannel, Foreground),
232    (RenameChannelResponse, Foreground),
233    (SetChannelMemberAdmin, Foreground),
234    (SearchProject, Background),
235    (SearchProjectResponse, Background),
236    (ShareProject, Foreground),
237    (ShareProjectResponse, Foreground),
238    (ShowContacts, Foreground),
239    (StartLanguageServer, Foreground),
240    (SynchronizeBuffers, Foreground),
241    (SynchronizeBuffersResponse, Foreground),
242    (RejoinChannelBuffers, Foreground),
243    (RejoinChannelBuffersResponse, Foreground),
244    (Test, Foreground),
245    (Unfollow, Foreground),
246    (UnshareProject, Foreground),
247    (UpdateBuffer, Foreground),
248    (UpdateBufferFile, Foreground),
249    (UpdateContacts, Foreground),
250    (DeleteChannel, Foreground),
251    (MoveChannel, Foreground),
252    (LinkChannel, Foreground),
253    (UnlinkChannel, Foreground),
254    (UpdateChannels, Foreground),
255    (UpdateDiagnosticSummary, Foreground),
256    (UpdateFollowers, Foreground),
257    (UpdateInviteInfo, Foreground),
258    (UpdateLanguageServer, Foreground),
259    (UpdateParticipantLocation, Foreground),
260    (UpdateProject, Foreground),
261    (UpdateProjectCollaborator, Foreground),
262    (UpdateWorktree, Foreground),
263    (UpdateWorktreeSettings, Foreground),
264    (UpdateDiffBase, Foreground),
265    (GetPrivateUserInfo, Foreground),
266    (GetPrivateUserInfoResponse, Foreground),
267    (GetChannelMembers, Foreground),
268    (GetChannelMembersResponse, Foreground),
269    (JoinChannelBuffer, Foreground),
270    (JoinChannelBufferResponse, Foreground),
271    (LeaveChannelBuffer, Background),
272    (UpdateChannelBuffer, Foreground),
273    (RemoveChannelBufferCollaborator, Foreground),
274    (AddChannelBufferCollaborator, Foreground),
275    (UpdateChannelBufferCollaborator, Foreground),
276);
277
278request_messages!(
279    (ApplyCodeAction, ApplyCodeActionResponse),
280    (
281        ApplyCompletionAdditionalEdits,
282        ApplyCompletionAdditionalEditsResponse
283    ),
284    (Call, Ack),
285    (CancelCall, Ack),
286    (CopyProjectEntry, ProjectEntryResponse),
287    (CreateProjectEntry, ProjectEntryResponse),
288    (CreateRoom, CreateRoomResponse),
289    (CreateChannel, CreateChannelResponse),
290    (DeclineCall, Ack),
291    (DeleteProjectEntry, ProjectEntryResponse),
292    (ExpandProjectEntry, ExpandProjectEntryResponse),
293    (Follow, FollowResponse),
294    (FormatBuffers, FormatBuffersResponse),
295    (GetCodeActions, GetCodeActionsResponse),
296    (GetHover, GetHoverResponse),
297    (GetCompletions, GetCompletionsResponse),
298    (GetDefinition, GetDefinitionResponse),
299    (GetTypeDefinition, GetTypeDefinitionResponse),
300    (GetDocumentHighlights, GetDocumentHighlightsResponse),
301    (GetReferences, GetReferencesResponse),
302    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
303    (GetProjectSymbols, GetProjectSymbolsResponse),
304    (FuzzySearchUsers, UsersResponse),
305    (GetUsers, UsersResponse),
306    (InviteChannelMember, Ack),
307    (JoinProject, JoinProjectResponse),
308    (JoinRoom, JoinRoomResponse),
309    (JoinChannelChat, JoinChannelChatResponse),
310    (LeaveRoom, Ack),
311    (RejoinRoom, RejoinRoomResponse),
312    (IncomingCall, Ack),
313    (OpenBufferById, OpenBufferResponse),
314    (OpenBufferByPath, OpenBufferResponse),
315    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
316    (Ping, Ack),
317    (PerformRename, PerformRenameResponse),
318    (PrepareRename, PrepareRenameResponse),
319    (OnTypeFormatting, OnTypeFormattingResponse),
320    (InlayHints, InlayHintsResponse),
321    (ResolveInlayHint, ResolveInlayHintResponse),
322    (RefreshInlayHints, Ack),
323    (ReloadBuffers, ReloadBuffersResponse),
324    (RequestContact, Ack),
325    (RemoveChannelMember, Ack),
326    (RemoveContact, Ack),
327    (RespondToContactRequest, Ack),
328    (RespondToChannelInvite, Ack),
329    (SetChannelMemberAdmin, Ack),
330    (SendChannelMessage, SendChannelMessageResponse),
331    (GetChannelMessages, GetChannelMessagesResponse),
332    (GetChannelMembers, GetChannelMembersResponse),
333    (JoinChannel, JoinRoomResponse),
334    (RemoveChannelMessage, Ack),
335    (DeleteChannel, Ack),
336    (RenameProjectEntry, ProjectEntryResponse),
337    (RenameChannel, RenameChannelResponse),
338    (LinkChannel, Ack),
339    (UnlinkChannel, Ack),
340    (MoveChannel, Ack),
341    (SaveBuffer, BufferSaved),
342    (SearchProject, SearchProjectResponse),
343    (ShareProject, ShareProjectResponse),
344    (SynchronizeBuffers, SynchronizeBuffersResponse),
345    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
346    (Test, Test),
347    (UpdateBuffer, Ack),
348    (UpdateParticipantLocation, Ack),
349    (UpdateProject, Ack),
350    (UpdateWorktree, Ack),
351    (JoinChannelBuffer, JoinChannelBufferResponse),
352    (LeaveChannelBuffer, Ack)
353);
354
355entity_messages!(
356    project_id,
357    AddProjectCollaborator,
358    ApplyCodeAction,
359    ApplyCompletionAdditionalEdits,
360    BufferReloaded,
361    BufferSaved,
362    CopyProjectEntry,
363    CreateBufferForPeer,
364    CreateProjectEntry,
365    DeleteProjectEntry,
366    ExpandProjectEntry,
367    Follow,
368    FormatBuffers,
369    GetCodeActions,
370    GetCompletions,
371    GetDefinition,
372    GetTypeDefinition,
373    GetDocumentHighlights,
374    GetHover,
375    GetReferences,
376    GetProjectSymbols,
377    JoinProject,
378    LeaveProject,
379    OpenBufferById,
380    OpenBufferByPath,
381    OpenBufferForSymbol,
382    PerformRename,
383    OnTypeFormatting,
384    InlayHints,
385    ResolveInlayHint,
386    RefreshInlayHints,
387    PrepareRename,
388    ReloadBuffers,
389    RemoveProjectCollaborator,
390    RenameProjectEntry,
391    SaveBuffer,
392    SearchProject,
393    StartLanguageServer,
394    SynchronizeBuffers,
395    Unfollow,
396    UnshareProject,
397    UpdateBuffer,
398    UpdateBufferFile,
399    UpdateDiagnosticSummary,
400    UpdateFollowers,
401    UpdateLanguageServer,
402    UpdateProject,
403    UpdateProjectCollaborator,
404    UpdateWorktree,
405    UpdateWorktreeSettings,
406    UpdateDiffBase
407);
408
409entity_messages!(
410    channel_id,
411    ChannelMessageSent,
412    UpdateChannelBuffer,
413    RemoveChannelBufferCollaborator,
414    RemoveChannelMessage,
415    AddChannelBufferCollaborator,
416    UpdateChannelBufferCollaborator
417);
418
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Largest capacity `MessageStream`'s scratch buffer keeps between messages.
const MAX_BUFFER_LEN: usize = MIB;
422
/// Frames protobuf envelopes over a WebSocket-like transport.
///
/// `write` is available when `S` is a sink of WebSocket messages, and `read` when `S` is a
/// stream of them.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch space reused for protobuf encoding (on write) and zstd decompression (on
    /// read). It is cleared and shrunk after a message is processed.
    encoding_buffer: Vec<u8>,
}
428
429#[allow(clippy::large_enum_variant)]
430#[derive(Debug)]
431pub enum Message {
432    Envelope(Envelope),
433    Ping,
434    Pong,
435}
436
437impl<S> MessageStream<S> {
438    pub fn new(stream: S) -> Self {
439        Self {
440            stream,
441            encoding_buffer: Vec::new(),
442        }
443    }
444
445    pub fn inner_mut(&mut self) -> &mut S {
446        &mut self.stream
447    }
448}
449
450impl<S> MessageStream<S>
451where
452    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
453{
454    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
455        #[cfg(any(test, feature = "test-support"))]
456        const COMPRESSION_LEVEL: i32 = -7;
457
458        #[cfg(not(any(test, feature = "test-support")))]
459        const COMPRESSION_LEVEL: i32 = 4;
460
461        match message {
462            Message::Envelope(message) => {
463                self.encoding_buffer.reserve(message.encoded_len());
464                message
465                    .encode(&mut self.encoding_buffer)
466                    .map_err(io::Error::from)?;
467                let buffer =
468                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
469                        .unwrap();
470
471                self.encoding_buffer.clear();
472                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
473                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
474            }
475            Message::Ping => {
476                self.stream
477                    .send(WebSocketMessage::Ping(Default::default()))
478                    .await?;
479            }
480            Message::Pong => {
481                self.stream
482                    .send(WebSocketMessage::Pong(Default::default()))
483                    .await?;
484            }
485        }
486
487        Ok(())
488    }
489}
490
491impl<S> MessageStream<S>
492where
493    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
494{
495    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
496        while let Some(bytes) = self.stream.next().await {
497            match bytes? {
498                WebSocketMessage::Binary(bytes) => {
499                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
500                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
501                        .map_err(io::Error::from)?;
502
503                    self.encoding_buffer.clear();
504                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
505                    return Ok(Message::Envelope(envelope));
506                }
507                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
508                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
509                WebSocketMessage::Close(_) => break,
510                _ => {}
511            }
512        }
513        Err(anyhow!("connection closed"))
514    }
515}
516
517impl From<Timestamp> for SystemTime {
518    fn from(val: Timestamp) -> Self {
519        UNIX_EPOCH
520            .checked_add(Duration::new(val.seconds, val.nanos))
521            .unwrap()
522    }
523}
524
525impl From<SystemTime> for Timestamp {
526    fn from(time: SystemTime) -> Self {
527        let duration = time.duration_since(UNIX_EPOCH).unwrap();
528        Self {
529            seconds: duration.as_secs(),
530            nanos: duration.subsec_nanos(),
531        }
532    }
533}
534
535impl From<u128> for Nonce {
536    fn from(nonce: u128) -> Self {
537        let upper_half = (nonce >> 64) as u64;
538        let lower_half = nonce as u64;
539        Self {
540            upper_half,
541            lower_half,
542        }
543    }
544}
545
546impl From<Nonce> for u128 {
547    fn from(nonce: Nonce) -> Self {
548        let upper_half = (nonce.upper_half as u128) << 64;
549        let lower_half = nonce.lower_half as u128;
550        upper_half | lower_half
551    }
552}
553
554pub fn split_worktree_update(
555    mut message: UpdateWorktree,
556    max_chunk_size: usize,
557) -> impl Iterator<Item = UpdateWorktree> {
558    let mut done_files = false;
559
560    let mut repository_map = message
561        .updated_repositories
562        .into_iter()
563        .map(|repo| (repo.work_directory_id, repo))
564        .collect::<HashMap<_, _>>();
565
566    iter::from_fn(move || {
567        if done_files {
568            return None;
569        }
570
571        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
572        let updated_entries: Vec<_> = message
573            .updated_entries
574            .drain(..updated_entries_chunk_size)
575            .collect();
576
577        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
578        let removed_entries = message
579            .removed_entries
580            .drain(..removed_entries_chunk_size)
581            .collect();
582
583        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
584
585        let mut updated_repositories = Vec::new();
586
587        if !repository_map.is_empty() {
588            for entry in &updated_entries {
589                if let Some(repo) = repository_map.remove(&entry.id) {
590                    updated_repositories.push(repo)
591                }
592            }
593        }
594
595        let removed_repositories = if done_files {
596            mem::take(&mut message.removed_repositories)
597        } else {
598            Default::default()
599        };
600
601        if done_files {
602            updated_repositories.extend(mem::take(&mut repository_map).into_values());
603        }
604
605        Some(UpdateWorktree {
606            project_id: message.project_id,
607            worktree_id: message.worktree_id,
608            root_name: message.root_name.clone(),
609            abs_path: message.abs_path.clone(),
610            updated_entries,
611            removed_entries,
612            scan_id: message.scan_id,
613            is_last_update: done_files && message.is_last_update,
614            updated_repositories,
615            removed_repositories,
616        })
617    })
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose payload is an `UpdateWorktree` with the given root name.
    fn update_worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        // A small message, then one far larger than MAX_BUFFER_LEN: in both cases the scratch
        // buffer must not retain more than MAX_BUFFER_LEN of capacity afterwards.
        for root_name in ["abcdefg".repeat(10), "abcdefg".repeat(1000000)] {
            sink.write(Message::Envelope(update_worktree_envelope(root_name)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}