proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
 17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46    fn sender_id(&self) -> ConnectionId;
 47    fn message_id(&self) -> u32;
 48}
 49
 50pub enum MessagePriority {
 51    Foreground,
 52    Background,
 53}
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
132messages!(
133    (Ack, Foreground),
134    (AddProjectCollaborator, Foreground),
135    (ApplyCodeAction, Background),
136    (ApplyCodeActionResponse, Background),
137    (ApplyCompletionAdditionalEdits, Background),
138    (ApplyCompletionAdditionalEditsResponse, Background),
139    (BufferReloaded, Foreground),
140    (BufferSaved, Foreground),
141    (Call, Foreground),
142    (CallCanceled, Foreground),
143    (CancelCall, Foreground),
144    (ChannelMessageSent, Foreground),
145    (CopyProjectEntry, Foreground),
146    (CreateBufferForPeer, Foreground),
147    (CreateProjectEntry, Foreground),
148    (CreateRoom, Foreground),
149    (CreateRoomResponse, Foreground),
150    (DeclineCall, Foreground),
151    (DeleteProjectEntry, Foreground),
152    (Error, Foreground),
153    (ExpandProjectEntry, Foreground),
154    (Follow, Foreground),
155    (FollowResponse, Foreground),
156    (FormatBuffers, Foreground),
157    (FormatBuffersResponse, Foreground),
158    (FuzzySearchUsers, Foreground),
159    (GetChannelMessages, Foreground),
160    (GetChannelMessagesResponse, Foreground),
161    (GetChannels, Foreground),
162    (GetChannelsResponse, Foreground),
163    (GetCodeActions, Background),
164    (GetCodeActionsResponse, Background),
165    (GetHover, Background),
166    (GetHoverResponse, Background),
167    (GetCompletions, Background),
168    (GetCompletionsResponse, Background),
169    (GetDefinition, Background),
170    (GetDefinitionResponse, Background),
171    (GetTypeDefinition, Background),
172    (GetTypeDefinitionResponse, Background),
173    (GetDocumentHighlights, Background),
174    (GetDocumentHighlightsResponse, Background),
175    (GetReferences, Background),
176    (GetReferencesResponse, Background),
177    (GetProjectSymbols, Background),
178    (GetProjectSymbolsResponse, Background),
179    (GetUsers, Foreground),
180    (Hello, Foreground),
181    (IncomingCall, Foreground),
182    (UsersResponse, Foreground),
183    (JoinChannel, Foreground),
184    (JoinChannelResponse, Foreground),
185    (JoinProject, Foreground),
186    (JoinProjectResponse, Foreground),
187    (JoinRoom, Foreground),
188    (JoinRoomResponse, Foreground),
189    (LeaveChannel, Foreground),
190    (LeaveProject, Foreground),
191    (LeaveRoom, Foreground),
192    (OpenBufferById, Background),
193    (OpenBufferByPath, Background),
194    (OpenBufferForSymbol, Background),
195    (OpenBufferForSymbolResponse, Background),
196    (OpenBufferResponse, Background),
197    (PerformRename, Background),
198    (PerformRenameResponse, Background),
199    (OnTypeFormatting, Background),
200    (OnTypeFormattingResponse, Background),
201    (Ping, Foreground),
202    (PrepareRename, Background),
203    (PrepareRenameResponse, Background),
204    (ExpandProjectEntryResponse, Foreground),
205    (ProjectEntryResponse, Foreground),
206    (RejoinRoom, Foreground),
207    (RejoinRoomResponse, Foreground),
208    (RemoveContact, Foreground),
209    (ReloadBuffers, Foreground),
210    (ReloadBuffersResponse, Foreground),
211    (RemoveProjectCollaborator, Foreground),
212    (RenameProjectEntry, Foreground),
213    (RequestContact, Foreground),
214    (RespondToContactRequest, Foreground),
215    (RoomUpdated, Foreground),
216    (SaveBuffer, Foreground),
217    (SearchProject, Background),
218    (SearchProjectResponse, Background),
219    (SendChannelMessage, Foreground),
220    (SendChannelMessageResponse, Foreground),
221    (ShareProject, Foreground),
222    (ShareProjectResponse, Foreground),
223    (ShowContacts, Foreground),
224    (StartLanguageServer, Foreground),
225    (SynchronizeBuffers, Foreground),
226    (SynchronizeBuffersResponse, Foreground),
227    (Test, Foreground),
228    (Unfollow, Foreground),
229    (UnshareProject, Foreground),
230    (UpdateBuffer, Foreground),
231    (UpdateBufferFile, Foreground),
232    (UpdateContacts, Foreground),
233    (UpdateDiagnosticSummary, Foreground),
234    (UpdateFollowers, Foreground),
235    (UpdateInviteInfo, Foreground),
236    (UpdateLanguageServer, Foreground),
237    (UpdateParticipantLocation, Foreground),
238    (UpdateProject, Foreground),
239    (UpdateProjectCollaborator, Foreground),
240    (UpdateWorktree, Foreground),
241    (UpdateWorktreeSettings, Foreground),
242    (UpdateDiffBase, Foreground),
243    (GetPrivateUserInfo, Foreground),
244    (GetPrivateUserInfoResponse, Foreground),
245);
246
247request_messages!(
248    (ApplyCodeAction, ApplyCodeActionResponse),
249    (
250        ApplyCompletionAdditionalEdits,
251        ApplyCompletionAdditionalEditsResponse
252    ),
253    (Call, Ack),
254    (CancelCall, Ack),
255    (CopyProjectEntry, ProjectEntryResponse),
256    (CreateProjectEntry, ProjectEntryResponse),
257    (CreateRoom, CreateRoomResponse),
258    (DeclineCall, Ack),
259    (DeleteProjectEntry, ProjectEntryResponse),
260    (ExpandProjectEntry, ExpandProjectEntryResponse),
261    (Follow, FollowResponse),
262    (FormatBuffers, FormatBuffersResponse),
263    (GetChannelMessages, GetChannelMessagesResponse),
264    (GetChannels, GetChannelsResponse),
265    (GetCodeActions, GetCodeActionsResponse),
266    (GetHover, GetHoverResponse),
267    (GetCompletions, GetCompletionsResponse),
268    (GetDefinition, GetDefinitionResponse),
269    (GetTypeDefinition, GetTypeDefinitionResponse),
270    (GetDocumentHighlights, GetDocumentHighlightsResponse),
271    (GetReferences, GetReferencesResponse),
272    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
273    (GetProjectSymbols, GetProjectSymbolsResponse),
274    (FuzzySearchUsers, UsersResponse),
275    (GetUsers, UsersResponse),
276    (JoinChannel, JoinChannelResponse),
277    (JoinProject, JoinProjectResponse),
278    (JoinRoom, JoinRoomResponse),
279    (LeaveRoom, Ack),
280    (RejoinRoom, RejoinRoomResponse),
281    (IncomingCall, Ack),
282    (OpenBufferById, OpenBufferResponse),
283    (OpenBufferByPath, OpenBufferResponse),
284    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
285    (Ping, Ack),
286    (PerformRename, PerformRenameResponse),
287    (PrepareRename, PrepareRenameResponse),
288    (OnTypeFormatting, OnTypeFormattingResponse),
289    (ReloadBuffers, ReloadBuffersResponse),
290    (RequestContact, Ack),
291    (RemoveContact, Ack),
292    (RespondToContactRequest, Ack),
293    (RenameProjectEntry, ProjectEntryResponse),
294    (SaveBuffer, BufferSaved),
295    (SearchProject, SearchProjectResponse),
296    (SendChannelMessage, SendChannelMessageResponse),
297    (ShareProject, ShareProjectResponse),
298    (SynchronizeBuffers, SynchronizeBuffersResponse),
299    (Test, Test),
300    (UpdateBuffer, Ack),
301    (UpdateParticipantLocation, Ack),
302    (UpdateProject, Ack),
303    (UpdateWorktree, Ack),
304);
305
306entity_messages!(
307    project_id,
308    AddProjectCollaborator,
309    ApplyCodeAction,
310    ApplyCompletionAdditionalEdits,
311    BufferReloaded,
312    BufferSaved,
313    CopyProjectEntry,
314    CreateBufferForPeer,
315    CreateProjectEntry,
316    DeleteProjectEntry,
317    ExpandProjectEntry,
318    Follow,
319    FormatBuffers,
320    GetCodeActions,
321    GetCompletions,
322    GetDefinition,
323    GetTypeDefinition,
324    GetDocumentHighlights,
325    GetHover,
326    GetReferences,
327    GetProjectSymbols,
328    JoinProject,
329    LeaveProject,
330    OpenBufferById,
331    OpenBufferByPath,
332    OpenBufferForSymbol,
333    PerformRename,
334    OnTypeFormatting,
335    PrepareRename,
336    ReloadBuffers,
337    RemoveProjectCollaborator,
338    RenameProjectEntry,
339    SaveBuffer,
340    SearchProject,
341    StartLanguageServer,
342    SynchronizeBuffers,
343    Unfollow,
344    UnshareProject,
345    UpdateBuffer,
346    UpdateBufferFile,
347    UpdateDiagnosticSummary,
348    UpdateFollowers,
349    UpdateLanguageServer,
350    UpdateProject,
351    UpdateProjectCollaborator,
352    UpdateWorktree,
353    UpdateWorktreeSettings,
354    UpdateDiffBase
355);
356
357entity_messages!(channel_id, ChannelMessageSent);
358
/// One kibibyte, in bytes.
const KIB: usize = 1024;
/// One mebibyte, in bytes.
const MIB: usize = 1024 * KIB;
/// Capacity (1 MiB) that `MessageStream` shrinks its scratch buffer back to after each
/// message.
const MAX_BUFFER_LEN: usize = MIB;
362
/// A stream of protobuf messages carried over an underlying transport `S`.
pub struct MessageStream<S> {
    /// The wrapped sink and/or stream of WebSocket messages.
    stream: S,
    /// Scratch buffer reused between calls: holds the protobuf bytes of the message
    /// currently being written or read, and is cleared afterwards.
    encoding_buffer: Vec<u8>,
}
368
369#[allow(clippy::large_enum_variant)]
370#[derive(Debug)]
371pub enum Message {
372    Envelope(Envelope),
373    Ping,
374    Pong,
375}
376
377impl<S> MessageStream<S> {
378    pub fn new(stream: S) -> Self {
379        Self {
380            stream,
381            encoding_buffer: Vec::new(),
382        }
383    }
384
385    pub fn inner_mut(&mut self) -> &mut S {
386        &mut self.stream
387    }
388}
389
390impl<S> MessageStream<S>
391where
392    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
393{
394    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
395        #[cfg(any(test, feature = "test-support"))]
396        const COMPRESSION_LEVEL: i32 = -7;
397
398        #[cfg(not(any(test, feature = "test-support")))]
399        const COMPRESSION_LEVEL: i32 = 4;
400
401        match message {
402            Message::Envelope(message) => {
403                self.encoding_buffer.reserve(message.encoded_len());
404                message
405                    .encode(&mut self.encoding_buffer)
406                    .map_err(io::Error::from)?;
407                let buffer =
408                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
409                        .unwrap();
410
411                self.encoding_buffer.clear();
412                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
413                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
414            }
415            Message::Ping => {
416                self.stream
417                    .send(WebSocketMessage::Ping(Default::default()))
418                    .await?;
419            }
420            Message::Pong => {
421                self.stream
422                    .send(WebSocketMessage::Pong(Default::default()))
423                    .await?;
424            }
425        }
426
427        Ok(())
428    }
429}
430
431impl<S> MessageStream<S>
432where
433    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
434{
435    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
436        while let Some(bytes) = self.stream.next().await {
437            match bytes? {
438                WebSocketMessage::Binary(bytes) => {
439                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
440                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
441                        .map_err(io::Error::from)?;
442
443                    self.encoding_buffer.clear();
444                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
445                    return Ok(Message::Envelope(envelope));
446                }
447                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
448                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
449                WebSocketMessage::Close(_) => break,
450                _ => {}
451            }
452        }
453        Err(anyhow!("connection closed"))
454    }
455}
456
457impl From<Timestamp> for SystemTime {
458    fn from(val: Timestamp) -> Self {
459        UNIX_EPOCH
460            .checked_add(Duration::new(val.seconds, val.nanos))
461            .unwrap()
462    }
463}
464
465impl From<SystemTime> for Timestamp {
466    fn from(time: SystemTime) -> Self {
467        let duration = time.duration_since(UNIX_EPOCH).unwrap();
468        Self {
469            seconds: duration.as_secs(),
470            nanos: duration.subsec_nanos(),
471        }
472    }
473}
474
475impl From<u128> for Nonce {
476    fn from(nonce: u128) -> Self {
477        let upper_half = (nonce >> 64) as u64;
478        let lower_half = nonce as u64;
479        Self {
480            upper_half,
481            lower_half,
482        }
483    }
484}
485
486impl From<Nonce> for u128 {
487    fn from(nonce: Nonce) -> Self {
488        let upper_half = (nonce.upper_half as u128) << 64;
489        let lower_half = nonce.lower_half as u128;
490        upper_half | lower_half
491    }
492}
493
494pub fn split_worktree_update(
495    mut message: UpdateWorktree,
496    max_chunk_size: usize,
497) -> impl Iterator<Item = UpdateWorktree> {
498    let mut done_files = false;
499
500    let mut repository_map = message
501        .updated_repositories
502        .into_iter()
503        .map(|repo| (repo.work_directory_id, repo))
504        .collect::<HashMap<_, _>>();
505
506    iter::from_fn(move || {
507        if done_files {
508            return None;
509        }
510
511        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
512        let updated_entries: Vec<_> = message
513            .updated_entries
514            .drain(..updated_entries_chunk_size)
515            .collect();
516
517        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
518        let removed_entries = message
519            .removed_entries
520            .drain(..removed_entries_chunk_size)
521            .collect();
522
523        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
524
525        let mut updated_repositories = Vec::new();
526
527        if !repository_map.is_empty() {
528            for entry in &updated_entries {
529                if let Some(repo) = repository_map.remove(&entry.id) {
530                    updated_repositories.push(repo)
531                }
532            }
533        }
534
535        let removed_repositories = if done_files {
536            mem::take(&mut message.removed_repositories)
537        } else {
538            Default::default()
539        };
540
541        if done_files {
542            updated_repositories.extend(mem::take(&mut repository_map).into_values());
543        }
544
545        Some(UpdateWorktree {
546            project_id: message.project_id,
547            worktree_id: message.worktree_id,
548            root_name: message.root_name.clone(),
549            abs_path: message.abs_path.clone(),
550            updated_entries,
551            removed_entries,
552            scan_id: message.scan_id,
553            is_last_update: done_files && message.is_last_update,
554            updated_repositories,
555            removed_repositories,
556        })
557    })
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// The scratch buffer must never stay larger than `MAX_BUFFER_LEN` after a write or
    /// a read, whether the message is small or several megabytes long.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        // First a small message, then one far larger than `MAX_BUFFER_LEN`.
        for repeat_count in [10, 1000000] {
            let envelope = Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repeat_count),
                    ..Default::default()
                })),
                ..Default::default()
            };
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId::from_u64` must invert `PeerId::as_u64`, including at the `u32::MAX`
    /// boundaries of both fields.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}