proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use collections::HashMap;
  5use futures::{SinkExt as _, StreamExt as _};
  6use prost::Message as _;
  7use serde::Serialize;
  8use std::any::{Any, TypeId};
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15use std::{fmt, mem};
 16
// Pulls in the protobuf types generated at build time into OUT_DIR (presumably
// by `prost-build` from build.rs — TODO confirm). Types such as `Envelope`,
// `PeerId`, `Timestamp`, `Nonce` and `UpdateWorktree` are used throughout this
// file without being imported, so they are expected to be defined by this
// generated file.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46    fn sender_id(&self) -> ConnectionId;
 47    fn message_id(&self) -> u32;
 48}
 49
 50pub enum MessagePriority {
 51    Foreground,
 52    Background,
 53}
 54
 55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 56    fn payload_type_id(&self) -> TypeId {
 57        TypeId::of::<T>()
 58    }
 59
 60    fn payload_type_name(&self) -> &'static str {
 61        T::NAME
 62    }
 63
 64    fn as_any(&self) -> &dyn Any {
 65        self
 66    }
 67
 68    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 69        self
 70    }
 71
 72    fn is_background(&self) -> bool {
 73        matches!(T::PRIORITY, MessagePriority::Background)
 74    }
 75
 76    fn original_sender_id(&self) -> Option<PeerId> {
 77        self.original_sender_id
 78    }
 79
 80    fn sender_id(&self) -> ConnectionId {
 81        self.sender_id
 82    }
 83
 84    fn message_id(&self) -> u32 {
 85        self.message_id
 86    }
 87}
 88
 89impl PeerId {
 90    pub fn from_u64(peer_id: u64) -> Self {
 91        let owner_id = (peer_id >> 32) as u32;
 92        let id = peer_id as u32;
 93        Self { owner_id, id }
 94    }
 95
 96    pub fn as_u64(self) -> u64 {
 97        ((self.owner_id as u64) << 32) | (self.id as u64)
 98    }
 99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106    fn cmp(&self, other: &Self) -> cmp::Ordering {
107        self.owner_id
108            .cmp(&other.owner_id)
109            .then_with(|| self.id.cmp(&other.id))
110    }
111}
112
113impl PartialOrd for PeerId {
114    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115        Some(self.cmp(other))
116    }
117}
118
119impl std::hash::Hash for PeerId {
120    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121        self.owner_id.hash(state);
122        self.id.hash(state);
123    }
124}
125
126impl fmt::Display for PeerId {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        write!(f, "{}/{}", self.owner_id, self.id)
129    }
130}
131
// Registry of every protobuf message type that can travel inside an
// `Envelope`, paired with its `MessagePriority`. The priority declared here is
// what `AnyTypedEnvelope::is_background` reports for envelopes of that type.
// Presumably the macro expands to an `EnvelopedMessage` impl per entry
// (supplying `NAME` and `PRIORITY`) — TODO confirm against the `messages!`
// definition. Every request/response type named in `request_messages!` and
// every type in `entity_messages!` must also appear here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
243
// Pairs each request message type with the message type sent back as its
// reply. Presumably the macro implements `RequestMessage` with
// `Response = <second type>` for each pair — TODO confirm against the
// `request_messages!` definition. Many requests share the generic `Ack`
// reply, `SaveBuffer` is answered with `BufferSaved`, and `Test` replies with
// itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
301
// Messages scoped to a project. The first argument names the message field
// (`project_id`) that presumably backs `EntityMessage::remote_entity_id` —
// TODO confirm against the `entity_messages!` definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

