proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
 17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46    fn sender_id(&self) -> ConnectionId;
 47    fn message_id(&self) -> u32;
 48}
 49
 50pub enum MessagePriority {
 51    Foreground,
 52    Background,
 53}
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
// Registry of every protocol message type, each paired with its `MessagePriority`
// (`Foreground` or `Background`).
// NOTE(review): the `messages!` macro is defined elsewhere; presumably it
// implements `EnvelopedMessage` for each type, using the given priority as
// `PRIORITY`. Confirm before relying on this.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeSettings, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
244
// Pairs each request message type with the message type sent back in reply.
// Many requests share the generic `Ack` reply. `(Test, Test)` replies with its
// own type.
// NOTE(review): presumably `request_messages!` (defined elsewhere) implements
// `RequestMessage` with `Response` set to the second element. Confirm.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    // Saving is acknowledged with the `BufferSaved` message, not a dedicated response type.
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
302
// Message types scoped to a single remote entity. The first argument names the
// field holding the entity id.
// NOTE(review): presumably `entity_messages!` (defined elsewhere) implements
// `EntityMessage::remote_entity_id` by reading that field. Confirm.
//
// Messages addressed to a project, identified by their `project_id` field.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateWorktreeSettings,
    UpdateDiffBase
);

// Messages addressed to a channel, identified by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
354
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Largest capacity `MessageStream` keeps in its scratch buffer after each
/// message. Anything beyond this is released via `shrink_to`.
const MAX_BUFFER_LEN: usize = MIB;
358
359/// A stream of protobuf messages.
360pub struct MessageStream<S> {
361    stream: S,
362    encoding_buffer: Vec<u8>,
363}
364
365#[allow(clippy::large_enum_variant)]
366#[derive(Debug)]
367pub enum Message {
368    Envelope(Envelope),
369    Ping,
370    Pong,
371}
372
373impl<S> MessageStream<S> {
374    pub fn new(stream: S) -> Self {
375        Self {
376            stream,
377            encoding_buffer: Vec::new(),
378        }
379    }
380
381    pub fn inner_mut(&mut self) -> &mut S {
382        &mut self.stream
383    }
384}
385
386impl<S> MessageStream<S>
387where
388    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
389{
390    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
391        #[cfg(any(test, feature = "test-support"))]
392        const COMPRESSION_LEVEL: i32 = -7;
393
394        #[cfg(not(any(test, feature = "test-support")))]
395        const COMPRESSION_LEVEL: i32 = 4;
396
397        match message {
398            Message::Envelope(message) => {
399                self.encoding_buffer.reserve(message.encoded_len());
400                message
401                    .encode(&mut self.encoding_buffer)
402                    .map_err(io::Error::from)?;
403                let buffer =
404                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
405                        .unwrap();
406
407                self.encoding_buffer.clear();
408                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
409                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
410            }
411            Message::Ping => {
412                self.stream
413                    .send(WebSocketMessage::Ping(Default::default()))
414                    .await?;
415            }
416            Message::Pong => {
417                self.stream
418                    .send(WebSocketMessage::Pong(Default::default()))
419                    .await?;
420            }
421        }
422
423        Ok(())
424    }
425}
426
427impl<S> MessageStream<S>
428where
429    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
430{
431    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
432        while let Some(bytes) = self.stream.next().await {
433            match bytes? {
434                WebSocketMessage::Binary(bytes) => {
435                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
436                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
437                        .map_err(io::Error::from)?;
438
439                    self.encoding_buffer.clear();
440                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
441                    return Ok(Message::Envelope(envelope));
442                }
443                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
444                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
445                WebSocketMessage::Close(_) => break,
446                _ => {}
447            }
448        }
449        Err(anyhow!("connection closed"))
450    }
451}
452
453impl From<Timestamp> for SystemTime {
454    fn from(val: Timestamp) -> Self {
455        UNIX_EPOCH
456            .checked_add(Duration::new(val.seconds, val.nanos))
457            .unwrap()
458    }
459}
460
461impl From<SystemTime> for Timestamp {
462    fn from(time: SystemTime) -> Self {
463        let duration = time.duration_since(UNIX_EPOCH).unwrap();
464        Self {
465            seconds: duration.as_secs(),
466            nanos: duration.subsec_nanos(),
467        }
468    }
469}
470
471impl From<u128> for Nonce {
472    fn from(nonce: u128) -> Self {
473        let upper_half = (nonce >> 64) as u64;
474        let lower_half = nonce as u64;
475        Self {
476            upper_half,
477            lower_half,
478        }
479    }
480}
481
482impl From<Nonce> for u128 {
483    fn from(nonce: Nonce) -> Self {
484        let upper_half = (nonce.upper_half as u128) << 64;
485        let lower_half = nonce.lower_half as u128;
486        upper_half | lower_half
487    }
488}
489
490pub fn split_worktree_update(
491    mut message: UpdateWorktree,
492    max_chunk_size: usize,
493) -> impl Iterator<Item = UpdateWorktree> {
494    let mut done_files = false;
495
496    let mut repository_map = message
497        .updated_repositories
498        .into_iter()
499        .map(|repo| (repo.work_directory_id, repo))
500        .collect::<HashMap<_, _>>();
501
502    iter::from_fn(move || {
503        if done_files {
504            return None;
505        }
506
507        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
508        let updated_entries: Vec<_> = message
509            .updated_entries
510            .drain(..updated_entries_chunk_size)
511            .collect();
512
513        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
514        let removed_entries = message
515            .removed_entries
516            .drain(..removed_entries_chunk_size)
517            .collect();
518
519        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
520
521        let mut updated_repositories = Vec::new();
522
523        if !repository_map.is_empty() {
524            for entry in &updated_entries {
525                if let Some(repo) = repository_map.remove(&entry.id) {
526                    updated_repositories.push(repo)
527                }
528            }
529        }
530
531        let removed_repositories = if done_files {
532            mem::take(&mut message.removed_repositories)
533        } else {
534            Default::default()
535        };
536
537        if done_files {
538            updated_repositories.extend(mem::take(&mut repository_map).into_values());
539        }
540
541        Some(UpdateWorktree {
542            project_id: message.project_id,
543            worktree_id: message.worktree_id,
544            root_name: message.root_name.clone(),
545            abs_path: message.abs_path.clone(),
546            updated_entries,
547            removed_entries,
548            scan_id: message.scan_id,
549            is_last_update: done_files && message.is_last_update,
550            updated_repositories,
551            removed_repositories,
552        })
553    })
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose `UpdateWorktree` payload has a root name made of
    /// `repeat_count` copies of `"abcdefg"`.
    fn update_worktree_envelope(repeat_count: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repeat_count),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        // One small message, then one far larger than MAX_BUFFER_LEN.
        for &repeat_count in &[10, 1000000] {
            sink.write(update_worktree_envelope(repeat_count))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for &(owner_id, id) in &cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}