proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Every message type that can be carried in an `Envelope`, each paired with a
// `MessagePriority` variant (presumably becoming that type's `EnvelopedMessage::PRIORITY`,
// which `AnyTypedEnvelope::is_background` reads). The `messages!` macro is defined outside
// this file — NOTE(review): confirm there what code it generates from this list.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateChannel, Foreground),
    (ChannelResponse, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (InviteChannelMember, Foreground),
    (UsersResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (RemoveChannelMember, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToChannelInvite, Foreground),
    (JoinChannel, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (RenameChannel, Foreground),
    (SetChannelMemberAdmin, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (RemoveChannel, Foreground),
    (UpdateChannels, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
    (GetChannelMembers, Foreground),
    (GetChannelMembersResponse, Foreground)
);
253
// Pairs each request message with the message type sent back in reply (presumably becoming
// its `RequestMessage::Response`). Several requests share a reply type — e.g. `Ack` for
// requests with no payload in the reply, `ProjectEntryResponse` for project-entry edits.
// Every name used here also appears in the `messages!` list above.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (CreateChannel, ChannelResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (InviteChannelMember, Ack),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveChannelMember, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RespondToChannelInvite, Ack),
    (SetChannelMemberAdmin, Ack),
    (GetChannelMembers, GetChannelMembersResponse),
    (JoinChannel, JoinRoomResponse),
    (RemoveChannel, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (RenameChannel, ChannelResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
319
// Messages scoped to a remote project. The first argument, `project_id`, presumably names
// the field each listed message exposes through `EntityMessage::remote_entity_id`
// (NOTE(review): confirm in the `entity_messages!` macro definition, outside this file).
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);
372
/// One kibibyte, in bytes.
const KIB: usize = 1024;
/// One mebibyte, in bytes.
const MIB: usize = 1024 * KIB;
/// Largest capacity the reusable encoding buffer is allowed to keep once a message has been
/// written or read (it is shrunk back to this after each one).
const MAX_BUFFER_LEN: usize = MIB;
376
377/// A stream of protobuf messages.
378pub struct MessageStream<S> {
379    stream: S,
380    encoding_buffer: Vec<u8>,
381}
382
383#[allow(clippy::large_enum_variant)]
384#[derive(Debug)]
385pub enum Message {
386    Envelope(Envelope),
387    Ping,
388    Pong,
389}
390
391impl<S> MessageStream<S> {
392    pub fn new(stream: S) -> Self {
393        Self {
394            stream,
395            encoding_buffer: Vec::new(),
396        }
397    }
398
399    pub fn inner_mut(&mut self) -> &mut S {
400        &mut self.stream
401    }
402}
403
404impl<S> MessageStream<S>
405where
406    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
407{
408    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
409        #[cfg(any(test, feature = "test-support"))]
410        const COMPRESSION_LEVEL: i32 = -7;
411
412        #[cfg(not(any(test, feature = "test-support")))]
413        const COMPRESSION_LEVEL: i32 = 4;
414
415        match message {
416            Message::Envelope(message) => {
417                self.encoding_buffer.reserve(message.encoded_len());
418                message
419                    .encode(&mut self.encoding_buffer)
420                    .map_err(io::Error::from)?;
421                let buffer =
422                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
423                        .unwrap();
424
425                self.encoding_buffer.clear();
426                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
427                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
428            }
429            Message::Ping => {
430                self.stream
431                    .send(WebSocketMessage::Ping(Default::default()))
432                    .await?;
433            }
434            Message::Pong => {
435                self.stream
436                    .send(WebSocketMessage::Pong(Default::default()))
437                    .await?;
438            }
439        }
440
441        Ok(())
442    }
443}
444
445impl<S> MessageStream<S>
446where
447    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
448{
449    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
450        while let Some(bytes) = self.stream.next().await {
451            match bytes? {
452                WebSocketMessage::Binary(bytes) => {
453                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
454                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
455                        .map_err(io::Error::from)?;
456
457                    self.encoding_buffer.clear();
458                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
459                    return Ok(Message::Envelope(envelope));
460                }
461                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
462                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
463                WebSocketMessage::Close(_) => break,
464                _ => {}
465            }
466        }
467        Err(anyhow!("connection closed"))
468    }
469}
470
471impl From<Timestamp> for SystemTime {
472    fn from(val: Timestamp) -> Self {
473        UNIX_EPOCH
474            .checked_add(Duration::new(val.seconds, val.nanos))
475            .unwrap()
476    }
477}
478
479impl From<SystemTime> for Timestamp {
480    fn from(time: SystemTime) -> Self {
481        let duration = time.duration_since(UNIX_EPOCH).unwrap();
482        Self {
483            seconds: duration.as_secs(),
484            nanos: duration.subsec_nanos(),
485        }
486    }
487}
488
489impl From<u128> for Nonce {
490    fn from(nonce: u128) -> Self {
491        let upper_half = (nonce >> 64) as u64;
492        let lower_half = nonce as u64;
493        Self {
494            upper_half,
495            lower_half,
496        }
497    }
498}
499
500impl From<Nonce> for u128 {
501    fn from(nonce: Nonce) -> Self {
502        let upper_half = (nonce.upper_half as u128) << 64;
503        let lower_half = nonce.lower_half as u128;
504        upper_half | lower_half
505    }
506}
507
508pub fn split_worktree_update(
509    mut message: UpdateWorktree,
510    max_chunk_size: usize,
511) -> impl Iterator<Item = UpdateWorktree> {
512    let mut done_files = false;
513
514    let mut repository_map = message
515        .updated_repositories
516        .into_iter()
517        .map(|repo| (repo.work_directory_id, repo))
518        .collect::<HashMap<_, _>>();
519
520    iter::from_fn(move || {
521        if done_files {
522            return None;
523        }
524
525        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
526        let updated_entries: Vec<_> = message
527            .updated_entries
528            .drain(..updated_entries_chunk_size)
529            .collect();
530
531        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
532        let removed_entries = message
533            .removed_entries
534            .drain(..removed_entries_chunk_size)
535            .collect();
536
537        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
538
539        let mut updated_repositories = Vec::new();
540
541        if !repository_map.is_empty() {
542            for entry in &updated_entries {
543                if let Some(repo) = repository_map.remove(&entry.id) {
544                    updated_repositories.push(repo)
545                }
546            }
547        }
548
549        let removed_repositories = if done_files {
550            mem::take(&mut message.removed_repositories)
551        } else {
552            Default::default()
553        };
554
555        if done_files {
556            updated_repositories.extend(mem::take(&mut repository_map).into_values());
557        }
558
559        Some(UpdateWorktree {
560            project_id: message.project_id,
561            worktree_id: message.worktree_id,
562            root_name: message.root_name.clone(),
563            abs_path: message.abs_path.clone(),
564            updated_entries,
565            removed_entries,
566            scan_id: message.scan_id,
567            is_last_update: done_files && message.is_last_update,
568            updated_repositories,
569            removed_repositories,
570        })
571    })
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message carrying an `UpdateWorktree` whose root name repeats
    /// "abcdefg" `repetitions` times.
    fn update_worktree_message(repetitions: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Writing both a small and a multi-megabyte message must leave the writer's scratch
        // buffer capped at MAX_BUFFER_LEN.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1_000_000] {
            sink.write(update_worktree_message(repetitions))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both back must leave the reader's scratch buffer capped as well.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}