proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    fmt::Debug,
 10    io,
 11    time::{Duration, SystemTime, UNIX_EPOCH},
 12};
 13
 14include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 15
 16pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 17    const NAME: &'static str;
 18    const PRIORITY: MessagePriority;
 19    fn into_envelope(
 20        self,
 21        id: u32,
 22        responding_to: Option<u32>,
 23        original_sender_id: Option<u32>,
 24    ) -> Envelope;
 25    fn from_envelope(envelope: Envelope) -> Option<Self>;
 26}
 27
/// An [`EnvelopedMessage`] that is associated with a remote entity.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    ///
    /// NOTE(review): the `entity_messages!` invocations in this file pass
    /// `project_id` / `channel_id`, so this presumably returns that field —
    /// confirm against the macro definition.
    fn remote_entity_id(&self) -> u64;
}
 31
/// An [`EnvelopedMessage`] that is paired with a response message type.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type paired with this request (presumably the reply a peer
    /// sends back; the pairs are listed in the `request_messages!` invocation).
    type Response: EnvelopedMessage;
}
 35
 36pub trait AnyTypedEnvelope: 'static + Send + Sync {
 37    fn payload_type_id(&self) -> TypeId;
 38    fn payload_type_name(&self) -> &'static str;
 39    fn as_any(&self) -> &dyn Any;
 40    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 41    fn is_background(&self) -> bool;
 42    fn original_sender_id(&self) -> Option<PeerId>;
 43}
 44
 45pub enum MessagePriority {
 46    Foreground,
 47    Background,
 48}
 49
 50impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 51    fn payload_type_id(&self) -> TypeId {
 52        TypeId::of::<T>()
 53    }
 54
 55    fn payload_type_name(&self) -> &'static str {
 56        T::NAME
 57    }
 58
 59    fn as_any(&self) -> &dyn Any {
 60        self
 61    }
 62
 63    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 64        self
 65    }
 66
 67    fn is_background(&self) -> bool {
 68        matches!(T::PRIORITY, MessagePriority::Background)
 69    }
 70
 71    fn original_sender_id(&self) -> Option<PeerId> {
 72        self.original_sender_id
 73    }
 74}
 75
// Every message type known to the protocol, paired with the name of a
// `MessagePriority` variant.
//
// NOTE(review): the `messages!` macro is defined in the parent module;
// presumably it implements `EnvelopedMessage` for each listed type using the
// given priority — confirm against the macro definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProject, Foreground),
    (RegisterWorktree, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UnregisterWorktree, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateWorktree, Foreground),
);
164
// `(Request, Response)` pairs.
//
// NOTE(review): the `request_messages!` macro is defined in the parent
// module; presumably it implements `RequestMessage` for the first type with
// `Response` set to the second — confirm against the macro definition.
// Several requests share a response type (`Ack`, `ProjectEntryResponse`,
// `UsersResponse`, `OpenBufferResponse`).
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
208
// Message types listed together with the identifier `project_id`.
//
// NOTE(review): the `entity_messages!` macro is defined in the parent module;
// presumably the first argument names the field that
// `EntityMessage::remote_entity_id` returns for every type that follows —
// confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UnregisterWorktree,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    RegisterWorktree,
    UpdateWorktree,
);

// The only message type listed with `channel_id` instead of `project_id`.
entity_messages!(channel_id, ChannelMessageSent);
256
/// Largest capacity, in bytes (1 MiB), that a `MessageStream` keeps in its
/// reusable buffer between messages; the buffer is shrunk back to this size
/// after each envelope is written or read.
const MAX_BUFFER_LEN: usize = 1 << 20;
258
/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S` and adds the framing used by `write` and
/// `read`: envelopes are protobuf-encoded and zstd-compressed into binary
/// WebSocket messages.
pub struct MessageStream<S> {
    // The wrapped sink and/or stream of WebSocket messages.
    stream: S,
    // Scratch buffer reused across calls: holds the uncompressed protobuf
    // bytes while writing or reading an envelope, and is cleared afterwards.
    encoding_buffer: Vec<u8>,
}
264
/// A message exchanged over a `MessageStream`.
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried in a binary WebSocket message.
    Envelope(Envelope),
    /// A WebSocket ping frame (its payload is not preserved).
    Ping,
    /// A WebSocket pong frame (its payload is not preserved).
    Pong,
}
271
272impl<S> MessageStream<S> {
273    pub fn new(stream: S) -> Self {
274        Self {
275            stream,
276            encoding_buffer: Vec::new(),
277        }
278    }
279
280    pub fn inner_mut(&mut self) -> &mut S {
281        &mut self.stream
282    }
283}
284
285impl<S> MessageStream<S>
286where
287    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
288{
289    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
290        #[cfg(any(test, feature = "test-support"))]
291        const COMPRESSION_LEVEL: i32 = -7;
292
293        #[cfg(not(any(test, feature = "test-support")))]
294        const COMPRESSION_LEVEL: i32 = 4;
295
296        match message {
297            Message::Envelope(message) => {
298                self.encoding_buffer.reserve(message.encoded_len());
299                message
300                    .encode(&mut self.encoding_buffer)
301                    .map_err(|err| io::Error::from(err))?;
302                let buffer =
303                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
304                        .unwrap();
305
306                self.encoding_buffer.clear();
307                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
308                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
309            }
310            Message::Ping => {
311                self.stream
312                    .send(WebSocketMessage::Ping(Default::default()))
313                    .await?;
314            }
315            Message::Pong => {
316                self.stream
317                    .send(WebSocketMessage::Pong(Default::default()))
318                    .await?;
319            }
320        }
321
322        Ok(())
323    }
324}
325
326impl<S> MessageStream<S>
327where
328    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
329{
330    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
331        while let Some(bytes) = self.stream.next().await {
332            match bytes? {
333                WebSocketMessage::Binary(bytes) => {
334                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
335                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
336                        .map_err(io::Error::from)?;
337
338                    self.encoding_buffer.clear();
339                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
340                    return Ok(Message::Envelope(envelope));
341                }
342                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
343                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
344                WebSocketMessage::Close(_) => break,
345                _ => {}
346            }
347        }
348        Err(anyhow!("connection closed"))
349    }
350}
351
352impl Into<SystemTime> for Timestamp {
353    fn into(self) -> SystemTime {
354        UNIX_EPOCH
355            .checked_add(Duration::new(self.seconds, self.nanos))
356            .unwrap()
357    }
358}
359
360impl From<SystemTime> for Timestamp {
361    fn from(time: SystemTime) -> Self {
362        let duration = time.duration_since(UNIX_EPOCH).unwrap();
363        Self {
364            seconds: duration.as_secs(),
365            nanos: duration.subsec_nanos(),
366        }
367    }
368}
369
370impl From<u128> for Nonce {
371    fn from(nonce: u128) -> Self {
372        let upper_half = (nonce >> 64) as u64;
373        let lower_half = nonce as u64;
374        Self {
375            upper_half,
376            lower_half,
377        }
378    }
379}
380
381impl From<Nonce> for u128 {
382    fn from(nonce: Nonce) -> Self {
383        let upper_half = (nonce.upper_half as u128) << 64;
384        let lower_half = nonce.lower_half as u128;
385        upper_half | lower_half
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose `root_name` is `"abcdefg"`
    /// repeated `repeats` times.
    fn worktree_message(repeats: usize) -> Message {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repeats),
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` bytes
    /// of capacity after a write or a read, for small and large messages alike.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeats in [10, 1000000] {
            sink.write(worktree_message(repeats)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(|msg| anyhow::Ok(msg)));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}