proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 // Brings in the code generated into `$OUT_DIR/zed.messages.rs` at build time.
// Presumably this is prost output for the protobuf message schema (confirm in
// the crate's build script). Types used in this file but not declared here,
// such as `Envelope`, `envelope::Payload`, `PeerId`, `Timestamp`, `Nonce`,
// `UpdateWorktree` and the message types listed in `messages!`, presumably
// come from this file.
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 /// Scheduling class attached to every [`EnvelopedMessage`] type through
/// [`EnvelopedMessage::PRIORITY`], and reported at runtime by
/// [`AnyTypedEnvelope::is_background`].
///
/// This is a plain fieldless enum, so it derives the cheap value-type traits.
/// That makes it printable in diagnostics and comparable without `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// `is_background` reports `false` for this priority.
    Foreground,
    /// `is_background` reports `true` for this priority.
    Background,
}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Registers every protobuf payload type that can be carried in an `Envelope`.
// Each type is paired with the `MessagePriority` variant (`Foreground` or
// `Background`) it is scheduled with. The `messages!` macro is imported from
// `super` and its expansion is not visible here. Presumably it implements
// `EnvelopedMessage` (including `NAME` and `PRIORITY`) for each listed type;
// confirm against the macro definition.
134messages!(
135    (Ack, Foreground),
136    (AddProjectCollaborator, Foreground),
137    (ApplyCodeAction, Background),
138    (ApplyCodeActionResponse, Background),
139    (ApplyCompletionAdditionalEdits, Background),
140    (ApplyCompletionAdditionalEditsResponse, Background),
141    (BufferReloaded, Foreground),
142    (BufferSaved, Foreground),
143    (Call, Foreground),
144    (CallCanceled, Foreground),
145    (CancelCall, Foreground),
146    (CopyProjectEntry, Foreground),
147    (CreateBufferForPeer, Foreground),
148    (CreateChannel, Foreground),
149    (ChannelResponse, Foreground),
150    (CreateProjectEntry, Foreground),
151    (CreateRoom, Foreground),
152    (CreateRoomResponse, Foreground),
153    (DeclineCall, Foreground),
154    (DeleteProjectEntry, Foreground),
155    (Error, Foreground),
156    (ExpandProjectEntry, Foreground),
157    (Follow, Foreground),
158    (FollowResponse, Foreground),
159    (FormatBuffers, Foreground),
160    (FormatBuffersResponse, Foreground),
161    (FuzzySearchUsers, Foreground),
162    (GetCodeActions, Background),
163    (GetCodeActionsResponse, Background),
164    (GetHover, Background),
165    (GetHoverResponse, Background),
166    (GetCompletions, Background),
167    (GetCompletionsResponse, Background),
168    (GetDefinition, Background),
169    (GetDefinitionResponse, Background),
170    (GetTypeDefinition, Background),
171    (GetTypeDefinitionResponse, Background),
172    (GetDocumentHighlights, Background),
173    (GetDocumentHighlightsResponse, Background),
174    (GetReferences, Background),
175    (GetReferencesResponse, Background),
176    (GetProjectSymbols, Background),
177    (GetProjectSymbolsResponse, Background),
178    (GetUsers, Foreground),
179    (Hello, Foreground),
180    (IncomingCall, Foreground),
181    (InviteChannelMember, Foreground),
182    (UsersResponse, Foreground),
183    (JoinProject, Foreground),
184    (JoinProjectResponse, Foreground),
185    (JoinRoom, Foreground),
186    (JoinRoomResponse, Foreground),
187    (LeaveProject, Foreground),
188    (LeaveRoom, Foreground),
189    (OpenBufferById, Background),
190    (OpenBufferByPath, Background),
191    (OpenBufferForSymbol, Background),
192    (OpenBufferForSymbolResponse, Background),
193    (OpenBufferResponse, Background),
194    (PerformRename, Background),
195    (PerformRenameResponse, Background),
196    (OnTypeFormatting, Background),
197    (OnTypeFormattingResponse, Background),
198    (InlayHints, Background),
199    (InlayHintsResponse, Background),
200    (ResolveInlayHint, Background),
201    (ResolveInlayHintResponse, Background),
202    (RefreshInlayHints, Foreground),
203    (Ping, Foreground),
204    (PrepareRename, Background),
205    (PrepareRenameResponse, Background),
206    (ExpandProjectEntryResponse, Foreground),
207    (ProjectEntryResponse, Foreground),
208    (RejoinRoom, Foreground),
209    (RejoinRoomResponse, Foreground),
210    (RemoveContact, Foreground),
211    (RemoveChannelMember, Foreground),
212    (ReloadBuffers, Foreground),
213    (ReloadBuffersResponse, Foreground),
214    (RemoveProjectCollaborator, Foreground),
215    (RenameProjectEntry, Foreground),
216    (RequestContact, Foreground),
217    (RespondToContactRequest, Foreground),
218    (RespondToChannelInvite, Foreground),
219    (JoinChannel, Foreground),
220    (RoomUpdated, Foreground),
221    (SaveBuffer, Foreground),
222    (RenameChannel, Foreground),
223    (SetChannelMemberAdmin, Foreground),
224    (SearchProject, Background),
225    (SearchProjectResponse, Background),
226    (ShareProject, Foreground),
227    (ShareProjectResponse, Foreground),
228    (ShowContacts, Foreground),
229    (StartLanguageServer, Foreground),
230    (SynchronizeBuffers, Foreground),
231    (SynchronizeBuffersResponse, Foreground),
232    (RejoinChannelBuffers, Foreground),
233    (RejoinChannelBuffersResponse, Foreground),
234    (Test, Foreground),
235    (Unfollow, Foreground),
236    (UnshareProject, Foreground),
237    (UpdateBuffer, Foreground),
238    (UpdateBufferFile, Foreground),
239    (UpdateContacts, Foreground),
240    (RemoveChannel, Foreground),
241    (UpdateChannels, Foreground),
242    (UpdateDiagnosticSummary, Foreground),
243    (UpdateFollowers, Foreground),
244    (UpdateInviteInfo, Foreground),
245    (UpdateLanguageServer, Foreground),
246    (UpdateParticipantLocation, Foreground),
247    (UpdateProject, Foreground),
248    (UpdateProjectCollaborator, Foreground),
249    (UpdateWorktree, Foreground),
250    (UpdateWorktreeSettings, Foreground),
251    (UpdateDiffBase, Foreground),
252    (GetPrivateUserInfo, Foreground),
253    (GetPrivateUserInfoResponse, Foreground),
254    (GetChannelMembers, Foreground),
255    (GetChannelMembersResponse, Foreground),
256    (JoinChannelBuffer, Foreground),
257    (JoinChannelBufferResponse, Foreground),
    // NOTE(review): unlike the neighbouring channel-buffer messages (all
    // `Foreground`), this one is `Background`. Confirm that is intentional.
258    (LeaveChannelBuffer, Background),
259    (UpdateChannelBuffer, Foreground),
260    (RemoveChannelBufferCollaborator, Foreground),
261    (AddChannelBufferCollaborator, Foreground),
262);
263
// Pairs each request message with the response type a peer is expected to
// reply with. Presumably the `request_messages!` macro (imported from `super`)
// implements `RequestMessage` with `type Response = <second element>` for each
// pair; confirm against the macro definition.
// Many requests reply with the shared `Ack` type, and `(Test, Test)` replies
// with its own type. Several requests also share a response type, for example
// `ProjectEntryResponse`, `UsersResponse`, `ChannelResponse` and
// `JoinRoomResponse`. Every type named here also appears in the `messages!`
// invocation in this file.
264request_messages!(
265    (ApplyCodeAction, ApplyCodeActionResponse),
266    (
267        ApplyCompletionAdditionalEdits,
268        ApplyCompletionAdditionalEditsResponse
269    ),
270    (Call, Ack),
271    (CancelCall, Ack),
272    (CopyProjectEntry, ProjectEntryResponse),
273    (CreateProjectEntry, ProjectEntryResponse),
274    (CreateRoom, CreateRoomResponse),
275    (CreateChannel, ChannelResponse),
276    (DeclineCall, Ack),
277    (DeleteProjectEntry, ProjectEntryResponse),
278    (ExpandProjectEntry, ExpandProjectEntryResponse),
279    (Follow, FollowResponse),
280    (FormatBuffers, FormatBuffersResponse),
281    (GetCodeActions, GetCodeActionsResponse),
282    (GetHover, GetHoverResponse),
283    (GetCompletions, GetCompletionsResponse),
284    (GetDefinition, GetDefinitionResponse),
285    (GetTypeDefinition, GetTypeDefinitionResponse),
286    (GetDocumentHighlights, GetDocumentHighlightsResponse),
287    (GetReferences, GetReferencesResponse),
288    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
289    (GetProjectSymbols, GetProjectSymbolsResponse),
290    (FuzzySearchUsers, UsersResponse),
291    (GetUsers, UsersResponse),
292    (InviteChannelMember, Ack),
293    (JoinProject, JoinProjectResponse),
294    (JoinRoom, JoinRoomResponse),
295    (LeaveRoom, Ack),
296    (RejoinRoom, RejoinRoomResponse),
297    (IncomingCall, Ack),
298    (OpenBufferById, OpenBufferResponse),
299    (OpenBufferByPath, OpenBufferResponse),
300    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
301    (Ping, Ack),
302    (PerformRename, PerformRenameResponse),
303    (PrepareRename, PrepareRenameResponse),
304    (OnTypeFormatting, OnTypeFormattingResponse),
305    (InlayHints, InlayHintsResponse),
306    (ResolveInlayHint, ResolveInlayHintResponse),
307    (RefreshInlayHints, Ack),
308    (ReloadBuffers, ReloadBuffersResponse),
309    (RequestContact, Ack),
310    (RemoveChannelMember, Ack),
311    (RemoveContact, Ack),
312    (RespondToContactRequest, Ack),
313    (RespondToChannelInvite, Ack),
314    (SetChannelMemberAdmin, Ack),
315    (GetChannelMembers, GetChannelMembersResponse),
316    (JoinChannel, JoinRoomResponse),
317    (RemoveChannel, Ack),
318    (RenameProjectEntry, ProjectEntryResponse),
319    (RenameChannel, ChannelResponse),
320    (SaveBuffer, BufferSaved),
321    (SearchProject, SearchProjectResponse),
322    (ShareProject, ShareProjectResponse),
323    (SynchronizeBuffers, SynchronizeBuffersResponse),
324    (RejoinChannelBuffers, RejoinChannelBuffersResponse),
325    (Test, Test),
326    (UpdateBuffer, Ack),
327    (UpdateParticipantLocation, Ack),
328    (UpdateProject, Ack),
329    (UpdateWorktree, Ack),
330    (JoinChannelBuffer, JoinChannelBufferResponse),
331    (LeaveChannelBuffer, Ack)
332);
333
// Messages that target a specific project. The first argument names the
// message field that identifies the entity (`project_id` here). Presumably the
// `entity_messages!` macro (imported from `super`) implements
// `EntityMessage::remote_entity_id` to return that field; confirm against the
// macro definition.
334entity_messages!(
335    project_id,
336    AddProjectCollaborator,
337    ApplyCodeAction,
338    ApplyCompletionAdditionalEdits,
339    BufferReloaded,
340    BufferSaved,
341    CopyProjectEntry,
342    CreateBufferForPeer,
343    CreateProjectEntry,
344    DeleteProjectEntry,
345    ExpandProjectEntry,
346    Follow,
347    FormatBuffers,
348    GetCodeActions,
349    GetCompletions,
350    GetDefinition,
351    GetTypeDefinition,
352    GetDocumentHighlights,
353    GetHover,
354    GetReferences,
355    GetProjectSymbols,
356    JoinProject,
357    LeaveProject,
358    OpenBufferById,
359    OpenBufferByPath,
360    OpenBufferForSymbol,
361    PerformRename,
362    OnTypeFormatting,
363    InlayHints,
364    ResolveInlayHint,
365    RefreshInlayHints,
366    PrepareRename,
367    ReloadBuffers,
368    RemoveProjectCollaborator,
369    RenameProjectEntry,
370    SaveBuffer,
371    SearchProject,
372    StartLanguageServer,
373    SynchronizeBuffers,
374    Unfollow,
375    UnshareProject,
376    UpdateBuffer,
377    UpdateBufferFile,
378    UpdateDiagnosticSummary,
379    UpdateFollowers,
380    UpdateLanguageServer,
381    UpdateProject,
382    UpdateProjectCollaborator,
383    UpdateWorktree,
384    UpdateWorktreeSettings,
385    UpdateDiffBase
386);
387
// Channel-buffer messages, identified by their `channel_id` field.
388entity_messages!(
389    channel_id,
390    UpdateChannelBuffer,
391    RemoveChannelBufferCollaborator,
392    AddChannelBufferCollaborator
393);
394
/// Bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte.
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s reusable encoding buffer is shrunk back to
/// after each message is written or read. This keeps one unusually large
/// message from pinning a big allocation for the rest of the stream's life.
const MAX_BUFFER_LEN: usize = MIB;
398
/// A stream of protobuf messages.
///
/// Wraps a transport `S` together with a scratch buffer. The `write` impl
/// requires `S` to be a sink of WebSocket messages, and the `read` impl
/// requires it to be a stream of them.
pub struct MessageStream<S> {
    /// Scratch space reused to encode outgoing envelopes and decode incoming ones.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport.
    stream: S,
}
404
405#[allow(clippy::large_enum_variant)]
406#[derive(Debug)]
407pub enum Message {
408    Envelope(Envelope),
409    Ping,
410    Pong,
411}
412
413impl<S> MessageStream<S> {
414    pub fn new(stream: S) -> Self {
415        Self {
416            stream,
417            encoding_buffer: Vec::new(),
418        }
419    }
420
421    pub fn inner_mut(&mut self) -> &mut S {
422        &mut self.stream
423    }
424}
425
426impl<S> MessageStream<S>
427where
428    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
429{
430    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
431        #[cfg(any(test, feature = "test-support"))]
432        const COMPRESSION_LEVEL: i32 = -7;
433
434        #[cfg(not(any(test, feature = "test-support")))]
435        const COMPRESSION_LEVEL: i32 = 4;
436
437        match message {
438            Message::Envelope(message) => {
439                self.encoding_buffer.reserve(message.encoded_len());
440                message
441                    .encode(&mut self.encoding_buffer)
442                    .map_err(io::Error::from)?;
443                let buffer =
444                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
445                        .unwrap();
446
447                self.encoding_buffer.clear();
448                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
449                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
450            }
451            Message::Ping => {
452                self.stream
453                    .send(WebSocketMessage::Ping(Default::default()))
454                    .await?;
455            }
456            Message::Pong => {
457                self.stream
458                    .send(WebSocketMessage::Pong(Default::default()))
459                    .await?;
460            }
461        }
462
463        Ok(())
464    }
465}
466
467impl<S> MessageStream<S>
468where
469    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
470{
471    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
472        while let Some(bytes) = self.stream.next().await {
473            match bytes? {
474                WebSocketMessage::Binary(bytes) => {
475                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
476                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
477                        .map_err(io::Error::from)?;
478
479                    self.encoding_buffer.clear();
480                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
481                    return Ok(Message::Envelope(envelope));
482                }
483                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
484                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
485                WebSocketMessage::Close(_) => break,
486                _ => {}
487            }
488        }
489        Err(anyhow!("connection closed"))
490    }
491}
492
493impl From<Timestamp> for SystemTime {
494    fn from(val: Timestamp) -> Self {
495        UNIX_EPOCH
496            .checked_add(Duration::new(val.seconds, val.nanos))
497            .unwrap()
498    }
499}
500
501impl From<SystemTime> for Timestamp {
502    fn from(time: SystemTime) -> Self {
503        let duration = time.duration_since(UNIX_EPOCH).unwrap();
504        Self {
505            seconds: duration.as_secs(),
506            nanos: duration.subsec_nanos(),
507        }
508    }
509}
510
511impl From<u128> for Nonce {
512    fn from(nonce: u128) -> Self {
513        let upper_half = (nonce >> 64) as u64;
514        let lower_half = nonce as u64;
515        Self {
516            upper_half,
517            lower_half,
518        }
519    }
520}
521
522impl From<Nonce> for u128 {
523    fn from(nonce: Nonce) -> Self {
524        let upper_half = (nonce.upper_half as u128) << 64;
525        let lower_half = nonce.lower_half as u128;
526        upper_half | lower_half
527    }
528}
529
530pub fn split_worktree_update(
531    mut message: UpdateWorktree,
532    max_chunk_size: usize,
533) -> impl Iterator<Item = UpdateWorktree> {
534    let mut done_files = false;
535
536    let mut repository_map = message
537        .updated_repositories
538        .into_iter()
539        .map(|repo| (repo.work_directory_id, repo))
540        .collect::<HashMap<_, _>>();
541
542    iter::from_fn(move || {
543        if done_files {
544            return None;
545        }
546
547        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
548        let updated_entries: Vec<_> = message
549            .updated_entries
550            .drain(..updated_entries_chunk_size)
551            .collect();
552
553        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
554        let removed_entries = message
555            .removed_entries
556            .drain(..removed_entries_chunk_size)
557            .collect();
558
559        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
560
561        let mut updated_repositories = Vec::new();
562
563        if !repository_map.is_empty() {
564            for entry in &updated_entries {
565                if let Some(repo) = repository_map.remove(&entry.id) {
566                    updated_repositories.push(repo)
567                }
568            }
569        }
570
571        let removed_repositories = if done_files {
572            mem::take(&mut message.removed_repositories)
573        } else {
574            Default::default()
575        };
576
577        if done_files {
578            updated_repositories.extend(mem::take(&mut repository_map).into_values());
579        }
580
581        Some(UpdateWorktree {
582            project_id: message.project_id,
583            worktree_id: message.worktree_id,
584            root_name: message.root_name.clone(),
585            abs_path: message.abs_path.clone(),
586            updated_entries,
587            removed_entries,
588            scan_id: message.scan_id,
589            is_last_update: done_files && message.is_last_update,
590            updated_repositories,
591            removed_repositories,
592        })
593    })
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an otherwise-default envelope carrying an `UpdateWorktree` whose
    /// `root_name` is `"abcdefg"` repeated `repetitions` times.
    fn worktree_envelope(repetitions: usize) -> Envelope {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repetitions),
            ..Default::default()
        };
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        }
    }

    /// After writing, then reading, both a small and a very large message, the
    /// reusable encoding buffer must never retain more than `MAX_BUFFER_LEN`
    /// bytes of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(Message::Envelope(worktree_envelope(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId::from_u64` must invert `PeerId::as_u64`, including at the
    /// `u32::MAX` boundary of either half.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}