proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::fmt;
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter, mem,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15
// Pull in the message types generated at build time into `OUT_DIR` (presumably by
// prost-build from the `zed.messages` protobuf package — confirm in build.rs). This is
// where `Envelope`, `PeerId`, `Timestamp`, `Nonce`, `UpdateWorktree`, etc. come from.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
/// A protobuf message type that can be wrapped in, and extracted from, an [`Envelope`].
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name identifying this message type (reported by
    /// [`AnyTypedEnvelope::payload_type_name`]).
    const NAME: &'static str;
    /// Scheduling priority of this message type; see [`AnyTypedEnvelope::is_background`].
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] carrying message `id`, the id of the message this one
    /// responds to (if any), and the peer that originally sent it (if any).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`.
    ///
    /// NOTE(review): presumably returns `None` when the envelope's payload is a different
    /// message type — confirm against the macro-generated impls.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 29
/// A message addressed to a specific remote entity (e.g. a project or a channel).
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity this message targets.
    ///
    /// NOTE(review): presumably read from the field named in the corresponding
    /// `entity_messages!` invocation (`project_id` / `channel_id`) — confirm.
    fn remote_entity_id(&self) -> u64;
}
 33
/// A message that expects a reply of type [`RequestMessage::Response`].
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
 37
/// A type-erased view of a [`TypedEnvelope`], so envelopes with different payload types
/// can be stored and inspected uniformly.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// [`TypeId`] of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// [`EnvelopedMessage::NAME`] of the payload message type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any`, e.g. to downcast back to `TypedEnvelope<T>`.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `dyn Any`, e.g. to downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// The peer that originally sent the message, if recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// Connection id stored on the envelope (presumably the connection it arrived on).
    fn sender_id(&self) -> ConnectionId;
    /// Id of the message.
    fn message_id(&self) -> u32;
}
 48
 49pub enum MessagePriority {
 50    Foreground,
 51    Background,
 52}
 53
 54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 55    fn payload_type_id(&self) -> TypeId {
 56        TypeId::of::<T>()
 57    }
 58
 59    fn payload_type_name(&self) -> &'static str {
 60        T::NAME
 61    }
 62
 63    fn as_any(&self) -> &dyn Any {
 64        self
 65    }
 66
 67    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 68        self
 69    }
 70
 71    fn is_background(&self) -> bool {
 72        matches!(T::PRIORITY, MessagePriority::Background)
 73    }
 74
 75    fn original_sender_id(&self) -> Option<PeerId> {
 76        self.original_sender_id
 77    }
 78
 79    fn sender_id(&self) -> ConnectionId {
 80        self.sender_id
 81    }
 82
 83    fn message_id(&self) -> u32 {
 84        self.message_id
 85    }
 86}
 87
 88impl PeerId {
 89    pub fn from_u64(peer_id: u64) -> Self {
 90        let owner_id = (peer_id >> 32) as u32;
 91        let id = peer_id as u32;
 92        Self { owner_id, id }
 93    }
 94
 95    pub fn as_u64(self) -> u64 {
 96        ((self.owner_id as u64) << 32) | (self.id as u64)
 97    }
 98}
 99
// `PeerId` is a pair of `u32` fields (`owner_id`, `id`), so copying it is trivial.
// NOTE(review): presumably the generated definition supplies `Clone` but not `Copy` — confirm.
impl Copy for PeerId {}

