proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
// Textually includes `zed.messages.rs` from the build's `OUT_DIR`. The message types used
// throughout this file (`Envelope`, `PeerId`, `UpdateWorktree`, ...) are not declared in the
// visible source, so they presumably come from this generated file — confirm against the
// build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
/// A message type that can be wrapped into, and extracted from, an [`Envelope`].
///
/// NOTE(review): no implementations are visible in this file; they are presumably
/// generated by the `messages!` invocation — confirm against the macro definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; exposed through `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of the message type; `AnyTypedEnvelope::is_background` reports whether
    /// this is `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Consumes the message and produces an [`Envelope`] from it together with the message
    /// `id`, the optional `responding_to` id (presumably the id of the request being
    /// answered — confirm) and the optional `original_sender_id`.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Attempts to extract a message of this type from `envelope`, returning `None` when
    /// that is not possible (presumably when the payload is a different message type).
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 32
/// An [`EnvelopedMessage`] that is addressed to a particular remote entity.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 36
/// An [`EnvelopedMessage`] that has an associated response message type.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 40
/// Type-erased interface over `TypedEnvelope<T>`.
///
/// It is implemented in this file for every `TypedEnvelope<T>` whose payload type `T`
/// implements [`EnvelopedMessage`], and exposes the envelope's metadata without naming `T`.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload type (`EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so callers can downcast it by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The original sender's peer id recorded on the envelope, if any.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// The connection id recorded on the envelope as its sender.
    fn sender_id(&self) -> ConnectionId;
    /// The message id recorded on the envelope.
    fn message_id(&self) -> u32;
}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
// `PeerId` is declared outside the visible source (presumably in the generated code), so
// these marker traits are implemented by hand here rather than derived.
// `Copy` lets methods such as `PeerId::as_u64` take `self` by value.
impl Copy for PeerId {}

