proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::fmt;
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15
// Protobuf message types generated at build time into Cargo's `OUT_DIR`. Types such as
// `Envelope`, `PeerId`, `Timestamp`, `Nonce` and `UpdateWorktree` are used below without a
// `use`, so they presumably come from this generated file (verify against build.rs).
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
/// A protobuf message type that can be carried inside an [`Envelope`].
///
/// Implementations are presumably generated by the `messages!` invocation in this file.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type. It is surfaced through
    /// `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of this message type. `Background` makes
    /// `AnyTypedEnvelope::is_background` return `true`.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] with the given message `id`.
    ///
    /// `responding_to` is the id of the request being answered, if any.
    /// `original_sender_id` identifies the peer that originated the message, if any.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts this message type from `envelope`.
    ///
    /// Presumably returns `None` when the envelope carries a different payload type. The
    /// generated impls are not visible here.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 29
/// A message addressed to a specific remote entity (e.g. a project or a channel).
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity this message targets. This is presumably the field named in the
    /// `entity_messages!` invocations below (`project_id` / `channel_id`).
    fn remote_entity_id(&self) -> u64;
}
 33
/// A message that expects a reply of type [`RequestMessage::Response`].
pub trait RequestMessage: EnvelopedMessage {
    /// The reply message type. Request/response pairs are listed in `request_messages!`.
    type Response: EnvelopedMessage;
}
 37
/// Type-erased view of a `TypedEnvelope<T>`.
///
/// Lets envelopes with different payload types be handled uniformly and downcast later via
/// [`AnyTypedEnvelope::as_any`] / [`AnyTypedEnvelope::into_any`].
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type `T` (not of the envelope wrapper).
    fn payload_type_id(&self) -> TypeId;
    /// `EnvelopedMessage::NAME` of the payload type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a `Box<dyn Any + Send + Sync>` for downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// Peer that originally sent the message, if recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// Connection the message arrived on.
    fn sender_id(&self) -> ConnectionId;
    /// Id of the message within its envelope.
    fn message_id(&self) -> u32;
}
 48
 49pub enum MessagePriority {
 50    Foreground,
 51    Background,
 52}
 53
 54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 55    fn payload_type_id(&self) -> TypeId {
 56        TypeId::of::<T>()
 57    }
 58
 59    fn payload_type_name(&self) -> &'static str {
 60        T::NAME
 61    }
 62
 63    fn as_any(&self) -> &dyn Any {
 64        self
 65    }
 66
 67    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 68        self
 69    }
 70
 71    fn is_background(&self) -> bool {
 72        matches!(T::PRIORITY, MessagePriority::Background)
 73    }
 74
 75    fn original_sender_id(&self) -> Option<PeerId> {
 76        self.original_sender_id
 77    }
 78
 79    fn sender_id(&self) -> ConnectionId {
 80        self.sender_id
 81    }
 82
 83    fn message_id(&self) -> u32 {
 84        self.message_id
 85    }
 86}
 87
 88impl PeerId {
 89    pub fn from_u64(peer_id: u64) -> Self {
 90        let owner_id = (peer_id >> 32) as u32;
 91        let id = peer_id as u32;
 92        Self { owner_id, id }
 93    }
 94
 95    pub fn as_u64(self) -> u64 {
 96        ((self.owner_id as u64) << 32) | (self.id as u64)
 97    }
 98}
 99
// `PeerId` holds only two `u32`s (`owner_id`, `id`), so copying it is trivial.
// NOTE(review): these are hand-written rather than derived, presumably because `PeerId` is
// generated from the .proto schema (which supplies `Clone`/`PartialEq`); confirm in build.rs.
impl Copy for PeerId {}

