proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
 15include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
 17pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 18    const NAME: &'static str;
 19    const PRIORITY: MessagePriority;
 20    fn into_envelope(
 21        self,
 22        id: u32,
 23        responding_to: Option<u32>,
 24        original_sender_id: Option<u32>,
 25    ) -> Envelope;
 26    fn from_envelope(envelope: Envelope) -> Option<Self>;
 27}
 28
 29pub trait EntityMessage: EnvelopedMessage {
 30    fn remote_entity_id(&self) -> u64;
 31}
 32
 33pub trait RequestMessage: EnvelopedMessage {
 34    type Response: EnvelopedMessage;
 35}
 36
 37pub trait AnyTypedEnvelope: 'static + Send + Sync {
 38    fn payload_type_id(&self) -> TypeId;
 39    fn payload_type_name(&self) -> &'static str;
 40    fn as_any(&self) -> &dyn Any;
 41    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 42    fn is_background(&self) -> bool;
 43    fn original_sender_id(&self) -> Option<PeerId>;
 44}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Registers every protobuf message type that can travel inside an `Envelope`,
// paired with its `MessagePriority`. The `messages!` macro is defined elsewhere
// (imported from `super`); presumably it implements `EnvelopedMessage` for each
// listed type (NAME / PRIORITY / envelope conversions) - confirm against the
// macro definition. Request/response pairs are declared in `request_messages!`.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
);
171
// Pairs each request message with the message type sent back in reply.
// The `request_messages!` macro is defined elsewhere; presumably it implements
// `RequestMessage` with `type Response = <second element>` for each pair.
// Several requests share a generic reply (`Ack`, `ProjectEntryResponse`,
// `UsersResponse`), and `Test` replies with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
217
// Message types scoped to a project. The first argument names the field that
// identifies the remote entity; presumably `entity_messages!` (defined
// elsewhere) implements `EntityMessage::remote_entity_id` by reading
// `project_id` from each listed message - confirm against the macro.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
);
267
// Message types scoped to a channel; the entity id is taken from `channel_id`
// (presumably by the same `entity_messages!` expansion as the project list).
entity_messages!(channel_id, ChannelMessageSent);
269
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Largest capacity `MessageStream` keeps in its scratch buffer between
/// messages; anything above this is released with `shrink_to` after each
/// message is encoded or decoded.
const MAX_BUFFER_LEN: usize = MIB;
273
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` and translates between WebSocket frames
/// and [`Message`] values. Reading requires `S` to be a stream of frames;
/// writing requires `S` to be a sink of frames.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch space reused for protobuf encoding and zstd decoding; its
    /// capacity is trimmed back to `MAX_BUFFER_LEN` after each message.
    encoding_buffer: Vec<u8>,
}
279
280#[allow(clippy::large_enum_variant)]
281#[derive(Debug)]
282pub enum Message {
283    Envelope(Envelope),
284    Ping,
285    Pong,
286}
287
288impl<S> MessageStream<S> {
289    pub fn new(stream: S) -> Self {
290        Self {
291            stream,
292            encoding_buffer: Vec::new(),
293        }
294    }
295
296    pub fn inner_mut(&mut self) -> &mut S {
297        &mut self.stream
298    }
299}
300
301impl<S> MessageStream<S>
302where
303    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
304{
305    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
306        #[cfg(any(test, feature = "test-support"))]
307        const COMPRESSION_LEVEL: i32 = -7;
308
309        #[cfg(not(any(test, feature = "test-support")))]
310        const COMPRESSION_LEVEL: i32 = 4;
311
312        match message {
313            Message::Envelope(message) => {
314                self.encoding_buffer.reserve(message.encoded_len());
315                message
316                    .encode(&mut self.encoding_buffer)
317                    .map_err(io::Error::from)?;
318                let buffer =
319                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
320                        .unwrap();
321
322                self.encoding_buffer.clear();
323                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
324                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
325            }
326            Message::Ping => {
327                self.stream
328                    .send(WebSocketMessage::Ping(Default::default()))
329                    .await?;
330            }
331            Message::Pong => {
332                self.stream
333                    .send(WebSocketMessage::Pong(Default::default()))
334                    .await?;
335            }
336        }
337
338        Ok(())
339    }
340}
341
342impl<S> MessageStream<S>
343where
344    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
345{
346    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
347        while let Some(bytes) = self.stream.next().await {
348            match bytes? {
349                WebSocketMessage::Binary(bytes) => {
350                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
351                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
352                        .map_err(io::Error::from)?;
353
354                    self.encoding_buffer.clear();
355                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
356                    return Ok(Message::Envelope(envelope));
357                }
358                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
359                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
360                WebSocketMessage::Close(_) => break,
361                _ => {}
362            }
363        }
364        Err(anyhow!("connection closed"))
365    }
366}
367
368impl From<Timestamp> for SystemTime {
369    fn from(val: Timestamp) -> Self {
370        UNIX_EPOCH
371            .checked_add(Duration::new(val.seconds, val.nanos))
372            .unwrap()
373    }
374}
375
376impl From<SystemTime> for Timestamp {
377    fn from(time: SystemTime) -> Self {
378        let duration = time.duration_since(UNIX_EPOCH).unwrap();
379        Self {
380            seconds: duration.as_secs(),
381            nanos: duration.subsec_nanos(),
382        }
383    }
384}
385
386impl From<u128> for Nonce {
387    fn from(nonce: u128) -> Self {
388        let upper_half = (nonce >> 64) as u64;
389        let lower_half = nonce as u64;
390        Self {
391            upper_half,
392            lower_half,
393        }
394    }
395}
396
397impl From<Nonce> for u128 {
398    fn from(nonce: Nonce) -> Self {
399        let upper_half = (nonce.upper_half as u128) << 64;
400        let lower_half = nonce.lower_half as u128;
401        upper_half | lower_half
402    }
403}
404
405pub fn split_worktree_update(
406    mut message: UpdateWorktree,
407    max_chunk_size: usize,
408) -> impl Iterator<Item = UpdateWorktree> {
409    let mut done = false;
410    iter::from_fn(move || {
411        if done {
412            return None;
413        }
414
415        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
416        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
417        done = message.updated_entries.is_empty();
418        Some(UpdateWorktree {
419            project_id: message.project_id,
420            worktree_id: message.worktree_id,
421            root_name: message.root_name.clone(),
422            updated_entries,
423            removed_entries: mem::take(&mut message.removed_entries),
424            scan_id: message.scan_id,
425            is_last_update: done && message.is_last_update,
426        })
427    })
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` with the given root name.
    fn update_worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// Round-trips a small and a very large envelope through `MessageStream`.
    /// It checks two things: each decoded envelope equals what was sent, and the
    /// scratch buffer never retains more than `MAX_BUFFER_LEN` of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let small = update_worktree_envelope("abcdefg".repeat(10));
        let large = update_worktree_envelope("abcdefg".repeat(1000000));

        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        sink.write(Message::Envelope(small.clone())).await.unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        sink.write(Message::Envelope(large.clone())).await.unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for expected in [small, large] {
            match stream.read().await.unwrap() {
                Message::Envelope(envelope) => assert_eq!(envelope, expected),
                other => panic!("expected an envelope, got {:?}", other),
            }
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}