proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Protobuf types for the `zed.messages` package, generated into Cargo's
// `OUT_DIR` (presumably by this crate's build script). Types used below without
// an import — `Envelope`, `Timestamp`, `Nonce`, `UpdateWorktree`, ... —
// presumably come from this generated file.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
/// A protobuf message type that can be wrapped in, and extracted from, an [`Envelope`].
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; surfaced via `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of this message type; `AnyTypedEnvelope::is_background`
    /// reports `true` when this is `MessagePriority::Background`.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] tagged with message `id`.
    ///
    /// `responding_to` is presumably the id of the request this message answers,
    /// and `original_sender_id` the peer that first sent it (when forwarded) —
    /// NOTE(review): confirm against the `messages!` macro implementation.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts this message type from `envelope`; presumably returns `None`
    /// when the envelope carries a different payload type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 28
/// A message addressed to a specific remote entity (e.g. a project or channel).
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity this message targets. The `entity_messages!` invocations
    /// presumably derive it from the named field (`project_id`, `channel_id`).
    fn remote_entity_id(&self) -> u64;
}
 32
/// A message that expects a reply; request/response pairings are declared with
/// `request_messages!`.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in response to this request.
    type Response: EnvelopedMessage;
}
 36
/// Type-erased view of a `TypedEnvelope<T>`, so envelopes of different payload
/// types can be handled uniformly and downcast later.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload message type `T`.
    fn payload_type_id(&self) -> TypeId;
    /// `EnvelopedMessage::NAME` of the payload message type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` for downcasting by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type has `MessagePriority::Background` priority.
    fn is_background(&self) -> bool;
    /// The peer that originally sent this message, if recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Registers every message type that can travel inside an `Envelope`, together
// with its `MessagePriority`. The `messages!` macro lives in the parent module;
// it presumably implements `EnvelopedMessage` (NAME / PRIORITY / envelope
// conversion) for each listed type. A type must appear here before it can be
// used in `request_messages!` or `entity_messages!` below.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RemoveContact, Foreground),
    (Ping, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
182
// Pairs each request type with the message type it is answered with
// (presumably implementing `RequestMessage::Response`). Several requests share
// the generic `Ack` response; `SaveBuffer` is answered with `BufferSaved`, and
// `Test` answers itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateWorktree, Ack),
);
235
// Message types addressed to a project. The first argument names the field
// that identifies the target entity; presumably the macro implements
// `EntityMessage::remote_entity_id` by reading `project_id` from each type.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
    UpdateDiffBase
);
283
// Message types addressed to a channel, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
285
286const KIB: usize = 1024;
287const MIB: usize = KIB * 1024;
288const MAX_BUFFER_LEN: usize = MIB;
289
/// A stream of protobuf messages carried over a WebSocket connection.
///
/// Wraps a WebSocket sink and/or stream `S` (the `write` / `read` methods are
/// available depending on which trait `S` implements) and converts between
/// [`Message`] values and zstd-compressed binary frames.
pub struct MessageStream<S> {
    /// The wrapped WebSocket sink and/or stream.
    stream: S,
    /// Scratch buffer reused across messages: protobuf encoding on write and
    /// zstd decompression on read. Cleared after each message and its capacity
    /// shrunk toward `MAX_BUFFER_LEN`.
    encoding_buffer: Vec<u8>,
}
295
/// A single unit of traffic on a [`MessageStream`].
// NOTE(review): `Envelope` is much larger than the unit variants; the lint is
// presumably silenced to avoid boxing every envelope — confirm that is intended.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protocol message, sent as a zstd-compressed protobuf binary frame.
    Envelope(Envelope),
    /// A WebSocket ping frame.
    Ping,
    /// A WebSocket pong frame.
    Pong,
}
303
304impl<S> MessageStream<S> {
305    pub fn new(stream: S) -> Self {
306        Self {
307            stream,
308            encoding_buffer: Vec::new(),
309        }
310    }
311
312    pub fn inner_mut(&mut self) -> &mut S {
313        &mut self.stream
314    }
315}
316
317impl<S> MessageStream<S>
318where
319    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
320{
321    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
322        #[cfg(any(test, feature = "test-support"))]
323        const COMPRESSION_LEVEL: i32 = -7;
324
325        #[cfg(not(any(test, feature = "test-support")))]
326        const COMPRESSION_LEVEL: i32 = 4;
327
328        match message {
329            Message::Envelope(message) => {
330                self.encoding_buffer.reserve(message.encoded_len());
331                message
332                    .encode(&mut self.encoding_buffer)
333                    .map_err(io::Error::from)?;
334                let buffer =
335                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
336                        .unwrap();
337
338                self.encoding_buffer.clear();
339                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
340                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
341            }
342            Message::Ping => {
343                self.stream
344                    .send(WebSocketMessage::Ping(Default::default()))
345                    .await?;
346            }
347            Message::Pong => {
348                self.stream
349                    .send(WebSocketMessage::Pong(Default::default()))
350                    .await?;
351            }
352        }
353
354        Ok(())
355    }
356}
357
358impl<S> MessageStream<S>
359where
360    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
361{
362    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
363        while let Some(bytes) = self.stream.next().await {
364            match bytes? {
365                WebSocketMessage::Binary(bytes) => {
366                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
367                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
368                        .map_err(io::Error::from)?;
369
370                    self.encoding_buffer.clear();
371                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
372                    return Ok(Message::Envelope(envelope));
373                }
374                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
375                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
376                WebSocketMessage::Close(_) => break,
377                _ => {}
378            }
379        }
380        Err(anyhow!("connection closed"))
381    }
382}
383
384impl From<Timestamp> for SystemTime {
385    fn from(val: Timestamp) -> Self {
386        UNIX_EPOCH
387            .checked_add(Duration::new(val.seconds, val.nanos))
388            .unwrap()
389    }
390}
391
392impl From<SystemTime> for Timestamp {
393    fn from(time: SystemTime) -> Self {
394        let duration = time.duration_since(UNIX_EPOCH).unwrap();
395        Self {
396            seconds: duration.as_secs(),
397            nanos: duration.subsec_nanos(),
398        }
399    }
400}
401
402impl From<u128> for Nonce {
403    fn from(nonce: u128) -> Self {
404        let upper_half = (nonce >> 64) as u64;
405        let lower_half = nonce as u64;
406        Self {
407            upper_half,
408            lower_half,
409        }
410    }
411}
412
413impl From<Nonce> for u128 {
414    fn from(nonce: Nonce) -> Self {
415        let upper_half = (nonce.upper_half as u128) << 64;
416        let lower_half = nonce.lower_half as u128;
417        upper_half | lower_half
418    }
419}
420
421pub fn split_worktree_update(
422    mut message: UpdateWorktree,
423    max_chunk_size: usize,
424) -> impl Iterator<Item = UpdateWorktree> {
425    let mut done = false;
426    iter::from_fn(move || {
427        if done {
428            return None;
429        }
430
431        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
432        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
433        done = message.updated_entries.is_empty();
434        Some(UpdateWorktree {
435            project_id: message.project_id,
436            worktree_id: message.worktree_id,
437            root_name: message.root_name.clone(),
438            updated_entries,
439            removed_entries: mem::take(&mut message.removed_entries),
440            scan_id: message.scan_id,
441            is_last_update: done && message.is_last_update,
442        })
443    })
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose only populated
    /// field is `root_name`, so the caller controls the encoded size.
    fn envelope_with_root_name(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` bytes of
    /// capacity after a message. This is checked for a small and a very large
    /// message, on both the writing and the reading side.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            let envelope = envelope_with_root_name("abcdefg".repeat(repetitions));
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}