proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    cmp,
 10    fmt::Debug,
 11    io, iter,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14use std::{fmt, mem};
 15
// Pull in code generated at build time into `OUT_DIR` (`zed.messages.rs`). The protobuf
// types this file uses without defining or importing them (`Envelope`, `PeerId`, `Timestamp`,
// `Nonce`, `UpdateWorktree`, ...) presumably come from here — TODO confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
 18pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 19    const NAME: &'static str;
 20    const PRIORITY: MessagePriority;
 21    fn into_envelope(
 22        self,
 23        id: u32,
 24        responding_to: Option<u32>,
 25        original_sender_id: Option<PeerId>,
 26    ) -> Envelope;
 27    fn from_envelope(envelope: Envelope) -> Option<Self>;
 28}
 29
 30pub trait EntityMessage: EnvelopedMessage {
 31    fn remote_entity_id(&self) -> u64;
 32}
 33
 34pub trait RequestMessage: EnvelopedMessage {
 35    type Response: EnvelopedMessage;
 36}
 37
 38pub trait AnyTypedEnvelope: 'static + Send + Sync {
 39    fn payload_type_id(&self) -> TypeId;
 40    fn payload_type_name(&self) -> &'static str;
 41    fn as_any(&self) -> &dyn Any;
 42    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 43    fn is_background(&self) -> bool;
 44    fn original_sender_id(&self) -> Option<PeerId>;
 45    fn sender_id(&self) -> ConnectionId;
 46    fn message_id(&self) -> u32;
 47}
 48
 49pub enum MessagePriority {
 50    Foreground,
 51    Background,
 52}
 53
 54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 55    fn payload_type_id(&self) -> TypeId {
 56        TypeId::of::<T>()
 57    }
 58
 59    fn payload_type_name(&self) -> &'static str {
 60        T::NAME
 61    }
 62
 63    fn as_any(&self) -> &dyn Any {
 64        self
 65    }
 66
 67    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 68        self
 69    }
 70
 71    fn is_background(&self) -> bool {
 72        matches!(T::PRIORITY, MessagePriority::Background)
 73    }
 74
 75    fn original_sender_id(&self) -> Option<PeerId> {
 76        self.original_sender_id
 77    }
 78
 79    fn sender_id(&self) -> ConnectionId {
 80        self.sender_id
 81    }
 82
 83    fn message_id(&self) -> u32 {
 84        self.message_id
 85    }
 86}
 87
 88impl PeerId {
 89    pub fn from_u64(peer_id: u64) -> Self {
 90        let owner_id = (peer_id >> 32) as u32;
 91        let id = peer_id as u32;
 92        Self { owner_id, id }
 93    }
 94
 95    pub fn as_u64(self) -> u64 {
 96        ((self.owner_id as u64) << 32) | (self.id as u64)
 97    }
 98}
 99
100impl Copy for PeerId {}
101
102impl Eq for PeerId {}
103
104impl Ord for PeerId {
105    fn cmp(&self, other: &Self) -> cmp::Ordering {
106        self.owner_id
107            .cmp(&other.owner_id)
108            .then_with(|| self.id.cmp(&other.id))
109    }
110}
111
112impl PartialOrd for PeerId {
113    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114        Some(self.cmp(other))
115    }
116}
117
118impl std::hash::Hash for PeerId {
119    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120        self.owner_id.hash(state);
121        self.id.hash(state);
122    }
123}
124
125impl fmt::Display for PeerId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "{}/{}", self.owner_id, self.id)
128    }
129}
130
// Registers every protobuf message type carried in an `Envelope`. Each entry is
// `(MessageType, Priority)`, where `Priority` names a `MessagePriority` variant; it becomes the
// type's `EnvelopedMessage::PRIORITY`, and `AnyTypedEnvelope::is_background` reports `true`
// only for `Background`. The exact code generated is defined by the `messages!` macro in the
// parent module — NOTE(review): a type missing from this list presumably cannot be sent or
// received as a typed envelope; confirm against the macro.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// Pairs each request message type with the message type sent back in reply, i.e.
// `(Request, Response)`; presumably the macro sets `RequestMessage::Response = Response` for
// each pair (the macro itself lives in the parent module — confirm there). Many requests share
// a generic response such as `Ack`, `UsersResponse` or `ProjectEntryResponse`, and `Test`
// replies with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Declares which message types target a remote entity. The first argument names the field
// that holds the entity id for the listed types; presumably the macro implements
// `EntityMessage::remote_entity_id` by returning that field (confirm in the parent module's
// macro definition).
//
// Project-scoped messages: the entity id is `project_id`.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

