proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::fmt;
  9use std::{
 10    cmp,
 11    fmt::Debug,
 12    io, iter, mem,
 13    time::{Duration, SystemTime, UNIX_EPOCH},
 14};
 15
// Generated protobuf definitions (e.g. `Envelope`, `PeerId`, `Timestamp`,
// `Nonce` and every message payload) written to Cargo's `OUT_DIR` at build
// time. None of these types are defined in this file; it only adds trait
// impls and helpers on top of them.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 17
 18pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 19    const NAME: &'static str;
 20    const PRIORITY: MessagePriority;
 21    fn into_envelope(
 22        self,
 23        id: u32,
 24        responding_to: Option<u32>,
 25        original_sender_id: Option<PeerId>,
 26    ) -> Envelope;
 27    fn from_envelope(envelope: Envelope) -> Option<Self>;
 28}
 29
 30pub trait EntityMessage: EnvelopedMessage {
 31    fn remote_entity_id(&self) -> u64;
 32}
 33
 34pub trait RequestMessage: EnvelopedMessage {
 35    type Response: EnvelopedMessage;
 36}
 37
 38pub trait AnyTypedEnvelope: 'static + Send + Sync {
 39    fn payload_type_id(&self) -> TypeId;
 40    fn payload_type_name(&self) -> &'static str;
 41    fn as_any(&self) -> &dyn Any;
 42    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 43    fn is_background(&self) -> bool;
 44    fn original_sender_id(&self) -> Option<PeerId>;
 45}
 46
 47pub enum MessagePriority {
 48    Foreground,
 49    Background,
 50}
 51
 52impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 53    fn payload_type_id(&self) -> TypeId {
 54        TypeId::of::<T>()
 55    }
 56
 57    fn payload_type_name(&self) -> &'static str {
 58        T::NAME
 59    }
 60
 61    fn as_any(&self) -> &dyn Any {
 62        self
 63    }
 64
 65    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 66        self
 67    }
 68
 69    fn is_background(&self) -> bool {
 70        matches!(T::PRIORITY, MessagePriority::Background)
 71    }
 72
 73    fn original_sender_id(&self) -> Option<PeerId> {
 74        self.original_sender_id
 75    }
 76}
 77
 78impl PeerId {
 79    pub fn from_u64(peer_id: u64) -> Self {
 80        let owner_id = (peer_id >> 32) as u32;
 81        let id = peer_id as u32;
 82        Self { owner_id, id }
 83    }
 84
 85    pub fn as_u64(self) -> u64 {
 86        ((self.owner_id as u64) << 32) | (self.id as u64)
 87    }
 88}
 89
 90impl Copy for PeerId {}
 91
 92impl Eq for PeerId {}
 93
 94impl Ord for PeerId {
 95    fn cmp(&self, other: &Self) -> cmp::Ordering {
 96        self.owner_id
 97            .cmp(&other.owner_id)
 98            .then_with(|| self.id.cmp(&other.id))
 99    }
100}
101
102impl PartialOrd for PeerId {
103    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
104        Some(self.cmp(other))
105    }
106}
107
108impl std::hash::Hash for PeerId {
109    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
110        self.owner_id.hash(state);
111        self.id.hash(state);
112    }
113}
114
115impl fmt::Display for PeerId {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        write!(f, "{}/{}", self.owner_id, self.id)
118    }
119}
120
// Registers every payload type that can travel inside an `Envelope`, paired
// with its `MessagePriority` (`Foreground` or `Background`).
// NOTE(review): `messages!` is imported from `super` and not visible here;
// presumably it emits the `EnvelopedMessage` impls (`NAME`, `PRIORITY`, and
// envelope conversion) for each listed type. Confirm whether entry order matters
// before reordering.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
230
// Pairs each request type with the message type sent back in reply
// (`RequestMessage::Response`). Many requests are answered with a bare `Ack`;
// `SaveBuffer` is answered with `BufferSaved`, and `Test` with itself.
// NOTE(review): `request_messages!` is imported from `super`; presumably it
// emits one `RequestMessage` impl per pair.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
286
// Messages scoped to a project. The first macro argument names the field
// used as the entity id. NOTE(review): presumably `entity_messages!` (from
// `super`) implements `EntityMessage::remote_entity_id` by reading that field.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

