proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Pulls in the protobuf types generated at build time into `OUT_DIR`
// (presumably by prost-build, given the `prost::Message` import). This is where
// `Envelope`, `Timestamp`, `Nonce` and the individual message structs used
// throughout this module come from, since none of them are imported above.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
 17pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 18    const NAME: &'static str;
 19    const PRIORITY: MessagePriority;
 20    fn into_envelope(
 21        self,
 22        id: u32,
 23        responding_to: Option<u32>,
 24        original_sender_id: Option<u32>,
 25    ) -> Envelope;
 26    fn from_envelope(envelope: Envelope) -> Option<Self>;
 27}
 28
 29pub trait EntityMessage: EnvelopedMessage {
 30    fn remote_entity_id(&self) -> u64;
 31}
 32
 33pub trait RequestMessage: EnvelopedMessage {
 34    type Response: EnvelopedMessage;
 35}
 36
 37pub trait AnyTypedEnvelope: 'static + Send + Sync {
 38    fn payload_type_id(&self) -> TypeId;
 39    fn payload_type_name(&self) -> &'static str;
 40    fn as_any(&self) -> &dyn Any;
 41    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 42    fn is_background(&self) -> bool;
 43    fn original_sender_id(&self) -> Option<PeerId>;
 44}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Lists the protobuf message types handled by this module, each paired with
