proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Textually includes `zed.messages.rs` from the build script's `OUT_DIR`.
// NOTE(review): presumably the prost-generated protobuf types used below
// (`Envelope`, `UpdateWorktree`, `Timestamp`, `Nonce`, ...) — confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
/// A message type that can be wrapped in, and recovered from, an [`Envelope`].
///
/// NOTE(review): implementations are presumably generated by the `messages!`
/// invocation in this file — confirm against the macro definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; reported by `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of the message; `AnyTypedEnvelope::is_background` reports
    /// whether this is `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Consumes the message and wraps it in an [`Envelope`].
    ///
    /// `id`, `responding_to` and `original_sender_id` are supplied by the caller;
    /// presumably they populate the envelope fields of the same names — TODO confirm.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Attempts to extract a message of this type from `envelope`.
    ///
    /// Returns `None` when extraction is not possible (presumably when the
    /// envelope carries a different payload type — TODO confirm).
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 28
/// An [`EnvelopedMessage`] that is addressed to a particular remote entity.
///
/// NOTE(review): implementations are presumably generated by the
/// `entity_messages!` invocations in this file — confirm.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 32
/// An [`EnvelopedMessage`] that is a request with an associated response type.
///
/// NOTE(review): implementations are presumably generated by the
/// `request_messages!` invocation in this file — confirm.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 36
/// Type-erased interface over a `TypedEnvelope<T>`.
///
/// This file implements it for every `TypedEnvelope<T>` whose payload `T` is an
/// [`EnvelopedMessage`].
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// Returns the [`TypeId`] of the payload type.
    fn payload_type_id(&self) -> TypeId;
    /// Returns the payload type's name (`EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as [`Any`] so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed [`Any`] so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Returns `true` when the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// Returns the original sender recorded on the envelope, if any.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Declares every message type handled by this module, each paired with a
