proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
 15include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
 17pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 18    const NAME: &'static str;
 19    const PRIORITY: MessagePriority;
 20    fn into_envelope(
 21        self,
 22        id: u32,
 23        responding_to: Option<u32>,
 24        original_sender_id: Option<u32>,
 25    ) -> Envelope;
 26    fn from_envelope(envelope: Envelope) -> Option<Self>;
 27}
 28
 29pub trait EntityMessage: EnvelopedMessage {
 30    fn remote_entity_id(&self) -> u64;
 31}
 32
 33pub trait RequestMessage: EnvelopedMessage {
 34    type Response: EnvelopedMessage;
 35}
 36
 37pub trait AnyTypedEnvelope: 'static + Send + Sync {
 38    fn payload_type_id(&self) -> TypeId;
 39    fn payload_type_name(&self) -> &'static str;
 40    fn as_any(&self) -> &dyn Any;
 41    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 42    fn is_background(&self) -> bool;
 43    fn original_sender_id(&self) -> Option<PeerId>;
 44}
 45
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
 77messages!(
 78    (Ack, Foreground),
 79    (AddProjectCollaborator, Foreground),
 80    (ApplyCodeAction, Background),
 81    (ApplyCodeActionResponse, Background),
 82    (ApplyCompletionAdditionalEdits, Background),
 83    (ApplyCompletionAdditionalEditsResponse, Background),
 84    (BufferReloaded, Foreground),
 85    (BufferSaved, Foreground),
 86    (RemoveContact, Foreground),
 87    (ChannelMessageSent, Foreground),
 88    (CopyProjectEntry, Foreground),
 89    (CreateProjectEntry, Foreground),
 90    (DeleteProjectEntry, Foreground),
 91    (Error, Foreground),
 92    (Follow, Foreground),
 93    (FollowResponse, Foreground),
 94    (FormatBuffers, Foreground),
 95    (FormatBuffersResponse, Foreground),
 96    (FuzzySearchUsers, Foreground),
 97    (GetChannelMessages, Foreground),
 98    (GetChannelMessagesResponse, Foreground),
 99    (GetChannels, Foreground),
100    (GetChannelsResponse, Foreground),
101    (GetCodeActions, Background),
102    (GetCodeActionsResponse, Background),
103    (GetHover, Background),
104    (GetHoverResponse, Background),
105    (GetCompletions, Background),
106    (GetCompletionsResponse, Background),
107    (GetDefinition, Background),
108    (GetDefinitionResponse, Background),
109    (GetDocumentHighlights, Background),
110    (GetDocumentHighlightsResponse, Background),
111    (GetReferences, Background),
112    (GetReferencesResponse, Background),
113    (GetProjectSymbols, Background),
114    (GetProjectSymbolsResponse, Background),
115    (GetUsers, Foreground),
116    (UsersResponse, Foreground),
117    (JoinChannel, Foreground),
118    (JoinChannelResponse, Foreground),
119    (JoinProject, Foreground),
120    (JoinProjectResponse, Foreground),
121    (JoinProjectRequestCancelled, Foreground),
122    (LeaveChannel, Foreground),
123    (LeaveProject, Foreground),
124    (OpenBufferById, Background),
125    (OpenBufferByPath, Background),
126    (OpenBufferForSymbol, Background),
127    (OpenBufferForSymbolResponse, Background),
128    (OpenBufferResponse, Background),
129    (PerformRename, Background),
130    (PerformRenameResponse, Background),
131    (PrepareRename, Background),
132    (PrepareRenameResponse, Background),
133    (ProjectEntryResponse, Foreground),
134    (RegisterProjectResponse, Foreground),
135    (Ping, Foreground),
136    (ProjectUnshared, Foreground),
137    (RegisterProject, Foreground),
138    (RegisterProjectActivity, Foreground),
139    (ReloadBuffers, Foreground),
140    (ReloadBuffersResponse, Foreground),
141    (RemoveProjectCollaborator, Foreground),
142    (RenameProjectEntry, Foreground),
143    (RequestContact, Foreground),
144    (RequestJoinProject, Foreground),
145    (RespondToContactRequest, Foreground),
146    (RespondToJoinProjectRequest, Foreground),
147    (SaveBuffer, Foreground),
148    (SearchProject, Background),
149    (SearchProjectResponse, Background),
150    (SendChannelMessage, Foreground),
151    (SendChannelMessageResponse, Foreground),
152    (ShowContacts, Foreground),
153    (StartLanguageServer, Foreground),
154    (Test, Foreground),
155    (Unfollow, Foreground),
156    (UnregisterProject, Foreground),
157    (UpdateBuffer, Foreground),
158    (UpdateBufferFile, Foreground),
159    (UpdateContacts, Foreground),
160    (UpdateDiagnosticSummary, Foreground),
161    (UpdateFollowers, Foreground),
162    (UpdateInviteInfo, Foreground),
163    (UpdateLanguageServer, Foreground),
164    (UpdateProject, Foreground),
165    (UpdateWorktree, Foreground),
166);
167
168request_messages!(
169    (ApplyCodeAction, ApplyCodeActionResponse),
170    (
171        ApplyCompletionAdditionalEdits,
172        ApplyCompletionAdditionalEditsResponse
173    ),
174    (CopyProjectEntry, ProjectEntryResponse),
175    (CreateProjectEntry, ProjectEntryResponse),
176    (DeleteProjectEntry, ProjectEntryResponse),
177    (Follow, FollowResponse),
178    (FormatBuffers, FormatBuffersResponse),
179    (GetChannelMessages, GetChannelMessagesResponse),
180    (GetChannels, GetChannelsResponse),
181    (GetCodeActions, GetCodeActionsResponse),
182    (GetHover, GetHoverResponse),
183    (GetCompletions, GetCompletionsResponse),
184    (GetDefinition, GetDefinitionResponse),
185    (GetDocumentHighlights, GetDocumentHighlightsResponse),
186    (GetReferences, GetReferencesResponse),
187    (GetProjectSymbols, GetProjectSymbolsResponse),
188    (FuzzySearchUsers, UsersResponse),
189    (GetUsers, UsersResponse),
190    (JoinChannel, JoinChannelResponse),
191    (JoinProject, JoinProjectResponse),
192    (OpenBufferById, OpenBufferResponse),
193    (OpenBufferByPath, OpenBufferResponse),
194    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
195    (Ping, Ack),
196    (PerformRename, PerformRenameResponse),
197    (PrepareRename, PrepareRenameResponse),
198    (RegisterProject, RegisterProjectResponse),
199    (ReloadBuffers, ReloadBuffersResponse),
200    (RequestContact, Ack),
201    (RemoveContact, Ack),
202    (RespondToContactRequest, Ack),
203    (RenameProjectEntry, ProjectEntryResponse),
204    (SaveBuffer, BufferSaved),
205    (SearchProject, SearchProjectResponse),
206    (SendChannelMessage, SendChannelMessageResponse),
207    (Test, Test),
208    (UnregisterProject, Ack),
209    (UpdateBuffer, Ack),
210    (UpdateWorktree, Ack),
211);
212
213entity_messages!(
214    project_id,
215    AddProjectCollaborator,
216    ApplyCodeAction,
217    ApplyCompletionAdditionalEdits,
218    BufferReloaded,
219    BufferSaved,
220    CopyProjectEntry,
221    CreateProjectEntry,
222    DeleteProjectEntry,
223    Follow,
224    FormatBuffers,
225    GetCodeActions,
226    GetCompletions,
227    GetDefinition,
228    GetDocumentHighlights,
229    GetHover,
230    GetReferences,
231    GetProjectSymbols,
232    JoinProject,
233    JoinProjectRequestCancelled,
234    LeaveProject,
235    OpenBufferById,
236    OpenBufferByPath,
237    OpenBufferForSymbol,
238    PerformRename,
239    PrepareRename,
240    ProjectUnshared,
241    RegisterProjectActivity,
242    ReloadBuffers,
243    RemoveProjectCollaborator,
244    RenameProjectEntry,
245    RequestJoinProject,
246    SaveBuffer,
247    SearchProject,
248    StartLanguageServer,
249    Unfollow,
250    UnregisterProject,
251    UpdateBuffer,
252    UpdateBufferFile,
253    UpdateDiagnosticSummary,
254    UpdateFollowers,
255    UpdateLanguageServer,
256    UpdateProject,
257    UpdateWorktree,
258);
259
260entity_messages!(channel_id, ChannelMessageSent);
261
/// Largest capacity, in bytes, that a `MessageStream`'s scratch encoding buffer keeps between
/// messages: after each message the buffer is cleared and `shrink_to`'d down to this size, so a
/// single huge message does not pin a huge allocation for the lifetime of the connection.
const MAX_BUFFER_LEN: usize = 1024 * 1024;
263
264/// A stream of protobuf messages.
265pub struct MessageStream<S> {
266    stream: S,
267    encoding_buffer: Vec<u8>,
268}
269
270#[derive(Debug)]
271pub enum Message {
272    Envelope(Envelope),
273    Ping,
274    Pong,
275}
276
277impl<S> MessageStream<S> {
278    pub fn new(stream: S) -> Self {
279        Self {
280            stream,
281            encoding_buffer: Vec::new(),
282        }
283    }
284
285    pub fn inner_mut(&mut self) -> &mut S {
286        &mut self.stream
287    }
288}
289
290impl<S> MessageStream<S>
291where
292    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
293{
294    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
295        #[cfg(any(test, feature = "test-support"))]
296        const COMPRESSION_LEVEL: i32 = -7;
297
298        #[cfg(not(any(test, feature = "test-support")))]
299        const COMPRESSION_LEVEL: i32 = 4;
300
301        match message {
302            Message::Envelope(message) => {
303                self.encoding_buffer.reserve(message.encoded_len());
304                message
305                    .encode(&mut self.encoding_buffer)
306                    .map_err(|err| io::Error::from(err))?;
307                let buffer =
308                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
309                        .unwrap();
310
311                self.encoding_buffer.clear();
312                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
313                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
314            }
315            Message::Ping => {
316                self.stream
317                    .send(WebSocketMessage::Ping(Default::default()))
318                    .await?;
319            }
320            Message::Pong => {
321                self.stream
322                    .send(WebSocketMessage::Pong(Default::default()))
323                    .await?;
324            }
325        }
326
327        Ok(())
328    }
329}
330
331impl<S> MessageStream<S>
332where
333    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
334{
335    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
336        while let Some(bytes) = self.stream.next().await {
337            match bytes? {
338                WebSocketMessage::Binary(bytes) => {
339                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
340                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
341                        .map_err(io::Error::from)?;
342
343                    self.encoding_buffer.clear();
344                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
345                    return Ok(Message::Envelope(envelope));
346                }
347                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
348                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
349                WebSocketMessage::Close(_) => break,
350                _ => {}
351            }
352        }
353        Err(anyhow!("connection closed"))
354    }
355}
356
357impl Into<SystemTime> for Timestamp {
358    fn into(self) -> SystemTime {
359        UNIX_EPOCH
360            .checked_add(Duration::new(self.seconds, self.nanos))
361            .unwrap()
362    }
363}
364
365impl From<SystemTime> for Timestamp {
366    fn from(time: SystemTime) -> Self {
367        let duration = time.duration_since(UNIX_EPOCH).unwrap();
368        Self {
369            seconds: duration.as_secs(),
370            nanos: duration.subsec_nanos(),
371        }
372    }
373}
374
375impl From<u128> for Nonce {
376    fn from(nonce: u128) -> Self {
377        let upper_half = (nonce >> 64) as u64;
378        let lower_half = nonce as u64;
379        Self {
380            upper_half,
381            lower_half,
382        }
383    }
384}
385
386impl From<Nonce> for u128 {
387    fn from(nonce: Nonce) -> Self {
388        let upper_half = (nonce.upper_half as u128) << 64;
389        let lower_half = nonce.lower_half as u128;
390        upper_half | lower_half
391    }
392}
393
394pub fn split_worktree_update(
395    mut message: UpdateWorktree,
396    max_chunk_size: usize,
397) -> impl Iterator<Item = UpdateWorktree> {
398    let mut done = false;
399    iter::from_fn(move || {
400        if done {
401            return None;
402        }
403
404        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
405        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
406        done = message.updated_entries.is_empty();
407        Some(UpdateWorktree {
408            project_id: message.project_id,
409            worktree_id: message.worktree_id,
410            root_name: message.root_name.clone(),
411            updated_entries,
412            removed_entries: mem::take(&mut message.removed_entries),
413            scan_id: message.scan_id,
414            is_last_update: done && message.is_last_update,
415        })
416    })
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message holding an `UpdateWorktree` whose only set field is `root_name`.
    fn update_worktree_message(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// The scratch buffer must never keep more than `MAX_BUFFER_LEN` of capacity, on either the
    /// write or the read side, even after handling a multi-megabyte message.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for root_name in ["abcdefg".repeat(10), "abcdefg".repeat(1000000)] {
            sink.write(update_worktree_message(root_name))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(|msg| anyhow::Ok(msg)));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}