// `Eq` is a prerequisite for the `Ord` implementation in this file.
impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
/// Partial order for `PeerId`; always `Some`, because it delegates to the total order
/// defined by the `Ord` implementation so the two can never disagree.
impl PartialOrd for PeerId {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Registers every message type together with its priority (`Foreground` or `Background`,
// matching the variants of `MessagePriority`).
// NOTE(review): the `messages!` macro is imported from `super` and not visible here;
// presumably it implements `EnvelopedMessage` for each listed type — confirm against the
// macro definition before relying on entry order or adding entries.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateChannel, Foreground),
    (ChannelResponse, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (InviteChannelMember, Foreground),
    (UsersResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (ResolveInlayHint, Background),
    (ResolveInlayHintResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (RemoveChannelMember, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToChannelInvite, Foreground),
    (JoinChannel, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (RenameChannel, Foreground),
    (SetChannelMemberAdmin, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (RemoveChannel, Foreground),
    (UpdateChannels, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
    (GetChannelMembers, Foreground),
    (GetChannelMembersResponse, Foreground),
    (JoinChannelBuffer, Foreground),
    (JoinChannelBufferResponse, Foreground),
    (LeaveChannelBuffer, Background),
    (UpdateChannelBuffer, Foreground),
    (RemoveChannelBufferCollaborator, Foreground),
    (AddChannelBufferCollaborator, Foreground),
);
261
// Pairs each request message type (first element) with the message type sent back in
// response (second element). Several requests share a response type (e.g. `Ack`).
// NOTE(review): the `request_messages!` macro is imported from `super` and not visible
// here; presumably it implements `RequestMessage` with `Response` set to the second
// element — confirm against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (CreateChannel, ChannelResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (InviteChannelMember, Ack),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (ResolveInlayHint, ResolveInlayHintResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveChannelMember, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RespondToChannelInvite, Ack),
    (SetChannelMemberAdmin, Ack),
    (GetChannelMembers, GetChannelMembersResponse),
    (JoinChannel, JoinRoomResponse),
    (RemoveChannel, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (RenameChannel, ChannelResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
    (JoinChannelBuffer, JoinChannelBufferResponse),
    (LeaveChannelBuffer, Ack)
);
330
// Message types associated with a project. The first argument names a field
// (`project_id`); the remaining arguments are message types.
// NOTE(review): the `entity_messages!` macro is imported from `super` and not visible
// here; presumably it implements `EntityMessage` so that `remote_entity_id` returns the
// named field — confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    ResolveInlayHint,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);
384
// Message types associated with a channel. The first argument names a field
// (`channel_id`); the remaining arguments are message types.
// NOTE(review): presumably `entity_messages!` implements `EntityMessage` so that
// `remote_entity_id` returns the named field — confirm against the macro definition.
entity_messages!(
    channel_id,
    UpdateChannelBuffer,
    RemoveChannelBufferCollaborator,
    AddChannelBufferCollaborator
);
391
/// Number of bytes in one kibibyte.
const KIB: usize = 1024;
/// Number of bytes in one mebibyte.
const MIB: usize = KIB * 1024;
/// Capacity, in bytes, that `MessageStream` shrinks its scratch buffer down to after each
/// message it writes or reads, so a single oversized message does not keep a large
/// allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
395
/// A stream of protobuf messages.
///
/// Wraps an underlying WebSocket sink and/or stream `S`. `write` sends envelopes as
/// zstd-compressed binary frames and `read` decodes them back.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer holding the uncompressed protobuf bytes of the message currently
    /// being written or read; cleared and shrunk to `MAX_BUFFER_LEN` after each message.
    encoding_buffer: Vec<u8>,
}
401
/// A unit of traffic exchanged through a `MessageStream`.
// The `Envelope` variant is much larger than the unit variants, hence the lint allowance.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried in a binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping frame.
    Ping,
    /// A WebSocket pong frame.
    Pong,
}
409
410impl<S> MessageStream<S> {
411    pub fn new(stream: S) -> Self {
412        Self {
413            stream,
414            encoding_buffer: Vec::new(),
415        }
416    }
417
418    pub fn inner_mut(&mut self) -> &mut S {
419        &mut self.stream
420    }
421}
422
423impl<S> MessageStream<S>
424where
425    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
426{
427    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
428        #[cfg(any(test, feature = "test-support"))]
429        const COMPRESSION_LEVEL: i32 = -7;
430
431        #[cfg(not(any(test, feature = "test-support")))]
432        const COMPRESSION_LEVEL: i32 = 4;
433
434        match message {
435            Message::Envelope(message) => {
436                self.encoding_buffer.reserve(message.encoded_len());
437                message
438                    .encode(&mut self.encoding_buffer)
439                    .map_err(io::Error::from)?;
440                let buffer =
441                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
442                        .unwrap();
443
444                self.encoding_buffer.clear();
445                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
446                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
447            }
448            Message::Ping => {
449                self.stream
450                    .send(WebSocketMessage::Ping(Default::default()))
451                    .await?;
452            }
453            Message::Pong => {
454                self.stream
455                    .send(WebSocketMessage::Pong(Default::default()))
456                    .await?;
457            }
458        }
459
460        Ok(())
461    }
462}
463
464impl<S> MessageStream<S>
465where
466    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
467{
468    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
469        while let Some(bytes) = self.stream.next().await {
470            match bytes? {
471                WebSocketMessage::Binary(bytes) => {
472                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
473                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
474                        .map_err(io::Error::from)?;
475
476                    self.encoding_buffer.clear();
477                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
478                    return Ok(Message::Envelope(envelope));
479                }
480                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
481                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
482                WebSocketMessage::Close(_) => break,
483                _ => {}
484            }
485        }
486        Err(anyhow!("connection closed"))
487    }
488}
489
490impl From<Timestamp> for SystemTime {
491    fn from(val: Timestamp) -> Self {
492        UNIX_EPOCH
493            .checked_add(Duration::new(val.seconds, val.nanos))
494            .unwrap()
495    }
496}
497
498impl From<SystemTime> for Timestamp {
499    fn from(time: SystemTime) -> Self {
500        let duration = time.duration_since(UNIX_EPOCH).unwrap();
501        Self {
502            seconds: duration.as_secs(),
503            nanos: duration.subsec_nanos(),
504        }
505    }
506}
507
508impl From<u128> for Nonce {
509    fn from(nonce: u128) -> Self {
510        let upper_half = (nonce >> 64) as u64;
511        let lower_half = nonce as u64;
512        Self {
513            upper_half,
514            lower_half,
515        }
516    }
517}
518
519impl From<Nonce> for u128 {
520    fn from(nonce: Nonce) -> Self {
521        let upper_half = (nonce.upper_half as u128) << 64;
522        let lower_half = nonce.lower_half as u128;
523        upper_half | lower_half
524    }
525}
526
527pub fn split_worktree_update(
528    mut message: UpdateWorktree,
529    max_chunk_size: usize,
530) -> impl Iterator<Item = UpdateWorktree> {
531    let mut done_files = false;
532
533    let mut repository_map = message
534        .updated_repositories
535        .into_iter()
536        .map(|repo| (repo.work_directory_id, repo))
537        .collect::<HashMap<_, _>>();
538
539    iter::from_fn(move || {
540        if done_files {
541            return None;
542        }
543
544        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
545        let updated_entries: Vec<_> = message
546            .updated_entries
547            .drain(..updated_entries_chunk_size)
548            .collect();
549
550        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
551        let removed_entries = message
552            .removed_entries
553            .drain(..removed_entries_chunk_size)
554            .collect();
555
556        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
557
558        let mut updated_repositories = Vec::new();
559
560        if !repository_map.is_empty() {
561            for entry in &updated_entries {
562                if let Some(repo) = repository_map.remove(&entry.id) {
563                    updated_repositories.push(repo)
564                }
565            }
566        }
567
568        let removed_repositories = if done_files {
569            mem::take(&mut message.removed_repositories)
570        } else {
571            Default::default()
572        };
573
574        if done_files {
575            updated_repositories.extend(mem::take(&mut repository_map).into_values());
576        }
577
578        Some(UpdateWorktree {
579            project_id: message.project_id,
580            worktree_id: message.worktree_id,
581            root_name: message.root_name.clone(),
582            abs_path: message.abs_path.clone(),
583            updated_entries,
584            removed_entries,
585            scan_id: message.scan_id,
586            is_last_update: done_files && message.is_last_update,
587            updated_repositories,
588            removed_repositories,
589        })
590    })
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that the scratch buffer of a `MessageStream` never retains more than
    // `MAX_BUFFER_LEN` bytes of capacity, on both the writing and the reading side, even
    // after handling a message far larger than that limit.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        // A small message (70-byte root name).
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(10),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        // A message whose root name alone (7,000,000 bytes) exceeds `MAX_BUFFER_LEN`.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(1000000),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        // Read both messages back through a second `MessageStream` over the same channel.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
    }

    // Round-trips `PeerId` through its packed `u64` form, covering the maximum value in
    // each half individually and in both halves together.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let peer_id = PeerId {
            owner_id: 10,
            id: 3,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: u32::MAX,
            id: 3,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: 10,
            id: u32::MAX,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: u32::MAX,
            id: u32::MAX,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
    }
}