// Field-wise equality on two integers is a total equivalence, so `PartialEq` can be
// upgraded to `Eq`. The `Ord` impl below requires `Eq`.
impl Eq for PeerId {}
103
104impl Ord for PeerId {
105    fn cmp(&self, other: &Self) -> cmp::Ordering {
106        self.owner_id
107            .cmp(&other.owner_id)
108            .then_with(|| self.id.cmp(&other.id))
109    }
110}
111
112impl PartialOrd for PeerId {
113    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114        Some(self.cmp(other))
115    }
116}
117
118impl std::hash::Hash for PeerId {
119    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120        self.owner_id.hash(state);
121        self.id.hash(state);
122    }
123}
124
125impl fmt::Display for PeerId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "{}/{}", self.owner_id, self.id)
128    }
129}
130
// Registry of every message type that can travel inside an `Envelope`, each paired with its
// `MessagePriority` variant. The `messages!` macro is imported from `super`. Its body is not
// visible here, but it presumably implements `EnvelopedMessage` (NAME/PRIORITY and the
// envelope conversions) for each listed type.
//
// Response types (e.g. `Ack`, `UsersResponse`, `BufferSaved`) must be listed here too,
// because `RequestMessage::Response` is itself bounded by `EnvelopedMessage`.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// `(Request, Response)` pairs. The macro presumably implements `RequestMessage` for each
// request with `type Response = <second element>` (macro body not visible here). Both sides
// must also be listed in `messages!` above, because `RequestMessage` and its `Response` are
// bounded by `EnvelopedMessage`. A response need not be named `*Response`: several
// requests reply with a bare `Ack`, and `SaveBuffer` replies with `BufferSaved`.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Messages scoped to a single project. The first argument names the field that
// identifies the target entity. The macro presumably implements `EntityMessage` with
// `remote_entity_id()` returning `self.project_id` (macro body not visible here).
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
345
// Channel-scoped messages. Their entity id is presumably read from the `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
347
const KIB: usize = 1024;
const MIB: usize = KIB * 1024;
/// Largest capacity `MessageStream` keeps in its scratch `encoding_buffer` between messages.
///
/// This is not a limit on message size. A larger message still goes through: the buffer
/// grows while handling it and is then shrunk back to at most this capacity.
const MAX_BUFFER_LEN: usize = MIB;
351
352/// A stream of protobuf messages.
353pub struct MessageStream<S> {
354    stream: S,
355    encoding_buffer: Vec<u8>,
356}
357
358#[allow(clippy::large_enum_variant)]
359#[derive(Debug)]
360pub enum Message {
361    Envelope(Envelope),
362    Ping,
363    Pong,
364}
365
366impl<S> MessageStream<S> {
367    pub fn new(stream: S) -> Self {
368        Self {
369            stream,
370            encoding_buffer: Vec::new(),
371        }
372    }
373
374    pub fn inner_mut(&mut self) -> &mut S {
375        &mut self.stream
376    }
377}
378
379impl<S> MessageStream<S>
380where
381    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
382{
383    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
384        #[cfg(any(test, feature = "test-support"))]
385        const COMPRESSION_LEVEL: i32 = -7;
386
387        #[cfg(not(any(test, feature = "test-support")))]
388        const COMPRESSION_LEVEL: i32 = 4;
389
390        match message {
391            Message::Envelope(message) => {
392                self.encoding_buffer.reserve(message.encoded_len());
393                message
394                    .encode(&mut self.encoding_buffer)
395                    .map_err(io::Error::from)?;
396                let buffer =
397                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
398                        .unwrap();
399
400                self.encoding_buffer.clear();
401                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
402                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
403            }
404            Message::Ping => {
405                self.stream
406                    .send(WebSocketMessage::Ping(Default::default()))
407                    .await?;
408            }
409            Message::Pong => {
410                self.stream
411                    .send(WebSocketMessage::Pong(Default::default()))
412                    .await?;
413            }
414        }
415
416        Ok(())
417    }
418}
419
420impl<S> MessageStream<S>
421where
422    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
423{
424    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
425        while let Some(bytes) = self.stream.next().await {
426            match bytes? {
427                WebSocketMessage::Binary(bytes) => {
428                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
429                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
430                        .map_err(io::Error::from)?;
431
432                    self.encoding_buffer.clear();
433                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
434                    return Ok(Message::Envelope(envelope));
435                }
436                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
437                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
438                WebSocketMessage::Close(_) => break,
439                _ => {}
440            }
441        }
442        Err(anyhow!("connection closed"))
443    }
444}
445
446impl From<Timestamp> for SystemTime {
447    fn from(val: Timestamp) -> Self {
448        UNIX_EPOCH
449            .checked_add(Duration::new(val.seconds, val.nanos))
450            .unwrap()
451    }
452}
453
454impl From<SystemTime> for Timestamp {
455    fn from(time: SystemTime) -> Self {
456        let duration = time.duration_since(UNIX_EPOCH).unwrap();
457        Self {
458            seconds: duration.as_secs(),
459            nanos: duration.subsec_nanos(),
460        }
461    }
462}
463
464impl From<u128> for Nonce {
465    fn from(nonce: u128) -> Self {
466        let upper_half = (nonce >> 64) as u64;
467        let lower_half = nonce as u64;
468        Self {
469            upper_half,
470            lower_half,
471        }
472    }
473}
474
475impl From<Nonce> for u128 {
476    fn from(nonce: Nonce) -> Self {
477        let upper_half = (nonce.upper_half as u128) << 64;
478        let lower_half = nonce.lower_half as u128;
479        upper_half | lower_half
480    }
481}
482
483pub fn split_worktree_update(
484    mut message: UpdateWorktree,
485    max_chunk_size: usize,
486) -> impl Iterator<Item = UpdateWorktree> {
487    let mut done = false;
488    iter::from_fn(move || {
489        if done {
490            return None;
491        }
492
493        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
494        let updated_entries = message
495            .updated_entries
496            .drain(..updated_entries_chunk_size)
497            .collect();
498
499        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
500        let removed_entries = message
501            .removed_entries
502            .drain(..removed_entries_chunk_size)
503            .collect();
504
505        done = message.updated_entries.is_empty() && message.removed_entries.is_empty();
506        Some(UpdateWorktree {
507            project_id: message.project_id,
508            worktree_id: message.worktree_id,
509            root_name: message.root_name.clone(),
510            abs_path: message.abs_path.clone(),
511            updated_entries,
512            removed_entries,
513            scan_id: message.scan_id,
514            is_last_update: done && message.is_last_update,
515        })
516    })
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that `MessageStream` never keeps more than `MAX_BUFFER_LEN` bytes of scratch
    // capacity, on both the write and the read path. It uses a small message and one whose
    // `root_name` alone is 7,000,000 bytes.
    #[gpui::test]
    async fn test_buffer_size() {
        // In-memory channel: frames written through `sink` are read back via `stream` below.
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(10),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        // This envelope encodes to far more than MAX_BUFFER_LEN. The buffer must grow for the
        // write and then be shrunk back.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(1000000),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        // Read both frames back. The decode path must shrink its buffer as well.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
    }

    // Round-trips `PeerId` through `as_u64`/`from_u64`, with `u32::MAX` in either or both
    // halves to catch truncation or shifting mistakes in the packing.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let peer_id = PeerId {
            owner_id: 10,
            id: 3,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: u32::MAX,
            id: 3,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: 10,
            id: u32::MAX,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        let peer_id = PeerId {
            owner_id: u32::MAX,
            id: u32::MAX,
        };
        assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
    }
}