proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    cmp,
 10    fmt::Debug,
 11    io, iter,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14use std::{fmt, mem};
 15
// Textually includes `zed.messages.rs` from the build's `OUT_DIR`. Types such as `Envelope`,
// `PeerId`, `Timestamp` and `Nonce` are used below without a local definition, so they
// presumably come from this generated file.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
/// A message that can be wrapped in an `Envelope` and recovered from one.
///
/// Implementations are not written by hand in this file; presumably the `messages!` macro
/// generates them for every type it lists (confirm against the macro definition).
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type, reported by `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of the message; `AnyTypedEnvelope::is_background` tests it against
    /// `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Consumes the message and wraps it in an `Envelope`.
    ///
    /// NOTE(review): `id`, `responding_to` and `original_sender_id` are presumably copied
    /// into the envelope's header fields; the implementation is not visible here.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`, or returns `None` (presumably when
    /// the envelope carries a different payload type).
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 29
/// An enveloped message that is addressed to a particular remote entity.
///
/// Presumably implemented by the `entity_messages!` macro, whose first argument
/// (`project_id`, `channel_id`) looks like the field that supplies the id; confirm against
/// the macro definition.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 33
/// An enveloped message that expects a reply of a specific type.
///
/// The request/response pairs passed to `request_messages!` in this file presumably produce
/// the implementations of this trait.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in response to this request.
    type Response: EnvelopedMessage;
}
 37
/// Accessors for an envelope that do not name its payload type in their signatures.
///
/// This file implements the trait for every `TypedEnvelope<T>` where `T: EnvelopedMessage`.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload type (the payload's `EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `dyn Any`, e.g. for downcasting to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `dyn Any`.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The envelope's `original_sender_id`, if it has one.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// The connection the envelope was received from (the envelope's `sender_id`).
    fn sender_id(&self) -> ConnectionId;
    /// The envelope's `message_id`.
    fn message_id(&self) -> u32;
}
 48
/// Priority class attached to every message type through `EnvelopedMessage::PRIORITY`.
///
/// The only use visible in this file is `AnyTypedEnvelope::is_background`, which reports
/// whether a payload's priority is `Background`.
pub enum MessagePriority {
    /// Assigned to most message types in the `messages!` list of this file.
    Foreground,
    /// Assigned to the message types listed with `Background` in `messages!`.
    Background,
}
 53
 54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 55    fn payload_type_id(&self) -> TypeId {
 56        TypeId::of::<T>()
 57    }
 58
 59    fn payload_type_name(&self) -> &'static str {
 60        T::NAME
 61    }
 62
 63    fn as_any(&self) -> &dyn Any {
 64        self
 65    }
 66
 67    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 68        self
 69    }
 70
 71    fn is_background(&self) -> bool {
 72        matches!(T::PRIORITY, MessagePriority::Background)
 73    }
 74
 75    fn original_sender_id(&self) -> Option<PeerId> {
 76        self.original_sender_id
 77    }
 78
 79    fn sender_id(&self) -> ConnectionId {
 80        self.sender_id
 81    }
 82
 83    fn message_id(&self) -> u32 {
 84        self.message_id
 85    }
 86}
 87
 88impl PeerId {
 89    pub fn from_u64(peer_id: u64) -> Self {
 90        let owner_id = (peer_id >> 32) as u32;
 91        let id = peer_id as u32;
 92        Self { owner_id, id }
 93    }
 94
 95    pub fn as_u64(self) -> u64 {
 96        ((self.owner_id as u64) << 32) | (self.id as u64)
 97    }
 98}
 99
