proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Textually includes the file `zed.messages.rs` found in the build's `OUT_DIR`.
// NOTE(review): presumably this is the prost-generated code that declares the
// protobuf types used below (`Envelope`, `Timestamp`, `Nonce`, `UpdateWorktree`, ...)
// — confirm against the crate's build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
/// A message type that can be converted to and from an [`Envelope`].
///
/// NOTE(review): implementations are not visible in this file; presumably they
/// are generated by the `messages!` invocation below — confirm against the macro.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; reported by `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of the message type; `AnyTypedEnvelope::is_background`
    /// checks it against `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Consumes the message and produces an [`Envelope`] for it.
    ///
    /// `id`, `responding_to` and `original_sender_id` are presumably stored in
    /// the envelope's fields of the same name — TODO confirm in the generated impls.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Attempts to recover a message of this type from `envelope`.
    ///
    /// Presumably returns `None` when the envelope carries a different payload
    /// type — TODO confirm in the generated impls.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 28
/// An [`EnvelopedMessage`] that identifies a remote entity by a numeric id.
///
/// NOTE(review): presumably implemented via the `entity_messages!` invocations
/// below, returning the field named there — confirm against the macro.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 32
/// An [`EnvelopedMessage`] that has an associated response message type.
///
/// NOTE(review): presumably implemented via the `request_messages!` invocation
/// below — confirm against the macro.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
 36
/// Object-safe view of a `TypedEnvelope<T>` with the payload type erased.
///
/// Implemented in this file for every `TypedEnvelope<T>` whose `T` is an
/// [`EnvelopedMessage`].
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type `T`.
    fn payload_type_id(&self) -> TypeId;
    /// The payload type's `EnvelopedMessage::NAME`.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The envelope's `original_sender_id`, if any.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Table of every message type together with its priority. Each entry is