// the `MessagePriority` variant it is registered under. The `messages!` macro
// (imported from the parent module) presumably uses the priority to implement
// `EnvelopedMessage::PRIORITY`, which `AnyTypedEnvelope::is_background` reads.
// NOTE(review): the list is only roughly alphabetical (e.g. `RemoveContact`,
// `RegisterProjectResponse`, `GetHover`); order is kept as-is because the
// macro definition is not visible here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
);
170
// Pairs each request message type with the message type sent back in reply;
// presumably the `request_messages!` macro uses this to implement
// `RequestMessage::Response`. Several requests share a reply type (`Ack`,
// `ProjectEntryResponse`, `UsersResponse`, `OpenBufferResponse`), and `Test`
// replies with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
216
// Message types addressed to a project. The first argument names the field
// (`project_id`) that presumably backs `EntityMessage::remote_entity_id` for
// every type listed after it.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
);
265
// `ChannelMessageSent` is addressed by its `channel_id` field rather than by a
// project id (presumably backing its `EntityMessage::remote_entity_id`).
entity_messages!(channel_id, ChannelMessageSent);
267
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity passed to `Vec::shrink_to` on `MessageStream`'s encoding buffer
/// after each message, so one unusually large message does not pin a large
/// allocation for the lifetime of the stream.
const MAX_BUFFER_LEN: usize = MIB;
271
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` and reuses one scratch buffer for encoding
/// outgoing envelopes and decoding incoming ones.
pub struct MessageStream<S> {
    /// The wrapped transport (a sink, a stream, or both).
    stream: S,
    /// Scratch space for uncompressed protobuf bytes; cleared after each
    /// message and shrunk back toward `MAX_BUFFER_LEN` bytes of capacity.
    encoding_buffer: Vec<u8>,
}
277
278#[allow(clippy::large_enum_variant)]
279#[derive(Debug)]
280pub enum Message {
281    Envelope(Envelope),
282    Ping,
283    Pong,
284}
285
286impl<S> MessageStream<S> {
287    pub fn new(stream: S) -> Self {
288        Self {
289            stream,
290            encoding_buffer: Vec::new(),
291        }
292    }
293
294    pub fn inner_mut(&mut self) -> &mut S {
295        &mut self.stream
296    }
297}
298
299impl<S> MessageStream<S>
300where
301    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
302{
303    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
304        #[cfg(any(test, feature = "test-support"))]
305        const COMPRESSION_LEVEL: i32 = -7;
306
307        #[cfg(not(any(test, feature = "test-support")))]
308        const COMPRESSION_LEVEL: i32 = 4;
309
310        match message {
311            Message::Envelope(message) => {
312                self.encoding_buffer.reserve(message.encoded_len());
313                message
314                    .encode(&mut self.encoding_buffer)
315                    .map_err(io::Error::from)?;
316                let buffer =
317                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
318                        .unwrap();
319
320                self.encoding_buffer.clear();
321                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
322                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
323            }
324            Message::Ping => {
325                self.stream
326                    .send(WebSocketMessage::Ping(Default::default()))
327                    .await?;
328            }
329            Message::Pong => {
330                self.stream
331                    .send(WebSocketMessage::Pong(Default::default()))
332                    .await?;
333            }
334        }
335
336        Ok(())
337    }
338}
339
340impl<S> MessageStream<S>
341where
342    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
343{
344    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
345        while let Some(bytes) = self.stream.next().await {
346            match bytes? {
347                WebSocketMessage::Binary(bytes) => {
348                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
349                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
350                        .map_err(io::Error::from)?;
351
352                    self.encoding_buffer.clear();
353                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
354                    return Ok(Message::Envelope(envelope));
355                }
356                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
357                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
358                WebSocketMessage::Close(_) => break,
359                _ => {}
360            }
361        }
362        Err(anyhow!("connection closed"))
363    }
364}
365
366impl From<Timestamp> for SystemTime {
367    fn from(val: Timestamp) -> Self {
368        UNIX_EPOCH
369            .checked_add(Duration::new(val.seconds, val.nanos))
370            .unwrap()
371    }
372}
373
374impl From<SystemTime> for Timestamp {
375    fn from(time: SystemTime) -> Self {
376        let duration = time.duration_since(UNIX_EPOCH).unwrap();
377        Self {
378            seconds: duration.as_secs(),
379            nanos: duration.subsec_nanos(),
380        }
381    }
382}
383
384impl From<u128> for Nonce {
385    fn from(nonce: u128) -> Self {
386        let upper_half = (nonce >> 64) as u64;
387        let lower_half = nonce as u64;
388        Self {
389            upper_half,
390            lower_half,
391        }
392    }
393}
394
395impl From<Nonce> for u128 {
396    fn from(nonce: Nonce) -> Self {
397        let upper_half = (nonce.upper_half as u128) << 64;
398        let lower_half = nonce.lower_half as u128;
399        upper_half | lower_half
400    }
401}
402
403pub fn split_worktree_update(
404    mut message: UpdateWorktree,
405    max_chunk_size: usize,
406) -> impl Iterator<Item = UpdateWorktree> {
407    let mut done = false;
408    iter::from_fn(move || {
409        if done {
410            return None;
411        }
412
413        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
414        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
415        done = message.updated_entries.is_empty();
416        Some(UpdateWorktree {
417            project_id: message.project_id,
418            worktree_id: message.worktree_id,
419            root_name: message.root_name.clone(),
420            updated_entries,
421            removed_entries: mem::take(&mut message.removed_entries),
422            scan_id: message.scan_id,
423            is_last_update: done && message.is_last_update,
424        })
425    })
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose `UpdateWorktree` payload has a root name of
    /// `"abcdefg"` repeated `repetitions` times.
    fn worktree_envelope(repetitions: usize) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// Writing and reading envelopes of very different sizes must keep the
    /// encoding buffer's capacity within `MAX_BUFFER_LEN`. Every envelope must
    /// also come back unchanged; the original test only checked capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let small = worktree_envelope(10);
        let large = worktree_envelope(1000000);

        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        sink.write(Message::Envelope(small.clone())).await.unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        sink.write(Message::Envelope(large.clone())).await.unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for expected in [small, large] {
            match stream.read().await.unwrap() {
                Message::Envelope(envelope) => assert_eq!(envelope, expected),
                other => panic!("expected an envelope, got {:?}", other),
            }
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}