// `PeerId` is used here as two `u32` fields (`owner_id`, `id`), so it can be bitwise-copied.
// `Copy` needs `Clone`, which is presumably implemented where the type is defined.
impl Copy for PeerId {}
101
// Marks `PeerId` equality as a full equivalence relation; required by the `Ord` impl in this
// file. The underlying `PartialEq` is presumably implemented where the type is defined.
impl Eq for PeerId {}
103
104impl Ord for PeerId {
105    fn cmp(&self, other: &Self) -> cmp::Ordering {
106        self.owner_id
107            .cmp(&other.owner_id)
108            .then_with(|| self.id.cmp(&other.id))
109    }
110}
111
112impl PartialOrd for PeerId {
113    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114        Some(self.cmp(other))
115    }
116}
117
118impl std::hash::Hash for PeerId {
119    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120        self.owner_id.hash(state);
121        self.id.hash(state);
122    }
123}
124
125impl fmt::Display for PeerId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "{}/{}", self.owner_id, self.id)
128    }
129}
130
// Every message type of the protocol, paired with its priority class. The second element of
// each pair names a `MessagePriority` variant; presumably the macro implements
// `EnvelopedMessage` for the type with that `PRIORITY` (confirm against the macro definition
// in the parent module).
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// Request/response pairs: the first element of each pair is the request type, the second
// the type sent back in reply (`Ack` where no data is returned). Presumably the macro
// implements `RequestMessage` for the request with `Response` set to the second type
// (confirm against the macro definition in the parent module).
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Messages that are scoped to a project. The first argument (`project_id`) is presumably
// the field that `EntityMessage::remote_entity_id` returns for each listed type (confirm
// against the macro definition in the parent module).
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
345
// Messages that are scoped to a channel; `channel_id` is presumably the field used as the
// remote entity id (confirm against the macro definition in the parent module).
entity_messages!(channel_id, ChannelMessageSent);
347
// Number of bytes in one kibibyte.
const KIB: usize = 1024;
// Number of bytes in one mebibyte.
const MIB: usize = KIB * 1024;
// Capacity (1 MiB) that `MessageStream` shrinks its scratch buffer down to after each
// message, so one large message does not keep a large allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
351
/// A stream of protobuf messages.
///
/// Wraps a WebSocket sink and/or stream `S`; envelopes are protobuf-encoded and
/// zstd-compressed on `write`, and decompressed and decoded on `read`.
pub struct MessageStream<S> {
    // The wrapped WebSocket sink/stream.
    stream: S,
    // Scratch space reused between messages: holds the encoded protobuf bytes before
    // compression when writing, and the decompressed bytes before decoding when reading.
    encoding_buffer: Vec<u8>,
}
357
/// A unit of traffic read from or written to a `MessageStream`.
// `Envelope` is much larger than the unit variants; the lint is silenced rather than boxing it.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried in a binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping frame (any payload is discarded on read).
    Ping,
    /// A WebSocket pong frame (any payload is discarded on read).
    Pong,
}
365
366impl<S> MessageStream<S> {
367    pub fn new(stream: S) -> Self {
368        Self {
369            stream,
370            encoding_buffer: Vec::new(),
371        }
372    }
373
374    pub fn inner_mut(&mut self) -> &mut S {
375        &mut self.stream
376    }
377}
378
379impl<S> MessageStream<S>
380where
381    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
382{
383    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
384        #[cfg(any(test, feature = "test-support"))]
385        const COMPRESSION_LEVEL: i32 = -7;
386
387        #[cfg(not(any(test, feature = "test-support")))]
388        const COMPRESSION_LEVEL: i32 = 4;
389
390        match message {
391            Message::Envelope(message) => {
392                self.encoding_buffer.reserve(message.encoded_len());
393                message
394                    .encode(&mut self.encoding_buffer)
395                    .map_err(io::Error::from)?;
396                let buffer =
397                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
398                        .unwrap();
399
400                self.encoding_buffer.clear();
401                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
402                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
403            }
404            Message::Ping => {
405                self.stream
406                    .send(WebSocketMessage::Ping(Default::default()))
407                    .await?;
408            }
409            Message::Pong => {
410                self.stream
411                    .send(WebSocketMessage::Pong(Default::default()))
412                    .await?;
413            }
414        }
415
416        Ok(())
417    }
418}
419
420impl<S> MessageStream<S>
421where
422    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
423{
424    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
425        while let Some(bytes) = self.stream.next().await {
426            match bytes? {
427                WebSocketMessage::Binary(bytes) => {
428                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
429                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
430                        .map_err(io::Error::from)?;
431
432                    self.encoding_buffer.clear();
433                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
434                    return Ok(Message::Envelope(envelope));
435                }
436                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
437                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
438                WebSocketMessage::Close(_) => break,
439                _ => {}
440            }
441        }
442        Err(anyhow!("connection closed"))
443    }
444}
445
446impl From<Timestamp> for SystemTime {
447    fn from(val: Timestamp) -> Self {
448        UNIX_EPOCH
449            .checked_add(Duration::new(val.seconds, val.nanos))
450            .unwrap()
451    }
452}
453
454impl From<SystemTime> for Timestamp {
455    fn from(time: SystemTime) -> Self {
456        let duration = time.duration_since(UNIX_EPOCH).unwrap();
457        Self {
458            seconds: duration.as_secs(),
459            nanos: duration.subsec_nanos(),
460        }
461    }
462}
463
464impl From<u128> for Nonce {
465    fn from(nonce: u128) -> Self {
466        let upper_half = (nonce >> 64) as u64;
467        let lower_half = nonce as u64;
468        Self {
469            upper_half,
470            lower_half,
471        }
472    }
473}
474
475impl From<Nonce> for u128 {
476    fn from(nonce: Nonce) -> Self {
477        let upper_half = (nonce.upper_half as u128) << 64;
478        let lower_half = nonce.lower_half as u128;
479        upper_half | lower_half
480    }
481}
482
483pub fn split_worktree_update(
484    mut message: UpdateWorktree,
485    max_chunk_size: usize,
486) -> impl Iterator<Item = UpdateWorktree> {
487    let mut done_files = false;
488    let mut done_statuses = false;
489    let mut repository_index = 0;
490    let mut root_repo_found = false;
491    iter::from_fn(move || {
492        if done_files && done_statuses {
493            return None;
494        }
495
496        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
497        let updated_entries: Vec<_> = message
498            .updated_entries
499            .drain(..updated_entries_chunk_size)
500            .collect();
501
502        let mut updated_repositories: Vec<_> = Default::default();
503
504        if !root_repo_found {
505            for entry in updated_entries.iter() {
506                if let Some(repo) = message.updated_repositories.get(0) {
507                    if repo.work_directory_id == entry.id {
508                        root_repo_found = true;
509                        updated_repositories.push(RepositoryEntry {
510                            work_directory_id: repo.work_directory_id,
511                            branch: repo.branch.clone(),
512                            removed_worktree_repo_paths: Default::default(),
513                            updated_worktree_statuses: Default::default(),
514                        });
515                        break;
516                    }
517                }
518            }
519        }
520
521        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
522        let removed_entries = message
523            .removed_entries
524            .drain(..removed_entries_chunk_size)
525            .collect();
526
527        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
528
529        // Wait to send repositories until after we've guaranteed that their associated entries
530        // will be read
531        if done_files {
532            let mut total_statuses = 0;
533            while total_statuses < max_chunk_size
534                && repository_index < message.updated_repositories.len()
535            {
536                let updated_statuses_chunk_size = cmp::min(
537                    message.updated_repositories[repository_index]
538                        .updated_worktree_statuses
539                        .len(),
540                    max_chunk_size - total_statuses,
541                );
542
543                let updated_statuses: Vec<_> = message.updated_repositories[repository_index]
544                    .updated_worktree_statuses
545                    .drain(..updated_statuses_chunk_size)
546                    .collect();
547
548                total_statuses += updated_statuses.len();
549
550                let done_this_repo = message.updated_repositories[repository_index]
551                    .updated_worktree_statuses
552                    .is_empty();
553
554                let removed_repo_paths = if done_this_repo {
555                    mem::take(
556                        &mut message.updated_repositories[repository_index]
557                            .removed_worktree_repo_paths,
558                    )
559                } else {
560                    Default::default()
561                };
562
563                updated_repositories.push(RepositoryEntry {
564                    work_directory_id: message.updated_repositories[repository_index]
565                        .work_directory_id,
566                    branch: message.updated_repositories[repository_index]
567                        .branch
568                        .clone(),
569                    updated_worktree_statuses: updated_statuses,
570                    removed_worktree_repo_paths: removed_repo_paths,
571                });
572
573                if done_this_repo {
574                    repository_index += 1;
575                }
576            }
577        } else {
578            Default::default()
579        };
580
581        let removed_repositories = if done_files && done_statuses {
582            mem::take(&mut message.removed_repositories)
583        } else {
584            Default::default()
585        };
586
587        done_statuses = repository_index >= message.updated_repositories.len();
588
589        Some(UpdateWorktree {
590            project_id: message.project_id,
591            worktree_id: message.worktree_id,
592            root_name: message.root_name.clone(),
593            abs_path: message.abs_path.clone(),
594            updated_entries,
595            removed_entries,
596            scan_id: message.scan_id,
597            is_last_update: done_files && message.is_last_update,
598            updated_repositories,
599            removed_repositories,
600        })
601    })
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message whose `UpdateWorktree` payload has a `root_name` made of
    /// `repetitions` copies of `"abcdefg"`.
    fn worktree_envelope(repetitions: usize) -> Message {
        let payload = envelope::Payload::UpdateWorktree(UpdateWorktree {
            root_name: "abcdefg".repeat(repetitions),
            ..Default::default()
        });
        Message::Envelope(Envelope {
            payload: Some(payload),
            ..Default::default()
        })
    }

    // The scratch buffer must never stay larger than `MAX_BUFFER_LEN`, whether the message
    // that passed through it was small or large, on both the write and the read side.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope(repetitions)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    // Packing into a `u64` and unpacking again must round-trip, including at the extremes
    // of both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}