proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
 17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46    fn sender_id(&self) -> ConnectionId;
 47    fn message_id(&self) -> u32;
 48}
 49
 50pub enum MessagePriority {
 51    Foreground,
 52    Background,
 53}
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
// Registry of every protobuf message type that may travel inside an `Envelope`, each paired
// with the `MessagePriority` variant it is declared with.
// NOTE(review): the `messages!` macro is defined elsewhere; presumably it generates the
// `EnvelopedMessage` impl (NAME / PRIORITY / envelope conversion) for each listed type —
// confirm against its definition before reordering or removing entries.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    // Shared response for both `GetUsers` and `FuzzySearchUsers` (see `request_messages!`).
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
241
// Pairs each request message type with the message type sent back in reply.
// NOTE(review): presumably the `request_messages!` macro implements `RequestMessage` for the
// first type with the second as `Response` — confirm against its definition. Every type named
// here must also appear in the `messages!` list above.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    // Requests with no meaningful reply payload are acknowledged with `Ack`.
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    // `BufferSaved` is also registered as a standalone message; here it is the reply type.
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    // `Test` is answered with another `Test` message.
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
298
// Message types addressed to a project; the first argument names the field that identifies
// the target entity.
// NOTE(review): presumably `entity_messages!` implements `EntityMessage::remote_entity_id`
// by reading that field — confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

// Channel-scoped messages, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
348
/// Bytes in one kibibyte (2^10).
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte (2^20).
const MIB: usize = KIB << 10;
/// Largest capacity `MessageStream` lets its scratch buffer keep after finishing a message;
/// the buffer is shrunk back to this bound so one huge message does not pin a large allocation.
const MAX_BUFFER_LEN: usize = MIB;
352
/// Adapter that exchanges protobuf [`Envelope`]s (zstd-compressed) and ping/pong frames over
/// an underlying WebSocket stream and/or sink `S`.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch space reused for protobuf encoding on write and zstd decoding on read.
    encoding_buffer: Vec<u8>,
}
358
359#[allow(clippy::large_enum_variant)]
360#[derive(Debug)]
361pub enum Message {
362    Envelope(Envelope),
363    Ping,
364    Pong,
365}
366
367impl<S> MessageStream<S> {
368    pub fn new(stream: S) -> Self {
369        Self {
370            stream,
371            encoding_buffer: Vec::new(),
372        }
373    }
374
375    pub fn inner_mut(&mut self) -> &mut S {
376        &mut self.stream
377    }
378}
379
380impl<S> MessageStream<S>
381where
382    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
383{
384    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
385        #[cfg(any(test, feature = "test-support"))]
386        const COMPRESSION_LEVEL: i32 = -7;
387
388        #[cfg(not(any(test, feature = "test-support")))]
389        const COMPRESSION_LEVEL: i32 = 4;
390
391        match message {
392            Message::Envelope(message) => {
393                self.encoding_buffer.reserve(message.encoded_len());
394                message
395                    .encode(&mut self.encoding_buffer)
396                    .map_err(io::Error::from)?;
397                let buffer =
398                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
399                        .unwrap();
400
401                self.encoding_buffer.clear();
402                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
403                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
404            }
405            Message::Ping => {
406                self.stream
407                    .send(WebSocketMessage::Ping(Default::default()))
408                    .await?;
409            }
410            Message::Pong => {
411                self.stream
412                    .send(WebSocketMessage::Pong(Default::default()))
413                    .await?;
414            }
415        }
416
417        Ok(())
418    }
419}
420
421impl<S> MessageStream<S>
422where
423    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
424{
425    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
426        while let Some(bytes) = self.stream.next().await {
427            match bytes? {
428                WebSocketMessage::Binary(bytes) => {
429                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
430                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
431                        .map_err(io::Error::from)?;
432
433                    self.encoding_buffer.clear();
434                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
435                    return Ok(Message::Envelope(envelope));
436                }
437                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
438                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
439                WebSocketMessage::Close(_) => break,
440                _ => {}
441            }
442        }
443        Err(anyhow!("connection closed"))
444    }
445}
446
447impl From<Timestamp> for SystemTime {
448    fn from(val: Timestamp) -> Self {
449        UNIX_EPOCH
450            .checked_add(Duration::new(val.seconds, val.nanos))
451            .unwrap()
452    }
453}
454
455impl From<SystemTime> for Timestamp {
456    fn from(time: SystemTime) -> Self {
457        let duration = time.duration_since(UNIX_EPOCH).unwrap();
458        Self {
459            seconds: duration.as_secs(),
460            nanos: duration.subsec_nanos(),
461        }
462    }
463}
464
465impl From<u128> for Nonce {
466    fn from(nonce: u128) -> Self {
467        let upper_half = (nonce >> 64) as u64;
468        let lower_half = nonce as u64;
469        Self {
470            upper_half,
471            lower_half,
472        }
473    }
474}
475
476impl From<Nonce> for u128 {
477    fn from(nonce: Nonce) -> Self {
478        let upper_half = (nonce.upper_half as u128) << 64;
479        let lower_half = nonce.lower_half as u128;
480        upper_half | lower_half
481    }
482}
483
484pub fn split_worktree_update(
485    mut message: UpdateWorktree,
486    max_chunk_size: usize,
487) -> impl Iterator<Item = UpdateWorktree> {
488    let mut done_files = false;
489
490    let mut repository_map = message
491        .updated_repositories
492        .into_iter()
493        .map(|repo| (repo.work_directory_id, repo))
494        .collect::<HashMap<_, _>>();
495
496    iter::from_fn(move || {
497        if done_files {
498            return None;
499        }
500
501        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
502        let updated_entries: Vec<_> = message
503            .updated_entries
504            .drain(..updated_entries_chunk_size)
505            .collect();
506
507        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
508        let removed_entries = message
509            .removed_entries
510            .drain(..removed_entries_chunk_size)
511            .collect();
512
513        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
514
515        let mut updated_repositories = Vec::new();
516
517        if !repository_map.is_empty() {
518            for entry in &updated_entries {
519                if let Some(repo) = repository_map.remove(&entry.id) {
520                    updated_repositories.push(repo)
521                }
522            }
523        }
524
525        let removed_repositories = if done_files {
526            mem::take(&mut message.removed_repositories)
527        } else {
528            Default::default()
529        };
530
531        if done_files {
532            updated_repositories.extend(mem::take(&mut repository_map).into_values());
533        }
534
535        Some(UpdateWorktree {
536            project_id: message.project_id,
537            worktree_id: message.worktree_id,
538            root_name: message.root_name.clone(),
539            abs_path: message.abs_path.clone(),
540            updated_entries,
541            removed_entries,
542            scan_id: message.scan_id,
543            is_last_update: done_files && message.is_last_update,
544            updated_repositories,
545            removed_repositories,
546        })
547    })
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message carrying an `UpdateWorktree` with the given root name.
    fn update_worktree_message(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Writing a small and then a very large message must never leave the scratch buffer
        // holding more than `MAX_BUFFER_LEN` bytes of capacity.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(update_worktree_message("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // The same bound must hold on the receiving side while those messages are decoded.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases: [(u32, u32); 4] = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}