proto.rs

  1#![allow(non_snake_case)]
  2
  3use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  4use anyhow::{anyhow, Result};
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use collections::HashMap;
  7use futures::{SinkExt as _, StreamExt as _};
  8use prost::Message as _;
  9use serde::Serialize;
 10use std::any::{Any, TypeId};
 11use std::{
 12    cmp,
 13    fmt::Debug,
 14    io, iter,
 15    time::{Duration, SystemTime, UNIX_EPOCH},
 16};
 17use std::{fmt, mem};
 18
 19include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 20
 21pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 22    const NAME: &'static str;
 23    const PRIORITY: MessagePriority;
 24    fn into_envelope(
 25        self,
 26        id: u32,
 27        responding_to: Option<u32>,
 28        original_sender_id: Option<PeerId>,
 29    ) -> Envelope;
 30    fn from_envelope(envelope: Envelope) -> Option<Self>;
 31}
 32
 33pub trait EntityMessage: EnvelopedMessage {
 34    fn remote_entity_id(&self) -> u64;
 35}
 36
 37pub trait RequestMessage: EnvelopedMessage {
 38    type Response: EnvelopedMessage;
 39}
 40
 41pub trait AnyTypedEnvelope: 'static + Send + Sync {
 42    fn payload_type_id(&self) -> TypeId;
 43    fn payload_type_name(&self) -> &'static str;
 44    fn as_any(&self) -> &dyn Any;
 45    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 46    fn is_background(&self) -> bool;
 47    fn original_sender_id(&self) -> Option<PeerId>;
 48    fn sender_id(&self) -> ConnectionId;
 49    fn message_id(&self) -> u32;
 50}
 51
 52pub enum MessagePriority {
 53    Foreground,
 54    Background,
 55}
 56
 57impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 58    fn payload_type_id(&self) -> TypeId {
 59        TypeId::of::<T>()
 60    }
 61
 62    fn payload_type_name(&self) -> &'static str {
 63        T::NAME
 64    }
 65
 66    fn as_any(&self) -> &dyn Any {
 67        self
 68    }
 69
 70    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 71        self
 72    }
 73
 74    fn is_background(&self) -> bool {
 75        matches!(T::PRIORITY, MessagePriority::Background)
 76    }
 77
 78    fn original_sender_id(&self) -> Option<PeerId> {
 79        self.original_sender_id
 80    }
 81
 82    fn sender_id(&self) -> ConnectionId {
 83        self.sender_id
 84    }
 85
 86    fn message_id(&self) -> u32 {
 87        self.message_id
 88    }
 89}
 90
 91impl PeerId {
 92    pub fn from_u64(peer_id: u64) -> Self {
 93        let owner_id = (peer_id >> 32) as u32;
 94        let id = peer_id as u32;
 95        Self { owner_id, id }
 96    }
 97
 98    pub fn as_u64(self) -> u64 {
 99        ((self.owner_id as u64) << 32) | (self.id as u64)
100    }
101}
102
103impl Copy for PeerId {}
104
105impl Eq for PeerId {}
106
107impl Ord for PeerId {
108    fn cmp(&self, other: &Self) -> cmp::Ordering {
109        self.owner_id
110            .cmp(&other.owner_id)
111            .then_with(|| self.id.cmp(&other.id))
112    }
113}
114
115impl PartialOrd for PeerId {
116    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
117        Some(self.cmp(other))
118    }
119}
120
121impl std::hash::Hash for PeerId {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        self.owner_id.hash(state);
124        self.id.hash(state);
125    }
126}
127
128impl fmt::Display for PeerId {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(f, "{}/{}", self.owner_id, self.id)
131    }
132}
133
// Registry of every protocol payload type, each paired with the `MessagePriority` it is
// declared with (compare `EnvelopedMessage::PRIORITY` and `AnyTypedEnvelope::is_background`).
// `messages!` is imported from the parent module and its definition is not visible here.
// Presumably it implements `EnvelopedMessage` for each listed type; confirm against the
// macro. The list is only roughly alphabetical, and its order is kept as-is.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
251
// Request/response pairings: each tuple is `(request type, type sent back in reply)`.
// `request_messages!` is imported from the parent module. Presumably it implements
// `RequestMessage` with `Response` set to the second type; confirm against the macro.
// Every type named here must also appear in the `messages!` registry.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
312
// Messages addressed to a remote entity. The first macro argument names the field that
// identifies the entity: `project_id` for this list and `channel_id` for
// `ChannelMessageSent`. `entity_messages!` is imported from the parent module. Presumably
// it implements `EntityMessage::remote_entity_id` by reading that field; confirm against
// the macro.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);

// Channel-scoped messages, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
367
const KIB: usize = 1 << 10;
const MIB: usize = 1024 * KIB;
/// Capacity, in bytes, that `MessageStream` shrinks its scratch buffer back to after each
/// message. This keeps one oversized message from pinning a large allocation.
const MAX_BUFFER_LEN: usize = MIB;

