proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Protobuf message types generated at build time into `$OUT_DIR/zed.messages.rs` (e.g.
// `Envelope`, `UpdateWorktree`, `Timestamp`, `Nonce` used below). Presumably emitted by prost
// from a build script, given the `prost::Message` usage in this file — confirm.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
/// A protobuf message type that can be wrapped in, and recovered from, an [`Envelope`].
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type, reported by `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Scheduling class of the message type; `AnyTypedEnvelope::is_background` is derived from it.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`].
    ///
    /// `id` is the envelope's message id, `responding_to` is presumably the id of the request this
    /// message answers (if any), and `original_sender_id` the peer that originated it when the
    /// message is forwarded — confirm against the macro-generated impls.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts `Self` from `envelope`; presumably returns `None` when the envelope carries a
    /// different payload type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 28
/// A message addressed to a particular remote entity.
///
/// The `entity_messages!` invocations in this file key these on `project_id` or `channel_id`,
/// so the id is presumably taken from that field — confirm against the macro.
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the remote entity this message targets.
    fn remote_entity_id(&self) -> u64;
}
 32
/// A message that expects a reply; the pairings are declared by `request_messages!` below.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in response to this request.
    type Response: EnvelopedMessage;
}
 36
/// Type-erased view of a `TypedEnvelope<T>`, so envelopes with different payload types can be
/// stored and dispatched uniformly.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type `T` (not of the envelope wrapper).
    fn payload_type_id(&self) -> TypeId;
    /// The payload type's `EnvelopedMessage::NAME`.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` for by-value downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// Peer that originally sent the message, if recorded on the envelope.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 45
