proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
 // Message types generated at build time (the file is read from Cargo's OUT_DIR,
// presumably written by a prost-based build script). Types this file uses without
// importing them — e.g. `Envelope`, `envelope::Payload`, `Timestamp`, `Nonce`,
// `UpdateWorktree` — are expected to be defined here.
 15include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
 17pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 18    const NAME: &'static str;
 19    const PRIORITY: MessagePriority;
 20    fn into_envelope(
 21        self,
 22        id: u32,
 23        responding_to: Option<u32>,
 24        original_sender_id: Option<u32>,
 25    ) -> Envelope;
 26    fn from_envelope(envelope: Envelope) -> Option<Self>;
 27}
 28
 29pub trait EntityMessage: EnvelopedMessage {
 30    fn remote_entity_id(&self) -> u64;
 31}
 32
 33pub trait RequestMessage: EnvelopedMessage {
 34    type Response: EnvelopedMessage;
 35}
 36
 37pub trait AnyTypedEnvelope: 'static + Send + Sync {
 38    fn payload_type_id(&self) -> TypeId;
 39    fn payload_type_name(&self) -> &'static str;
 40    fn as_any(&self) -> &dyn Any;
 41    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 42    fn is_background(&self) -> bool;
 43    fn original_sender_id(&self) -> Option<PeerId>;
 44}
 45
 /// Priority class attached to every message type through
/// [`EnvelopedMessage::PRIORITY`].
///
/// [`AnyTypedEnvelope::is_background`] reports `true` exactly for
/// [`MessagePriority::Background`]. NOTE(review): how callers schedule the two classes
/// is decided outside this file.
pub enum MessagePriority {
    /// Regular priority; `is_background()` returns `false`.
    Foreground,
    /// Background priority; `is_background()` returns `true`.
    Background,
}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
 // Registers every message type that can travel in an `Envelope`, each paired with a