/// A stream of protobuf messages.
///
/// Wraps a WebSocket stream and/or sink `S`. It writes and reads protobuf messages through
/// a reusable scratch buffer.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch space used for protobuf encoding on write and zstd decompression on read.
    encoding_buffer: Vec<u8>,
}
377
378#[allow(clippy::large_enum_variant)]
379#[derive(Debug)]
380pub enum Message {
381    Envelope(Envelope),
382    Ping,
383    Pong,
384}
385
386impl<S> MessageStream<S> {
387    pub fn new(stream: S) -> Self {
388        Self {
389            stream,
390            encoding_buffer: Vec::new(),
391        }
392    }
393
394    pub fn inner_mut(&mut self) -> &mut S {
395        &mut self.stream
396    }
397}
398
399impl<S> MessageStream<S>
400where
401    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
402{
403    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
404        #[cfg(any(test, feature = "test-support"))]
405        const COMPRESSION_LEVEL: i32 = -7;
406
407        #[cfg(not(any(test, feature = "test-support")))]
408        const COMPRESSION_LEVEL: i32 = 4;
409
410        match message {
411            Message::Envelope(message) => {
412                self.encoding_buffer.reserve(message.encoded_len());
413                message
414                    .encode(&mut self.encoding_buffer)
415                    .map_err(io::Error::from)?;
416                let buffer =
417                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
418                        .unwrap();
419
420                self.encoding_buffer.clear();
421                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
422                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
423            }
424            Message::Ping => {
425                self.stream
426                    .send(WebSocketMessage::Ping(Default::default()))
427                    .await?;
428            }
429            Message::Pong => {
430                self.stream
431                    .send(WebSocketMessage::Pong(Default::default()))
432                    .await?;
433            }
434        }
435
436        Ok(())
437    }
438}
439
440impl<S> MessageStream<S>
441where
442    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
443{
444    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
445        while let Some(bytes) = self.stream.next().await {
446            match bytes? {
447                WebSocketMessage::Binary(bytes) => {
448                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
449                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
450                        .map_err(io::Error::from)?;
451
452                    self.encoding_buffer.clear();
453                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
454                    return Ok(Message::Envelope(envelope));
455                }
456                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
457                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
458                WebSocketMessage::Close(_) => break,
459                _ => {}
460            }
461        }
462        Err(anyhow!("connection closed"))
463    }
464}
465
466impl From<Timestamp> for SystemTime {
467    fn from(val: Timestamp) -> Self {
468        UNIX_EPOCH
469            .checked_add(Duration::new(val.seconds, val.nanos))
470            .unwrap()
471    }
472}
473
474impl From<SystemTime> for Timestamp {
475    fn from(time: SystemTime) -> Self {
476        let duration = time.duration_since(UNIX_EPOCH).unwrap();
477        Self {
478            seconds: duration.as_secs(),
479            nanos: duration.subsec_nanos(),
480        }
481    }
482}
483
484impl From<u128> for Nonce {
485    fn from(nonce: u128) -> Self {
486        let upper_half = (nonce >> 64) as u64;
487        let lower_half = nonce as u64;
488        Self {
489            upper_half,
490            lower_half,
491        }
492    }
493}
494
495impl From<Nonce> for u128 {
496    fn from(nonce: Nonce) -> Self {
497        let upper_half = (nonce.upper_half as u128) << 64;
498        let lower_half = nonce.lower_half as u128;
499        upper_half | lower_half
500    }
501}
502
503pub fn split_worktree_update(
504    mut message: UpdateWorktree,
505    max_chunk_size: usize,
506) -> impl Iterator<Item = UpdateWorktree> {
507    let mut done_files = false;
508
509    let mut repository_map = message
510        .updated_repositories
511        .into_iter()
512        .map(|repo| (repo.work_directory_id, repo))
513        .collect::<HashMap<_, _>>();
514
515    iter::from_fn(move || {
516        if done_files {
517            return None;
518        }
519
520        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
521        let updated_entries: Vec<_> = message
522            .updated_entries
523            .drain(..updated_entries_chunk_size)
524            .collect();
525
526        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
527        let removed_entries = message
528            .removed_entries
529            .drain(..removed_entries_chunk_size)
530            .collect();
531
532        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
533
534        let mut updated_repositories = Vec::new();
535
536        if !repository_map.is_empty() {
537            for entry in &updated_entries {
538                if let Some(repo) = repository_map.remove(&entry.id) {
539                    updated_repositories.push(repo)
540                }
541            }
542        }
543
544        let removed_repositories = if done_files {
545            mem::take(&mut message.removed_repositories)
546        } else {
547            Default::default()
548        };
549
550        if done_files {
551            updated_repositories.extend(mem::take(&mut repository_map).into_values());
552        }
553
554        Some(UpdateWorktree {
555            project_id: message.project_id,
556            worktree_id: message.worktree_id,
557            root_name: message.root_name.clone(),
558            abs_path: message.abs_path.clone(),
559            updated_entries,
560            removed_entries,
561            scan_id: message.scan_id,
562            is_last_update: done_files && message.is_last_update,
563            updated_repositories,
564            removed_repositories,
565        })
566    })
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose only non-default field is
    /// `root_name`.
    fn update_worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write one small message and one roughly 7 MB message. After each, the scratch
        // buffer must be shrunk back to at most `MAX_BUFFER_LEN`.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1_000_000] {
            let envelope = update_worktree_envelope("abcdefg".repeat(repetitions));
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both messages back; the same capacity bound must hold on the decode side.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}