proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
// Brings in the generated message types this file relies on but does not define (`Envelope`,
// `PeerId`, `Timestamp`, `Nonce`, and the individual message structs). NOTE(review): presumably
// emitted by this crate's build script into OUT_DIR — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46    fn sender_id(&self) -> ConnectionId;
 47    fn message_id(&self) -> u32;
 48}
 49
 50pub enum MessagePriority {
 51    Foreground,
 52    Background,
 53}
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
// Registers every message type with its `MessagePriority`. `Background` types report `true`
// from `AnyTypedEnvelope::is_background`. NOTE(review): the `messages!` macro is defined
// elsewhere; presumably it implements `EnvelopedMessage` (NAME/PRIORITY/into_/from_envelope)
// for each entry. Request/response and entity types used further down also appear here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (ExpandProjectEntry, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (InlayHints, Background),
    (InlayHintsResponse, Background),
    (RefreshInlayHints, Foreground),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ExpandProjectEntryResponse, Foreground),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
249
// Pairs each request message with the message type expected in reply. Several requests share
// a generic reply (e.g. `Ack`, `ProjectEntryResponse`, `UsersResponse`). NOTE(review):
// presumably expands to `impl RequestMessage for <request> { type Response = <reply>; }`;
// the macro body is not visible here.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (ExpandProjectEntry, ExpandProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (InlayHints, InlayHintsResponse),
    (RefreshInlayHints, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
310
// Messages addressed to a specific project; the first argument names the field read as the
// entity id. NOTE(review): presumably implements `EntityMessage::remote_entity_id` by
// returning that field — the macro body is not visible here.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    ExpandProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    InlayHints,
    RefreshInlayHints,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);

// Channel-scoped messages use `channel_id` as their entity id.
entity_messages!(channel_id, ChannelMessageSent);
365
/// Bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte.
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch buffer is shrunk back to after every message.
const MAX_BUFFER_LEN: usize = MIB;
369
/// A stream of protobuf messages.
///
/// Wraps a transport `S` that carries WebSocket messages and owns a reusable scratch buffer
/// for encoding and decoding envelopes.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch space for the protobuf bytes of the envelope currently being written or read.
    encoding_buffer: Vec<u8>,
}
375
376#[allow(clippy::large_enum_variant)]
377#[derive(Debug)]
378pub enum Message {
379    Envelope(Envelope),
380    Ping,
381    Pong,
382}
383
384impl<S> MessageStream<S> {
385    pub fn new(stream: S) -> Self {
386        Self {
387            stream,
388            encoding_buffer: Vec::new(),
389        }
390    }
391
392    pub fn inner_mut(&mut self) -> &mut S {
393        &mut self.stream
394    }
395}
396
397impl<S> MessageStream<S>
398where
399    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
400{
401    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
402        #[cfg(any(test, feature = "test-support"))]
403        const COMPRESSION_LEVEL: i32 = -7;
404
405        #[cfg(not(any(test, feature = "test-support")))]
406        const COMPRESSION_LEVEL: i32 = 4;
407
408        match message {
409            Message::Envelope(message) => {
410                self.encoding_buffer.reserve(message.encoded_len());
411                message
412                    .encode(&mut self.encoding_buffer)
413                    .map_err(io::Error::from)?;
414                let buffer =
415                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
416                        .unwrap();
417
418                self.encoding_buffer.clear();
419                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
420                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
421            }
422            Message::Ping => {
423                self.stream
424                    .send(WebSocketMessage::Ping(Default::default()))
425                    .await?;
426            }
427            Message::Pong => {
428                self.stream
429                    .send(WebSocketMessage::Pong(Default::default()))
430                    .await?;
431            }
432        }
433
434        Ok(())
435    }
436}
437
438impl<S> MessageStream<S>
439where
440    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
441{
442    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
443        while let Some(bytes) = self.stream.next().await {
444            match bytes? {
445                WebSocketMessage::Binary(bytes) => {
446                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
447                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
448                        .map_err(io::Error::from)?;
449
450                    self.encoding_buffer.clear();
451                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
452                    return Ok(Message::Envelope(envelope));
453                }
454                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
455                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
456                WebSocketMessage::Close(_) => break,
457                _ => {}
458            }
459        }
460        Err(anyhow!("connection closed"))
461    }
462}
463
464impl From<Timestamp> for SystemTime {
465    fn from(val: Timestamp) -> Self {
466        UNIX_EPOCH
467            .checked_add(Duration::new(val.seconds, val.nanos))
468            .unwrap()
469    }
470}
471
472impl From<SystemTime> for Timestamp {
473    fn from(time: SystemTime) -> Self {
474        let duration = time.duration_since(UNIX_EPOCH).unwrap();
475        Self {
476            seconds: duration.as_secs(),
477            nanos: duration.subsec_nanos(),
478        }
479    }
480}
481
482impl From<u128> for Nonce {
483    fn from(nonce: u128) -> Self {
484        let upper_half = (nonce >> 64) as u64;
485        let lower_half = nonce as u64;
486        Self {
487            upper_half,
488            lower_half,
489        }
490    }
491}
492
493impl From<Nonce> for u128 {
494    fn from(nonce: Nonce) -> Self {
495        let upper_half = (nonce.upper_half as u128) << 64;
496        let lower_half = nonce.lower_half as u128;
497        upper_half | lower_half
498    }
499}
500
501pub fn split_worktree_update(
502    mut message: UpdateWorktree,
503    max_chunk_size: usize,
504) -> impl Iterator<Item = UpdateWorktree> {
505    let mut done_files = false;
506
507    let mut repository_map = message
508        .updated_repositories
509        .into_iter()
510        .map(|repo| (repo.work_directory_id, repo))
511        .collect::<HashMap<_, _>>();
512
513    iter::from_fn(move || {
514        if done_files {
515            return None;
516        }
517
518        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
519        let updated_entries: Vec<_> = message
520            .updated_entries
521            .drain(..updated_entries_chunk_size)
522            .collect();
523
524        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
525        let removed_entries = message
526            .removed_entries
527            .drain(..removed_entries_chunk_size)
528            .collect();
529
530        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
531
532        let mut updated_repositories = Vec::new();
533
534        if !repository_map.is_empty() {
535            for entry in &updated_entries {
536                if let Some(repo) = repository_map.remove(&entry.id) {
537                    updated_repositories.push(repo)
538                }
539            }
540        }
541
542        let removed_repositories = if done_files {
543            mem::take(&mut message.removed_repositories)
544        } else {
545            Default::default()
546        };
547
548        if done_files {
549            updated_repositories.extend(mem::take(&mut repository_map).into_values());
550        }
551
552        Some(UpdateWorktree {
553            project_id: message.project_id,
554            worktree_id: message.worktree_id,
555            root_name: message.root_name.clone(),
556            abs_path: message.abs_path.clone(),
557            updated_entries,
558            removed_entries,
559            scan_id: message.scan_id,
560            is_last_update: done_files && message.is_last_update,
561            updated_repositories,
562            removed_repositories,
563        })
564    })
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an `UpdateWorktree` with the given root name in an envelope message.
    fn worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Writing a small and then a large message must leave the scratch buffer bounded.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for &repetitions in &[10, 1000000] {
            sink.write(worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both messages back must keep the receiving buffer bounded too.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for &(owner_id, id) in &cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}