proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{cmp, iter, mem};
  9use std::{
 10    fmt::Debug,
 11    io,
 12    time::{Duration, SystemTime, UNIX_EPOCH},
 13};
 14
 // Textually includes `zed.messages.rs` from the build's `OUT_DIR` into this module.
 // Presumably this is the build-script-generated protobuf code that declares `Envelope`,
 // `Timestamp`, `Nonce`, `UpdateWorktree`, etc. — confirm against the crate's build script.
 15include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 16
 /// A message type that can be converted into an [`Envelope`] and extracted back out of one.
 ///
 /// Implementations are not written in this file by hand; presumably the `messages!`
 /// macro invocation further down generates them — confirm in the parent module.
 17pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
        /// Static name of the message type; `TypedEnvelope<T>` reports it from
        /// `AnyTypedEnvelope::payload_type_name`.
 18    const NAME: &'static str;
        /// Scheduling class of the message; `TypedEnvelope<T>` reports
        /// `is_background() == true` exactly when this is `MessagePriority::Background`.
 19    const PRIORITY: MessagePriority;
        /// Consumes the message and wraps it in an [`Envelope`].
        ///
        /// NOTE(review): `id`, `responding_to` and `original_sender_id` look like the envelope's
        /// own id, the id of the request being answered, and the peer that first sent the
        /// message — the envelope layout is not visible here, confirm against the generated code.
 20    fn into_envelope(
 21        self,
 22        id: u32,
 23        responding_to: Option<u32>,
 24        original_sender_id: Option<u32>,
 25    ) -> Envelope;
        /// Attempts to recover a message of this type from `envelope`.
        ///
        /// Presumably returns `None` when the envelope carries a different payload type —
        /// confirm in the generating macro.
 26    fn from_envelope(envelope: Envelope) -> Option<Self>;
 27}
 28
 /// An [`EnvelopedMessage`] that is addressed to a particular remote entity.
 ///
 /// Presumably implemented by the `entity_messages!` invocations below, which name the
 /// field (`project_id` / `channel_id`) that supplies the id — confirm in the macro definition.
 29pub trait EntityMessage: EnvelopedMessage {
        /// Returns the id of the remote entity this message refers to.
 30    fn remote_entity_id(&self) -> u64;
 31}
 32
 /// An [`EnvelopedMessage`] that is a request, paired at the type level with its response.
 ///
 /// Presumably implemented by the `request_messages!` invocation below, which lists
 /// `(Request, Response)` pairs — confirm in the macro definition.
 33pub trait RequestMessage: EnvelopedMessage {
        /// The message type sent back in answer to this request.
 34    type Response: EnvelopedMessage;
 35}
 36
 /// Type-erased interface over a `TypedEnvelope<T>`.
 ///
 /// Exposes the payload's type identity and priority without naming `T`, and allows
 /// converting to `dyn Any` so callers can downcast to the concrete envelope type.
 37pub trait AnyTypedEnvelope: 'static + Send + Sync {
        /// `TypeId` of the payload type carried by the envelope.
 38    fn payload_type_id(&self) -> TypeId;
        /// Static name of the payload type (`EnvelopedMessage::NAME` for `TypedEnvelope<T>`).
 39    fn payload_type_name(&self) -> &'static str;
        /// Borrows the envelope as `dyn Any`.
 40    fn as_any(&self) -> &dyn Any;
        /// Converts the boxed envelope into a boxed `dyn Any`.
 41    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
        /// Whether the payload type is classified as background priority.
 42    fn is_background(&self) -> bool;
        /// The envelope's `original_sender_id`, if it has one.
 43    fn original_sender_id(&self) -> Option<PeerId>;
 44}
 45
 /// Priority class assigned to each message type through `EnvelopedMessage::PRIORITY`.
 ///
 /// Within this file the only consumer is `AnyTypedEnvelope::is_background`; how the two
 /// classes are scheduled is decided by code outside this file.
 46pub enum MessagePriority {
 47    Foreground,
 48    Background,
 49}
 50
 51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 52    fn payload_type_id(&self) -> TypeId {
 53        TypeId::of::<T>()
 54    }
 55
 56    fn payload_type_name(&self) -> &'static str {
 57        T::NAME
 58    }
 59
 60    fn as_any(&self) -> &dyn Any {
 61        self
 62    }
 63
 64    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 65        self
 66    }
 67
 68    fn is_background(&self) -> bool {
 69        matches!(T::PRIORITY, MessagePriority::Background)
 70    }
 71
 72    fn original_sender_id(&self) -> Option<PeerId> {
 73        self.original_sender_id
 74    }
 75}
 76
 // Registry of every message type together with its priority class.
 // Each entry is `(MessageType, Priority)`, where the second element names a
 // `MessagePriority` variant. Presumably the macro (imported from the parent module)
 // implements `EnvelopedMessage` for each listed type — confirm in the macro definition.
 77messages!(
 78    (Ack, Foreground),
 79    (AddProjectCollaborator, Foreground),
 80    (ApplyCodeAction, Background),
 81    (ApplyCodeActionResponse, Background),
 82    (ApplyCompletionAdditionalEdits, Background),
 83    (ApplyCompletionAdditionalEditsResponse, Background),
 84    (BufferReloaded, Foreground),
 85    (BufferSaved, Foreground),
 86    (RemoveContact, Foreground),
 87    (ChannelMessageSent, Foreground),
 88    (CopyProjectEntry, Foreground),
 89    (CreateProjectEntry, Foreground),
 90    (DeleteProjectEntry, Foreground),
 91    (Error, Foreground),
 92    (Follow, Foreground),
 93    (FollowResponse, Foreground),
 94    (FormatBuffers, Foreground),
 95    (FormatBuffersResponse, Foreground),
 96    (FuzzySearchUsers, Foreground),
 97    (GetChannelMessages, Foreground),
 98    (GetChannelMessagesResponse, Foreground),
 99    (GetChannels, Foreground),
100    (GetChannelsResponse, Foreground),
101    (GetCodeActions, Background),
102    (GetCodeActionsResponse, Background),
103    (GetHover, Background),
104    (GetHoverResponse, Background),
105    (GetCompletions, Background),
106    (GetCompletionsResponse, Background),
107    (GetDefinition, Background),
108    (GetDefinitionResponse, Background),
109    (GetTypeDefinition, Background),
110    (GetTypeDefinitionResponse, Background),
111    (GetDocumentHighlights, Background),
112    (GetDocumentHighlightsResponse, Background),
113    (GetReferences, Background),
114    (GetReferencesResponse, Background),
115    (GetProjectSymbols, Background),
116    (GetProjectSymbolsResponse, Background),
117    (GetUsers, Foreground),
118    (UsersResponse, Foreground),
119    (JoinChannel, Foreground),
120    (JoinChannelResponse, Foreground),
121    (JoinProject, Foreground),
122    (JoinProjectResponse, Foreground),
123    (JoinProjectRequestCancelled, Foreground),
124    (LeaveChannel, Foreground),
125    (LeaveProject, Foreground),
126    (OpenBufferById, Background),
127    (OpenBufferByPath, Background),
128    (OpenBufferForSymbol, Background),
129    (OpenBufferForSymbolResponse, Background),
130    (OpenBufferResponse, Background),
131    (PerformRename, Background),
132    (PerformRenameResponse, Background),
133    (PrepareRename, Background),
134    (PrepareRenameResponse, Background),
135    (ProjectEntryResponse, Foreground),
136    (ProjectUnshared, Foreground),
137    (RegisterProjectResponse, Foreground),
138    (Ping, Foreground),
139    (RegisterProject, Foreground),
140    (RegisterProjectActivity, Foreground),
141    (ReloadBuffers, Foreground),
142    (ReloadBuffersResponse, Foreground),
143    (RemoveProjectCollaborator, Foreground),
144    (RenameProjectEntry, Foreground),
145    (RequestContact, Foreground),
146    (RequestJoinProject, Foreground),
147    (RespondToContactRequest, Foreground),
148    (RespondToJoinProjectRequest, Foreground),
149    (SaveBuffer, Foreground),
150    (SearchProject, Background),
151    (SearchProjectResponse, Background),
152    (SendChannelMessage, Foreground),
153    (SendChannelMessageResponse, Foreground),
154    (ShowContacts, Foreground),
155    (StartLanguageServer, Foreground),
156    (Test, Foreground),
157    (Unfollow, Foreground),
158    (UnregisterProject, Foreground),
159    (UpdateBuffer, Foreground),
160    (UpdateBufferFile, Foreground),
161    (UpdateContacts, Foreground),
162    (UpdateDiagnosticSummary, Foreground),
163    (UpdateFollowers, Foreground),
164    (UpdateInviteInfo, Foreground),
165    (UpdateLanguageServer, Foreground),
166    (UpdateProject, Foreground),
167    (UpdateWorktree, Foreground),
168    (UpdateWorktreeExtensions, Background),
169);
170
// Pairs each request message type with the message type sent in reply, written as
// `(Request, Response)`. Several requests share a response type (e.g. `Ack`,
// `ProjectEntryResponse`, `UsersResponse`). Presumably the macro (imported from the
// parent module) implements `RequestMessage` with `Response` set to the second
// element — confirm in the macro definition.
171request_messages!(
172    (ApplyCodeAction, ApplyCodeActionResponse),
173    (
174        ApplyCompletionAdditionalEdits,
175        ApplyCompletionAdditionalEditsResponse
176    ),
177    (CopyProjectEntry, ProjectEntryResponse),
178    (CreateProjectEntry, ProjectEntryResponse),
179    (DeleteProjectEntry, ProjectEntryResponse),
180    (Follow, FollowResponse),
181    (FormatBuffers, FormatBuffersResponse),
182    (GetChannelMessages, GetChannelMessagesResponse),
183    (GetChannels, GetChannelsResponse),
184    (GetCodeActions, GetCodeActionsResponse),
185    (GetHover, GetHoverResponse),
186    (GetCompletions, GetCompletionsResponse),
187    (GetDefinition, GetDefinitionResponse),
188    (GetTypeDefinition, GetTypeDefinitionResponse),
189    (GetDocumentHighlights, GetDocumentHighlightsResponse),
190    (GetReferences, GetReferencesResponse),
191    (GetProjectSymbols, GetProjectSymbolsResponse),
192    (FuzzySearchUsers, UsersResponse),
193    (GetUsers, UsersResponse),
194    (JoinChannel, JoinChannelResponse),
195    (JoinProject, JoinProjectResponse),
196    (OpenBufferById, OpenBufferResponse),
197    (OpenBufferByPath, OpenBufferResponse),
198    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
199    (Ping, Ack),
200    (PerformRename, PerformRenameResponse),
201    (PrepareRename, PrepareRenameResponse),
202    (RegisterProject, RegisterProjectResponse),
203    (ReloadBuffers, ReloadBuffersResponse),
204    (RequestContact, Ack),
205    (RemoveContact, Ack),
206    (RespondToContactRequest, Ack),
207    (RenameProjectEntry, ProjectEntryResponse),
208    (SaveBuffer, BufferSaved),
209    (SearchProject, SearchProjectResponse),
210    (SendChannelMessage, SendChannelMessageResponse),
211    (Test, Test),
212    (UnregisterProject, Ack),
213    (UpdateBuffer, Ack),
214    (UpdateWorktree, Ack),
215);
216
// Message types that are scoped to a project. The first argument (`project_id`) is a
// field name and the remaining arguments are message types. Presumably the macro
// (imported from the parent module) implements `EntityMessage` for each type with
// `remote_entity_id` returning that field — confirm in the macro definition.
217entity_messages!(
218    project_id,
219    AddProjectCollaborator,
220    ApplyCodeAction,
221    ApplyCompletionAdditionalEdits,
222    BufferReloaded,
223    BufferSaved,
224    CopyProjectEntry,
225    CreateProjectEntry,
226    DeleteProjectEntry,
227    Follow,
228    FormatBuffers,
229    GetCodeActions,
230    GetCompletions,
231    GetDefinition,
232    GetTypeDefinition,
233    GetDocumentHighlights,
234    GetHover,
235    GetReferences,
236    GetProjectSymbols,
237    JoinProject,
238    JoinProjectRequestCancelled,
239    LeaveProject,
240    OpenBufferById,
241    OpenBufferByPath,
242    OpenBufferForSymbol,
243    PerformRename,
244    PrepareRename,
245    ProjectUnshared,
246    RegisterProjectActivity,
247    ReloadBuffers,
248    RemoveProjectCollaborator,
249    RenameProjectEntry,
250    RequestJoinProject,
251    SaveBuffer,
252    SearchProject,
253    StartLanguageServer,
254    Unfollow,
255    UnregisterProject,
256    UpdateBuffer,
257    UpdateBufferFile,
258    UpdateDiagnosticSummary,
259    UpdateFollowers,
260    UpdateLanguageServer,
261    UpdateProject,
262    UpdateWorktree,
263    UpdateWorktreeExtensions,
264);
265
// Same macro as the project-scoped list, but keyed by the `channel_id` field:
// `ChannelMessageSent` is the only message type listed with it.
266entity_messages!(channel_id, ChannelMessageSent);
267
/// Upper bound, in bytes (1 MiB), passed to `shrink_to` on a `MessageStream`'s scratch
/// buffer after each envelope is written or read.
const MAX_BUFFER_LEN: usize = 1 << 20;
269
270/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S`. When `S` is a sink of WebSocket messages the
/// stream can `write`; when `S` is a stream of WebSocket messages it can `read`.
271pub struct MessageStream<S> {
       // The wrapped WebSocket sink and/or stream.
272    stream: S,
       // Scratch space reused between calls: holds the protobuf-encoded bytes before
       // compression in `write`, and the decompressed bytes before decoding in `read`.
273    encoding_buffer: Vec<u8>,
274}
275
/// A unit of traffic exchanged through a `MessageStream`.
///
/// `Envelope` carries a protobuf [`Envelope`] and travels as a zstd-compressed binary
/// WebSocket frame; `Ping` and `Pong` correspond to WebSocket ping and pong frames.
276#[derive(Debug)]
277pub enum Message {
278    Envelope(Envelope),
279    Ping,
280    Pong,
281}
282
283impl<S> MessageStream<S> {
284    pub fn new(stream: S) -> Self {
285        Self {
286            stream,
287            encoding_buffer: Vec::new(),
288        }
289    }
290
291    pub fn inner_mut(&mut self) -> &mut S {
292        &mut self.stream
293    }
294}
295
296impl<S> MessageStream<S>
297where
298    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
299{
300    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
301        #[cfg(any(test, feature = "test-support"))]
302        const COMPRESSION_LEVEL: i32 = -7;
303
304        #[cfg(not(any(test, feature = "test-support")))]
305        const COMPRESSION_LEVEL: i32 = 4;
306
307        match message {
308            Message::Envelope(message) => {
309                self.encoding_buffer.reserve(message.encoded_len());
310                message
311                    .encode(&mut self.encoding_buffer)
312                    .map_err(|err| io::Error::from(err))?;
313                let buffer =
314                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
315                        .unwrap();
316
317                self.encoding_buffer.clear();
318                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
319                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
320            }
321            Message::Ping => {
322                self.stream
323                    .send(WebSocketMessage::Ping(Default::default()))
324                    .await?;
325            }
326            Message::Pong => {
327                self.stream
328                    .send(WebSocketMessage::Pong(Default::default()))
329                    .await?;
330            }
331        }
332
333        Ok(())
334    }
335}
336
337impl<S> MessageStream<S>
338where
339    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
340{
341    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
342        while let Some(bytes) = self.stream.next().await {
343            match bytes? {
344                WebSocketMessage::Binary(bytes) => {
345                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
346                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
347                        .map_err(io::Error::from)?;
348
349                    self.encoding_buffer.clear();
350                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
351                    return Ok(Message::Envelope(envelope));
352                }
353                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
354                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
355                WebSocketMessage::Close(_) => break,
356                _ => {}
357            }
358        }
359        Err(anyhow!("connection closed"))
360    }
361}
362
363impl Into<SystemTime> for Timestamp {
364    fn into(self) -> SystemTime {
365        UNIX_EPOCH
366            .checked_add(Duration::new(self.seconds, self.nanos))
367            .unwrap()
368    }
369}
370
371impl From<SystemTime> for Timestamp {
372    fn from(time: SystemTime) -> Self {
373        let duration = time.duration_since(UNIX_EPOCH).unwrap();
374        Self {
375            seconds: duration.as_secs(),
376            nanos: duration.subsec_nanos(),
377        }
378    }
379}
380
381impl From<u128> for Nonce {
382    fn from(nonce: u128) -> Self {
383        let upper_half = (nonce >> 64) as u64;
384        let lower_half = nonce as u64;
385        Self {
386            upper_half,
387            lower_half,
388        }
389    }
390}
391
392impl From<Nonce> for u128 {
393    fn from(nonce: Nonce) -> Self {
394        let upper_half = (nonce.upper_half as u128) << 64;
395        let lower_half = nonce.lower_half as u128;
396        upper_half | lower_half
397    }
398}
399
400pub fn split_worktree_update(
401    mut message: UpdateWorktree,
402    max_chunk_size: usize,
403) -> impl Iterator<Item = UpdateWorktree> {
404    let mut done = false;
405    iter::from_fn(move || {
406        if done {
407            return None;
408        }
409
410        let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
411        let updated_entries = message.updated_entries.drain(..chunk_size).collect();
412        done = message.updated_entries.is_empty();
413        Some(UpdateWorktree {
414            project_id: message.project_id,
415            worktree_id: message.worktree_id,
416            root_name: message.root_name.clone(),
417            updated_entries,
418            removed_entries: mem::take(&mut message.removed_entries),
419            scan_id: message.scan_id,
420            is_last_update: done && message.is_last_update,
421        })
422    })
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message whose payload is an `UpdateWorktree` with the given
    /// `root_name` and every other field left at its default.
    fn worktree_message(root_name: String) -> Message {
        let update = UpdateWorktree {
            root_name,
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    /// The scratch buffer must never stay above `MAX_BUFFER_LEN` after a write or a read,
    /// for both a small and a multi-megabyte message.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for &repetitions in &[10, 1000000] {
            sink.write(worktree_message("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(|msg| anyhow::Ok(msg)));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}