proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
// Splices in the Rust code generated into OUT_DIR for the `zed.messages`
// protobuf package (presumably produced by prost in build.rs — TODO confirm).
// Types used below without an import, such as `Envelope`, `Timestamp`,
// `Nonce` and `UpdateWorktree`, are expected to come from this file.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
 17pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 18    const NAME: &'static str;
 19    const PRIORITY: MessagePriority;
 20    fn into_envelope(
 21        self,
 22        id: u32,
 23        responding_to: Option<u32>,
 24        original_sender_id: Option<u32>,
 25    ) -> Envelope;
 26    fn from_envelope(envelope: Envelope) -> Option<Self>;
 27}
 28
 29pub trait EntityMessage: EnvelopedMessage {
 30    fn remote_entity_id(&self) -> u64;
 31}
 32
 33pub trait RequestMessage: EnvelopedMessage {
 34    type Response: EnvelopedMessage;
 35}
 36
 37pub trait AnyTypedEnvelope: 'static + Send + Sync {
 38    fn payload_type_id(&self) -> TypeId;
 39    fn payload_type_name(&self) -> &'static str;
 40    fn as_any(&self) -> &dyn Any;
 41    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 42    fn is_background(&self) -> bool;
 43    fn original_sender_id(&self) -> Option<PeerId>;
 44}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
// Registers every payload type that can travel inside an `Envelope`, each
// paired with its `MessagePriority`. The macro itself is defined elsewhere.
// Presumably it generates the `EnvelopedMessage` impls (`NAME`, `PRIORITY`,
// envelope conversions) — TODO confirm.
// NOTE(review): every type named in `request_messages!` and
// `entity_messages!` below currently also appears here; keep it that way
// when adding messages.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProjectResponse, Foreground),
    (RemoveContact, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
);
175
// Pairs each request type with the message type expected in reply.
// Presumably the macro uses these pairs to implement `RequestMessage` with
// `type Response = <second element>` — TODO confirm against its definition.
// Several requests share one response type; e.g. `Ack` answers requests that
// need only an acknowledgement, and `ProjectEntryResponse` serves all the
// project-entry mutations.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
223
// Message types scoped to a project. The first argument names the payload
// field holding the entity id; presumably the macro returns it from
// `EntityMessage::remote_entity_id` — TODO confirm against its definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
);
273
// Message types scoped to a channel, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
275
// Binary size units (powers of two) used for the buffer limit below.
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;

/// Largest scratch-buffer capacity `MessageStream` keeps after each message.
/// Anything above this is released via `Vec::shrink_to`.
const MAX_BUFFER_LEN: usize = MIB;
279
/// A stream of protobuf messages.
///
/// Wraps a WebSocket-style transport `S` and converts between [`Message`]
/// values and WebSocket frames: envelopes become zstd-compressed binary
/// frames, and pings/pongs map to control frames.
pub struct MessageStream<S> {
    /// Scratch space reused for every encode/decode, so each message does not
    /// need a fresh allocation.
    encoding_buffer: Vec<u8>,
    /// The underlying transport (a sink, a stream, or both).
    stream: S,
}
285
286#[allow(clippy::large_enum_variant)]
287#[derive(Debug)]
288pub enum Message {
289    Envelope(Envelope),
290    Ping,
291    Pong,
292}
293
294impl<S> MessageStream<S> {
295    pub fn new(stream: S) -> Self {
296        Self {
297            stream,
298            encoding_buffer: Vec::new(),
299        }
300    }
301
302    pub fn inner_mut(&mut self) -> &mut S {
303        &mut self.stream
304    }
305}
306
307impl<S> MessageStream<S>
308where
309    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
310{
311    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
312        #[cfg(any(test, feature = "test-support"))]
313        const COMPRESSION_LEVEL: i32 = -7;
314
315        #[cfg(not(any(test, feature = "test-support")))]
316        const COMPRESSION_LEVEL: i32 = 4;
317
318        match message {
319            Message::Envelope(message) => {
320                self.encoding_buffer.reserve(message.encoded_len());
321                message
322                    .encode(&mut self.encoding_buffer)
323                    .map_err(io::Error::from)?;
324                let buffer =
325                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
326                        .unwrap();
327
328                self.encoding_buffer.clear();
329                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
330                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
331            }
332            Message::Ping => {
333                self.stream
334                    .send(WebSocketMessage::Ping(Default::default()))
335                    .await?;
336            }
337            Message::Pong => {
338                self.stream
339                    .send(WebSocketMessage::Pong(Default::default()))
340                    .await?;
341            }
342        }
343
344        Ok(())
345    }
346}
347
348impl<S> MessageStream<S>
349where
350    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
351{
352    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
353        while let Some(bytes) = self.stream.next().await {
354            match bytes? {
355                WebSocketMessage::Binary(bytes) => {
356                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
357                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
358                        .map_err(io::Error::from)?;
359
360                    self.encoding_buffer.clear();
361                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
362                    return Ok(Message::Envelope(envelope));
363                }
364                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
365                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
366                WebSocketMessage::Close(_) => break,
367                _ => {}
368            }
369        }
370        Err(anyhow!("connection closed"))
371    }
372}
373
374impl From<Timestamp> for SystemTime {
375    fn from(val: Timestamp) -> Self {
376        UNIX_EPOCH
377            .checked_add(Duration::new(val.seconds, val.nanos))
378            .unwrap()
379    }
380}
381
382impl From<SystemTime> for Timestamp {
383    fn from(time: SystemTime) -> Self {
384        let duration = time.duration_since(UNIX_EPOCH).unwrap();
385        Self {
386            seconds: duration.as_secs(),
387            nanos: duration.subsec_nanos(),
388        }
389    }
390}
391
392impl From<u128> for Nonce {
393    fn from(nonce: u128) -> Self {
394        let upper_half = (nonce >> 64) as u64;
395        let lower_half = nonce as u64;
396        Self {
397            upper_half,
398            lower_half,
399        }
400    }
401}
402
403impl From<Nonce> for u128 {
404    fn from(nonce: Nonce) -> Self {
405        let upper_half = (nonce.upper_half as u128) << 64;
406        let lower_half = nonce.lower_half as u128;
407        upper_half | lower_half
408    }
409}
410
411pub fn split_worktree_update(
412    mut message: UpdateWorktree,
413    max_chunk_size: usize,
414) -> impl Iterator<Item = UpdateWorktree> {
415    let mut done = false;
416    iter::from_fn(move || {
417        if done {
418            return None;
419        }
420
421        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
422        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
423        done = message.updated_entries.is_empty();
424        Some(UpdateWorktree {
425            project_id: message.project_id,
426            worktree_id: message.worktree_id,
427            root_name: message.root_name.clone(),
428            updated_entries,
429            removed_entries: mem::take(&mut message.removed_entries),
430            scan_id: message.scan_id,
431            is_last_update: done && message.is_last_update,
432        })
433    })
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an `UpdateWorktree` whose only populated field is `root_name` in
    /// a `Message`, so the encoded size scales with the string's length.
    fn worktree_update_with_root_name(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// `MessageStream` must never retain more than `MAX_BUFFER_LEN` bytes of
    /// scratch capacity after a write or read. This includes after a message
    /// whose unencoded root name (7,000,000 bytes) exceeds that limit.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for root_name in ["abcdefg".repeat(10), "abcdefg".repeat(1000000)] {
            sink.write(worktree_update_with_root_name(root_name))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}