proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    fmt::Debug,
 10    io,
 11    time::{Duration, SystemTime, UNIX_EPOCH},
 12};
 13
 14include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 15
 16pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
 17    const NAME: &'static str;
 18    const PRIORITY: MessagePriority;
 19    fn into_envelope(
 20        self,
 21        id: u32,
 22        responding_to: Option<u32>,
 23        original_sender_id: Option<u32>,
 24    ) -> Envelope;
 25    fn from_envelope(envelope: Envelope) -> Option<Self>;
 26}
 27
 28pub trait EntityMessage: EnvelopedMessage {
 29    fn remote_entity_id(&self) -> u64;
 30}
 31
 32pub trait RequestMessage: EnvelopedMessage {
 33    type Response: EnvelopedMessage;
 34}
 35
 36pub trait AnyTypedEnvelope: 'static + Send + Sync {
 37    fn payload_type_id(&self) -> TypeId;
 38    fn payload_type_name(&self) -> &'static str;
 39    fn as_any(&self) -> &dyn Any;
 40    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 41    fn is_background(&self) -> bool;
 42    fn original_sender_id(&self) -> Option<PeerId>;
 43}
 44
 45pub enum MessagePriority {
 46    Foreground,
 47    Background,
 48}
 49
 50impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 51    fn payload_type_id(&self) -> TypeId {
 52        TypeId::of::<T>()
 53    }
 54
 55    fn payload_type_name(&self) -> &'static str {
 56        T::NAME
 57    }
 58
 59    fn as_any(&self) -> &dyn Any {
 60        self
 61    }
 62
 63    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 64        self
 65    }
 66
 67    fn is_background(&self) -> bool {
 68        matches!(T::PRIORITY, MessagePriority::Background)
 69    }
 70
 71    fn original_sender_id(&self) -> Option<PeerId> {
 72        self.original_sender_id
 73    }
 74}
 75
// Table of every message type exchanged in an `Envelope`, each paired with a
// `MessagePriority` variant (`Foreground` or `Background`).
//
// The `messages!` macro is defined in the parent module. Presumably it generates the
// `EnvelopedMessage` impls (including `NAME` and `PRIORITY`) for each entry; confirm
// against the macro definition.
//
// Note that the language-server-style requests (code actions, hover, completions,
// definitions, references, renames, symbol and project search, opening buffers) and
// their responses are the entries marked `Background`.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProject, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
);
165
// Pairs each request message type with the message type expected in reply. The macro
// is defined in the parent module and presumably sets `RequestMessage::Response` to
// the second element; confirm against its definition.
//
// Several requests share a reply type: many are acknowledged with a bare `Ack`, and
// all project-entry operations reply with `ProjectEntryResponse`. `Test` replies with
// another `Test`.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
210
// Message types addressed to a project, identified by each message's `project_id`
// field. The macro is defined in the parent module and presumably implements
// `EntityMessage::remote_entity_id` by returning that field; confirm against its
// definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
);
256
// Message types addressed to a channel, identified by their `channel_id` field
// (presumably returned by `EntityMessage::remote_entity_id`; the macro is defined in
// the parent module).
entity_messages!(channel_id, ChannelMessageSent);
258
/// Maximum capacity, in bytes, that `MessageStream`'s scratch `encoding_buffer` keeps
/// after each message (1 MiB). Larger messages are still handled, but the excess
/// capacity is released afterwards with `shrink_to`.
const MAX_BUFFER_LEN: usize = 1024 * 1024;
260
/// A stream of protobuf messages layered over a WebSocket-style transport.
///
/// Envelopes are encoded with prost and compressed with zstd before being sent as binary
/// frames, and decoded the same way on receipt.
pub struct MessageStream<S> {
    /// The wrapped transport: a `Sink` for writing and/or a `Stream` for reading.
    stream: S,
    /// Scratch space reused across messages. Its capacity is trimmed back to
    /// `MAX_BUFFER_LEN` after each envelope.
    encoding_buffer: Vec<u8>,
}
266
267#[derive(Debug)]
268pub enum Message {
269    Envelope(Envelope),
270    Ping,
271    Pong,
272}
273
274impl<S> MessageStream<S> {
275    pub fn new(stream: S) -> Self {
276        Self {
277            stream,
278            encoding_buffer: Vec::new(),
279        }
280    }
281
282    pub fn inner_mut(&mut self) -> &mut S {
283        &mut self.stream
284    }
285}
286
287impl<S> MessageStream<S>
288where
289    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
290{
291    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
292        #[cfg(any(test, feature = "test-support"))]
293        const COMPRESSION_LEVEL: i32 = -7;
294
295        #[cfg(not(any(test, feature = "test-support")))]
296        const COMPRESSION_LEVEL: i32 = 4;
297
298        match message {
299            Message::Envelope(message) => {
300                self.encoding_buffer.reserve(message.encoded_len());
301                message
302                    .encode(&mut self.encoding_buffer)
303                    .map_err(|err| io::Error::from(err))?;
304                let buffer =
305                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
306                        .unwrap();
307
308                self.encoding_buffer.clear();
309                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
310                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
311            }
312            Message::Ping => {
313                self.stream
314                    .send(WebSocketMessage::Ping(Default::default()))
315                    .await?;
316            }
317            Message::Pong => {
318                self.stream
319                    .send(WebSocketMessage::Pong(Default::default()))
320                    .await?;
321            }
322        }
323
324        Ok(())
325    }
326}
327
328impl<S> MessageStream<S>
329where
330    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
331{
332    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
333        while let Some(bytes) = self.stream.next().await {
334            match bytes? {
335                WebSocketMessage::Binary(bytes) => {
336                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
337                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
338                        .map_err(io::Error::from)?;
339
340                    self.encoding_buffer.clear();
341                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
342                    return Ok(Message::Envelope(envelope));
343                }
344                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
345                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
346                WebSocketMessage::Close(_) => break,
347                _ => {}
348            }
349        }
350        Err(anyhow!("connection closed"))
351    }
352}
353
354impl Into<SystemTime> for Timestamp {
355    fn into(self) -> SystemTime {
356        UNIX_EPOCH
357            .checked_add(Duration::new(self.seconds, self.nanos))
358            .unwrap()
359    }
360}
361
362impl From<SystemTime> for Timestamp {
363    fn from(time: SystemTime) -> Self {
364        let duration = time.duration_since(UNIX_EPOCH).unwrap();
365        Self {
366            seconds: duration.as_secs(),
367            nanos: duration.subsec_nanos(),
368        }
369    }
370}
371
372impl From<u128> for Nonce {
373    fn from(nonce: u128) -> Self {
374        let upper_half = (nonce >> 64) as u64;
375        let lower_half = nonce as u64;
376        Self {
377            upper_half,
378            lower_half,
379        }
380    }
381}
382
383impl From<Nonce> for u128 {
384    fn from(nonce: Nonce) -> Self {
385        let upper_half = (nonce.upper_half as u128) << 64;
386        let lower_half = nonce.lower_half as u128;
387        upper_half | lower_half
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose only non-default field is `root_name`.
    fn update_worktree_message(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// Writes one small and one very large message, then reads both back. After each
    /// step, checks that neither the writer's nor the reader's scratch buffer keeps more
    /// than `MAX_BUFFER_LEN` bytes of capacity.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for root_name in ["abcdefg".repeat(10), "abcdefg".repeat(1000000)] {
            sink.write(update_worktree_message(root_name)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(|msg| anyhow::Ok(msg)));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}