proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    cmp,
 10    fmt::Debug,
 11    io, iter,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14use std::{fmt, mem};
 15
 16include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
 18pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 19    const NAME: &'static str;
 20    const PRIORITY: MessagePriority;
 21    fn into_envelope(
 22        self,
 23        id: u32,
 24        responding_to: Option<u32>,
 25        original_sender_id: Option<PeerId>,
 26    ) -> Envelope;
 27    fn from_envelope(envelope: Envelope) -> Option<Self>;
 28}
 29
 30pub trait EntityMessage: EnvelopedMessage {
 31    fn remote_entity_id(&self) -> u64;
 32}
 33
 34pub trait RequestMessage: EnvelopedMessage {
 35    type Response: EnvelopedMessage;
 36}
 37
 38pub trait AnyTypedEnvelope: 'static + Send + Sync {
 39    fn payload_type_id(&self) -> TypeId;
 40    fn payload_type_name(&self) -> &'static str;
 41    fn as_any(&self) -> &dyn Any;
 42    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 43    fn is_background(&self) -> bool;
 44    fn original_sender_id(&self) -> Option<PeerId>;
 45    fn sender_id(&self) -> ConnectionId;
 46    fn message_id(&self) -> u32;
 47}
 48
 49pub enum MessagePriority {
 50    Foreground,
 51    Background,
 52}
 53
 54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 55    fn payload_type_id(&self) -> TypeId {
 56        TypeId::of::<T>()
 57    }
 58
 59    fn payload_type_name(&self) -> &'static str {
 60        T::NAME
 61    }
 62
 63    fn as_any(&self) -> &dyn Any {
 64        self
 65    }
 66
 67    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 68        self
 69    }
 70
 71    fn is_background(&self) -> bool {
 72        matches!(T::PRIORITY, MessagePriority::Background)
 73    }
 74
 75    fn original_sender_id(&self) -> Option<PeerId> {
 76        self.original_sender_id
 77    }
 78
 79    fn sender_id(&self) -> ConnectionId {
 80        self.sender_id
 81    }
 82
 83    fn message_id(&self) -> u32 {
 84        self.message_id
 85    }
 86}
 87
 88impl PeerId {
 89    pub fn from_u64(peer_id: u64) -> Self {
 90        let owner_id = (peer_id >> 32) as u32;
 91        let id = peer_id as u32;
 92        Self { owner_id, id }
 93    }
 94
 95    pub fn as_u64(self) -> u64 {
 96        ((self.owner_id as u64) << 32) | (self.id as u64)
 97    }
 98}
 99
100impl Copy for PeerId {}
101
102impl Eq for PeerId {}
103
104impl Ord for PeerId {
105    fn cmp(&self, other: &Self) -> cmp::Ordering {
106        self.owner_id
107            .cmp(&other.owner_id)
108            .then_with(|| self.id.cmp(&other.id))
109    }
110}
111
112impl PartialOrd for PeerId {
113    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114        Some(self.cmp(other))
115    }
116}
117
118impl std::hash::Hash for PeerId {
119    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120        self.owner_id.hash(state);
121        self.id.hash(state);
122    }
123}
124
125impl fmt::Display for PeerId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "{}/{}", self.owner_id, self.id)
128    }
129}
130
// Every protobuf message type that can travel inside an `Envelope`, each paired with the
// `MessagePriority` variant it is tagged with. `AnyTypedEnvelope::is_background` reports
// `true` for the `Background` entries. The `messages!` macro lives in the parent module; it
// presumably generates the `EnvelopedMessage` impl for each entry (the trait's `PRIORITY`
// const takes these variants). Confirm against its definition before reordering entries.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// Pairs each request message type with the message type sent back in reply. The macro
// presumably sets `RequestMessage::Response` for each request (confirm against its
// definition in the parent module). Several requests share a generic reply such as `Ack`,
// `ProjectEntryResponse` or `UsersResponse`. Every type named here also appears in the
// `messages!` list above.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Message types addressed to a single remote entity. The first macro argument names the
// field that identifies that entity (`project_id` for project-scoped messages,
// `channel_id` for channel messages). Presumably the macro uses it to implement
// `EntityMessage::remote_entity_id`; confirm against its definition in the parent module.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