// Messages scoped to a channel, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
352
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch buffer is shrunk back to after each
/// message, bounding the memory retained after an unusually large one.
const MAX_BUFFER_LEN: usize = MIB;
356
357/// A stream of protobuf messages.
358pub struct MessageStream<S> {
359    stream: S,
360    encoding_buffer: Vec<u8>,
361}
362
363#[allow(clippy::large_enum_variant)]
364#[derive(Debug)]
365pub enum Message {
366    Envelope(Envelope),
367    Ping,
368    Pong,
369}
370
371impl<S> MessageStream<S> {
372    pub fn new(stream: S) -> Self {
373        Self {
374            stream,
375            encoding_buffer: Vec::new(),
376        }
377    }
378
379    pub fn inner_mut(&mut self) -> &mut S {
380        &mut self.stream
381    }
382}
383
384impl<S> MessageStream<S>
385where
386    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
387{
388    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
389        #[cfg(any(test, feature = "test-support"))]
390        const COMPRESSION_LEVEL: i32 = -7;
391
392        #[cfg(not(any(test, feature = "test-support")))]
393        const COMPRESSION_LEVEL: i32 = 4;
394
395        match message {
396            Message::Envelope(message) => {
397                self.encoding_buffer.reserve(message.encoded_len());
398                message
399                    .encode(&mut self.encoding_buffer)
400                    .map_err(io::Error::from)?;
401                let buffer =
402                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
403                        .unwrap();
404
405                self.encoding_buffer.clear();
406                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
407                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
408            }
409            Message::Ping => {
410                self.stream
411                    .send(WebSocketMessage::Ping(Default::default()))
412                    .await?;
413            }
414            Message::Pong => {
415                self.stream
416                    .send(WebSocketMessage::Pong(Default::default()))
417                    .await?;
418            }
419        }
420
421        Ok(())
422    }
423}
424
425impl<S> MessageStream<S>
426where
427    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
428{
429    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
430        while let Some(bytes) = self.stream.next().await {
431            match bytes? {
432                WebSocketMessage::Binary(bytes) => {
433                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
434                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
435                        .map_err(io::Error::from)?;
436
437                    self.encoding_buffer.clear();
438                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
439                    return Ok(Message::Envelope(envelope));
440                }
441                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
442                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
443                WebSocketMessage::Close(_) => break,
444                _ => {}
445            }
446        }
447        Err(anyhow!("connection closed"))
448    }
449}
450
451impl From<Timestamp> for SystemTime {
452    fn from(val: Timestamp) -> Self {
453        UNIX_EPOCH
454            .checked_add(Duration::new(val.seconds, val.nanos))
455            .unwrap()
456    }
457}
458
459impl From<SystemTime> for Timestamp {
460    fn from(time: SystemTime) -> Self {
461        let duration = time.duration_since(UNIX_EPOCH).unwrap();
462        Self {
463            seconds: duration.as_secs(),
464            nanos: duration.subsec_nanos(),
465        }
466    }
467}
468
469impl From<u128> for Nonce {
470    fn from(nonce: u128) -> Self {
471        let upper_half = (nonce >> 64) as u64;
472        let lower_half = nonce as u64;
473        Self {
474            upper_half,
475            lower_half,
476        }
477    }
478}
479
480impl From<Nonce> for u128 {
481    fn from(nonce: Nonce) -> Self {
482        let upper_half = (nonce.upper_half as u128) << 64;
483        let lower_half = nonce.lower_half as u128;
484        upper_half | lower_half
485    }
486}
487
488pub fn split_worktree_update(
489    mut message: UpdateWorktree,
490    max_chunk_size: usize,
491) -> impl Iterator<Item = UpdateWorktree> {
492    let mut done_files = false;
493
494    let mut repository_map = message
495        .updated_repositories
496        .into_iter()
497        .map(|repo| (repo.work_directory_id, repo))
498        .collect::<HashMap<_, _>>();
499
500    // Maintain a list of inflight repositories
501    // Every time we send a repository's work directory, stick it in the list of in-flight repositories
502    // Every go of the loop, drain each in-flight repository's statuses
503    // Until we have no more data
504
505    iter::from_fn(move || {
506        if done_files {
507            return None;
508        }
509
510        let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
511        let updated_entries: Vec<_> = message
512            .updated_entries
513            .drain(..updated_entries_chunk_size)
514            .collect();
515
516        let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
517        let removed_entries = message
518            .removed_entries
519            .drain(..removed_entries_chunk_size)
520            .collect();
521
522        done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
523
524        let mut updated_repositories = Vec::new();
525
526        if !repository_map.is_empty() {
527            for entry in &updated_entries {
528                if let Some(repo) = repository_map.remove(&entry.id) {
529                    updated_repositories.push(repo)
530                }
531            }
532        }
533
534        let removed_repositories = if done_files {
535            mem::take(&mut message.removed_repositories)
536        } else {
537            Default::default()
538        };
539
540        if done_files {
541            updated_repositories.extend(mem::take(&mut repository_map).into_values());
542        }
543
544        Some(UpdateWorktree {
545            project_id: message.project_id,
546            worktree_id: message.worktree_id,
547            root_name: message.root_name.clone(),
548            abs_path: message.abs_path.clone(),
549            updated_entries,
550            removed_entries,
551            scan_id: message.scan_id,
552            is_last_update: done_files && message.is_last_update,
553            updated_repositories,
554            removed_repositories,
555        })
556    })
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose only non-default
    /// field is `root_name`, so callers control the encoded message size.
    fn worktree_update_with_root_name(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// The scratch buffer must be shrunk back to at most `MAX_BUFFER_LEN`
    /// after both a small and a very large message, on the write side and on
    /// the read side.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for &repeat_count in &[10, 1000000] {
            sink.write(worktree_update_with_root_name(
                "abcdefg".repeat(repeat_count),
            ))
            .await
            .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `as_u64` / `from_u64` must round-trip, including at the `u32` extremes
    /// of either half.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for &(owner_id, id) in &cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}