proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::fmt;
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter, mem,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15
 16include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
 18pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 19    const NAME: &'static str;
 20    const PRIORITY: MessagePriority;
 21    fn into_envelope(
 22        self,
 23        id: u32,
 24        responding_to: Option<u32>,
 25        original_sender_id: Option<PeerId>,
 26    ) -> Envelope;
 27    fn from_envelope(envelope: Envelope) -> Option<Self>;
 28}
 29
 30pub trait EntityMessage: EnvelopedMessage {
 31    fn remote_entity_id(&self) -> u64;
 32}
 33
 34pub trait RequestMessage: EnvelopedMessage {
 35    type Response: EnvelopedMessage;
 36}
 37
 38pub trait AnyTypedEnvelope: 'static + Send + Sync {
 39    fn payload_type_id(&self) -> TypeId;
 40    fn payload_type_name(&self) -> &'static str;
 41    fn as_any(&self) -> &dyn Any;
 42    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 43    fn is_background(&self) -> bool;
 44    fn original_sender_id(&self) -> Option<PeerId>;
 45}
 46
 47pub enum MessagePriority {
 48    Foreground,
 49    Background,
 50}
 51
 52impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 53    fn payload_type_id(&self) -> TypeId {
 54        TypeId::of::<T>()
 55    }
 56
 57    fn payload_type_name(&self) -> &'static str {
 58        T::NAME
 59    }
 60
 61    fn as_any(&self) -> &dyn Any {
 62        self
 63    }
 64
 65    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 66        self
 67    }
 68
 69    fn is_background(&self) -> bool {
 70        matches!(T::PRIORITY, MessagePriority::Background)
 71    }
 72
 73    fn original_sender_id(&self) -> Option<PeerId> {
 74        self.original_sender_id
 75    }
 76}
 77
 78impl PeerId {
 79    pub fn from_u64(peer_id: u64) -> Self {
 80        let owner_id = (peer_id >> 32) as u32;
 81        let id = peer_id as u32;
 82        Self { owner_id, id }
 83    }
 84
 85    pub fn as_u64(self) -> u64 {
 86        ((self.owner_id as u64) << 32) | (self.id as u64)
 87    }
 88}
 89
 90impl Copy for PeerId {}
 91
 92impl Eq for PeerId {}
 93
 94impl Ord for PeerId {
 95    fn cmp(&self, other: &Self) -> cmp::Ordering {
 96        self.owner_id
 97            .cmp(&other.owner_id)
 98            .then_with(|| self.id.cmp(&other.id))
 99    }
100}
101
102impl PartialOrd for PeerId {
103    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
104        Some(self.cmp(other))
105    }
106}
107
108impl std::hash::Hash for PeerId {
109    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
110        self.owner_id.hash(state);
111        self.id.hash(state);
112    }
113}
114
115impl fmt::Display for PeerId {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        write!(f, "{}/{}", self.owner_id, self.id)
118    }
119}
120
121messages!(
122    (Ack, Foreground),
123    (AddProjectCollaborator, Foreground),
124    (ApplyCodeAction, Background),
125    (ApplyCodeActionResponse, Background),
126    (ApplyCompletionAdditionalEdits, Background),
127    (ApplyCompletionAdditionalEditsResponse, Background),
128    (BufferReloaded, Foreground),
129    (BufferSaved, Foreground),
130    (Call, Foreground),
131    (CallCanceled, Foreground),
132    (CancelCall, Foreground),
133    (ChannelMessageSent, Foreground),
134    (CopyProjectEntry, Foreground),
135    (CreateBufferForPeer, Foreground),
136    (CreateProjectEntry, Foreground),
137    (CreateRoom, Foreground),
138    (CreateRoomResponse, Foreground),
139    (DeclineCall, Foreground),
140    (DeleteProjectEntry, Foreground),
141    (Error, Foreground),
142    (Follow, Foreground),
143    (FollowResponse, Foreground),
144    (FormatBuffers, Foreground),
145    (FormatBuffersResponse, Foreground),
146    (FuzzySearchUsers, Foreground),
147    (GetChannelMessages, Foreground),
148    (GetChannelMessagesResponse, Foreground),
149    (GetChannels, Foreground),
150    (GetChannelsResponse, Foreground),
151    (GetCodeActions, Background),
152    (GetCodeActionsResponse, Background),
153    (GetHover, Background),
154    (GetHoverResponse, Background),
155    (GetCompletions, Background),
156    (GetCompletionsResponse, Background),
157    (GetDefinition, Background),
158    (GetDefinitionResponse, Background),
159    (GetTypeDefinition, Background),
160    (GetTypeDefinitionResponse, Background),
161    (GetDocumentHighlights, Background),
162    (GetDocumentHighlightsResponse, Background),
163    (GetReferences, Background),
164    (GetReferencesResponse, Background),
165    (GetProjectSymbols, Background),
166    (GetProjectSymbolsResponse, Background),
167    (GetUsers, Foreground),
168    (Hello, Foreground),
169    (IncomingCall, Foreground),
170    (UsersResponse, Foreground),
171    (JoinChannel, Foreground),
172    (JoinChannelResponse, Foreground),
173    (JoinProject, Foreground),
174    (JoinProjectResponse, Foreground),
175    (JoinRoom, Foreground),
176    (JoinRoomResponse, Foreground),
177    (LeaveChannel, Foreground),
178    (LeaveProject, Foreground),
179    (LeaveRoom, Foreground),
180    (OpenBufferById, Background),
181    (OpenBufferByPath, Background),
182    (OpenBufferForSymbol, Background),
183    (OpenBufferForSymbolResponse, Background),
184    (OpenBufferResponse, Background),
185    (PerformRename, Background),
186    (PerformRenameResponse, Background),
187    (Ping, Foreground),
188    (PrepareRename, Background),
189    (PrepareRenameResponse, Background),
190    (ProjectEntryResponse, Foreground),
191    (RemoveContact, Foreground),
192    (ReloadBuffers, Foreground),
193    (ReloadBuffersResponse, Foreground),
194    (RemoveProjectCollaborator, Foreground),
195    (RenameProjectEntry, Foreground),
196    (RequestContact, Foreground),
197    (RespondToContactRequest, Foreground),
198    (RoomUpdated, Foreground),
199    (SaveBuffer, Foreground),
200    (SearchProject, Background),
201    (SearchProjectResponse, Background),
202    (SendChannelMessage, Foreground),
203    (SendChannelMessageResponse, Foreground),
204    (ShareProject, Foreground),
205    (ShareProjectResponse, Foreground),
206    (ShowContacts, Foreground),
207    (StartLanguageServer, Foreground),
208    (Test, Foreground),
209    (Unfollow, Foreground),
210    (UnshareProject, Foreground),
211    (UpdateBuffer, Foreground),
212    (UpdateBufferFile, Foreground),
213    (UpdateContacts, Foreground),
214    (UpdateDiagnosticSummary, Foreground),
215    (UpdateFollowers, Foreground),
216    (UpdateInviteInfo, Foreground),
217    (UpdateLanguageServer, Foreground),
218    (UpdateParticipantLocation, Foreground),
219    (UpdateProject, Foreground),
220    (UpdateWorktree, Foreground),
221    (UpdateDiffBase, Background),
222    (GetPrivateUserInfo, Foreground),
223    (GetPrivateUserInfoResponse, Foreground),
224);
225
226request_messages!(
227    (ApplyCodeAction, ApplyCodeActionResponse),
228    (
229        ApplyCompletionAdditionalEdits,
230        ApplyCompletionAdditionalEditsResponse
231    ),
232    (Call, Ack),
233    (CancelCall, Ack),
234    (CopyProjectEntry, ProjectEntryResponse),
235    (CreateProjectEntry, ProjectEntryResponse),
236    (CreateRoom, CreateRoomResponse),
237    (DeclineCall, Ack),
238    (DeleteProjectEntry, ProjectEntryResponse),
239    (Follow, FollowResponse),
240    (FormatBuffers, FormatBuffersResponse),
241    (GetChannelMessages, GetChannelMessagesResponse),
242    (GetChannels, GetChannelsResponse),
243    (GetCodeActions, GetCodeActionsResponse),
244    (GetHover, GetHoverResponse),
245    (GetCompletions, GetCompletionsResponse),
246    (GetDefinition, GetDefinitionResponse),
247    (GetTypeDefinition, GetTypeDefinitionResponse),
248    (GetDocumentHighlights, GetDocumentHighlightsResponse),
249    (GetReferences, GetReferencesResponse),
250    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
251    (GetProjectSymbols, GetProjectSymbolsResponse),
252    (FuzzySearchUsers, UsersResponse),
253    (GetUsers, UsersResponse),
254    (JoinChannel, JoinChannelResponse),
255    (JoinProject, JoinProjectResponse),
256    (JoinRoom, JoinRoomResponse),
257    (IncomingCall, Ack),
258    (OpenBufferById, OpenBufferResponse),
259    (OpenBufferByPath, OpenBufferResponse),
260    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
261    (Ping, Ack),
262    (PerformRename, PerformRenameResponse),
263    (PrepareRename, PrepareRenameResponse),
264    (ReloadBuffers, ReloadBuffersResponse),
265    (RequestContact, Ack),
266    (RemoveContact, Ack),
267    (RespondToContactRequest, Ack),
268    (RenameProjectEntry, ProjectEntryResponse),
269    (SaveBuffer, BufferSaved),
270    (SearchProject, SearchProjectResponse),
271    (SendChannelMessage, SendChannelMessageResponse),
272    (ShareProject, ShareProjectResponse),
273    (Test, Test),
274    (UpdateBuffer, Ack),
275    (UpdateParticipantLocation, Ack),
276    (UpdateProject, Ack),
277    (UpdateWorktree, Ack),
278);
279
280entity_messages!(
281    project_id,
282    AddProjectCollaborator,
283    ApplyCodeAction,
284    ApplyCompletionAdditionalEdits,
285    BufferReloaded,
286    BufferSaved,
287    CopyProjectEntry,
288    CreateBufferForPeer,
289    CreateProjectEntry,
290    DeleteProjectEntry,
291    Follow,
292    FormatBuffers,
293    GetCodeActions,
294    GetCompletions,
295    GetDefinition,
296    GetTypeDefinition,
297    GetDocumentHighlights,
298    GetHover,
299    GetReferences,
300    GetProjectSymbols,
301    JoinProject,
302    LeaveProject,
303    OpenBufferById,
304    OpenBufferByPath,
305    OpenBufferForSymbol,
306    PerformRename,
307    PrepareRename,
308    ReloadBuffers,
309    RemoveProjectCollaborator,
310    RenameProjectEntry,
311    SaveBuffer,
312    SearchProject,
313    StartLanguageServer,
314    Unfollow,
315    UnshareProject,
316    UpdateBuffer,
317    UpdateBufferFile,
318    UpdateDiagnosticSummary,
319    UpdateFollowers,
320    UpdateLanguageServer,
321    UpdateProject,
322    UpdateWorktree,
323    UpdateDiffBase
324);
325
326entity_messages!(channel_id, ChannelMessageSent);
327
/// Number of bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Number of bytes in one mebibyte.
const MIB: usize = KIB * KIB;
/// Capacity that a `MessageStream`'s scratch buffer is shrunk back to after
/// each message, so one oversized message does not pin memory indefinitely.
const MAX_BUFFER_LEN: usize = MIB;
331
/// A stream of protobuf messages layered over a websocket-style transport.
pub struct MessageStream<S> {
    /// The underlying transport: a sink when writing, a stream when reading.
    stream: S,
    /// Scratch buffer reused across messages; it holds the uncompressed
    /// protobuf bytes while a message is being written or read.
    encoding_buffer: Vec<u8>,
}
337
338#[allow(clippy::large_enum_variant)]
339#[derive(Debug)]
340pub enum Message {
341    Envelope(Envelope),
342    Ping,
343    Pong,
344}
345
346impl<S> MessageStream<S> {
347    pub fn new(stream: S) -> Self {
348        Self {
349            stream,
350            encoding_buffer: Vec::new(),
351        }
352    }
353
354    pub fn inner_mut(&mut self) -> &mut S {
355        &mut self.stream
356    }
357}
358
359impl<S> MessageStream<S>
360where
361    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
362{
363    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
364        #[cfg(any(test, feature = "test-support"))]
365        const COMPRESSION_LEVEL: i32 = -7;
366
367        #[cfg(not(any(test, feature = "test-support")))]
368        const COMPRESSION_LEVEL: i32 = 4;
369
370        match message {
371            Message::Envelope(message) => {
372                self.encoding_buffer.reserve(message.encoded_len());
373                message
374                    .encode(&mut self.encoding_buffer)
375                    .map_err(io::Error::from)?;
376                let buffer =
377                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
378                        .unwrap();
379
380                self.encoding_buffer.clear();
381                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
382                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
383            }
384            Message::Ping => {
385                self.stream
386                    .send(WebSocketMessage::Ping(Default::default()))
387                    .await?;
388            }
389            Message::Pong => {
390                self.stream
391                    .send(WebSocketMessage::Pong(Default::default()))
392                    .await?;
393            }
394        }
395
396        Ok(())
397    }
398}
399
400impl<S> MessageStream<S>
401where
402    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
403{
404    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
405        while let Some(bytes) = self.stream.next().await {
406            match bytes? {
407                WebSocketMessage::Binary(bytes) => {
408                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
409                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
410                        .map_err(io::Error::from)?;
411
412                    self.encoding_buffer.clear();
413                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
414                    return Ok(Message::Envelope(envelope));
415                }
416                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
417                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
418                WebSocketMessage::Close(_) => break,
419                _ => {}
420            }
421        }
422        Err(anyhow!("connection closed"))
423    }
424}
425
426impl From<Timestamp> for SystemTime {
427    fn from(val: Timestamp) -> Self {
428        UNIX_EPOCH
429            .checked_add(Duration::new(val.seconds, val.nanos))
430            .unwrap()
431    }
432}
433
434impl From<SystemTime> for Timestamp {
435    fn from(time: SystemTime) -> Self {
436        let duration = time.duration_since(UNIX_EPOCH).unwrap();
437        Self {
438            seconds: duration.as_secs(),
439            nanos: duration.subsec_nanos(),
440        }
441    }
442}
443
444impl From<u128> for Nonce {
445    fn from(nonce: u128) -> Self {
446        let upper_half = (nonce >> 64) as u64;
447        let lower_half = nonce as u64;
448        Self {
449            upper_half,
450            lower_half,
451        }
452    }
453}
454
455impl From<Nonce> for u128 {
456    fn from(nonce: Nonce) -> Self {
457        let upper_half = (nonce.upper_half as u128) << 64;
458        let lower_half = nonce.lower_half as u128;
459        upper_half | lower_half
460    }
461}
462
463pub fn split_worktree_update(
464    mut message: UpdateWorktree,
465    max_chunk_size: usize,
466) -> impl Iterator<Item = UpdateWorktree> {
467    let mut done = false;
468    iter::from_fn(move || {
469        if done {
470            return None;
471        }
472
473        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
474        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
475        done = message.updated_entries.is_empty();
476        Some(UpdateWorktree {
477            project_id: message.project_id,
478            worktree_id: message.worktree_id,
479            root_name: message.root_name.clone(),
480            abs_path: message.abs_path.clone(),
481            updated_entries,
482            removed_entries: mem::take(&mut message.removed_entries),
483            scan_id: message.scan_id,
484            is_last_update: done && message.is_last_update,
485        })
486    })
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message whose `UpdateWorktree` payload has a root
    /// name made of `"abcdefg"` repeated `repeats` times.
    fn worktree_envelope(repeats: usize) -> Message {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repeats),
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    /// The scratch buffer must never stay larger than `MAX_BUFFER_LEN` after
    /// a write or a read, for small and for very large messages alike.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeats in [10, 1000000] {
            sink.write(worktree_envelope(repeats)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// Packing a peer id into a `u64` and unpacking it again is lossless,
    /// including at the extremes of both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}