// priority (`Foreground` or `Background`, the variants of `MessagePriority`).
// NOTE(review): the `messages!` macro lives in the parent module and is not
// visible here; presumably it implements `EnvelopedMessage` for each listed
// type with the given priority — confirm against the macro definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
181
// Pairs each request message type (first element) with a second message type.
// NOTE(review): the `request_messages!` macro is defined in the parent module
// and not visible here; presumably it implements `RequestMessage` for the
// first type with the second as its `Response` — confirm. Several requests
// share a response type (`Ack`, `ProjectEntryResponse`, `UsersResponse`,
// `OpenBufferResponse`), and `Test` is paired with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
235
// Lists the message types associated with the `project_id` field.
// NOTE(review): the `entity_messages!` macro is defined in the parent module
// and not visible here; presumably it implements `EntityMessage` for each
// listed type with `remote_entity_id` returning its `project_id` — confirm.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateDiffBase
);
281
// Same macro as the `project_id` invocation, applied to `ChannelMessageSent`
// with the `channel_id` field.
// NOTE(review): presumably makes `remote_entity_id` return `channel_id` — confirm.
entity_messages!(channel_id, ChannelMessageSent);
283
/// Number of bytes in one kibibyte.
const KIB: usize = 1024;
/// Number of bytes in one mebibyte.
const MIB: usize = KIB * 1024;
/// Capacity, in bytes, passed to `Vec::shrink_to` on `MessageStream`'s scratch
/// buffer after each envelope is written or read, so that one large message
/// does not keep a large allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
287
/// A stream of protobuf messages.
///
/// Wraps a transport `S`; `write` is available when `S` is a sink of websocket
/// messages and `read` when `S` is a stream of them.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer holding the uncompressed protobuf bytes of the envelope
    /// currently being written or read; reused across calls.
    encoding_buffer: Vec<u8>,
}
293
/// A message that can be written to or read from a [`MessageStream`].
// NOTE(review): the lint allowance is presumably needed because `Envelope` is
// far larger than the two unit variants — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf [`Envelope`]; travels as a zstd-compressed binary websocket frame.
    Envelope(Envelope),
    /// A websocket ping frame.
    Ping,
    /// A websocket pong frame.
    Pong,
}
301
302impl<S> MessageStream<S> {
303    pub fn new(stream: S) -> Self {
304        Self {
305            stream,
306            encoding_buffer: Vec::new(),
307        }
308    }
309
310    pub fn inner_mut(&mut self) -> &mut S {
311        &mut self.stream
312    }
313}
314
315impl<S> MessageStream<S>
316where
317    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
318{
319    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
320        #[cfg(any(test, feature = "test-support"))]
321        const COMPRESSION_LEVEL: i32 = -7;
322
323        #[cfg(not(any(test, feature = "test-support")))]
324        const COMPRESSION_LEVEL: i32 = 4;
325
326        match message {
327            Message::Envelope(message) => {
328                self.encoding_buffer.reserve(message.encoded_len());
329                message
330                    .encode(&mut self.encoding_buffer)
331                    .map_err(io::Error::from)?;
332                let buffer =
333                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
334                        .unwrap();
335
336                self.encoding_buffer.clear();
337                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
338                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
339            }
340            Message::Ping => {
341                self.stream
342                    .send(WebSocketMessage::Ping(Default::default()))
343                    .await?;
344            }
345            Message::Pong => {
346                self.stream
347                    .send(WebSocketMessage::Pong(Default::default()))
348                    .await?;
349            }
350        }
351
352        Ok(())
353    }
354}
355
356impl<S> MessageStream<S>
357where
358    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
359{
360    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
361        while let Some(bytes) = self.stream.next().await {
362            match bytes? {
363                WebSocketMessage::Binary(bytes) => {
364                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
365                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
366                        .map_err(io::Error::from)?;
367
368                    self.encoding_buffer.clear();
369                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
370                    return Ok(Message::Envelope(envelope));
371                }
372                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
373                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
374                WebSocketMessage::Close(_) => break,
375                _ => {}
376            }
377        }
378        Err(anyhow!("connection closed"))
379    }
380}
381
382impl From<Timestamp> for SystemTime {
383    fn from(val: Timestamp) -> Self {
384        UNIX_EPOCH
385            .checked_add(Duration::new(val.seconds, val.nanos))
386            .unwrap()
387    }
388}
389
390impl From<SystemTime> for Timestamp {
391    fn from(time: SystemTime) -> Self {
392        let duration = time.duration_since(UNIX_EPOCH).unwrap();
393        Self {
394            seconds: duration.as_secs(),
395            nanos: duration.subsec_nanos(),
396        }
397    }
398}
399
400impl From<u128> for Nonce {
401    fn from(nonce: u128) -> Self {
402        let upper_half = (nonce >> 64) as u64;
403        let lower_half = nonce as u64;
404        Self {
405            upper_half,
406            lower_half,
407        }
408    }
409}
410
411impl From<Nonce> for u128 {
412    fn from(nonce: Nonce) -> Self {
413        let upper_half = (nonce.upper_half as u128) << 64;
414        let lower_half = nonce.lower_half as u128;
415        upper_half | lower_half
416    }
417}
418
419pub fn split_worktree_update(
420    mut message: UpdateWorktree,
421    max_chunk_size: usize,
422) -> impl Iterator<Item = UpdateWorktree> {
423    let mut done = false;
424    iter::from_fn(move || {
425        if done {
426            return None;
427        }
428
429        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
430        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
431        done = message.updated_entries.is_empty();
432        Some(UpdateWorktree {
433            project_id: message.project_id,
434            worktree_id: message.worktree_id,
435            root_name: message.root_name.clone(),
436            abs_path: message.abs_path.clone(),
437            updated_entries,
438            removed_entries: mem::take(&mut message.removed_entries),
439            scan_id: message.scan_id,
440            is_last_update: done && message.is_last_update,
441        })
442    })
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose `UpdateWorktree` payload has a root name made
    /// of "abcdefg" repeated `repeats` times.
    fn worktree_envelope(repeats: usize) -> Message {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repeats),
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    /// Writes a small and a very large envelope, reads both back, and checks
    /// after every step that the scratch buffer's capacity stays within
    /// `MAX_BUFFER_LEN`.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeats in [10, 1000000] {
            sink.write(worktree_envelope(repeats)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}