proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
 17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46    fn sender_id(&self) -> ConnectionId;
 47    fn message_id(&self) -> u32;
 48}
 49
 /// Priority class attached to every message type through
 /// `EnvelopedMessage::PRIORITY`.
 pub enum MessagePriority {
     /// Foreground priority.
     Foreground,
     /// Background priority; `AnyTypedEnvelope::is_background` reports `true`
     /// for payload types that use it.
     Background,
 }
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
132messages!(
133    (Ack, Foreground),
134    (AddProjectCollaborator, Foreground),
135    (ApplyCodeAction, Background),
136    (ApplyCodeActionResponse, Background),
137    (ApplyCompletionAdditionalEdits, Background),
138    (ApplyCompletionAdditionalEditsResponse, Background),
139    (BufferReloaded, Foreground),
140    (BufferSaved, Foreground),
141    (Call, Foreground),
142    (CallCanceled, Foreground),
143    (CancelCall, Foreground),
144    (ChannelMessageSent, Foreground),
145    (CopyProjectEntry, Foreground),
146    (CreateBufferForPeer, Foreground),
147    (CreateProjectEntry, Foreground),
148    (CreateRoom, Foreground),
149    (CreateRoomResponse, Foreground),
150    (DeclineCall, Foreground),
151    (DeleteProjectEntry, Foreground),
152    (Error, Foreground),
153    (ExpandProjectEntry, Foreground),
154    (Follow, Foreground),
155    (FollowResponse, Foreground),
156    (FormatBuffers, Foreground),
157    (FormatBuffersResponse, Foreground),
158    (FuzzySearchUsers, Foreground),
159    (GetChannelMessages, Foreground),
160    (GetChannelMessagesResponse, Foreground),
161    (GetChannels, Foreground),
162    (GetChannelsResponse, Foreground),
163    (GetCodeActions, Background),
164    (GetCodeActionsResponse, Background),
165    (GetHover, Background),
166    (GetHoverResponse, Background),
167    (GetCompletions, Background),
168    (GetCompletionsResponse, Background),
169    (GetDefinition, Background),
170    (GetDefinitionResponse, Background),
171    (GetTypeDefinition, Background),
172    (GetTypeDefinitionResponse, Background),
173    (GetDocumentHighlights, Background),
174    (GetDocumentHighlightsResponse, Background),
175    (GetReferences, Background),
176    (GetReferencesResponse, Background),
177    (GetProjectSymbols, Background),
178    (GetProjectSymbolsResponse, Background),
179    (GetUsers, Foreground),
180    (Hello, Foreground),
181    (IncomingCall, Foreground),
182    (UsersResponse, Foreground),
183    (JoinChannel, Foreground),
184    (JoinChannelResponse, Foreground),
185    (JoinProject, Foreground),
186    (JoinProjectResponse, Foreground),
187    (JoinRoom, Foreground),
188    (JoinRoomResponse, Foreground),
189    (LeaveChannel, Foreground),
190    (LeaveProject, Foreground),
191    (LeaveRoom, Foreground),
192    (OpenBufferById, Background),
193    (OpenBufferByPath, Background),
194    (OpenBufferForSymbol, Background),
195    (OpenBufferForSymbolResponse, Background),
196    (OpenBufferResponse, Background),
197    (PerformRename, Background),
198    (PerformRenameResponse, Background),
199    (OnTypeFormatting, Background),
200    (OnTypeFormattingResponse, Background),
201    (InlayHints, Background),
202    (InlayHintsResponse, Background),
203    (Ping, Foreground),
204    (PrepareRename, Background),
205    (PrepareRenameResponse, Background),
206    (ExpandProjectEntryResponse, Foreground),
207    (ProjectEntryResponse, Foreground),
208    (RejoinRoom, Foreground),
209    (RejoinRoomResponse, Foreground),
210    (RemoveContact, Foreground),
211    (ReloadBuffers, Foreground),
212    (ReloadBuffersResponse, Foreground),
213    (RemoveProjectCollaborator, Foreground),
214    (RenameProjectEntry, Foreground),
215    (RequestContact, Foreground),
216    (RespondToContactRequest, Foreground),
217    (RoomUpdated, Foreground),
218    (SaveBuffer, Foreground),
219    (SearchProject, Background),
220    (SearchProjectResponse, Background),
221    (SendChannelMessage, Foreground),
222    (SendChannelMessageResponse, Foreground),
223    (ShareProject, Foreground),
224    (ShareProjectResponse, Foreground),
225    (ShowContacts, Foreground),
226    (StartLanguageServer, Foreground),
227    (SynchronizeBuffers, Foreground),
228    (SynchronizeBuffersResponse, Foreground),
229    (Test, Foreground),
230    (Unfollow, Foreground),
231    (UnshareProject, Foreground),
232    (UpdateBuffer, Foreground),
233    (UpdateBufferFile, Foreground),
234    (UpdateContacts, Foreground),
235    (UpdateDiagnosticSummary, Foreground),
236    (UpdateFollowers, Foreground),
237    (UpdateInviteInfo, Foreground),
238    (UpdateLanguageServer, Foreground),
239    (UpdateParticipantLocation, Foreground),
240    (UpdateProject, Foreground),
241    (UpdateProjectCollaborator, Foreground),
242    (UpdateWorktree, Foreground),
243    (UpdateWorktreeSettings, Foreground),
244    (UpdateDiffBase, Foreground),
245    (GetPrivateUserInfo, Foreground),
246    (GetPrivateUserInfoResponse, Foreground),
247);
248
249request_messages!(
250    (ApplyCodeAction, ApplyCodeActionResponse),
251    (
252        ApplyCompletionAdditionalEdits,
253        ApplyCompletionAdditionalEditsResponse
254    ),
255    (Call, Ack),
256    (CancelCall, Ack),
257    (CopyProjectEntry, ProjectEntryResponse),
258    (CreateProjectEntry, ProjectEntryResponse),
259    (CreateRoom, CreateRoomResponse),
260    (DeclineCall, Ack),
261    (DeleteProjectEntry, ProjectEntryResponse),
262    (ExpandProjectEntry, ExpandProjectEntryResponse),
263    (Follow, FollowResponse),
264    (FormatBuffers, FormatBuffersResponse),
265    (GetChannelMessages, GetChannelMessagesResponse),
266    (GetChannels, GetChannelsResponse),
267    (GetCodeActions, GetCodeActionsResponse),
268    (GetHover, GetHoverResponse),
269    (GetCompletions, GetCompletionsResponse),
270    (GetDefinition, GetDefinitionResponse),
271    (GetTypeDefinition, GetTypeDefinitionResponse),
272    (GetDocumentHighlights, GetDocumentHighlightsResponse),
273    (GetReferences, GetReferencesResponse),
274    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
275    (GetProjectSymbols, GetProjectSymbolsResponse),
276    (FuzzySearchUsers, UsersResponse),
277    (GetUsers, UsersResponse),
278    (JoinChannel, JoinChannelResponse),
279    (JoinProject, JoinProjectResponse),
280    (JoinRoom, JoinRoomResponse),
281    (LeaveRoom, Ack),
282    (RejoinRoom, RejoinRoomResponse),
283    (IncomingCall, Ack),
284    (OpenBufferById, OpenBufferResponse),
285    (OpenBufferByPath, OpenBufferResponse),
286    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
287    (Ping, Ack),
288    (PerformRename, PerformRenameResponse),
289    (PrepareRename, PrepareRenameResponse),
290    (OnTypeFormatting, OnTypeFormattingResponse),
291    (InlayHints, InlayHintsResponse),
292    (ReloadBuffers, ReloadBuffersResponse),
293    (RequestContact, Ack),
294    (RemoveContact, Ack),
295    (RespondToContactRequest, Ack),
296    (RenameProjectEntry, ProjectEntryResponse),
297    (SaveBuffer, BufferSaved),
298    (SearchProject, SearchProjectResponse),
299    (SendChannelMessage, SendChannelMessageResponse),
300    (ShareProject, ShareProjectResponse),
301    (SynchronizeBuffers, SynchronizeBuffersResponse),
302    (Test, Test),
303    (UpdateBuffer, Ack),
304    (UpdateParticipantLocation, Ack),
305    (UpdateProject, Ack),
306    (UpdateWorktree, Ack),
307);
308
309entity_messages!(
310    project_id,
311    AddProjectCollaborator,
312    ApplyCodeAction,
313    ApplyCompletionAdditionalEdits,
314    BufferReloaded,
315    BufferSaved,
316    CopyProjectEntry,
317    CreateBufferForPeer,
318    CreateProjectEntry,
319    DeleteProjectEntry,
320    ExpandProjectEntry,
321    Follow,
322    FormatBuffers,
323    GetCodeActions,
324    GetCompletions,
325    GetDefinition,
326    GetTypeDefinition,
327    GetDocumentHighlights,
328    GetHover,
329    GetReferences,
330    GetProjectSymbols,
331    JoinProject,
332    LeaveProject,
333    OpenBufferById,
334    OpenBufferByPath,
335    OpenBufferForSymbol,
336    PerformRename,
337    OnTypeFormatting,
338    InlayHints,
339    PrepareRename,
340    ReloadBuffers,
341    RemoveProjectCollaborator,
342    RenameProjectEntry,
343    SaveBuffer,
344    SearchProject,
345    StartLanguageServer,
346    SynchronizeBuffers,
347    Unfollow,
348    UnshareProject,
349    UpdateBuffer,
350    UpdateBufferFile,
351    UpdateDiagnosticSummary,
352    UpdateFollowers,
353    UpdateLanguageServer,
354    UpdateProject,
355    UpdateProjectCollaborator,
356    UpdateWorktree,
357    UpdateWorktreeSettings,
358    UpdateDiffBase
359);
360
361entity_messages!(channel_id, ChannelMessageSent);
362
/// One kibibyte, in bytes.
const KIB: usize = 1 << 10;
/// One mebibyte, in bytes.
const MIB: usize = KIB << 10;
/// Largest capacity `MessageStream` keeps in its scratch buffer between messages.
const MAX_BUFFER_LEN: usize = MIB;
366
/// A stream of protobuf messages layered over an underlying transport `S`.
pub struct MessageStream<S> {
    /// The wrapped sink and/or stream of WebSocket messages.
    stream: S,
    /// Scratch space reused across calls for the uncompressed protobuf bytes.
    encoding_buffer: Vec<u8>,
}
372
373#[allow(clippy::large_enum_variant)]
374#[derive(Debug)]
375pub enum Message {
376    Envelope(Envelope),
377    Ping,
378    Pong,
379}
380
381impl<S> MessageStream<S> {
382    pub fn new(stream: S) -> Self {
383        Self {
384            stream,
385            encoding_buffer: Vec::new(),
386        }
387    }
388
389    pub fn inner_mut(&mut self) -> &mut S {
390        &mut self.stream
391    }
392}
393
394impl<S> MessageStream<S>
395where
396    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
397{
398    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
399        #[cfg(any(test, feature = "test-support"))]
400        const COMPRESSION_LEVEL: i32 = -7;
401
402        #[cfg(not(any(test, feature = "test-support")))]
403        const COMPRESSION_LEVEL: i32 = 4;
404
405        match message {
406            Message::Envelope(message) => {
407                self.encoding_buffer.reserve(message.encoded_len());
408                message
409                    .encode(&mut self.encoding_buffer)
410                    .map_err(io::Error::from)?;
411                let buffer =
412                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
413                        .unwrap();
414
415                self.encoding_buffer.clear();
416                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
417                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
418            }
419            Message::Ping => {
420                self.stream
421                    .send(WebSocketMessage::Ping(Default::default()))
422                    .await?;
423            }
424            Message::Pong => {
425                self.stream
426                    .send(WebSocketMessage::Pong(Default::default()))
427                    .await?;
428            }
429        }
430
431        Ok(())
432    }
433}
434
435impl<S> MessageStream<S>
436where
437    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
438{
439    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
440        while let Some(bytes) = self.stream.next().await {
441            match bytes? {
442                WebSocketMessage::Binary(bytes) => {
443                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
444                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
445                        .map_err(io::Error::from)?;
446
447                    self.encoding_buffer.clear();
448                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
449                    return Ok(Message::Envelope(envelope));
450                }
451                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
452                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
453                WebSocketMessage::Close(_) => break,
454                _ => {}
455            }
456        }
457        Err(anyhow!("connection closed"))
458    }
459}
460
461impl From<Timestamp> for SystemTime {
462    fn from(val: Timestamp) -> Self {
463        UNIX_EPOCH
464            .checked_add(Duration::new(val.seconds, val.nanos))
465            .unwrap()
466    }
467}
468
469impl From<SystemTime> for Timestamp {
470    fn from(time: SystemTime) -> Self {
471        let duration = time.duration_since(UNIX_EPOCH).unwrap();
472        Self {
473            seconds: duration.as_secs(),
474            nanos: duration.subsec_nanos(),
475        }
476    }
477}
478
479impl From<u128> for Nonce {
480    fn from(nonce: u128) -> Self {
481        let upper_half = (nonce >> 64) as u64;
482        let lower_half = nonce as u64;
483        Self {
484            upper_half,
485            lower_half,
486        }
487    }
488}
489
490impl From<Nonce> for u128 {
491    fn from(nonce: Nonce) -> Self {
492        let upper_half = (nonce.upper_half as u128) << 64;
493        let lower_half = nonce.lower_half as u128;
494        upper_half | lower_half
495    }
496}
497
498pub fn split_worktree_update(
499    mut message: UpdateWorktree,
500    max_chunk_size: usize,
501) -> impl Iterator<Item = UpdateWorktree> {
502    let mut done_files = false;
503
504    let mut repository_map = message
505        .updated_repositories
506        .into_iter()
507        .map(|repo| (repo.work_directory_id, repo))
508        .collect::<HashMap<_, _>>();
509
510    iter::from_fn(move || {
511        if done_files {
512            return None;
513        }
514
515        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
516        let updated_entries: Vec<_> = message
517            .updated_entries
518            .drain(..updated_entries_chunk_size)
519            .collect();
520
521        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
522        let removed_entries = message
523            .removed_entries
524            .drain(..removed_entries_chunk_size)
525            .collect();
526
527        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
528
529        let mut updated_repositories = Vec::new();
530
531        if !repository_map.is_empty() {
532            for entry in &updated_entries {
533                if let Some(repo) = repository_map.remove(&entry.id) {
534                    updated_repositories.push(repo)
535                }
536            }
537        }
538
539        let removed_repositories = if done_files {
540            mem::take(&mut message.removed_repositories)
541        } else {
542            Default::default()
543        };
544
545        if done_files {
546            updated_repositories.extend(mem::take(&mut repository_map).into_values());
547        }
548
549        Some(UpdateWorktree {
550            project_id: message.project_id,
551            worktree_id: message.worktree_id,
552            root_name: message.root_name.clone(),
553            abs_path: message.abs_path.clone(),
554            updated_entries,
555            removed_entries,
556            scan_id: message.scan_id,
557            is_last_update: done_files && message.is_last_update,
558            updated_repositories,
559            removed_repositories,
560        })
561    })
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose payload is an `UpdateWorktree` with the given
    /// root name and every other field defaulted.
    fn worktree_message(root_name: String) -> Message {
        let payload = envelope::Payload::UpdateWorktree(UpdateWorktree {
            root_name,
            ..Default::default()
        });
        Message::Envelope(Envelope {
            payload: Some(payload),
            ..Default::default()
        })
    }

    /// The scratch buffer must stay within `MAX_BUFFER_LEN` after every write
    /// and every read, for a small message and for a multi-megabyte one.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        for repetitions in [10, 1000000] {
            sink.write(worktree_message("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `as_u64` followed by `from_u64` must round-trip, including at the
    /// extremes of both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}