proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::fmt;
  9use std::str::FromStr;
 10use std::{
 11    cmp,
 12    fmt::Debug,
 13    io, iter, mem,
 14    time::{Duration, SystemTime, UNIX_EPOCH},
 15};
 16
 17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 18
 19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 20    const NAME: &'static str;
 21    const PRIORITY: MessagePriority;
 22    fn into_envelope(
 23        self,
 24        id: u32,
 25        responding_to: Option<u32>,
 26        original_sender_id: Option<PeerId>,
 27    ) -> Envelope;
 28    fn from_envelope(envelope: Envelope) -> Option<Self>;
 29}
 30
 31pub trait EntityMessage: EnvelopedMessage {
 32    fn remote_entity_id(&self) -> u64;
 33}
 34
 35pub trait RequestMessage: EnvelopedMessage {
 36    type Response: EnvelopedMessage;
 37}
 38
 39pub trait AnyTypedEnvelope: 'static + Send + Sync {
 40    fn payload_type_id(&self) -> TypeId;
 41    fn payload_type_name(&self) -> &'static str;
 42    fn as_any(&self) -> &dyn Any;
 43    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 44    fn is_background(&self) -> bool;
 45    fn original_sender_id(&self) -> Option<PeerId>;
 46}
 47
 48pub enum MessagePriority {
 49    Foreground,
 50    Background,
 51}
 52
 53impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 54    fn payload_type_id(&self) -> TypeId {
 55        TypeId::of::<T>()
 56    }
 57
 58    fn payload_type_name(&self) -> &'static str {
 59        T::NAME
 60    }
 61
 62    fn as_any(&self) -> &dyn Any {
 63        self
 64    }
 65
 66    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 67        self
 68    }
 69
 70    fn is_background(&self) -> bool {
 71        matches!(T::PRIORITY, MessagePriority::Background)
 72    }
 73
 74    fn original_sender_id(&self) -> Option<PeerId> {
 75        self.original_sender_id.clone()
 76    }
 77}
 78
 79impl PeerId {
 80    pub fn from_u64(peer_id: u64) -> Self {
 81        let owner_id = (peer_id >> 32) as u32;
 82        let id = peer_id as u32;
 83        Self { owner_id, id }
 84    }
 85
 86    pub fn as_u64(self) -> u64 {
 87        ((self.owner_id as u64) << 32) | (self.id as u64)
 88    }
 89}
 90
 91impl Copy for PeerId {}
 92
 93impl Eq for PeerId {}
 94
 95impl Ord for PeerId {
 96    fn cmp(&self, other: &Self) -> cmp::Ordering {
 97        self.owner_id
 98            .cmp(&other.owner_id)
 99            .then_with(|| self.id.cmp(&other.id))
100    }
101}
102
103impl PartialOrd for PeerId {
104    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
105        Some(self.cmp(other))
106    }
107}
108
109impl std::hash::Hash for PeerId {
110    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
111        self.owner_id.hash(state);
112        self.id.hash(state);
113    }
114}
115
116impl fmt::Display for PeerId {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        write!(f, "{}/{}", self.owner_id, self.id)
119    }
120}
121
122impl FromStr for PeerId {
123    type Err = anyhow::Error;
124
125    fn from_str(s: &str) -> Result<Self, Self::Err> {
126        let mut components = s.split('/');
127        let owner_id = components
128            .next()
129            .ok_or_else(|| anyhow!("invalid peer id {:?}", s))?
130            .parse()?;
131        let id = components
132            .next()
133            .ok_or_else(|| anyhow!("invalid peer id {:?}", s))?
134            .parse()?;
135        Ok(PeerId { owner_id, id })
136    }
137}
138
// Registry of every message type that may be carried in an `Envelope`, each
// paired with a `MessagePriority` variant (`Foreground` or `Background`).
// The `messages!` macro is imported from the parent module and its expansion is
// not visible here. It presumably implements `EnvelopedMessage` (including
// `NAME` and `PRIORITY`) for each entry; confirm against the macro definition
// before relying on entry order.
// NOTE(review): the list is mostly, but not strictly, alphabetical (e.g.
// `UsersResponse`, `RemoveContact`, `UpdateDiffBase`, `GetPrivateUserInfo*`
// are out of place).
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
243
// Request/response pairings: each request type on the left is answered with the
// message type on the right. Several requests share a response type (`Ack`,
// `ProjectEntryResponse`, `OpenBufferResponse`, `UsersResponse`). `SaveBuffer` is
// answered with `BufferSaved`, and `Test` is answered with another `Test`.
// The `request_messages!` macro lives in the parent module. Given the
// `RequestMessage::Response` associated type above, it presumably implements
// `RequestMessage` with `Response` set to the right-hand type; confirm against
// the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Project-scoped messages. The first argument names the field that identifies
// the target entity. The `entity_messages!` macro (parent module) presumably
// implements `EntityMessage::remote_entity_id` by reading that field from each
// listed message; confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateDiffBase
);