entity_messages!(channel_id, ChannelMessageSent);
347
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s reusable encoding buffer is shrunk back to after each
/// message, so a single huge message does not pin its allocation indefinitely.
const MAX_BUFFER_LEN: usize = MIB;
351
352/// A stream of protobuf messages.
353pub struct MessageStream<S> {
354    stream: S,
355    encoding_buffer: Vec<u8>,
356}
357
358#[allow(clippy::large_enum_variant)]
359#[derive(Debug)]
360pub enum Message {
361    Envelope(Envelope),
362    Ping,
363    Pong,
364}
365
366impl<S> MessageStream<S> {
367    pub fn new(stream: S) -> Self {
368        Self {
369            stream,
370            encoding_buffer: Vec::new(),
371        }
372    }
373
374    pub fn inner_mut(&mut self) -> &mut S {
375        &mut self.stream
376    }
377}
378
379impl<S> MessageStream<S>
380where
381    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
382{
383    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
384        #[cfg(any(test, feature = "test-support"))]
385        const COMPRESSION_LEVEL: i32 = -7;
386
387        #[cfg(not(any(test, feature = "test-support")))]
388        const COMPRESSION_LEVEL: i32 = 4;
389
390        match message {
391            Message::Envelope(message) => {
392                self.encoding_buffer.reserve(message.encoded_len());
393                message
394                    .encode(&mut self.encoding_buffer)
395                    .map_err(io::Error::from)?;
396                let buffer =
397                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
398                        .unwrap();
399
400                self.encoding_buffer.clear();
401                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
402                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
403            }
404            Message::Ping => {
405                self.stream
406                    .send(WebSocketMessage::Ping(Default::default()))
407                    .await?;
408            }
409            Message::Pong => {
410                self.stream
411                    .send(WebSocketMessage::Pong(Default::default()))
412                    .await?;
413            }
414        }
415
416        Ok(())
417    }
418}
419
420impl<S> MessageStream<S>
421where
422    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
423{
424    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
425        while let Some(bytes) = self.stream.next().await {
426            match bytes? {
427                WebSocketMessage::Binary(bytes) => {
428                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
429                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
430                        .map_err(io::Error::from)?;
431
432                    self.encoding_buffer.clear();
433                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
434                    return Ok(Message::Envelope(envelope));
435                }
436                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
437                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
438                WebSocketMessage::Close(_) => break,
439                _ => {}
440            }
441        }
442        Err(anyhow!("connection closed"))
443    }
444}
445
446impl From<Timestamp> for SystemTime {
447    fn from(val: Timestamp) -> Self {
448        UNIX_EPOCH
449            .checked_add(Duration::new(val.seconds, val.nanos))
450            .unwrap()
451    }
452}
453
454impl From<SystemTime> for Timestamp {
455    fn from(time: SystemTime) -> Self {
456        let duration = time.duration_since(UNIX_EPOCH).unwrap();
457        Self {
458            seconds: duration.as_secs(),
459            nanos: duration.subsec_nanos(),
460        }
461    }
462}
463
464impl From<u128> for Nonce {
465    fn from(nonce: u128) -> Self {
466        let upper_half = (nonce >> 64) as u64;
467        let lower_half = nonce as u64;
468        Self {
469            upper_half,
470            lower_half,
471        }
472    }
473}
474
475impl From<Nonce> for u128 {
476    fn from(nonce: Nonce) -> Self {
477        let upper_half = (nonce.upper_half as u128) << 64;
478        let lower_half = nonce.lower_half as u128;
479        upper_half | lower_half
480    }
481}
482
483pub fn split_worktree_update(
484    mut message: UpdateWorktree,
485    max_chunk_size: usize,
486) -> impl Iterator<Item = UpdateWorktree> {
487    let mut done_files = false;
488    let mut done_statuses = false;
489    let mut repository_index = 0;
490    iter::from_fn(move || {
491        if done_files && done_statuses {
492            return None;
493        }
494
495        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
496        let updated_entries = message
497            .updated_entries
498            .drain(..updated_entries_chunk_size)
499            .collect();
500
501        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
502        let removed_entries = message
503            .removed_entries
504            .drain(..removed_entries_chunk_size)
505            .collect();
506
507        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
508
509        // Wait to send repositories until after we've guaranteed that their associated entries
510        // will be read
511        let updated_repositories = if done_files {
512            let mut total_statuses = 0;
513            let mut updated_repositories = Vec::new();
514            while total_statuses < max_chunk_size
515                && repository_index < message.updated_repositories.len()
516            {
517                let updated_statuses_chunk_size = cmp::min(
518                    message.updated_repositories[repository_index]
519                        .updated_worktree_statuses
520                        .len(),
521                    max_chunk_size - total_statuses,
522                );
523
524                let updated_statuses: Vec<_> = message.updated_repositories[repository_index]
525                    .updated_worktree_statuses
526                    .drain(..updated_statuses_chunk_size)
527                    .collect();
528
529                total_statuses += updated_statuses.len();
530
531                let done_this_repo = message.updated_repositories[repository_index]
532                    .updated_worktree_statuses
533                    .is_empty();
534
535                let removed_repo_paths = if done_this_repo {
536                    mem::take(
537                        &mut message.updated_repositories[repository_index]
538                            .removed_worktree_repo_paths,
539                    )
540                } else {
541                    Default::default()
542                };
543
544                updated_repositories.push(RepositoryEntry {
545                    work_directory_id: message.updated_repositories[repository_index]
546                        .work_directory_id,
547                    branch: message.updated_repositories[repository_index]
548                        .branch
549                        .clone(),
550                    updated_worktree_statuses: updated_statuses,
551                    removed_worktree_repo_paths: removed_repo_paths,
552                });
553
554                if done_this_repo {
555                    repository_index += 1;
556                }
557            }
558
559            updated_repositories
560        } else {
561            Default::default()
562        };
563
564        let removed_repositories = if done_files && done_statuses {
565            mem::take(&mut message.removed_repositories)
566        } else {
567            Default::default()
568        };
569
570        done_statuses = repository_index >= message.updated_repositories.len();
571
572        Some(UpdateWorktree {
573            project_id: message.project_id,
574            worktree_id: message.worktree_id,
575            root_name: message.root_name.clone(),
576            abs_path: message.abs_path.clone(),
577            updated_entries,
578            removed_entries,
579            scan_id: message.scan_id,
580            is_last_update: done_files && message.is_last_update,
581            updated_repositories,
582            removed_repositories,
583        })
584    })
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message carrying an `UpdateWorktree` with the given root name.
    fn update_worktree_message(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        // A small message, then one whose encoding is far larger than MAX_BUFFER_LEN. The
        // scratch buffer must be trimmed back after each write.
        for repetitions in [10, 1000000] {
            sink.write(update_worktree_message("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both messages back must trim the decoding buffer as well.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}