proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
 15include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
/// A message type that can be wrapped into, and extracted from, an [`Envelope`].
///
/// NOTE(review): implementations are presumably generated by the `messages!`
/// macro imported from the parent module — confirm against its definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; surfaced through
    /// `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of the message type; surfaced through
    /// `AnyTypedEnvelope::is_background`.
    const PRIORITY: MessagePriority;
    /// Consumes the message and wraps it in an [`Envelope`].
    ///
    /// `id` is the id for the new envelope; `responding_to` and
    /// `original_sender_id` are optional ids that are presumably stored on the
    /// envelope as-is (TODO confirm against the generated implementation).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Tries to extract a message of this type from `envelope`, returning
    /// `None` when that is not possible.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 28
/// An [`EnvelopedMessage`] that is addressed to a specific remote entity.
///
/// NOTE(review): implementations presumably come from the `entity_messages!`
/// invocations in this file, keyed on the named id field — confirm.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
 32
/// An [`EnvelopedMessage`] that is a request and has a statically known
/// response type.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
 36
/// Type-erased view of a `TypedEnvelope<T>`, so envelopes with different
/// payload types can be handled through one trait object.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type carried by the envelope.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload type (`EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type has `MessagePriority::Background` priority.
    fn is_background(&self) -> bool;
    /// Id of the peer that originally sent the message, if one is recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 45