// Marks `PeerId` equality as total, which the `Ord` and `Hash` impls in this file rely on.
// NOTE(review): presumably `PartialEq` comes from the generated definition — confirm.
impl Eq for PeerId {}
103
104impl Ord for PeerId {
105    fn cmp(&self, other: &Self) -> cmp::Ordering {
106        self.owner_id
107            .cmp(&other.owner_id)
108            .then_with(|| self.id.cmp(&other.id))
109    }
110}
111
112impl PartialOrd for PeerId {
113    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114        Some(self.cmp(other))
115    }
116}
117
118impl std::hash::Hash for PeerId {
119    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120        self.owner_id.hash(state);
121        self.id.hash(state);
122    }
123}
124
125impl fmt::Display for PeerId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "{}/{}", self.owner_id, self.id)
128    }
129}
130
// Registers every protobuf message type exchanged over RPC together with its
// `MessagePriority` (`Foreground` or `Background`).
// NOTE(review): `messages!` is defined in the parent module; presumably it implements
// `EnvelopedMessage` (setting `NAME` and `PRIORITY`) for each listed type — confirm there.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// Pairs each request message type with the response type it expects, as
// `(Request, Response)`. Several requests share a generic `Ack` response.
// NOTE(review): presumably `request_messages!` implements `RequestMessage` with
// `Response = <second element>` for each pair — confirm in the parent module.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
296
// Message types that target a particular project. The first argument names the field
// holding the entity id; presumably `entity_messages!` implements `EntityMessage` so that
// `remote_entity_id()` returns that field — NOTE(review): confirm in the parent module.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
344
// Message types that target a particular channel, identified by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
346
/// Bytes in one kibibyte (2^10).
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte (2^20).
const MIB: usize = KIB << 10;
/// Largest capacity `MessageStream` keeps in its scratch buffer after processing an
/// envelope; anything above this is released with `shrink_to`.
const MAX_BUFFER_LEN: usize = MIB;
350
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` and converts between [`Message`] values and
/// zstd-compressed, protobuf-encoded binary frames (see `write` and `read`).
pub struct MessageStream<S> {
    // The underlying WebSocket sink and/or stream.
    stream: S,
    // Scratch space reused by both `write` (protobuf encoding) and `read` (zstd
    // decompression); after each envelope it is cleared and its capacity trimmed to
    // `MAX_BUFFER_LEN` so one huge message does not pin memory.
    encoding_buffer: Vec<u8>,
}
356
/// One frame read from or written to a [`MessageStream`].
// NOTE(review): `Envelope` is presumably much larger than the unit variants, which is what
// trips `large_enum_variant`; boxing it was evidently judged not worth it — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, sent as a zstd-compressed binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame (with an empty payload when written).
    Ping,
    /// A WebSocket pong control frame (with an empty payload when written).
    Pong,
}
364
365impl<S> MessageStream<S> {
366    pub fn new(stream: S) -> Self {
367        Self {
368            stream,
369            encoding_buffer: Vec::new(),
370        }
371    }
372
373    pub fn inner_mut(&mut self) -> &mut S {
374        &mut self.stream
375    }
376}
377
378impl<S> MessageStream<S>
379where
380    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
381{
382    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
383        #[cfg(any(test, feature = "test-support"))]
384        const COMPRESSION_LEVEL: i32 = -7;
385
386        #[cfg(not(any(test, feature = "test-support")))]
387        const COMPRESSION_LEVEL: i32 = 4;
388
389        match message {
390            Message::Envelope(message) => {
391                self.encoding_buffer.reserve(message.encoded_len());
392                message
393                    .encode(&mut self.encoding_buffer)
394                    .map_err(io::Error::from)?;
395                let buffer =
396                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
397                        .unwrap();
398
399                self.encoding_buffer.clear();
400                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
401                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
402            }
403            Message::Ping => {
404                self.stream
405                    .send(WebSocketMessage::Ping(Default::default()))
406                    .await?;
407            }
408            Message::Pong => {
409                self.stream
410                    .send(WebSocketMessage::Pong(Default::default()))
411                    .await?;
412            }
413        }
414
415        Ok(())
416    }
417}
418
419impl<S> MessageStream<S>
420where
421    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
422{
423    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
424        while let Some(bytes) = self.stream.next().await {
425            match bytes? {
426                WebSocketMessage::Binary(bytes) => {
427                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
428                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
429                        .map_err(io::Error::from)?;
430
431                    self.encoding_buffer.clear();
432                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
433                    return Ok(Message::Envelope(envelope));
434                }
435                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
436                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
437                WebSocketMessage::Close(_) => break,
438                _ => {}
439            }
440        }
441        Err(anyhow!("connection closed"))
442    }
443}
444
445impl From<Timestamp> for SystemTime {
446    fn from(val: Timestamp) -> Self {
447        UNIX_EPOCH
448            .checked_add(Duration::new(val.seconds, val.nanos))
449            .unwrap()
450    }
451}
452
453impl From<SystemTime> for Timestamp {
454    fn from(time: SystemTime) -> Self {
455        let duration = time.duration_since(UNIX_EPOCH).unwrap();
456        Self {
457            seconds: duration.as_secs(),
458            nanos: duration.subsec_nanos(),
459        }
460    }
461}
462
463impl From<u128> for Nonce {
464    fn from(nonce: u128) -> Self {
465        let upper_half = (nonce >> 64) as u64;
466        let lower_half = nonce as u64;
467        Self {
468            upper_half,
469            lower_half,
470        }
471    }
472}
473
474impl From<Nonce> for u128 {
475    fn from(nonce: Nonce) -> Self {
476        let upper_half = (nonce.upper_half as u128) << 64;
477        let lower_half = nonce.lower_half as u128;
478        upper_half | lower_half
479    }
480}
481
482pub fn split_worktree_update(
483    mut message: UpdateWorktree,
484    max_chunk_size: usize,
485) -> impl Iterator<Item = UpdateWorktree> {
486    let mut done = false;
487    iter::from_fn(move || {
488        if done {
489            return None;
490        }
491
492        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
493        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
494        done = message.updated_entries.is_empty();
495        Some(UpdateWorktree {
496            project_id: message.project_id,
497            worktree_id: message.worktree_id,
498            root_name: message.root_name.clone(),
499            abs_path: message.abs_path.clone(),
500            updated_entries,
501            removed_entries: mem::take(&mut message.removed_entries),
502            scan_id: message.scan_id,
503            is_last_update: done && message.is_last_update,
504        })
505    })
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Writing and then reading envelopes of very different sizes must never leave the
    /// scratch buffer holding more than `MAX_BUFFER_LEN` bytes of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        // One small envelope, then one of roughly 7 MB.
        for repetitions in [10, 1000000] {
            let envelope = Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repetitions),
                    ..Default::default()
                })),
                ..Default::default()
            };
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both frames back; decoding must trim the buffer the same way.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId` survives a round trip through its packed `u64` form, including at the
    /// extremes of both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}