proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
// Pull in the message types generated at build time into `$OUT_DIR/zed.messages.rs`
// (presumably by prost-build from the protobuf schema — confirm in build.rs). Types used
// below without an import — `Envelope`, `PeerId`, `Timestamp`, `Nonce`, `UpdateWorktree`,
// and the other message structs — presumably come from this file.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Message types registered with the `messages!` macro, each paired with a
// `MessagePriority` variant (`Foreground` or `Background`). That priority is what
// `AnyTypedEnvelope::is_background` reports for received envelopes of the type.
// NOTE(review): `messages!` is defined elsewhere; it presumably implements
// `EnvelopedMessage` (including `NAME` and `PRIORITY`) for each listed type — confirm
// against the macro definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateChannel, Foreground),
    (ChannelResponse, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (InviteChannelMember, Foreground),
    (UsersResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (RemoveChannelMember, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToChannelInvite, Foreground),
    (JoinChannel, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (RenameChannel, Foreground),
    (SetChannelMemberAdmin, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (RemoveChannel, Foreground),
    (UpdateChannels, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
    (GetChannelMembers, Foreground),
    (GetChannelMembersResponse, Foreground),
    (GetChannelBuffer, Foreground),
    (GetChannelBufferResponse, Foreground)
);
255
// Request/response pairings, one `(Request, Response)` entry per request type. Several
// requests share a response type (e.g. `Ack`, `ProjectEntryResponse`, `UsersResponse`).
// NOTE(review): `request_messages!` is defined elsewhere; it presumably implements
// `RequestMessage` with `type Response = <second element>` for each entry — confirm
// against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (CreateChannel, ChannelResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (InviteChannelMember, Ack),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveChannelMember, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RespondToChannelInvite, Ack),
    (SetChannelMemberAdmin, Ack),
    (GetChannelMembers, GetChannelMembersResponse),
    (JoinChannel, JoinRoomResponse),
    (RemoveChannel, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (RenameChannel, ChannelResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
    (GetChannelBuffer, GetChannelBufferResponse)
);
322
// Message types scoped to a project. The first argument names a field, `project_id`,
// which each listed type presumably carries.
// NOTE(review): `entity_messages!` is defined elsewhere; it presumably implements
// `EntityMessage::remote_entity_id` by returning that field — confirm against the macro
// definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);
375
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity the reusable `MessageStream` encoding buffer is shrunk back to after each
/// envelope, so a single oversized message does not keep a large allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
379
/// A stream of protobuf messages.
///
/// Wraps a transport `S` of WebSocket frames. Envelopes are sent and received as
/// zstd-compressed binary frames (see `write`/`read`).
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch space for encoded or decoded envelope bytes. It is reused across messages and
    /// cleared after each envelope.
    encoding_buffer: Vec<u8>,
}
385
386#[allow(clippy::large_enum_variant)]
387#[derive(Debug)]
388pub enum Message {
389    Envelope(Envelope),
390    Ping,
391    Pong,
392}
393
394impl<S> MessageStream<S> {
395    pub fn new(stream: S) -> Self {
396        Self {
397            stream,
398            encoding_buffer: Vec::new(),
399        }
400    }
401
402    pub fn inner_mut(&mut self) -> &mut S {
403        &mut self.stream
404    }
405}
406
407impl<S> MessageStream<S>
408where
409    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
410{
411    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
412        #[cfg(any(test, feature = "test-support"))]
413        const COMPRESSION_LEVEL: i32 = -7;
414
415        #[cfg(not(any(test, feature = "test-support")))]
416        const COMPRESSION_LEVEL: i32 = 4;
417
418        match message {
419            Message::Envelope(message) => {
420                self.encoding_buffer.reserve(message.encoded_len());
421                message
422                    .encode(&mut self.encoding_buffer)
423                    .map_err(io::Error::from)?;
424                let buffer =
425                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
426                        .unwrap();
427
428                self.encoding_buffer.clear();
429                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
430                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
431            }
432            Message::Ping => {
433                self.stream
434                    .send(WebSocketMessage::Ping(Default::default()))
435                    .await?;
436            }
437            Message::Pong => {
438                self.stream
439                    .send(WebSocketMessage::Pong(Default::default()))
440                    .await?;
441            }
442        }
443
444        Ok(())
445    }
446}
447
448impl<S> MessageStream<S>
449where
450    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
451{
452    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
453        while let Some(bytes) = self.stream.next().await {
454            match bytes? {
455                WebSocketMessage::Binary(bytes) => {
456                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
457                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
458                        .map_err(io::Error::from)?;
459
460                    self.encoding_buffer.clear();
461                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
462                    return Ok(Message::Envelope(envelope));
463                }
464                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
465                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
466                WebSocketMessage::Close(_) => break,
467                _ => {}
468            }
469        }
470        Err(anyhow!("connection closed"))
471    }
472}
473
474impl From<Timestamp> for SystemTime {
475    fn from(val: Timestamp) -> Self {
476        UNIX_EPOCH
477            .checked_add(Duration::new(val.seconds, val.nanos))
478            .unwrap()
479    }
480}
481
482impl From<SystemTime> for Timestamp {
483    fn from(time: SystemTime) -> Self {
484        let duration = time.duration_since(UNIX_EPOCH).unwrap();
485        Self {
486            seconds: duration.as_secs(),
487            nanos: duration.subsec_nanos(),
488        }
489    }
490}
491
492impl From<u128> for Nonce {
493    fn from(nonce: u128) -> Self {
494        let upper_half = (nonce >> 64) as u64;
495        let lower_half = nonce as u64;
496        Self {
497            upper_half,
498            lower_half,
499        }
500    }
501}
502
503impl From<Nonce> for u128 {
504    fn from(nonce: Nonce) -> Self {
505        let upper_half = (nonce.upper_half as u128) << 64;
506        let lower_half = nonce.lower_half as u128;
507        upper_half | lower_half
508    }
509}
510
511pub fn split_worktree_update(
512    mut message: UpdateWorktree,
513    max_chunk_size: usize,
514) -> impl Iterator<Item = UpdateWorktree> {
515    let mut done_files = false;
516
517    let mut repository_map = message
518        .updated_repositories
519        .into_iter()
520        .map(|repo| (repo.work_directory_id, repo))
521        .collect::<HashMap<_, _>>();
522
523    iter::from_fn(move || {
524        if done_files {
525            return None;
526        }
527
528        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
529        let updated_entries: Vec<_> = message
530            .updated_entries
531            .drain(..updated_entries_chunk_size)
532            .collect();
533
534        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
535        let removed_entries = message
536            .removed_entries
537            .drain(..removed_entries_chunk_size)
538            .collect();
539
540        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
541
542        let mut updated_repositories = Vec::new();
543
544        if !repository_map.is_empty() {
545            for entry in &updated_entries {
546                if let Some(repo) = repository_map.remove(&entry.id) {
547                    updated_repositories.push(repo)
548                }
549            }
550        }
551
552        let removed_repositories = if done_files {
553            mem::take(&mut message.removed_repositories)
554        } else {
555            Default::default()
556        };
557
558        if done_files {
559            updated_repositories.extend(mem::take(&mut repository_map).into_values());
560        }
561
562        Some(UpdateWorktree {
563            project_id: message.project_id,
564            worktree_id: message.worktree_id,
565            root_name: message.root_name.clone(),
566            abs_path: message.abs_path.clone(),
567            updated_entries,
568            removed_entries,
569            scan_id: message.scan_id,
570            is_last_update: done_files && message.is_last_update,
571            updated_repositories,
572            removed_repositories,
573        })
574    })
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose root name is `repeat` copies of `"abcdefg"`.
    fn worktree_envelope(repeat: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repeat),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// The reusable encoding buffer must never keep more than `MAX_BUFFER_LEN` of capacity,
    /// on either the write or the read side, even after a large message.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeat in [10, 1_000_000] {
            sink.write(worktree_envelope(repeat)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `as_u64` followed by `from_u64` round-trips, including at the `u32` extremes.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}