/// Priority class attached to every message type through
/// `EnvelopedMessage::PRIORITY`.
pub enum MessagePriority {
    /// The default class used by most message types in this file.
    Foreground,
    /// Messages for which `AnyTypedEnvelope::is_background` returns `true`.
    Background,
}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Every message type known to the protocol, paired with its priority
// (`Foreground` or `Background`, see `MessagePriority`).
//
// NOTE(review): `messages!` is imported from the parent module; presumably it
// implements `EnvelopedMessage` for each listed type using the given priority
// — confirm against the macro definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Background),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
181
// Request/response pairings, written as `(Request, Response)`.
//
// NOTE(review): `request_messages!` is imported from the parent module;
// presumably it implements `RequestMessage` for the first type of each pair
// with the second as its `Response` — confirm against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateDiagnosticSummary, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
236
// Message types tied to a project. The first argument names an id field and
// the remaining arguments are the message types.
//
// NOTE(review): `entity_messages!` is imported from the parent module;
// presumably it implements `EntityMessage` for each listed type by returning
// its `project_id` field from `remote_entity_id` — confirm against the macro.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateDiffBase
);
282
// Message types tied to a channel, using `channel_id` as the named id field.
// NOTE(review): presumably makes `remote_entity_id` return `channel_id` for
// `ChannelMessageSent` — confirm against the `entity_messages!` definition.
entity_messages!(channel_id, ChannelMessageSent);
284
/// One kibibyte, in bytes.
const KIB: usize = 1 << 10;
/// One mebibyte, in bytes.
const MIB: usize = KIB << 10;
/// Upper bound on the capacity the reusable encoding buffer keeps between
/// messages; larger allocations are shrunk back to this size.
const MAX_BUFFER_LEN: usize = MIB;
288
/// A stream of protobuf messages.
///
/// Wraps an underlying websocket sink and/or stream `S`; `write` is available
/// when `S` is a sink of websocket messages and `read` when it is a stream of
/// them.
pub struct MessageStream<S> {
    /// The wrapped websocket sink/stream.
    stream: S,
    /// Scratch buffer reused across calls: `write` encodes the protobuf into
    /// it before compressing, `read` decompresses into it before decoding.
    /// It is cleared and shrunk to at most `MAX_BUFFER_LEN` after each message.
    encoding_buffer: Vec<u8>,
}
294
/// A message exchanged over a [`MessageStream`].
// The lint is silenced because `Envelope` is presumably much larger than the
// two unit variants — NOTE(review): confirm boxing it is not preferred.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope; sent and received as a binary websocket frame.
    Envelope(Envelope),
    /// A websocket ping frame.
    Ping,
    /// A websocket pong frame.
    Pong,
}
302
303impl<S> MessageStream<S> {
304    pub fn new(stream: S) -> Self {
305        Self {
306            stream,
307            encoding_buffer: Vec::new(),
308        }
309    }
310
311    pub fn inner_mut(&mut self) -> &mut S {
312        &mut self.stream
313    }
314}
315
316impl<S> MessageStream<S>
317where
318    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
319{
320    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
321        #[cfg(any(test, feature = "test-support"))]
322        const COMPRESSION_LEVEL: i32 = -7;
323
324        #[cfg(not(any(test, feature = "test-support")))]
325        const COMPRESSION_LEVEL: i32 = 4;
326
327        match message {
328            Message::Envelope(message) => {
329                self.encoding_buffer.reserve(message.encoded_len());
330                message
331                    .encode(&mut self.encoding_buffer)
332                    .map_err(io::Error::from)?;
333                let buffer =
334                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
335                        .unwrap();
336
337                self.encoding_buffer.clear();
338                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
339                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
340            }
341            Message::Ping => {
342                self.stream
343                    .send(WebSocketMessage::Ping(Default::default()))
344                    .await?;
345            }
346            Message::Pong => {
347                self.stream
348                    .send(WebSocketMessage::Pong(Default::default()))
349                    .await?;
350            }
351        }
352
353        Ok(())
354    }
355}
356
357impl<S> MessageStream<S>
358where
359    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
360{
361    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
362        while let Some(bytes) = self.stream.next().await {
363            match bytes? {
364                WebSocketMessage::Binary(bytes) => {
365                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
366                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
367                        .map_err(io::Error::from)?;
368
369                    self.encoding_buffer.clear();
370                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
371                    return Ok(Message::Envelope(envelope));
372                }
373                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
374                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
375                WebSocketMessage::Close(_) => break,
376                _ => {}
377            }
378        }
379        Err(anyhow!("connection closed"))
380    }
381}
382
383impl From<Timestamp> for SystemTime {
384    fn from(val: Timestamp) -> Self {
385        UNIX_EPOCH
386            .checked_add(Duration::new(val.seconds, val.nanos))
387            .unwrap()
388    }
389}
390
391impl From<SystemTime> for Timestamp {
392    fn from(time: SystemTime) -> Self {
393        let duration = time.duration_since(UNIX_EPOCH).unwrap();
394        Self {
395            seconds: duration.as_secs(),
396            nanos: duration.subsec_nanos(),
397        }
398    }
399}
400
401impl From<u128> for Nonce {
402    fn from(nonce: u128) -> Self {
403        let upper_half = (nonce >> 64) as u64;
404        let lower_half = nonce as u64;
405        Self {
406            upper_half,
407            lower_half,
408        }
409    }
410}
411
412impl From<Nonce> for u128 {
413    fn from(nonce: Nonce) -> Self {
414        let upper_half = (nonce.upper_half as u128) << 64;
415        let lower_half = nonce.lower_half as u128;
416        upper_half | lower_half
417    }
418}
419
420pub fn split_worktree_update(
421    mut message: UpdateWorktree,
422    max_chunk_size: usize,
423) -> impl Iterator<Item = UpdateWorktree> {
424    let mut done = false;
425    iter::from_fn(move || {
426        if done {
427            return None;
428        }
429
430        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
431        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
432        done = message.updated_entries.is_empty();
433        Some(UpdateWorktree {
434            project_id: message.project_id,
435            worktree_id: message.worktree_id,
436            root_name: message.root_name.clone(),
437            abs_path: message.abs_path.clone(),
438            updated_entries,
439            removed_entries: mem::take(&mut message.removed_entries),
440            scan_id: message.scan_id,
441            is_last_update: done && message.is_last_update,
442        })
443    })
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose payload is an `UpdateWorktree` with the given
    /// root name and default values everywhere else.
    fn worktree_message(root_name: String) -> Message {
        let payload = envelope::Payload::UpdateWorktree(UpdateWorktree {
            root_name,
            ..Default::default()
        });
        Message::Envelope(Envelope {
            payload: Some(payload),
            ..Default::default()
        })
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` bytes
    /// of capacity after a write or a read, for both a small and a very large
    /// message.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));

        for repetitions in [10, 1000000] {
            sink.write(worktree_message("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}