proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    fmt::Debug,
 10    io,
 11    time::{Duration, SystemTime, UNIX_EPOCH},
 12};
 13
// Pulls in the prost-generated protobuf types used throughout this module (`Envelope`,
// `envelope::Payload`, `Timestamp`, `Nonce`, and the message structs registered below).
// The file is read from OUT_DIR, so it is produced at build time — presumably by this
// crate's build script; confirm.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 15
 16pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 17    const NAME: &'static str;
 18    const PRIORITY: MessagePriority;
 19    fn into_envelope(
 20        self,
 21        id: u32,
 22        responding_to: Option<u32>,
 23        original_sender_id: Option<u32>,
 24    ) -> Envelope;
 25    fn from_envelope(envelope: Envelope) -> Option<Self>;
 26}
 27
 28pub trait EntityMessage: EnvelopedMessage {
 29    fn remote_entity_id(&self) -> u64;
 30}
 31
 32pub trait RequestMessage: EnvelopedMessage {
 33    type Response: EnvelopedMessage;
 34}
 35
 36pub trait AnyTypedEnvelope: 'static + Send + Sync {
 37    fn payload_type_id(&self) -> TypeId;
 38    fn payload_type_name(&self) -> &'static str;
 39    fn as_any(&self) -> &dyn Any;
 40    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 41    fn is_background(&self) -> bool;
 42    fn original_sender_id(&self) -> Option<PeerId>;
 43}
 44
 45pub enum MessagePriority {
 46    Foreground,
 47    Background,
 48}
 49
 50impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 51    fn payload_type_id(&self) -> TypeId {
 52        TypeId::of::<T>()
 53    }
 54
 55    fn payload_type_name(&self) -> &'static str {
 56        T::NAME
 57    }
 58
 59    fn as_any(&self) -> &dyn Any {
 60        self
 61    }
 62
 63    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 64        self
 65    }
 66
 67    fn is_background(&self) -> bool {
 68        matches!(T::PRIORITY, MessagePriority::Background)
 69    }
 70
 71    fn original_sender_id(&self) -> Option<PeerId> {
 72        self.original_sender_id
 73    }
 74}
 75