// `MessagePriority` variant. The macro is imported from the parent module.
// NOTE(review): it presumably generates the `EnvelopedMessage` impls (`NAME`,
// `PRIORITY`, envelope conversions) from these pairs — confirm against its
// definition. `TypedEnvelope::is_background` reports the priority chosen here.
 77messages!(
 78    (Ack, Foreground),
 79    (AddProjectCollaborator, Foreground),
 80    (ApplyCodeAction, Background),
 81    (ApplyCodeActionResponse, Background),
 82    (ApplyCompletionAdditionalEdits, Background),
 83    (ApplyCompletionAdditionalEditsResponse, Background),
 84    (BufferReloaded, Foreground),
 85    (BufferSaved, Foreground),
 86    (RemoveContact, Foreground),
 87    (ChannelMessageSent, Foreground),
 88    (CopyProjectEntry, Foreground),
 89    (CreateBufferForPeer, Foreground),
 90    (CreateProjectEntry, Foreground),
 91    (DeleteProjectEntry, Foreground),
 92    (Error, Foreground),
 93    (Follow, Foreground),
 94    (FollowResponse, Foreground),
 95    (FormatBuffers, Foreground),
 96    (FormatBuffersResponse, Foreground),
 97    (FuzzySearchUsers, Foreground),
 98    (GetChannelMessages, Foreground),
 99    (GetChannelMessagesResponse, Foreground),
100    (GetChannels, Foreground),
101    (GetChannelsResponse, Foreground),
102    (GetCodeActions, Background),
103    (GetCodeActionsResponse, Background),
104    (GetHover, Background),
105    (GetHoverResponse, Background),
106    (GetCompletions, Background),
107    (GetCompletionsResponse, Background),
108    (GetDefinition, Background),
109    (GetDefinitionResponse, Background),
110    (GetTypeDefinition, Background),
111    (GetTypeDefinitionResponse, Background),
112    (GetDocumentHighlights, Background),
113    (GetDocumentHighlightsResponse, Background),
114    (GetReferences, Background),
115    (GetReferencesResponse, Background),
116    (GetProjectSymbols, Background),
117    (GetProjectSymbolsResponse, Background),
118    (GetUsers, Foreground),
119    (UsersResponse, Foreground),
120    (JoinChannel, Foreground),
121    (JoinChannelResponse, Foreground),
122    (JoinProject, Foreground),
123    (JoinProjectResponse, Foreground),
124    (JoinProjectRequestCancelled, Foreground),
125    (LeaveChannel, Foreground),
126    (LeaveProject, Foreground),
127    (OpenBufferById, Background),
128    (OpenBufferByPath, Background),
129    (OpenBufferForSymbol, Background),
130    (OpenBufferForSymbolResponse, Background),
131    (OpenBufferResponse, Background),
132    (PerformRename, Background),
133    (PerformRenameResponse, Background),
134    (PrepareRename, Background),
135    (PrepareRenameResponse, Background),
136    (ProjectEntryResponse, Foreground),
137    (ProjectUnshared, Foreground),
138    (RegisterProjectResponse, Foreground),
139    (Ping, Foreground),
140    (RegisterProject, Foreground),
141    (RegisterProjectActivity, Foreground),
142    (ReloadBuffers, Foreground),
143    (ReloadBuffersResponse, Foreground),
144    (RemoveProjectCollaborator, Foreground),
145    (RenameProjectEntry, Foreground),
146    (RequestContact, Foreground),
147    (RequestJoinProject, Foreground),
148    (RespondToContactRequest, Foreground),
149    (RespondToJoinProjectRequest, Foreground),
150    (SaveBuffer, Foreground),
151    (SearchProject, Background),
152    (SearchProjectResponse, Background),
153    (SendChannelMessage, Foreground),
154    (SendChannelMessageResponse, Foreground),
155    (ShowContacts, Foreground),
156    (StartLanguageServer, Foreground),
157    (Test, Foreground),
158    (Unfollow, Foreground),
159    (UnregisterProject, Foreground),
160    (UpdateBuffer, Foreground),
161    (UpdateBufferFile, Foreground),
162    (UpdateContacts, Foreground),
163    (UpdateDiagnosticSummary, Foreground),
164    (UpdateFollowers, Foreground),
165    (UpdateInviteInfo, Foreground),
166    (UpdateLanguageServer, Foreground),
167    (UpdateProject, Foreground),
168    (UpdateWorktree, Foreground),
169    (UpdateWorktreeExtensions, Background),
170    (GetPrivateUserInfo, Foreground),
171    (GetPrivateUserInfoResponse, Foreground),
172);
173
// Pairs each request type with the message type sent back in reply.
// NOTE(review): presumably this generates the `RequestMessage` impls (setting
// `Response`) — confirm against the macro definition. Several requests share a
// reply type: the project-entry operations all answer with `ProjectEntryResponse`,
// the user lookups with `UsersResponse`, and a number of requests with `Ack`.
174request_messages!(
175    (ApplyCodeAction, ApplyCodeActionResponse),
176    (
177        ApplyCompletionAdditionalEdits,
178        ApplyCompletionAdditionalEditsResponse
179    ),
180    (CopyProjectEntry, ProjectEntryResponse),
181    (CreateProjectEntry, ProjectEntryResponse),
182    (DeleteProjectEntry, ProjectEntryResponse),
183    (Follow, FollowResponse),
184    (FormatBuffers, FormatBuffersResponse),
185    (GetChannelMessages, GetChannelMessagesResponse),
186    (GetChannels, GetChannelsResponse),
187    (GetCodeActions, GetCodeActionsResponse),
188    (GetHover, GetHoverResponse),
189    (GetCompletions, GetCompletionsResponse),
190    (GetDefinition, GetDefinitionResponse),
191    (GetTypeDefinition, GetTypeDefinitionResponse),
192    (GetDocumentHighlights, GetDocumentHighlightsResponse),
193    (GetReferences, GetReferencesResponse),
194    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
195    (GetProjectSymbols, GetProjectSymbolsResponse),
196    (FuzzySearchUsers, UsersResponse),
197    (GetUsers, UsersResponse),
198    (JoinChannel, JoinChannelResponse),
199    (JoinProject, JoinProjectResponse),
200    (OpenBufferById, OpenBufferResponse),
201    (OpenBufferByPath, OpenBufferResponse),
202    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
203    (Ping, Ack),
204    (PerformRename, PerformRenameResponse),
205    (PrepareRename, PrepareRenameResponse),
206    (RegisterProject, RegisterProjectResponse),
207    (ReloadBuffers, ReloadBuffersResponse),
208    (RequestContact, Ack),
209    (RemoveContact, Ack),
210    (RespondToContactRequest, Ack),
211    (RenameProjectEntry, ProjectEntryResponse),
212    (SaveBuffer, BufferSaved),
213    (SearchProject, SearchProjectResponse),
214    (SendChannelMessage, SendChannelMessageResponse),
215    (Test, Test),
216    (UnregisterProject, Ack),
217    (UpdateBuffer, Ack),
218    (UpdateWorktree, Ack),
219);
220
// Message types addressed to a project. The first argument names the field
// (`project_id`) that presumably supplies `EntityMessage::remote_entity_id` for each
// listed type — NOTE(review): confirm against the `entity_messages!` definition.
221entity_messages!(
222    project_id,
223    AddProjectCollaborator,
224    ApplyCodeAction,
225    ApplyCompletionAdditionalEdits,
226    BufferReloaded,
227    BufferSaved,
228    CopyProjectEntry,
229    CreateBufferForPeer,
230    CreateProjectEntry,
231    DeleteProjectEntry,
232    Follow,
233    FormatBuffers,
234    GetCodeActions,
235    GetCompletions,
236    GetDefinition,
237    GetTypeDefinition,
238    GetDocumentHighlights,
239    GetHover,
240    GetReferences,
241    GetProjectSymbols,
242    JoinProject,
243    JoinProjectRequestCancelled,
244    LeaveProject,
245    OpenBufferById,
246    OpenBufferByPath,
247    OpenBufferForSymbol,
248    PerformRename,
249    PrepareRename,
250    ProjectUnshared,
251    RegisterProjectActivity,
252    ReloadBuffers,
253    RemoveProjectCollaborator,
254    RenameProjectEntry,
255    RequestJoinProject,
256    SaveBuffer,
257    SearchProject,
258    StartLanguageServer,
259    Unfollow,
260    UnregisterProject,
261    UpdateBuffer,
262    UpdateBufferFile,
263    UpdateDiagnosticSummary,
264    UpdateFollowers,
265    UpdateLanguageServer,
266    UpdateProject,
267    UpdateWorktree,
268    UpdateWorktreeExtensions,
269);
270
// Channel-scoped entity message: `ChannelMessageSent` is keyed by its `channel_id`
// field (presumably exposed as `EntityMessage::remote_entity_id`; NOTE(review): confirm).
271entity_messages!(channel_id, ChannelMessageSent);
272
/// Bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte.
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s reusable encoding buffer is shrunk back to after each
/// message, so a single oversized message does not keep a large allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
276
/// A stream of protobuf messages: wraps a WebSocket transport `S` and frames each
/// [`Envelope`] as a zstd-compressed binary message (see the `read`/`write` impls).
pub struct MessageStream<S> {
    /// The wrapped transport (a sink, a stream, or both).
    stream: S,
    /// Scratch space reused across messages for encoding and decoding; its capacity is
    /// trimmed back to `MAX_BUFFER_LEN` after every message.
    encoding_buffer: Vec<u8>,
}
282
283#[allow(clippy::large_enum_variant)]
284#[derive(Debug)]
285pub enum Message {
286    Envelope(Envelope),
287    Ping,
288    Pong,
289}
290
291impl<S> MessageStream<S> {
292    pub fn new(stream: S) -> Self {
293        Self {
294            stream,
295            encoding_buffer: Vec::new(),
296        }
297    }
298
299    pub fn inner_mut(&mut self) -> &mut S {
300        &mut self.stream
301    }
302}
303
304impl<S> MessageStream<S>
305where
306    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
307{
308    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
309        #[cfg(any(test, feature = "test-support"))]
310        const COMPRESSION_LEVEL: i32 = -7;
311
312        #[cfg(not(any(test, feature = "test-support")))]
313        const COMPRESSION_LEVEL: i32 = 4;
314
315        match message {
316            Message::Envelope(message) => {
317                self.encoding_buffer.reserve(message.encoded_len());
318                message
319                    .encode(&mut self.encoding_buffer)
320                    .map_err(io::Error::from)?;
321                let buffer =
322                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
323                        .unwrap();
324
325                self.encoding_buffer.clear();
326                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
327                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
328            }
329            Message::Ping => {
330                self.stream
331                    .send(WebSocketMessage::Ping(Default::default()))
332                    .await?;
333            }
334            Message::Pong => {
335                self.stream
336                    .send(WebSocketMessage::Pong(Default::default()))
337                    .await?;
338            }
339        }
340
341        Ok(())
342    }
343}
344
345impl<S> MessageStream<S>
346where
347    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
348{
349    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
350        while let Some(bytes) = self.stream.next().await {
351            match bytes? {
352                WebSocketMessage::Binary(bytes) => {
353                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
354                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
355                        .map_err(io::Error::from)?;
356
357                    self.encoding_buffer.clear();
358                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
359                    return Ok(Message::Envelope(envelope));
360                }
361                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
362                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
363                WebSocketMessage::Close(_) => break,
364                _ => {}
365            }
366        }
367        Err(anyhow!("connection closed"))
368    }
369}
370
371impl From<Timestamp> for SystemTime {
372    fn from(val: Timestamp) -> Self {
373        UNIX_EPOCH
374            .checked_add(Duration::new(val.seconds, val.nanos))
375            .unwrap()
376    }
377}
378
379impl From<SystemTime> for Timestamp {
380    fn from(time: SystemTime) -> Self {
381        let duration = time.duration_since(UNIX_EPOCH).unwrap();
382        Self {
383            seconds: duration.as_secs(),
384            nanos: duration.subsec_nanos(),
385        }
386    }
387}
388
389impl From<u128> for Nonce {
390    fn from(nonce: u128) -> Self {
391        let upper_half = (nonce >> 64) as u64;
392        let lower_half = nonce as u64;
393        Self {
394            upper_half,
395            lower_half,
396        }
397    }
398}
399
400impl From<Nonce> for u128 {
401    fn from(nonce: Nonce) -> Self {
402        let upper_half = (nonce.upper_half as u128) << 64;
403        let lower_half = nonce.lower_half as u128;
404        upper_half | lower_half
405    }
406}
407
408pub fn split_worktree_update(
409    mut message: UpdateWorktree,
410    max_chunk_size: usize,
411) -> impl Iterator<Item = UpdateWorktree> {
412    let mut done = false;
413    iter::from_fn(move || {
414        if done {
415            return None;
416        }
417
418        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
419        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
420        done = message.updated_entries.is_empty();
421        Some(UpdateWorktree {
422            project_id: message.project_id,
423            worktree_id: message.worktree_id,
424            root_name: message.root_name.clone(),
425            updated_entries,
426            removed_entries: mem::take(&mut message.removed_entries),
427            scan_id: message.scan_id,
428            is_last_update: done && message.is_last_update,
429        })
430    })
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an `UpdateWorktree` with the given root name in an envelope message, so the
    /// encoded size can be controlled through the name's length.
    fn worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// After each small or very large message, the scratch buffer must not retain
    /// more than `MAX_BUFFER_LEN` of capacity, on both the writing and reading side.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}