// Channel-scoped messages: the entity id is `channel_id`.
entity_messages!(channel_id, ChannelMessageSent);
347
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s reusable scratch buffer is shrunk back to after each
/// message, so one oversized message does not pin a large allocation indefinitely.
const MAX_BUFFER_LEN: usize = MIB;
351
352/// A stream of protobuf messages.
353pub struct MessageStream<S> {
354    stream: S,
355    encoding_buffer: Vec<u8>,
356}
357
358#[allow(clippy::large_enum_variant)]
359#[derive(Debug)]
360pub enum Message {
361    Envelope(Envelope),
362    Ping,
363    Pong,
364}
365
366impl<S> MessageStream<S> {
367    pub fn new(stream: S) -> Self {
368        Self {
369            stream,
370            encoding_buffer: Vec::new(),
371        }
372    }
373
374    pub fn inner_mut(&mut self) -> &mut S {
375        &mut self.stream
376    }
377}
378
379impl<S> MessageStream<S>
380where
381    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
382{
383    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
384        #[cfg(any(test, feature = "test-support"))]
385        const COMPRESSION_LEVEL: i32 = -7;
386
387        #[cfg(not(any(test, feature = "test-support")))]
388        const COMPRESSION_LEVEL: i32 = 4;
389
390        match message {
391            Message::Envelope(message) => {
392                self.encoding_buffer.reserve(message.encoded_len());
393                message
394                    .encode(&mut self.encoding_buffer)
395                    .map_err(io::Error::from)?;
396                let buffer =
397                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
398                        .unwrap();
399
400                self.encoding_buffer.clear();
401                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
402                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
403            }
404            Message::Ping => {
405                self.stream
406                    .send(WebSocketMessage::Ping(Default::default()))
407                    .await?;
408            }
409            Message::Pong => {
410                self.stream
411                    .send(WebSocketMessage::Pong(Default::default()))
412                    .await?;
413            }
414        }
415
416        Ok(())
417    }
418}
419
420impl<S> MessageStream<S>
421where
422    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
423{
424    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
425        while let Some(bytes) = self.stream.next().await {
426            match bytes? {
427                WebSocketMessage::Binary(bytes) => {
428                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
429                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
430                        .map_err(io::Error::from)?;
431
432                    self.encoding_buffer.clear();
433                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
434                    return Ok(Message::Envelope(envelope));
435                }
436                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
437                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
438                WebSocketMessage::Close(_) => break,
439                _ => {}
440            }
441        }
442        Err(anyhow!("connection closed"))
443    }
444}
445
446impl From<Timestamp> for SystemTime {
447    fn from(val: Timestamp) -> Self {
448        UNIX_EPOCH
449            .checked_add(Duration::new(val.seconds, val.nanos))
450            .unwrap()
451    }
452}
453
454impl From<SystemTime> for Timestamp {
455    fn from(time: SystemTime) -> Self {
456        let duration = time.duration_since(UNIX_EPOCH).unwrap();
457        Self {
458            seconds: duration.as_secs(),
459            nanos: duration.subsec_nanos(),
460        }
461    }
462}
463
464impl From<u128> for Nonce {
465    fn from(nonce: u128) -> Self {
466        let upper_half = (nonce >> 64) as u64;
467        let lower_half = nonce as u64;
468        Self {
469            upper_half,
470            lower_half,
471        }
472    }
473}
474
475impl From<Nonce> for u128 {
476    fn from(nonce: Nonce) -> Self {
477        let upper_half = (nonce.upper_half as u128) << 64;
478        let lower_half = nonce.lower_half as u128;
479        upper_half | lower_half
480    }
481}
482
483pub fn split_worktree_update(
484    mut message: UpdateWorktree,
485    max_chunk_size: usize,
486) -> impl Iterator<Item = UpdateWorktree> {
487    let mut done = false;
488    iter::from_fn(move || {
489        if done {
490            return None;
491        }
492
493        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
494        let updated_entries = message
495            .updated_entries
496            .drain(..updated_entries_chunk_size)
497            .collect();
498
499        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
500        let removed_entries = message
501            .removed_entries
502            .drain(..removed_entries_chunk_size)
503            .collect();
504
505        done = message.updated_entries.is_empty() && message.removed_entries.is_empty();
506
507        // Wait to send repositories until after we've guaranteed that their associated entries
508        // will be read
509        let updated_repositories = if done {
510            mem::take(&mut message.updated_repositories)
511        } else {
512            Default::default()
513        };
514
515        let removed_repositories = if done {
516            mem::take(&mut message.removed_repositories)
517        } else {
518            Default::default()
519        };
520
521        Some(UpdateWorktree {
522            project_id: message.project_id,
523            worktree_id: message.worktree_id,
524            root_name: message.root_name.clone(),
525            abs_path: message.abs_path.clone(),
526            updated_entries,
527            removed_entries,
528            scan_id: message.scan_id,
529            is_last_update: done && message.is_last_update,
530            updated_repositories,
531            removed_repositories,
532        })
533    })
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose payload is an `UpdateWorktree` with the given root name.
    fn update_worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// Neither writing nor reading, small or very large, may leave the scratch buffer with more
    /// than `MAX_BUFFER_LEN` bytes of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(update_worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `from_u64` must invert `as_u64`, including when either half is all ones.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}