/// Scheduling class of a message type, exposed as `EnvelopedMessage::PRIORITY`.
pub enum MessagePriority {
    /// Any message that is not background; `is_background` returns `false`.
    Foreground,
    /// `AnyTypedEnvelope::is_background` returns `true` for these messages.
    Background,
}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Registry of every protobuf payload type carried in an `Envelope`, each paired with its
// `MessagePriority`. Presumably the `messages!` macro (defined elsewhere) implements
// `EnvelopedMessage` for each entry with `PRIORITY` taken from the second element — confirm.
// Any type listed in `request_messages!` / `entity_messages!` below must also appear here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
174
// Request → response pairings. Presumably `request_messages!` implements `RequestMessage` for the
// first type with `Response` set to the second — confirm against the macro definition.
// Several requests share a generic reply (`Ack`, `ProjectEntryResponse`, `UsersResponse`), and
// `Test` responds with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
221
// Messages routed by project. The first argument names the field that identifies the target
// entity; presumably the macro implements `EntityMessage::remote_entity_id` by reading
// `project_id` from each listed type — confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
    UpdateDiffBase
);
272
// Messages routed by channel rather than project (entity id presumably read from `channel_id`).
entity_messages!(channel_id, ChannelMessageSent);
274
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch buffer is shrunk back to after each message, so one
/// very large message does not pin a large allocation for the lifetime of the connection.
const MAX_BUFFER_LEN: usize = MIB;
278
/// A stream of protobuf messages.
///
/// Wraps a WebSocket message stream and/or sink `S`. Outgoing [`Envelope`]s are protobuf-encoded
/// and zstd-compressed into binary frames, and incoming binary frames are decompressed and decoded.
pub struct MessageStream<S> {
    /// The underlying WebSocket stream/sink.
    stream: S,
    /// Scratch space for uncompressed protobuf bytes, reused across messages. It is cleared after
    /// each message and its capacity shrunk to at most `MAX_BUFFER_LEN`.
    encoding_buffer: Vec<u8>,
}
284
// NOTE(review): `large_enum_variant` is allowed presumably because boxing `Envelope` would add an
// allocation per message for little benefit (the other variants are unit) — confirm.
/// A single frame read from or written to a [`MessageStream`].
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried as a zstd-compressed binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame.
    Ping,
    /// A WebSocket pong control frame.
    Pong,
}
292
293impl<S> MessageStream<S> {
294    pub fn new(stream: S) -> Self {
295        Self {
296            stream,
297            encoding_buffer: Vec::new(),
298        }
299    }
300
301    pub fn inner_mut(&mut self) -> &mut S {
302        &mut self.stream
303    }
304}
305
306impl<S> MessageStream<S>
307where
308    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
309{
310    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
311        #[cfg(any(test, feature = "test-support"))]
312        const COMPRESSION_LEVEL: i32 = -7;
313
314        #[cfg(not(any(test, feature = "test-support")))]
315        const COMPRESSION_LEVEL: i32 = 4;
316
317        match message {
318            Message::Envelope(message) => {
319                self.encoding_buffer.reserve(message.encoded_len());
320                message
321                    .encode(&mut self.encoding_buffer)
322                    .map_err(io::Error::from)?;
323                let buffer =
324                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
325                        .unwrap();
326
327                self.encoding_buffer.clear();
328                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
329                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
330            }
331            Message::Ping => {
332                self.stream
333                    .send(WebSocketMessage::Ping(Default::default()))
334                    .await?;
335            }
336            Message::Pong => {
337                self.stream
338                    .send(WebSocketMessage::Pong(Default::default()))
339                    .await?;
340            }
341        }
342
343        Ok(())
344    }
345}
346
347impl<S> MessageStream<S>
348where
349    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
350{
351    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
352        while let Some(bytes) = self.stream.next().await {
353            match bytes? {
354                WebSocketMessage::Binary(bytes) => {
355                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
356                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
357                        .map_err(io::Error::from)?;
358
359                    self.encoding_buffer.clear();
360                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
361                    return Ok(Message::Envelope(envelope));
362                }
363                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
364                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
365                WebSocketMessage::Close(_) => break,
366                _ => {}
367            }
368        }
369        Err(anyhow!("connection closed"))
370    }
371}
372
373impl From<Timestamp> for SystemTime {
374    fn from(val: Timestamp) -> Self {
375        UNIX_EPOCH
376            .checked_add(Duration::new(val.seconds, val.nanos))
377            .unwrap()
378    }
379}
380
381impl From<SystemTime> for Timestamp {
382    fn from(time: SystemTime) -> Self {
383        let duration = time.duration_since(UNIX_EPOCH).unwrap();
384        Self {
385            seconds: duration.as_secs(),
386            nanos: duration.subsec_nanos(),
387        }
388    }
389}
390
391impl From<u128> for Nonce {
392    fn from(nonce: u128) -> Self {
393        let upper_half = (nonce >> 64) as u64;
394        let lower_half = nonce as u64;
395        Self {
396            upper_half,
397            lower_half,
398        }
399    }
400}
401
402impl From<Nonce> for u128 {
403    fn from(nonce: Nonce) -> Self {
404        let upper_half = (nonce.upper_half as u128) << 64;
405        let lower_half = nonce.lower_half as u128;
406        upper_half | lower_half
407    }
408}
409
410pub fn split_worktree_update(
411    mut message: UpdateWorktree,
412    max_chunk_size: usize,
413) -> impl Iterator<Item = UpdateWorktree> {
414    let mut done = false;
415    iter::from_fn(move || {
416        if done {
417            return None;
418        }
419
420        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
421        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
422        done = message.updated_entries.is_empty();
423        Some(UpdateWorktree {
424            project_id: message.project_id,
425            worktree_id: message.worktree_id,
426            root_name: message.root_name.clone(),
427            updated_entries,
428            removed_entries: mem::take(&mut message.removed_entries),
429            scan_id: message.scan_id,
430            is_last_update: done && message.is_last_update,
431        })
432    })
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose only populated field is an `UpdateWorktree` with `root_name`.
    fn worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` of capacity, even after
    /// encoding or decoding a message larger than that.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}