proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 /// Priority tag attached to every message type through `EnvelopedMessage::PRIORITY`.
///
/// `Debug`, `Clone`, `Copy`, `PartialEq` and `Eq` are derived so priorities can be
/// logged and compared directly instead of only pattern-matched.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    Foreground,
    Background,
}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
134messages!(
135    (Ack, Foreground),
136    (AddProjectCollaborator, Foreground),
137    (ApplyCodeAction, Background),
138    (ApplyCodeActionResponse, Background),
139    (ApplyCompletionAdditionalEdits, Background),
140    (ApplyCompletionAdditionalEditsResponse, Background),
141    (BufferReloaded, Foreground),
142    (BufferSaved, Foreground),
143    (Call, Foreground),
144    (CallCanceled, Foreground),
145    (CancelCall, Foreground),
146    (CopyProjectEntry, Foreground),
147    (CreateBufferForPeer, Foreground),
148    (CreateChannel, Foreground),
149    (ChannelResponse, Foreground),
150    (ChannelMessageSent, Foreground),
151    (CreateProjectEntry, Foreground),
152    (CreateRoom, Foreground),
153    (CreateRoomResponse, Foreground),
154    (DeclineCall, Foreground),
155    (DeleteProjectEntry, Foreground),
156    (Error, Foreground),
157    (ExpandProjectEntry, Foreground),
158    (Follow, Foreground),
159    (FollowResponse, Foreground),
160    (FormatBuffers, Foreground),
161    (FormatBuffersResponse, Foreground),
162    (FuzzySearchUsers, Foreground),
163    (GetCodeActions, Background),
164    (GetCodeActionsResponse, Background),
165    (GetHover, Background),
166    (GetHoverResponse, Background),
167    (GetChannelMessages, Background),
168    (GetChannelMessagesResponse, Background),
169    (SendChannelMessage, Background),
170    (SendChannelMessageResponse, Background),
171    (GetCompletions, Background),
172    (GetCompletionsResponse, Background),
173    (GetDefinition, Background),
174    (GetDefinitionResponse, Background),
175    (GetTypeDefinition, Background),
176    (GetTypeDefinitionResponse, Background),
177    (GetDocumentHighlights, Background),
178    (GetDocumentHighlightsResponse, Background),
179    (GetReferences, Background),
180    (GetReferencesResponse, Background),
181    (GetProjectSymbols, Background),
182    (GetProjectSymbolsResponse, Background),
183    (GetUsers, Foreground),
184    (Hello, Foreground),
185    (IncomingCall, Foreground),
186    (InviteChannelMember, Foreground),
187    (UsersResponse, Foreground),
188    (JoinProject, Foreground),
189    (JoinProjectResponse, Foreground),
190    (JoinRoom, Foreground),
191    (JoinRoomResponse, Foreground),
192    (JoinChannelChat, Foreground),
193    (JoinChannelChatResponse, Foreground),
194    (LeaveChannelChat, Foreground),
195    (LeaveProject, Foreground),
196    (LeaveRoom, Foreground),
197    (OpenBufferById, Background),
198    (OpenBufferByPath, Background),
199    (OpenBufferForSymbol, Background),
200    (OpenBufferForSymbolResponse, Background),
201    (OpenBufferResponse, Background),
202    (PerformRename, Background),
203    (PerformRenameResponse, Background),
204    (OnTypeFormatting, Background),
205    (OnTypeFormattingResponse, Background),
206    (InlayHints, Background),
207    (InlayHintsResponse, Background),
208    (ResolveInlayHint, Background),
209    (ResolveInlayHintResponse, Background),
210    (RefreshInlayHints, Foreground),
211    (Ping, Foreground),
212    (PrepareRename, Background),
213    (PrepareRenameResponse, Background),
214    (ExpandProjectEntryResponse, Foreground),
215    (ProjectEntryResponse, Foreground),
216    (RejoinRoom, Foreground),
217    (RejoinRoomResponse, Foreground),
218    (RemoveContact, Foreground),
219    (RemoveChannelMember, Foreground),
220    (RemoveChannelMessage, Foreground),
221    (ReloadBuffers, Foreground),
222    (ReloadBuffersResponse, Foreground),
223    (RemoveProjectCollaborator, Foreground),
224    (RenameProjectEntry, Foreground),
225    (RequestContact, Foreground),
226    (RespondToContactRequest, Foreground),
227    (RespondToChannelInvite, Foreground),
228    (JoinChannel, Foreground),
229    (RoomUpdated, Foreground),
230    (SaveBuffer, Foreground),
231    (RenameChannel, Foreground),
232    (SetChannelMemberAdmin, Foreground),
233    (SearchProject, Background),
234    (SearchProjectResponse, Background),
235    (ShareProject, Foreground),
236    (ShareProjectResponse, Foreground),
237    (ShowContacts, Foreground),
238    (StartLanguageServer, Foreground),
239    (SynchronizeBuffers, Foreground),
240    (SynchronizeBuffersResponse, Foreground),
241    (RejoinChannelBuffers, Foreground),
242    (RejoinChannelBuffersResponse, Foreground),
243    (Test, Foreground),
244    (Unfollow, Foreground),
245    (UnshareProject, Foreground),
246    (UpdateBuffer, Foreground),
247    (UpdateBufferFile, Foreground),
248    (UpdateContacts, Foreground),
249    (RemoveChannel, Foreground),
250    (UpdateChannels, Foreground),
251    (UpdateDiagnosticSummary, Foreground),
252    (UpdateFollowers, Foreground),
253    (UpdateInviteInfo, Foreground),
254    (UpdateLanguageServer, Foreground),
255    (UpdateParticipantLocation, Foreground),
256    (UpdateProject, Foreground),
257    (UpdateProjectCollaborator, Foreground),
258    (UpdateWorktree, Foreground),
259    (UpdateWorktreeSettings, Foreground),
260    (UpdateDiffBase, Foreground),
261    (GetPrivateUserInfo, Foreground),
262    (GetPrivateUserInfoResponse, Foreground),
263    (GetChannelMembers, Foreground),
264    (GetChannelMembersResponse, Foreground),
265    (JoinChannelBuffer, Foreground),
266    (JoinChannelBufferResponse, Foreground),
267    (LeaveChannelBuffer, Background),
268    (UpdateChannelBuffer, Foreground),
269    (RemoveChannelBufferCollaborator, Foreground),
270    (AddChannelBufferCollaborator, Foreground),
271    (UpdateChannelBufferCollaborator, Foreground),
272);
273
274request_messages!(
275    (ApplyCodeAction, ApplyCodeActionResponse),
276    (
277        ApplyCompletionAdditionalEdits,
278        ApplyCompletionAdditionalEditsResponse
279    ),
280    (Call, Ack),
281    (CancelCall, Ack),
282    (CopyProjectEntry, ProjectEntryResponse),
283    (CreateProjectEntry, ProjectEntryResponse),
284    (CreateRoom, CreateRoomResponse),
285    (CreateChannel, ChannelResponse),
286    (DeclineCall, Ack),
287    (DeleteProjectEntry, ProjectEntryResponse),
288    (ExpandProjectEntry, ExpandProjectEntryResponse),
289    (Follow, FollowResponse),
290    (FormatBuffers, FormatBuffersResponse),
291    (GetCodeActions, GetCodeActionsResponse),
292    (GetHover, GetHoverResponse),
293    (GetCompletions, GetCompletionsResponse),
294    (GetDefinition, GetDefinitionResponse),
295    (GetTypeDefinition, GetTypeDefinitionResponse),
296    (GetDocumentHighlights, GetDocumentHighlightsResponse),
297    (GetReferences, GetReferencesResponse),
298    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
299    (GetProjectSymbols, GetProjectSymbolsResponse),
300    (FuzzySearchUsers, UsersResponse),
301    (GetUsers, UsersResponse),
302    (InviteChannelMember, Ack),
303    (JoinProject, JoinProjectResponse),
304    (JoinRoom, JoinRoomResponse),
305    (JoinChannelChat, JoinChannelChatResponse),
306    (LeaveRoom, Ack),
307    (RejoinRoom, RejoinRoomResponse),
308    (IncomingCall, Ack),
309    (OpenBufferById, OpenBufferResponse),
310    (OpenBufferByPath, OpenBufferResponse),
311    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
312    (Ping, Ack),
313    (PerformRename, PerformRenameResponse),
314    (PrepareRename, PrepareRenameResponse),
315    (OnTypeFormatting, OnTypeFormattingResponse),
316    (InlayHints, InlayHintsResponse),
317    (ResolveInlayHint, ResolveInlayHintResponse),
318    (RefreshInlayHints, Ack),
319    (ReloadBuffers, ReloadBuffersResponse),
320    (RequestContact, Ack),
321    (RemoveChannelMember, Ack),
322    (RemoveContact, Ack),
323    (RespondToContactRequest, Ack),
324    (RespondToChannelInvite, Ack),
325    (SetChannelMemberAdmin, Ack),
326    (SendChannelMessage, SendChannelMessageResponse),
327    (GetChannelMessages, GetChannelMessagesResponse),
328    (GetChannelMembers, GetChannelMembersResponse),
329    (JoinChannel, JoinRoomResponse),
330    (RemoveChannel, Ack),
331    (RemoveChannelMessage, Ack),
332    (RenameProjectEntry, ProjectEntryResponse),
333    (RenameChannel, ChannelResponse),
334    (SaveBuffer, BufferSaved),
335    (SearchProject, SearchProjectResponse),
336    (ShareProject, ShareProjectResponse),
337    (SynchronizeBuffers, SynchronizeBuffersResponse),
338    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
339    (Test, Test),
340    (UpdateBuffer, Ack),
341    (UpdateParticipantLocation, Ack),
342    (UpdateProject, Ack),
343    (UpdateWorktree, Ack),
344    (JoinChannelBuffer, JoinChannelBufferResponse),
345    (LeaveChannelBuffer, Ack)
346);
347
348entity_messages!(
349    project_id,
350    AddProjectCollaborator,
351    ApplyCodeAction,
352    ApplyCompletionAdditionalEdits,
353    BufferReloaded,
354    BufferSaved,
355    CopyProjectEntry,
356    CreateBufferForPeer,
357    CreateProjectEntry,
358    DeleteProjectEntry,
359    ExpandProjectEntry,
360    Follow,
361    FormatBuffers,
362    GetCodeActions,
363    GetCompletions,
364    GetDefinition,
365    GetTypeDefinition,
366    GetDocumentHighlights,
367    GetHover,
368    GetReferences,
369    GetProjectSymbols,
370    JoinProject,
371    LeaveProject,
372    OpenBufferById,
373    OpenBufferByPath,
374    OpenBufferForSymbol,
375    PerformRename,
376    OnTypeFormatting,
377    InlayHints,
378    ResolveInlayHint,
379    RefreshInlayHints,
380    PrepareRename,
381    ReloadBuffers,
382    RemoveProjectCollaborator,
383    RenameProjectEntry,
384    SaveBuffer,
385    SearchProject,
386    StartLanguageServer,
387    SynchronizeBuffers,
388    Unfollow,
389    UnshareProject,
390    UpdateBuffer,
391    UpdateBufferFile,
392    UpdateDiagnosticSummary,
393    UpdateFollowers,
394    UpdateLanguageServer,
395    UpdateProject,
396    UpdateProjectCollaborator,
397    UpdateWorktree,
398    UpdateWorktreeSettings,
399    UpdateDiffBase
400);
401
402entity_messages!(
403    channel_id,
404    ChannelMessageSent,
405    UpdateChannelBuffer,
406    RemoveChannelBufferCollaborator,
407    RemoveChannelMessage,
408    AddChannelBufferCollaborator,
409    UpdateChannelBufferCollaborator
410);
411
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;
/// Largest capacity (1 MiB) the message stream's scratch buffer keeps after a
/// message; larger allocations are released via `shrink_to`.
const MAX_BUFFER_LEN: usize = MIB;
415
416/// A stream of protobuf messages.
417pub struct MessageStream<S> {
418    stream: S,
419    encoding_buffer: Vec<u8>,
420}
421
422#[allow(clippy::large_enum_variant)]
423#[derive(Debug)]
424pub enum Message {
425    Envelope(Envelope),
426    Ping,
427    Pong,
428}
429
430impl<S> MessageStream<S> {
431    pub fn new(stream: S) -> Self {
432        Self {
433            stream,
434            encoding_buffer: Vec::new(),
435        }
436    }
437
438    pub fn inner_mut(&mut self) -> &mut S {
439        &mut self.stream
440    }
441}
442
443impl<S> MessageStream<S>
444where
445    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
446{
447    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
448        #[cfg(any(test, feature = "test-support"))]
449        const COMPRESSION_LEVEL: i32 = -7;
450
451        #[cfg(not(any(test, feature = "test-support")))]
452        const COMPRESSION_LEVEL: i32 = 4;
453
454        match message {
455            Message::Envelope(message) => {
456                self.encoding_buffer.reserve(message.encoded_len());
457                message
458                    .encode(&mut self.encoding_buffer)
459                    .map_err(io::Error::from)?;
460                let buffer =
461                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
462                        .unwrap();
463
464                self.encoding_buffer.clear();
465                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
466                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
467            }
468            Message::Ping => {
469                self.stream
470                    .send(WebSocketMessage::Ping(Default::default()))
471                    .await?;
472            }
473            Message::Pong => {
474                self.stream
475                    .send(WebSocketMessage::Pong(Default::default()))
476                    .await?;
477            }
478        }
479
480        Ok(())
481    }
482}
483
484impl<S> MessageStream<S>
485where
486    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
487{
488    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
489        while let Some(bytes) = self.stream.next().await {
490            match bytes? {
491                WebSocketMessage::Binary(bytes) => {
492                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
493                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
494                        .map_err(io::Error::from)?;
495
496                    self.encoding_buffer.clear();
497                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
498                    return Ok(Message::Envelope(envelope));
499                }
500                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
501                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
502                WebSocketMessage::Close(_) => break,
503                _ => {}
504            }
505        }
506        Err(anyhow!("connection closed"))
507    }
508}
509
510impl From<Timestamp> for SystemTime {
511    fn from(val: Timestamp) -> Self {
512        UNIX_EPOCH
513            .checked_add(Duration::new(val.seconds, val.nanos))
514            .unwrap()
515    }
516}
517
518impl From<SystemTime> for Timestamp {
519    fn from(time: SystemTime) -> Self {
520        let duration = time.duration_since(UNIX_EPOCH).unwrap();
521        Self {
522            seconds: duration.as_secs(),
523            nanos: duration.subsec_nanos(),
524        }
525    }
526}
527
528impl From<u128> for Nonce {
529    fn from(nonce: u128) -> Self {
530        let upper_half = (nonce >> 64) as u64;
531        let lower_half = nonce as u64;
532        Self {
533            upper_half,
534            lower_half,
535        }
536    }
537}
538
539impl From<Nonce> for u128 {
540    fn from(nonce: Nonce) -> Self {
541        let upper_half = (nonce.upper_half as u128) << 64;
542        let lower_half = nonce.lower_half as u128;
543        upper_half | lower_half
544    }
545}
546
547pub fn split_worktree_update(
548    mut message: UpdateWorktree,
549    max_chunk_size: usize,
550) -> impl Iterator<Item = UpdateWorktree> {
551    let mut done_files = false;
552
553    let mut repository_map = message
554        .updated_repositories
555        .into_iter()
556        .map(|repo| (repo.work_directory_id, repo))
557        .collect::<HashMap<_, _>>();
558
559    iter::from_fn(move || {
560        if done_files {
561            return None;
562        }
563
564        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
565        let updated_entries: Vec<_> = message
566            .updated_entries
567            .drain(..updated_entries_chunk_size)
568            .collect();
569
570        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
571        let removed_entries = message
572            .removed_entries
573            .drain(..removed_entries_chunk_size)
574            .collect();
575
576        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
577
578        let mut updated_repositories = Vec::new();
579
580        if !repository_map.is_empty() {
581            for entry in &updated_entries {
582                if let Some(repo) = repository_map.remove(&entry.id) {
583                    updated_repositories.push(repo)
584                }
585            }
586        }
587
588        let removed_repositories = if done_files {
589            mem::take(&mut message.removed_repositories)
590        } else {
591            Default::default()
592        };
593
594        if done_files {
595            updated_repositories.extend(mem::take(&mut repository_map).into_values());
596        }
597
598        Some(UpdateWorktree {
599            project_id: message.project_id,
600            worktree_id: message.worktree_id,
601            root_name: message.root_name.clone(),
602            abs_path: message.abs_path.clone(),
603            updated_entries,
604            removed_entries,
605            scan_id: message.scan_id,
606            is_last_update: done_files && message.is_last_update,
607            updated_repositories,
608            removed_repositories,
609        })
610    })
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose root name is `"abcdefg"`
    /// repeated `repetitions` times.
    fn update_worktree_message(repetitions: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write a small message and then a very large one. The scratch buffer
        // must never keep more than `MAX_BUFFER_LEN` of capacity afterwards.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1_000_000] {
            sink.write(update_worktree_message(repetitions))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both back; the same capacity bound must hold while decoding.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}