proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::time::Instant;
 12use std::{
 13    cmp,
 14    fmt::Debug,
 15    io, iter,
 16    time::{Duration, SystemTime, UNIX_EPOCH},
 17};
 18use std::{fmt, mem};
 19
 20include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 21
 22pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 23    const NAME: &'static str;
 24    const PRIORITY: MessagePriority;
 25    fn into_envelope(
 26        self,
 27        id: u32,
 28        responding_to: Option<u32>,
 29        original_sender_id: Option<PeerId>,
 30    ) -> Envelope;
 31    fn from_envelope(envelope: Envelope) -> Option<Self>;
 32}
 33
 34pub trait EntityMessage: EnvelopedMessage {
 35    type Entity;
 36    fn remote_entity_id(&self) -> u64;
 37}
 38
 39pub trait RequestMessage: EnvelopedMessage {
 40    type Response: EnvelopedMessage;
 41}
 42
 43pub trait AnyTypedEnvelope: 'static + Send + Sync {
 44    fn payload_type_id(&self) -> TypeId;
 45    fn payload_type_name(&self) -> &'static str;
 46    fn as_any(&self) -> &dyn Any;
 47    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 48    fn is_background(&self) -> bool;
 49    fn original_sender_id(&self) -> Option<PeerId>;
 50    fn sender_id(&self) -> ConnectionId;
 51    fn message_id(&self) -> u32;
 52}
 53
 54pub enum MessagePriority {
 55    Foreground,
 56    Background,
 57}
 58
 59impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 60    fn payload_type_id(&self) -> TypeId {
 61        TypeId::of::<T>()
 62    }
 63
 64    fn payload_type_name(&self) -> &'static str {
 65        T::NAME
 66    }
 67
 68    fn as_any(&self) -> &dyn Any {
 69        self
 70    }
 71
 72    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 73        self
 74    }
 75
 76    fn is_background(&self) -> bool {
 77        matches!(T::PRIORITY, MessagePriority::Background)
 78    }
 79
 80    fn original_sender_id(&self) -> Option<PeerId> {
 81        self.original_sender_id
 82    }
 83
 84    fn sender_id(&self) -> ConnectionId {
 85        self.sender_id
 86    }
 87
 88    fn message_id(&self) -> u32 {
 89        self.message_id
 90    }
 91}
 92
 93impl PeerId {
 94    pub fn from_u64(peer_id: u64) -> Self {
 95        let owner_id = (peer_id >> 32) as u32;
 96        let id = peer_id as u32;
 97        Self { owner_id, id }
 98    }
 99
100    pub fn as_u64(self) -> u64 {
101        ((self.owner_id as u64) << 32) | (self.id as u64)
102    }
103}
104
105impl Copy for PeerId {}
106
107impl Eq for PeerId {}
108
109impl Ord for PeerId {
110    fn cmp(&self, other: &Self) -> cmp::Ordering {
111        self.owner_id
112            .cmp(&other.owner_id)
113            .then_with(|| self.id.cmp(&other.id))
114    }
115}
116
117impl PartialOrd for PeerId {
118    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
119        Some(self.cmp(other))
120    }
121}
122
123impl std::hash::Hash for PeerId {
124    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
125        self.owner_id.hash(state);
126        self.id.hash(state);
127    }
128}
129
130impl fmt::Display for PeerId {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(f, "{}/{}", self.owner_id, self.id)
133    }
134}
135
136messages!(
137    (Ack, Foreground),
138    (AckBufferOperation, Background),
139    (AckChannelMessage, Background),
140    (AddNotification, Foreground),
141    (AddProjectCollaborator, Foreground),
142    (ApplyCodeAction, Background),
143    (ApplyCodeActionResponse, Background),
144    (ApplyCompletionAdditionalEdits, Background),
145    (ApplyCompletionAdditionalEditsResponse, Background),
146    (BufferReloaded, Foreground),
147    (BufferSaved, Foreground),
148    (Call, Foreground),
149    (CallCanceled, Foreground),
150    (CancelCall, Foreground),
151    (ChannelMessageSent, Foreground),
152    (ChannelMessageUpdate, Foreground),
153    (CompleteWithLanguageModel, Background),
154    (ComputeEmbeddings, Background),
155    (ComputeEmbeddingsResponse, Background),
156    (CopyProjectEntry, Foreground),
157    (CountTokensWithLanguageModel, Background),
158    (CountTokensResponse, Background),
159    (CreateBufferForPeer, Foreground),
160    (CreateChannel, Foreground),
161    (CreateChannelResponse, Foreground),
162    (CreateProjectEntry, Foreground),
163    (CreateRoom, Foreground),
164    (CreateRoomResponse, Foreground),
165    (DeclineCall, Foreground),
166    (DeleteChannel, Foreground),
167    (DeleteNotification, Foreground),
168    (UpdateNotification, Foreground),
169    (DeleteProjectEntry, Foreground),
170    (EndStream, Foreground),
171    (Error, Foreground),
172    (ExpandProjectEntry, Foreground),
173    (ExpandProjectEntryResponse, Foreground),
174    (Follow, Foreground),
175    (FollowResponse, Foreground),
176    (FormatBuffers, Foreground),
177    (FormatBuffersResponse, Foreground),
178    (FuzzySearchUsers, Foreground),
179    (GetCachedEmbeddings, Background),
180    (GetCachedEmbeddingsResponse, Background),
181    (GetChannelMembers, Foreground),
182    (GetChannelMembersResponse, Foreground),
183    (GetChannelMessages, Background),
184    (GetChannelMessagesById, Background),
185    (GetChannelMessagesResponse, Background),
186    (GetCodeActions, Background),
187    (GetCodeActionsResponse, Background),
188    (GetCompletions, Background),
189    (GetCompletionsResponse, Background),
190    (GetDefinition, Background),
191    (GetDefinitionResponse, Background),
192    (GetDocumentHighlights, Background),
193    (GetDocumentHighlightsResponse, Background),
194    (GetHover, Background),
195    (GetHoverResponse, Background),
196    (GetNotifications, Foreground),
197    (GetNotificationsResponse, Foreground),
198    (GetPrivateUserInfo, Foreground),
199    (GetPrivateUserInfoResponse, Foreground),
200    (GetProjectSymbols, Background),
201    (GetProjectSymbolsResponse, Background),
202    (GetReferences, Background),
203    (GetReferencesResponse, Background),
204    (GetSupermavenApiKey, Background),
205    (GetSupermavenApiKeyResponse, Background),
206    (GetTypeDefinition, Background),
207    (GetTypeDefinitionResponse, Background),
208    (GetImplementation, Background),
209    (GetImplementationResponse, Background),
210    (GetUsers, Foreground),
211    (Hello, Foreground),
212    (IncomingCall, Foreground),
213    (InlayHints, Background),
214    (InlayHintsResponse, Background),
215    (InviteChannelMember, Foreground),
216    (JoinChannel, Foreground),
217    (JoinChannelBuffer, Foreground),
218    (JoinChannelBufferResponse, Foreground),
219    (JoinChannelChat, Foreground),
220    (JoinChannelChatResponse, Foreground),
221    (JoinProject, Foreground),
222    (JoinHostedProject, Foreground),
223    (JoinProjectResponse, Foreground),
224    (JoinRoom, Foreground),
225    (JoinRoomResponse, Foreground),
226    (LanguageModelResponse, Background),
227    (LeaveChannelBuffer, Background),
228    (LeaveChannelChat, Foreground),
229    (LeaveProject, Foreground),
230    (LeaveRoom, Foreground),
231    (MarkNotificationRead, Foreground),
232    (MoveChannel, Foreground),
233    (OnTypeFormatting, Background),
234    (OnTypeFormattingResponse, Background),
235    (OpenBufferById, Background),
236    (OpenBufferByPath, Background),
237    (OpenBufferForSymbol, Background),
238    (OpenBufferForSymbolResponse, Background),
239    (OpenBufferResponse, Background),
240    (PerformRename, Background),
241    (PerformRenameResponse, Background),
242    (Ping, Foreground),
243    (PrepareRename, Background),
244    (PrepareRenameResponse, Background),
245    (ProjectEntryResponse, Foreground),
246    (RefreshInlayHints, Foreground),
247    (RejoinChannelBuffers, Foreground),
248    (RejoinChannelBuffersResponse, Foreground),
249    (RejoinRoom, Foreground),
250    (RejoinRoomResponse, Foreground),
251    (ReloadBuffers, Foreground),
252    (ReloadBuffersResponse, Foreground),
253    (RemoveChannelMember, Foreground),
254    (RemoveChannelMessage, Foreground),
255    (UpdateChannelMessage, Foreground),
256    (RemoveContact, Foreground),
257    (RemoveProjectCollaborator, Foreground),
258    (RenameChannel, Foreground),
259    (RenameChannelResponse, Foreground),
260    (RenameProjectEntry, Foreground),
261    (RequestContact, Foreground),
262    (ResolveCompletionDocumentation, Background),
263    (ResolveCompletionDocumentationResponse, Background),
264    (ResolveInlayHint, Background),
265    (ResolveInlayHintResponse, Background),
266    (RespondToChannelInvite, Foreground),
267    (RespondToContactRequest, Foreground),
268    (RoomUpdated, Foreground),
269    (SaveBuffer, Foreground),
270    (SetChannelMemberRole, Foreground),
271    (SetChannelVisibility, Foreground),
272    (SearchProject, Background),
273    (SearchProjectResponse, Background),
274    (SendChannelMessage, Background),
275    (SendChannelMessageResponse, Background),
276    (ShareProject, Foreground),
277    (ShareProjectResponse, Foreground),
278    (ShowContacts, Foreground),
279    (StartLanguageServer, Foreground),
280    (SynchronizeBuffers, Foreground),
281    (SynchronizeBuffersResponse, Foreground),
282    (Test, Foreground),
283    (Unfollow, Foreground),
284    (UnshareProject, Foreground),
285    (UpdateBuffer, Foreground),
286    (UpdateBufferFile, Foreground),
287    (UpdateChannelBuffer, Foreground),
288    (UpdateChannelBufferCollaborators, Foreground),
289    (UpdateChannels, Foreground),
290    (UpdateUserChannels, Foreground),
291    (UpdateContacts, Foreground),
292    (UpdateDiagnosticSummary, Foreground),
293    (UpdateDiffBase, Foreground),
294    (UpdateFollowers, Foreground),
295    (UpdateInviteInfo, Foreground),
296    (UpdateLanguageServer, Foreground),
297    (UpdateParticipantLocation, Foreground),
298    (UpdateProject, Foreground),
299    (UpdateProjectCollaborator, Foreground),
300    (UpdateWorktree, Foreground),
301    (UpdateWorktreeSettings, Foreground),
302    (UsersResponse, Foreground),
303    (LspExtExpandMacro, Background),
304    (LspExtExpandMacroResponse, Background),
305    (SetRoomParticipantRole, Foreground),
306    (BlameBuffer, Foreground),
307    (BlameBufferResponse, Foreground),
308    (CreateDevServerProject, Background),
309    (CreateDevServerProjectResponse, Foreground),
310    (CreateDevServer, Foreground),
311    (CreateDevServerResponse, Foreground),
312    (DevServerInstructions, Foreground),
313    (ShutdownDevServer, Foreground),
314    (ReconnectDevServer, Foreground),
315    (ReconnectDevServerResponse, Foreground),
316    (ShareDevServerProject, Foreground),
317    (JoinDevServerProject, Foreground),
318    (RejoinRemoteProjects, Foreground),
319    (RejoinRemoteProjectsResponse, Foreground),
320    (MultiLspQuery, Background),
321    (MultiLspQueryResponse, Background),
322    (DevServerProjectsUpdate, Foreground),
323    (ValidateDevServerProjectRequest, Background),
324    (DeleteDevServer, Foreground),
325    (DeleteDevServerProject, Foreground),
326    (RegenerateDevServerToken, Foreground),
327    (RegenerateDevServerTokenResponse, Foreground),
328    (RenameDevServer, Foreground),
329    (OpenNewBuffer, Foreground)
330);
331
332request_messages!(
333    (ApplyCodeAction, ApplyCodeActionResponse),
334    (
335        ApplyCompletionAdditionalEdits,
336        ApplyCompletionAdditionalEditsResponse
337    ),
338    (Call, Ack),
339    (CancelCall, Ack),
340    (CopyProjectEntry, ProjectEntryResponse),
341    (CompleteWithLanguageModel, LanguageModelResponse),
342    (ComputeEmbeddings, ComputeEmbeddingsResponse),
343    (CountTokensWithLanguageModel, CountTokensResponse),
344    (CreateChannel, CreateChannelResponse),
345    (CreateProjectEntry, ProjectEntryResponse),
346    (CreateRoom, CreateRoomResponse),
347    (DeclineCall, Ack),
348    (DeleteChannel, Ack),
349    (DeleteProjectEntry, ProjectEntryResponse),
350    (ExpandProjectEntry, ExpandProjectEntryResponse),
351    (Follow, FollowResponse),
352    (FormatBuffers, FormatBuffersResponse),
353    (FuzzySearchUsers, UsersResponse),
354    (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
355    (GetChannelMembers, GetChannelMembersResponse),
356    (GetChannelMessages, GetChannelMessagesResponse),
357    (GetChannelMessagesById, GetChannelMessagesResponse),
358    (GetCodeActions, GetCodeActionsResponse),
359    (GetCompletions, GetCompletionsResponse),
360    (GetDefinition, GetDefinitionResponse),
361    (GetImplementation, GetImplementationResponse),
362    (GetDocumentHighlights, GetDocumentHighlightsResponse),
363    (GetHover, GetHoverResponse),
364    (GetNotifications, GetNotificationsResponse),
365    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
366    (GetProjectSymbols, GetProjectSymbolsResponse),
367    (GetReferences, GetReferencesResponse),
368    (GetSupermavenApiKey, GetSupermavenApiKeyResponse),
369    (GetTypeDefinition, GetTypeDefinitionResponse),
370    (GetUsers, UsersResponse),
371    (IncomingCall, Ack),
372    (InlayHints, InlayHintsResponse),
373    (InviteChannelMember, Ack),
374    (JoinChannel, JoinRoomResponse),
375    (JoinChannelBuffer, JoinChannelBufferResponse),
376    (JoinChannelChat, JoinChannelChatResponse),
377    (JoinHostedProject, JoinProjectResponse),
378    (JoinProject, JoinProjectResponse),
379    (JoinRoom, JoinRoomResponse),
380    (LeaveChannelBuffer, Ack),
381    (LeaveRoom, Ack),
382    (MarkNotificationRead, Ack),
383    (MoveChannel, Ack),
384    (OnTypeFormatting, OnTypeFormattingResponse),
385    (OpenBufferById, OpenBufferResponse),
386    (OpenBufferByPath, OpenBufferResponse),
387    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
388    (OpenNewBuffer, OpenBufferResponse),
389    (PerformRename, PerformRenameResponse),
390    (Ping, Ack),
391    (PrepareRename, PrepareRenameResponse),
392    (RefreshInlayHints, Ack),
393    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
394    (RejoinRoom, RejoinRoomResponse),
395    (ReloadBuffers, ReloadBuffersResponse),
396    (RemoveChannelMember, Ack),
397    (RemoveChannelMessage, Ack),
398    (UpdateChannelMessage, Ack),
399    (RemoveContact, Ack),
400    (RenameChannel, RenameChannelResponse),
401    (RenameProjectEntry, ProjectEntryResponse),
402    (RequestContact, Ack),
403    (
404        ResolveCompletionDocumentation,
405        ResolveCompletionDocumentationResponse
406    ),
407    (ResolveInlayHint, ResolveInlayHintResponse),
408    (RespondToChannelInvite, Ack),
409    (RespondToContactRequest, Ack),
410    (SaveBuffer, BufferSaved),
411    (SearchProject, SearchProjectResponse),
412    (SendChannelMessage, SendChannelMessageResponse),
413    (SetChannelMemberRole, Ack),
414    (SetChannelVisibility, Ack),
415    (ShareProject, ShareProjectResponse),
416    (SynchronizeBuffers, SynchronizeBuffersResponse),
417    (Test, Test),
418    (UpdateBuffer, Ack),
419    (UpdateParticipantLocation, Ack),
420    (UpdateProject, Ack),
421    (UpdateWorktree, Ack),
422    (LspExtExpandMacro, LspExtExpandMacroResponse),
423    (SetRoomParticipantRole, Ack),
424    (BlameBuffer, BlameBufferResponse),
425    (CreateDevServerProject, CreateDevServerProjectResponse),
426    (CreateDevServer, CreateDevServerResponse),
427    (ShutdownDevServer, Ack),
428    (ShareDevServerProject, ShareProjectResponse),
429    (JoinDevServerProject, JoinProjectResponse),
430    (RejoinRemoteProjects, RejoinRemoteProjectsResponse),
431    (ReconnectDevServer, ReconnectDevServerResponse),
432    (ValidateDevServerProjectRequest, Ack),
433    (MultiLspQuery, MultiLspQueryResponse),
434    (DeleteDevServer, Ack),
435    (DeleteDevServerProject, Ack),
436    (RegenerateDevServerToken, RegenerateDevServerTokenResponse),
437    (RenameDevServer, Ack)
438);
439
440entity_messages!(
441    {project_id, ShareProject},
442    AddProjectCollaborator,
443    ApplyCodeAction,
444    ApplyCompletionAdditionalEdits,
445    BlameBuffer,
446    BufferReloaded,
447    BufferSaved,
448    CopyProjectEntry,
449    CreateBufferForPeer,
450    CreateProjectEntry,
451    DeleteProjectEntry,
452    ExpandProjectEntry,
453    FormatBuffers,
454    GetCodeActions,
455    GetCompletions,
456    GetDefinition,
457    GetImplementation,
458    GetDocumentHighlights,
459    GetHover,
460    GetProjectSymbols,
461    GetReferences,
462    GetTypeDefinition,
463    InlayHints,
464    JoinProject,
465    LeaveProject,
466    MultiLspQuery,
467    OnTypeFormatting,
468    OpenNewBuffer,
469    OpenBufferById,
470    OpenBufferByPath,
471    OpenBufferForSymbol,
472    PerformRename,
473    PrepareRename,
474    RefreshInlayHints,
475    ReloadBuffers,
476    RemoveProjectCollaborator,
477    RenameProjectEntry,
478    ResolveCompletionDocumentation,
479    ResolveInlayHint,
480    SaveBuffer,
481    SearchProject,
482    StartLanguageServer,
483    SynchronizeBuffers,
484    UnshareProject,
485    UpdateBuffer,
486    UpdateBufferFile,
487    UpdateDiagnosticSummary,
488    UpdateDiffBase,
489    UpdateLanguageServer,
490    UpdateProject,
491    UpdateProjectCollaborator,
492    UpdateWorktree,
493    UpdateWorktreeSettings,
494    LspExtExpandMacro,
495);
496
497entity_messages!(
498    {channel_id, Channel},
499    ChannelMessageSent,
500    ChannelMessageUpdate,
501    RemoveChannelMessage,
502    UpdateChannelMessage,
503    UpdateChannelBuffer,
504    UpdateChannelBufferCollaborators,
505);
506
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity the stream's scratch buffer is shrunk back to after each envelope,
/// so a single oversized message does not keep a large allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
510
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S`; envelopes travel as zstd-compressed,
/// protobuf-encoded binary frames.
pub struct MessageStream<S> {
    /// Scratch space reused across messages for protobuf encoding and decoding.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport.
    stream: S,
}
516
517#[allow(clippy::large_enum_variant)]
518#[derive(Debug)]
519pub enum Message {
520    Envelope(Envelope),
521    Ping,
522    Pong,
523}
524
525impl<S> MessageStream<S> {
526    pub fn new(stream: S) -> Self {
527        Self {
528            stream,
529            encoding_buffer: Vec::new(),
530        }
531    }
532
533    pub fn inner_mut(&mut self) -> &mut S {
534        &mut self.stream
535    }
536}
537
538impl<S> MessageStream<S>
539where
540    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
541{
542    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
543        #[cfg(any(test, feature = "test-support"))]
544        const COMPRESSION_LEVEL: i32 = -7;
545
546        #[cfg(not(any(test, feature = "test-support")))]
547        const COMPRESSION_LEVEL: i32 = 4;
548
549        match message {
550            Message::Envelope(message) => {
551                self.encoding_buffer.reserve(message.encoded_len());
552                message
553                    .encode(&mut self.encoding_buffer)
554                    .map_err(io::Error::from)?;
555                let buffer =
556                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
557                        .unwrap();
558
559                self.encoding_buffer.clear();
560                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
561                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
562            }
563            Message::Ping => {
564                self.stream
565                    .send(WebSocketMessage::Ping(Default::default()))
566                    .await?;
567            }
568            Message::Pong => {
569                self.stream
570                    .send(WebSocketMessage::Pong(Default::default()))
571                    .await?;
572            }
573        }
574
575        Ok(())
576    }
577}
578
579impl<S> MessageStream<S>
580where
581    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
582{
583    pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> {
584        while let Some(bytes) = self.stream.next().await {
585            let received_at = Instant::now();
586            match bytes? {
587                WebSocketMessage::Binary(bytes) => {
588                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
589                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
590                        .map_err(io::Error::from)?;
591
592                    self.encoding_buffer.clear();
593                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
594                    return Ok((Message::Envelope(envelope), received_at));
595                }
596                WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
597                WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
598                WebSocketMessage::Close(_) => break,
599                _ => {}
600            }
601        }
602        Err(anyhow!("connection closed"))
603    }
604}
605
606impl From<Timestamp> for SystemTime {
607    fn from(val: Timestamp) -> Self {
608        UNIX_EPOCH
609            .checked_add(Duration::new(val.seconds, val.nanos))
610            .unwrap()
611    }
612}
613
614impl From<SystemTime> for Timestamp {
615    fn from(time: SystemTime) -> Self {
616        let duration = time.duration_since(UNIX_EPOCH).unwrap();
617        Self {
618            seconds: duration.as_secs(),
619            nanos: duration.subsec_nanos(),
620        }
621    }
622}
623
624impl From<u128> for Nonce {
625    fn from(nonce: u128) -> Self {
626        let upper_half = (nonce >> 64) as u64;
627        let lower_half = nonce as u64;
628        Self {
629            upper_half,
630            lower_half,
631        }
632    }
633}
634
635impl From<Nonce> for u128 {
636    fn from(nonce: Nonce) -> Self {
637        let upper_half = (nonce.upper_half as u128) << 64;
638        let lower_half = nonce.lower_half as u128;
639        upper_half | lower_half
640    }
641}
642
643pub fn split_worktree_update(
644    mut message: UpdateWorktree,
645    max_chunk_size: usize,
646) -> impl Iterator<Item = UpdateWorktree> {
647    let mut done_files = false;
648
649    let mut repository_map = message
650        .updated_repositories
651        .into_iter()
652        .map(|repo| (repo.work_directory_id, repo))
653        .collect::<HashMap<_, _>>();
654
655    iter::from_fn(move || {
656        if done_files {
657            return None;
658        }
659
660        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
661        let updated_entries: Vec<_> = message
662            .updated_entries
663            .drain(..updated_entries_chunk_size)
664            .collect();
665
666        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
667        let removed_entries = message
668            .removed_entries
669            .drain(..removed_entries_chunk_size)
670            .collect();
671
672        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
673
674        let mut updated_repositories = Vec::new();
675
676        if !repository_map.is_empty() {
677            for entry in &updated_entries {
678                if let Some(repo) = repository_map.remove(&entry.id) {
679                    updated_repositories.push(repo)
680                }
681            }
682        }
683
684        let removed_repositories = if done_files {
685            mem::take(&mut message.removed_repositories)
686        } else {
687            Default::default()
688        };
689
690        if done_files {
691            updated_repositories.extend(mem::take(&mut repository_map).into_values());
692        }
693
694        Some(UpdateWorktree {
695            project_id: message.project_id,
696            worktree_id: message.worktree_id,
697            root_name: message.root_name.clone(),
698            abs_path: message.abs_path.clone(),
699            updated_entries,
700            removed_entries,
701            scan_id: message.scan_id,
702            is_last_update: done_files && message.is_last_update,
703            updated_repositories,
704            removed_repositories,
705        })
706    })
707}
708
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message carrying an `UpdateWorktree` whose root name
    /// is `"abcdefg"` repeated `repetitions` times.
    fn worktree_update(repetitions: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// Round-trips a small and a very large message, checking that neither the
    /// writer's nor the reader's scratch buffer keeps more than
    /// `MAX_BUFFER_LEN` bytes of capacity afterwards.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1_000_000] {
            sink.write(worktree_update(repetitions)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId::from_u64` must invert `PeerId::as_u64`, including at the
    /// extremes of both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}