proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
 17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
/// A message type that can be packed into, and recovered from, an `Envelope`.
///
/// NOTE(review): no implementation is visible in this part of the file; the impls
/// are presumably generated by the `messages!` macro — confirm against its definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; this is what
    /// `AnyTypedEnvelope::payload_type_name` reports for a typed envelope.
    const NAME: &'static str;
    /// Priority class of the message; `AnyTypedEnvelope::is_background` reports
    /// whether this is `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Consumes the message and wraps it in an `Envelope`.
    ///
    /// `id` is the id of the new envelope. `responding_to` is presumably the id of
    /// the message being answered and `original_sender_id` the peer that first sent
    /// it — TODO confirm against the implementations.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`, or returns `None`
    /// (presumably when the envelope carries a different payload — TODO confirm).
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 30
/// An [`EnvelopedMessage`] that is addressed to a particular remote entity.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    ///
    /// NOTE(review): the `entity_messages!` invocations in this file name a field
    /// (`project_id`, `channel_id`); presumably that field supplies this id — confirm.
    fn remote_entity_id(&self) -> u64;
}
 34
/// An [`EnvelopedMessage`] that has an associated response message type.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 38
/// Type-erased view of a `TypedEnvelope<T>`, usable without knowing the payload type `T`.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type carried by the envelope.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload type (its `EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `dyn Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The envelope's `original_sender_id` field.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// The envelope's `sender_id` field.
    fn sender_id(&self) -> ConnectionId;
    /// The envelope's `message_id` field.
    fn message_id(&self) -> u32;
}
 49