// Registers every protobuf message type exchanged over RPC, each paired with its
// `MessagePriority` (`Foreground` / `Background`, see `EnvelopedMessage::PRIORITY`).
// The `messages!` macro is imported from the parent module and is not visible here;
// presumably it generates the `EnvelopedMessage` impl for each listed type — confirm
// against its definition.
//
// NOTE(review): the list is mostly alphabetical, but a few entries are out of place
// (e.g. `RemoveContact`, `GetHover`, `UsersResponse`, `ProjectEntryResponse`, `Ping`).
// Keeping it sorted makes a missing registration easier to spot.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
);
166
// Pairs each request message with the message type used for its reply, i.e. the
// `RequestMessage::Response` type. The macro lives in the parent module; presumably it
// expands to `impl RequestMessage for <Request> { type Response = <Response>; }` —
// confirm against its definition.
//
// Several requests share a response type: `Ack`, `ProjectEntryResponse`,
// `UsersResponse`, and `OpenBufferResponse` are each reused. `Test` replies with another
// `Test`, and `SaveBuffer` replies with `BufferSaved`.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
211
// Messages scoped to a project. The first argument names the field that presumably
// backs `EntityMessage::remote_entity_id` for every listed type (here `project_id`).
// The macro is defined in the parent module and not visible here — confirm.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
);
258
// `ChannelMessageSent` targets a channel rather than a project, so its entity id
// presumably comes from its `channel_id` field — confirm against the macro.
entity_messages!(channel_id, ChannelMessageSent);
260
/// Capacity (in bytes) that `MessageStream`'s scratch buffer is shrunk back to after
/// each read or write, so one oversized message does not pin a large allocation.
const MAX_BUFFER_LEN: usize = 1024 * 1024;
262
/// Frames zstd-compressed protobuf envelopes over an underlying WebSocket
/// sink/stream `S`.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch bytes reused across messages: protobuf encoding on write, zstd
    /// decompression on read. Cleared and capacity-trimmed after every message.
    encoding_buffer: Vec<u8>,
}
268
269#[derive(Debug)]
270pub enum Message {
271    Envelope(Envelope),
272    Ping,
273    Pong,
274}
275
276impl<S> MessageStream<S> {
277    pub fn new(stream: S) -> Self {
278        Self {
279            stream,
280            encoding_buffer: Vec::new(),
281        }
282    }
283
284    pub fn inner_mut(&mut self) -> &mut S {
285        &mut self.stream
286    }
287}
288
289impl<S> MessageStream<S>
290where
291    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
292{
293    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
294        #[cfg(any(test, feature = "test-support"))]
295        const COMPRESSION_LEVEL: i32 = -7;
296
297        #[cfg(not(any(test, feature = "test-support")))]
298        const COMPRESSION_LEVEL: i32 = 4;
299
300        match message {
301            Message::Envelope(message) => {
302                self.encoding_buffer.reserve(message.encoded_len());
303                message
304                    .encode(&mut self.encoding_buffer)
305                    .map_err(|err| io::Error::from(err))?;
306                let buffer =
307                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
308                        .unwrap();
309
310                self.encoding_buffer.clear();
311                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
312                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
313            }
314            Message::Ping => {
315                self.stream
316                    .send(WebSocketMessage::Ping(Default::default()))
317                    .await?;
318            }
319            Message::Pong => {
320                self.stream
321                    .send(WebSocketMessage::Pong(Default::default()))
322                    .await?;
323            }
324        }
325
326        Ok(())
327    }
328}
329
330impl<S> MessageStream<S>
331where
332    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
333{
334    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
335        while let Some(bytes) = self.stream.next().await {
336            match bytes? {
337                WebSocketMessage::Binary(bytes) => {
338                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
339                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
340                        .map_err(io::Error::from)?;
341
342                    self.encoding_buffer.clear();
343                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
344                    return Ok(Message::Envelope(envelope));
345                }
346                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
347                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
348                WebSocketMessage::Close(_) => break,
349                _ => {}
350            }
351        }
352        Err(anyhow!("connection closed"))
353    }
354}
355
356impl Into<SystemTime> for Timestamp {
357    fn into(self) -> SystemTime {
358        UNIX_EPOCH
359            .checked_add(Duration::new(self.seconds, self.nanos))
360            .unwrap()
361    }
362}
363
364impl From<SystemTime> for Timestamp {
365    fn from(time: SystemTime) -> Self {
366        let duration = time.duration_since(UNIX_EPOCH).unwrap();
367        Self {
368            seconds: duration.as_secs(),
369            nanos: duration.subsec_nanos(),
370        }
371    }
372}
373
374impl From<u128> for Nonce {
375    fn from(nonce: u128) -> Self {
376        let upper_half = (nonce >> 64) as u64;
377        let lower_half = nonce as u64;
378        Self {
379            upper_half,
380            lower_half,
381        }
382    }
383}
384
385impl From<Nonce> for u128 {
386    fn from(nonce: Nonce) -> Self {
387        let upper_half = (nonce.upper_half as u128) << 64;
388        let lower_half = nonce.lower_half as u128;
389        upper_half | lower_half
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose root name is `"abcdefg"`
    /// repeated `repeats` times, so the encoded size scales with `repeats`.
    fn worktree_update(repeats: usize) -> Message {
        let update = UpdateWorktree {
            root_name: "abcdefg".repeat(repeats),
            ..Default::default()
        };
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write a small message, then a ~7 MB one. Neither may leave the scratch buffer
        // holding more than MAX_BUFFER_LEN bytes of capacity.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for &repeats in &[10, 1000000] {
            sink.write(worktree_update(repeats)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both messages back must respect the same bound.
        let mut stream = MessageStream::new(rx.map(|frame| anyhow::Ok(frame)));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}