// Channel-scoped messages, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
345
/// Bytes in one kibibyte.
const KIB: usize = 1024;
/// Bytes in one mebibyte.
const MIB: usize = 1024 * KIB;
/// Capacity (in bytes) that `MessageStream`'s scratch buffer is shrunk back to
/// after every message, so one oversized message does not keep a large
/// allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
349
/// A stream of protobuf messages.
///
/// Wraps a WebSocket sink and/or stream `S`. Outgoing [`Envelope`]s are
/// zstd-compressed and incoming binary frames are decompressed.
pub struct MessageStream<S> {
    /// Scratch space, reused across messages, for uncompressed protobuf bytes.
    encoding_buffer: Vec<u8>,
    /// The wrapped WebSocket sink/stream.
    stream: S,
}
355
356#[allow(clippy::large_enum_variant)]
357#[derive(Debug)]
358pub enum Message {
359    Envelope(Envelope),
360    Ping,
361    Pong,
362}
363
364impl<S> MessageStream<S> {
365    pub fn new(stream: S) -> Self {
366        Self {
367            stream,
368            encoding_buffer: Vec::new(),
369        }
370    }
371
372    pub fn inner_mut(&mut self) -> &mut S {
373        &mut self.stream
374    }
375}
376
377impl<S> MessageStream<S>
378where
379    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
380{
381    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
382        #[cfg(any(test, feature = "test-support"))]
383        const COMPRESSION_LEVEL: i32 = -7;
384
385        #[cfg(not(any(test, feature = "test-support")))]
386        const COMPRESSION_LEVEL: i32 = 4;
387
388        match message {
389            Message::Envelope(message) => {
390                self.encoding_buffer.reserve(message.encoded_len());
391                message
392                    .encode(&mut self.encoding_buffer)
393                    .map_err(io::Error::from)?;
394                let buffer =
395                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
396                        .unwrap();
397
398                self.encoding_buffer.clear();
399                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
400                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
401            }
402            Message::Ping => {
403                self.stream
404                    .send(WebSocketMessage::Ping(Default::default()))
405                    .await?;
406            }
407            Message::Pong => {
408                self.stream
409                    .send(WebSocketMessage::Pong(Default::default()))
410                    .await?;
411            }
412        }
413
414        Ok(())
415    }
416}
417
418impl<S> MessageStream<S>
419where
420    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
421{
422    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
423        while let Some(bytes) = self.stream.next().await {
424            match bytes? {
425                WebSocketMessage::Binary(bytes) => {
426                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
427                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
428                        .map_err(io::Error::from)?;
429
430                    self.encoding_buffer.clear();
431                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
432                    return Ok(Message::Envelope(envelope));
433                }
434                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
435                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
436                WebSocketMessage::Close(_) => break,
437                _ => {}
438            }
439        }
440        Err(anyhow!("connection closed"))
441    }
442}
443
444impl From<Timestamp> for SystemTime {
445    fn from(val: Timestamp) -> Self {
446        UNIX_EPOCH
447            .checked_add(Duration::new(val.seconds, val.nanos))
448            .unwrap()
449    }
450}
451
452impl From<SystemTime> for Timestamp {
453    fn from(time: SystemTime) -> Self {
454        let duration = time.duration_since(UNIX_EPOCH).unwrap();
455        Self {
456            seconds: duration.as_secs(),
457            nanos: duration.subsec_nanos(),
458        }
459    }
460}
461
462impl From<u128> for Nonce {
463    fn from(nonce: u128) -> Self {
464        let upper_half = (nonce >> 64) as u64;
465        let lower_half = nonce as u64;
466        Self {
467            upper_half,
468            lower_half,
469        }
470    }
471}
472
473impl From<Nonce> for u128 {
474    fn from(nonce: Nonce) -> Self {
475        let upper_half = (nonce.upper_half as u128) << 64;
476        let lower_half = nonce.lower_half as u128;
477        upper_half | lower_half
478    }
479}
480
481pub fn split_worktree_update(
482    mut message: UpdateWorktree,
483    max_chunk_size: usize,
484) -> impl Iterator<Item = UpdateWorktree> {
485    let mut done = false;
486    iter::from_fn(move || {
487        if done {
488            return None;
489        }
490
491        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
492        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
493        done = message.updated_entries.is_empty();
494        Some(UpdateWorktree {
495            project_id: message.project_id,
496            worktree_id: message.worktree_id,
497            root_name: message.root_name.clone(),
498            abs_path: message.abs_path.clone(),
499            updated_entries,
500            removed_entries: mem::take(&mut message.removed_entries),
501            scan_id: message.scan_id,
502            is_last_update: done && message.is_last_update,
503        })
504    })
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose root name is `"abcdefg"`
    /// repeated `repetitions` times, which controls the encoded message size.
    fn worktree_envelope(repetitions: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Writing a small and then a very large message must leave the writer's
        // scratch buffer capacity at or below MAX_BUFFER_LEN each time.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope(repetitions)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both messages back must bound the reader's buffer the same way.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}