// Messages scoped to a channel, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
336
// Binary size units.
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;
// Capacity that `MessageStream` shrinks its scratch buffer back to (via
// `shrink_to`) after handling an envelope, so one huge message does not pin
// a huge allocation.
const MAX_BUFFER_LEN: usize = MIB;
340
/// A stream of protobuf messages.
///
/// Wraps a WebSocket sink and/or stream `S`: `write` is available when `S` is
/// a `Sink<WebSocketMessage>` and `read` when it is a `Stream` of
/// `WebSocketMessage`s. Envelopes travel as zstd-compressed protobuf in binary
/// frames.
pub struct MessageStream<S> {
    // The wrapped WebSocket transport.
    stream: S,
    // Reusable scratch space for protobuf encoding / zstd decoding; its
    // capacity is shrunk back toward `MAX_BUFFER_LEN` after each envelope.
    encoding_buffer: Vec<u8>,
}
346
/// One unit of traffic on a [`MessageStream`]: a protobuf [`Envelope`], or a
/// WebSocket ping/pong control frame.
// Clippy flags the size difference between `Envelope` and the unit variants;
// NOTE(review): presumably allowed to avoid boxing every envelope — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    Envelope(Envelope),
    Ping,
    Pong,
}
354
355impl<S> MessageStream<S> {
356    pub fn new(stream: S) -> Self {
357        Self {
358            stream,
359            encoding_buffer: Vec::new(),
360        }
361    }
362
363    pub fn inner_mut(&mut self) -> &mut S {
364        &mut self.stream
365    }
366}
367
368impl<S> MessageStream<S>
369where
370    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
371{
372    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
373        #[cfg(any(test, feature = "test-support"))]
374        const COMPRESSION_LEVEL: i32 = -7;
375
376        #[cfg(not(any(test, feature = "test-support")))]
377        const COMPRESSION_LEVEL: i32 = 4;
378
379        match message {
380            Message::Envelope(message) => {
381                self.encoding_buffer.reserve(message.encoded_len());
382                message
383                    .encode(&mut self.encoding_buffer)
384                    .map_err(io::Error::from)?;
385                let buffer =
386                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
387                        .unwrap();
388
389                self.encoding_buffer.clear();
390                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
391                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
392            }
393            Message::Ping => {
394                self.stream
395                    .send(WebSocketMessage::Ping(Default::default()))
396                    .await?;
397            }
398            Message::Pong => {
399                self.stream
400                    .send(WebSocketMessage::Pong(Default::default()))
401                    .await?;
402            }
403        }
404
405        Ok(())
406    }
407}
408
409impl<S> MessageStream<S>
410where
411    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
412{
413    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
414        while let Some(bytes) = self.stream.next().await {
415            match bytes? {
416                WebSocketMessage::Binary(bytes) => {
417                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
418                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
419                        .map_err(io::Error::from)?;
420
421                    self.encoding_buffer.clear();
422                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
423                    return Ok(Message::Envelope(envelope));
424                }
425                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
426                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
427                WebSocketMessage::Close(_) => break,
428                _ => {}
429            }
430        }
431        Err(anyhow!("connection closed"))
432    }
433}
434
435impl From<Timestamp> for SystemTime {
436    fn from(val: Timestamp) -> Self {
437        UNIX_EPOCH
438            .checked_add(Duration::new(val.seconds, val.nanos))
439            .unwrap()
440    }
441}
442
443impl From<SystemTime> for Timestamp {
444    fn from(time: SystemTime) -> Self {
445        let duration = time.duration_since(UNIX_EPOCH).unwrap();
446        Self {
447            seconds: duration.as_secs(),
448            nanos: duration.subsec_nanos(),
449        }
450    }
451}
452
453impl From<u128> for Nonce {
454    fn from(nonce: u128) -> Self {
455        let upper_half = (nonce >> 64) as u64;
456        let lower_half = nonce as u64;
457        Self {
458            upper_half,
459            lower_half,
460        }
461    }
462}
463
464impl From<Nonce> for u128 {
465    fn from(nonce: Nonce) -> Self {
466        let upper_half = (nonce.upper_half as u128) << 64;
467        let lower_half = nonce.lower_half as u128;
468        upper_half | lower_half
469    }
470}
471
472pub fn split_worktree_update(
473    mut message: UpdateWorktree,
474    max_chunk_size: usize,
475) -> impl Iterator<Item = UpdateWorktree> {
476    let mut done = false;
477    iter::from_fn(move || {
478        if done {
479            return None;
480        }
481
482        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
483        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
484        done = message.updated_entries.is_empty();
485        Some(UpdateWorktree {
486            project_id: message.project_id,
487            worktree_id: message.worktree_id,
488            root_name: message.root_name.clone(),
489            abs_path: message.abs_path.clone(),
490            updated_entries,
491            removed_entries: mem::take(&mut message.removed_entries),
492            scan_id: message.scan_id,
493            is_last_update: done && message.is_last_update,
494        })
495    })
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose root name is
    /// `"abcdefg"` repeated `repetitions` times.
    fn worktree_envelope(repetitions: usize) -> Message {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repetitions),
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    // Whether a message is small or large, neither the writer nor the reader
    // may keep more than `MAX_BUFFER_LEN` bytes of scratch capacity afterwards.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope(repetitions)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    // Packing into a u64 and unpacking again must round-trip, including at
    // the u32 extremes of both halves.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}