/// Priority class attached to every message type via `EnvelopedMessage::PRIORITY`.
pub enum MessagePriority {
    /// Messages for which `AnyTypedEnvelope::is_background` returns `false`.
    Foreground,
    /// Messages for which `AnyTypedEnvelope::is_background` returns `true`.
    Background,
}
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
// `PeerId` is passed by value throughout this file (e.g. `as_u64(self)`), so it is marked `Copy`.
impl Copy for PeerId {}
102
// Marker impl: `Eq` is a supertrait of `Ord`, which is implemented for `PeerId` in this file.
impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
// Lists every message type together with a priority name (`Foreground` or
// `Background`, matching the variants of `MessagePriority`).
//
// NOTE(review): the `messages!` macro is imported from the parent module and is not
// visible here; presumably it implements `EnvelopedMessage` for each listed type
// using the given priority — confirm against the macro definition before
// reordering or regrouping entries.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
243
// Pairs each request message type with the message type used to answer it,
// written as `(Request, Response)`. Several requests share a response type
// (for example `Ack` or `ProjectEntryResponse`).
//
// NOTE(review): the `request_messages!` macro is imported from the parent module;
// presumably it implements `RequestMessage` with `Response` set to the second
// element of each pair — confirm against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
301
// Message types listed after the leading `project_id` argument.
//
// NOTE(review): the `entity_messages!` macro is imported from the parent module;
// presumably it implements `EntityMessage` for each listed type, returning the
// named field (`project_id`) from `remote_entity_id` — confirm against the macro
// definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
350
// Same macro as the `project_id` invocation, here with `channel_id` as the leading
// argument and a single message type (presumably keyed by its `channel_id` field — confirm).
entity_messages!(channel_id, ChannelMessageSent);
352
/// Number of bytes in one kibibyte.
const KIB: usize = 1024;
/// Number of bytes in one mebibyte.
const MIB: usize = KIB * 1024;
/// Capacity (1 MiB) that `MessageStream` shrinks its scratch buffer down to after
/// each envelope is written or read, so one large message does not pin memory.
const MAX_BUFFER_LEN: usize = MIB;
356
/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S` that carries WebSocket messages; envelopes
/// are protobuf-encoded and zstd-compressed on write, and the reverse on read.
pub struct MessageStream<S> {
    /// The wrapped sink and/or stream of WebSocket messages.
    stream: S,
    /// Scratch buffer holding the uncompressed protobuf bytes of the envelope
    /// currently being written or read; cleared after every envelope.
    encoding_buffer: Vec<u8>,
}
362
/// A single unit exchanged over a [`MessageStream`].
// The `Envelope` variant is much larger than the two unit variants, which is what
// the lint below complains about; it is left unboxed here.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf `Envelope`, sent as a compressed binary WebSocket message.
    Envelope(Envelope),
    /// A WebSocket ping (sent with an empty payload).
    Ping,
    /// A WebSocket pong (sent with an empty payload).
    Pong,
}
370
371impl<S> MessageStream<S> {
372    pub fn new(stream: S) -> Self {
373        Self {
374            stream,
375            encoding_buffer: Vec::new(),
376        }
377    }
378
379    pub fn inner_mut(&mut self) -> &mut S {
380        &mut self.stream
381    }
382}
383
384impl<S> MessageStream<S>
385where
386    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
387{
388    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
389        #[cfg(any(test, feature = "test-support"))]
390        const COMPRESSION_LEVEL: i32 = -7;
391
392        #[cfg(not(any(test, feature = "test-support")))]
393        const COMPRESSION_LEVEL: i32 = 4;
394
395        match message {
396            Message::Envelope(message) => {
397                self.encoding_buffer.reserve(message.encoded_len());
398                message
399                    .encode(&mut self.encoding_buffer)
400                    .map_err(io::Error::from)?;
401                let buffer =
402                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
403                        .unwrap();
404
405                self.encoding_buffer.clear();
406                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
407                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
408            }
409            Message::Ping => {
410                self.stream
411                    .send(WebSocketMessage::Ping(Default::default()))
412                    .await?;
413            }
414            Message::Pong => {
415                self.stream
416                    .send(WebSocketMessage::Pong(Default::default()))
417                    .await?;
418            }
419        }
420
421        Ok(())
422    }
423}
424
425impl<S> MessageStream<S>
426where
427    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
428{
429    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
430        while let Some(bytes) = self.stream.next().await {
431            match bytes? {
432                WebSocketMessage::Binary(bytes) => {
433                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
434                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
435                        .map_err(io::Error::from)?;
436
437                    self.encoding_buffer.clear();
438                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
439                    return Ok(Message::Envelope(envelope));
440                }
441                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
442                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
443                WebSocketMessage::Close(_) => break,
444                _ => {}
445            }
446        }
447        Err(anyhow!("connection closed"))
448    }
449}
450
451impl From<Timestamp> for SystemTime {
452    fn from(val: Timestamp) -> Self {
453        UNIX_EPOCH
454            .checked_add(Duration::new(val.seconds, val.nanos))
455            .unwrap()
456    }
457}
458
459impl From<SystemTime> for Timestamp {
460    fn from(time: SystemTime) -> Self {
461        let duration = time.duration_since(UNIX_EPOCH).unwrap();
462        Self {
463            seconds: duration.as_secs(),
464            nanos: duration.subsec_nanos(),
465        }
466    }
467}
468
469impl From<u128> for Nonce {
470    fn from(nonce: u128) -> Self {
471        let upper_half = (nonce >> 64) as u64;
472        let lower_half = nonce as u64;
473        Self {
474            upper_half,
475            lower_half,
476        }
477    }
478}
479
480impl From<Nonce> for u128 {
481    fn from(nonce: Nonce) -> Self {
482        let upper_half = (nonce.upper_half as u128) << 64;
483        let lower_half = nonce.lower_half as u128;
484        upper_half | lower_half
485    }
486}
487
488pub fn split_worktree_update(
489    mut message: UpdateWorktree,
490    max_chunk_size: usize,
491) -> impl Iterator<Item = UpdateWorktree> {
492    let mut done_files = false;
493
494    let mut repository_map = message
495        .updated_repositories
496        .into_iter()
497        .map(|repo| (repo.work_directory_id, repo))
498        .collect::<HashMap<_, _>>();
499
500    iter::from_fn(move || {
501        if done_files {
502            return None;
503        }
504
505        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
506        let updated_entries: Vec<_> = message
507            .updated_entries
508            .drain(..updated_entries_chunk_size)
509            .collect();
510
511        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
512        let removed_entries = message
513            .removed_entries
514            .drain(..removed_entries_chunk_size)
515            .collect();
516
517        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
518
519        let mut updated_repositories = Vec::new();
520
521        if !repository_map.is_empty() {
522            for entry in &updated_entries {
523                if let Some(repo) = repository_map.remove(&entry.id) {
524                    updated_repositories.push(repo)
525                }
526            }
527        }
528
529        let removed_repositories = if done_files {
530            mem::take(&mut message.removed_repositories)
531        } else {
532            Default::default()
533        };
534
535        if done_files {
536            updated_repositories.extend(mem::take(&mut repository_map).into_values());
537        }
538
539        Some(UpdateWorktree {
540            project_id: message.project_id,
541            worktree_id: message.worktree_id,
542            root_name: message.root_name.clone(),
543            abs_path: message.abs_path.clone(),
544            updated_entries,
545            removed_entries,
546            scan_id: message.scan_id,
547            is_last_update: done_files && message.is_last_update,
548            updated_repositories,
549            removed_repositories,
550        })
551    })
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// The scratch buffer of a `MessageStream` must never stay larger than
    /// `MAX_BUFFER_LEN` after a write or a read, for small and for large messages.
    #[gpui::test]
    async fn test_buffer_size() {
        // Builds an envelope whose encoded size scales with `repeats`.
        let envelope_with_root_name = |repeats: usize| {
            Message::Envelope(Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repeats),
                    ..Default::default()
                })),
                ..Default::default()
            })
        };

        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write a small message, then one far larger than the buffer limit.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeats in [10, 1000000] {
            sink.write(envelope_with_root_name(repeats)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both messages back and check the same bound on the reading side.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId::as_u64` followed by `PeerId::from_u64` must round-trip, including
    /// at the extremes of both fields.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}