// `(MessageType, Priority)`, where `Priority` names a `MessagePriority` variant
// (`Foreground` or `Background`).
// NOTE(review): the `messages!` macro is defined outside this file; presumably it
// implements `EnvelopedMessage` for each listed type — confirm against the macro.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RemoveContact, Foreground),
    (Ping, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
);
179
// Table pairing request message types with the message type used as their
// response. Each entry is `(Request, Response)`; several requests share `Ack` or
// `ProjectEntryResponse`, and `Test` is paired with itself.
// NOTE(review): the `request_messages!` macro is defined outside this file;
// presumably it implements `RequestMessage` with `Response` set to the second
// element — confirm against the macro.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateWorktree, Ack),
);
231
// Message types associated with a project. The first argument (`project_id`) is
// an identifier rather than a type; the remaining arguments are message types.
// NOTE(review): the `entity_messages!` macro is defined outside this file;
// presumably it implements `EntityMessage` for each type, with
// `remote_entity_id` returning the named field — confirm against the macro.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
);
278
// Same macro as for project messages, but with `channel_id` as the leading
// identifier and `ChannelMessageSent` as the only listed message type.
// NOTE(review): presumably `channel_id` is the field used as the entity id —
// confirm against the macro definition.
entity_messages!(channel_id, ChannelMessageSent);
280
/// Number of bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Number of bytes in one mebibyte.
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream` shrinks its scratch buffer back to after each
/// message it encodes or decodes.
const MAX_BUFFER_LEN: usize = MIB;
284
/// A stream of protobuf messages.
///
/// Wraps an underlying WebSocket sink and/or stream `S` and converts between
/// WebSocket frames and [`Message`] values.
pub struct MessageStream<S> {
    /// The wrapped WebSocket sink/stream.
    stream: S,
    /// Scratch buffer reused across calls: it holds the uncompressed protobuf
    /// bytes while writing or reading, and is cleared and shrunk to at most
    /// `MAX_BUFFER_LEN` capacity afterwards.
    encoding_buffer: Vec<u8>,
}
290
/// A message exchanged over a [`MessageStream`].
// NOTE(review): the lint is silenced because `Envelope` is far larger than the
// two unit variants; presumably it is left unboxed deliberately — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope; sent as a compressed binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping frame (any payload is discarded on read).
    Ping,
    /// A WebSocket pong frame (any payload is discarded on read).
    Pong,
}
298
299impl<S> MessageStream<S> {
300    pub fn new(stream: S) -> Self {
301        Self {
302            stream,
303            encoding_buffer: Vec::new(),
304        }
305    }
306
307    pub fn inner_mut(&mut self) -> &mut S {
308        &mut self.stream
309    }
310}
311
312impl<S> MessageStream<S>
313where
314    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
315{
316    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
317        #[cfg(any(test, feature = "test-support"))]
318        const COMPRESSION_LEVEL: i32 = -7;
319
320        #[cfg(not(any(test, feature = "test-support")))]
321        const COMPRESSION_LEVEL: i32 = 4;
322
323        match message {
324            Message::Envelope(message) => {
325                self.encoding_buffer.reserve(message.encoded_len());
326                message
327                    .encode(&mut self.encoding_buffer)
328                    .map_err(io::Error::from)?;
329                let buffer =
330                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
331                        .unwrap();
332
333                self.encoding_buffer.clear();
334                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
335                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
336            }
337            Message::Ping => {
338                self.stream
339                    .send(WebSocketMessage::Ping(Default::default()))
340                    .await?;
341            }
342            Message::Pong => {
343                self.stream
344                    .send(WebSocketMessage::Pong(Default::default()))
345                    .await?;
346            }
347        }
348
349        Ok(())
350    }
351}
352
353impl<S> MessageStream<S>
354where
355    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
356{
357    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
358        while let Some(bytes) = self.stream.next().await {
359            match bytes? {
360                WebSocketMessage::Binary(bytes) => {
361                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
362                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
363                        .map_err(io::Error::from)?;
364
365                    self.encoding_buffer.clear();
366                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
367                    return Ok(Message::Envelope(envelope));
368                }
369                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
370                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
371                WebSocketMessage::Close(_) => break,
372                _ => {}
373            }
374        }
375        Err(anyhow!("connection closed"))
376    }
377}
378
379impl From<Timestamp> for SystemTime {
380    fn from(val: Timestamp) -> Self {
381        UNIX_EPOCH
382            .checked_add(Duration::new(val.seconds, val.nanos))
383            .unwrap()
384    }
385}
386
387impl From<SystemTime> for Timestamp {
388    fn from(time: SystemTime) -> Self {
389        let duration = time.duration_since(UNIX_EPOCH).unwrap();
390        Self {
391            seconds: duration.as_secs(),
392            nanos: duration.subsec_nanos(),
393        }
394    }
395}
396
397impl From<u128> for Nonce {
398    fn from(nonce: u128) -> Self {
399        let upper_half = (nonce >> 64) as u64;
400        let lower_half = nonce as u64;
401        Self {
402            upper_half,
403            lower_half,
404        }
405    }
406}
407
408impl From<Nonce> for u128 {
409    fn from(nonce: Nonce) -> Self {
410        let upper_half = (nonce.upper_half as u128) << 64;
411        let lower_half = nonce.lower_half as u128;
412        upper_half | lower_half
413    }
414}
415
416pub fn split_worktree_update(
417    mut message: UpdateWorktree,
418    max_chunk_size: usize,
419) -> impl Iterator<Item = UpdateWorktree> {
420    let mut done = false;
421    iter::from_fn(move || {
422        if done {
423            return None;
424        }
425
426        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
427        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
428        done = message.updated_entries.is_empty();
429        Some(UpdateWorktree {
430            project_id: message.project_id,
431            worktree_id: message.worktree_id,
432            root_name: message.root_name.clone(),
433            updated_entries,
434            removed_entries: mem::take(&mut message.removed_entries),
435            scan_id: message.scan_id,
436            is_last_update: done && message.is_last_update,
437        })
438    })
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose `root_name` is
    /// `"abcdefg"` repeated `repeat` times; all other fields are defaults.
    fn worktree_envelope(repeat: usize) -> Message {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repeat),
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` capacity
    /// after writing or reading, for both a small and a multi-megabyte message.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeat in [10, 1000000] {
            sink.write(worktree_envelope(repeat)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}