proto.rs

  1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::Message as WebSocketMessage;
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message as _;
  6use serde::Serialize;
  7use std::any::{Any, TypeId};
  8use std::{
  9    fmt::Debug,
 10    io,
 11    time::{Duration, SystemTime, UNIX_EPOCH},
 12};
 13
// Generated protobuf message types, read from the build output directory (`OUT_DIR`). These are
// presumably the prost output of this crate's build script, providing `Envelope` and the payload
// structs named in the registries below — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 15
/// A protocol message that can be wrapped in, and recovered from, an [`Envelope`].
///
/// Implementations are presumably generated by the `messages!` macro for each registered
/// `(Type, Priority)` pair — confirm against the macro definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name reported for this message type (returned by `AnyTypedEnvelope::payload_type_name`).
    const NAME: &'static str;
    /// Priority class of this message type; `Background` makes
    /// `AnyTypedEnvelope::is_background` return `true`.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`].
    ///
    /// * `id` - identifier to assign to the envelope.
    /// * `responding_to` - presumably the id of the envelope this one answers, if any.
    /// * `original_sender_id` - presumably the id of the peer that first sent the message when
    ///   it is being relayed; TODO confirm.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts this message from `envelope`; presumably returns `None` when the envelope's
    /// payload is a different message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
 27
/// A message addressed to a specific remote entity.
///
/// Implemented via `entity_messages!`, whose first argument (e.g. `project_id`, `channel_id`)
/// presumably names the field returned by [`EntityMessage::remote_entity_id`].
pub trait EntityMessage: EnvelopedMessage {
    /// Identifier of the remote entity this message targets.
    fn remote_entity_id(&self) -> u64;
}
 31
/// A message that expects a reply of type [`RequestMessage::Response`].
///
/// Request/response pairings are declared in `request_messages!`.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
 35
/// Type-erased view of a `TypedEnvelope<T>`, implemented for every `T: EnvelopedMessage`.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload message type `T`.
    fn payload_type_id(&self) -> TypeId;
    /// `EnvelopedMessage::NAME` of the payload message type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` so callers can downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The envelope's original sender, if one was recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
}
 44
 45pub enum MessagePriority {
 46    Foreground,
 47    Background,
 48}
 49
 50impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 51    fn payload_type_id(&self) -> TypeId {
 52        TypeId::of::<T>()
 53    }
 54
 55    fn payload_type_name(&self) -> &'static str {
 56        T::NAME
 57    }
 58
 59    fn as_any(&self) -> &dyn Any {
 60        self
 61    }
 62
 63    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 64        self
 65    }
 66
 67    fn is_background(&self) -> bool {
 68        matches!(T::PRIORITY, MessagePriority::Background)
 69    }
 70
 71    fn original_sender_id(&self) -> Option<PeerId> {
 72        self.original_sender_id
 73    }
 74}
 75
// Registry of every protocol message type and its `MessagePriority`, one `(Type, Priority)` pair
// per entry. The macro presumably implements `EnvelopedMessage` for each type with `PRIORITY` set
// from the second element — confirm against the `messages!` definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProject, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
);
163
// Pairs each request type with the type sent back in reply, as `(Request, Response)`. The macro
// presumably implements `RequestMessage` with `type Response = Response` for each entry. Several
// requests (e.g. `Ping`, `UpdateBuffer`) share the generic `Ack` reply.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
207
// Messages scoped to a project. The first argument presumably names the field that
// `EntityMessage::remote_entity_id` returns for each listed type — confirm against the
// `entity_messages!` definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
);
252
// Messages scoped to a chat channel; presumably keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
254
/// Capacity, in bytes (1 MiB), that a `MessageStream`'s scratch buffer is shrunk back to after
/// each message is read or written, bounding the memory kept after an unusually large message.
const MAX_BUFFER_LEN: usize = 1024 * 1024;
256
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` (a sink, a stream, or both). Envelopes travel as
/// zstd-compressed protobuf in binary frames.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch buffer reused across reads and writes for (de)serialization; its capacity is
    /// shrunk back to `MAX_BUFFER_LEN` after each message.
    encoding_buffer: Vec<u8>,
}
262
/// One logical unit carried over a [`MessageStream`].
#[derive(Debug)]
pub enum Message {
    /// A protocol envelope, sent as a zstd-compressed protobuf binary frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame (sent with an empty payload).
    Ping,
    /// A WebSocket pong control frame (sent with an empty payload).
    Pong,
}
269
270impl<S> MessageStream<S> {
271    pub fn new(stream: S) -> Self {
272        Self {
273            stream,
274            encoding_buffer: Vec::new(),
275        }
276    }
277
278    pub fn inner_mut(&mut self) -> &mut S {
279        &mut self.stream
280    }
281}
282
283impl<S> MessageStream<S>
284where
285    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
286{
287    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
288        #[cfg(any(test, feature = "test-support"))]
289        const COMPRESSION_LEVEL: i32 = -7;
290
291        #[cfg(not(any(test, feature = "test-support")))]
292        const COMPRESSION_LEVEL: i32 = 4;
293
294        match message {
295            Message::Envelope(message) => {
296                self.encoding_buffer.reserve(message.encoded_len());
297                message
298                    .encode(&mut self.encoding_buffer)
299                    .map_err(|err| io::Error::from(err))?;
300                let buffer =
301                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
302                        .unwrap();
303
304                self.encoding_buffer.clear();
305                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
306                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
307            }
308            Message::Ping => {
309                self.stream
310                    .send(WebSocketMessage::Ping(Default::default()))
311                    .await?;
312            }
313            Message::Pong => {
314                self.stream
315                    .send(WebSocketMessage::Pong(Default::default()))
316                    .await?;
317            }
318        }
319
320        Ok(())
321    }
322}
323
324impl<S> MessageStream<S>
325where
326    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
327{
328    pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
329        while let Some(bytes) = self.stream.next().await {
330            match bytes? {
331                WebSocketMessage::Binary(bytes) => {
332                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
333                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
334                        .map_err(io::Error::from)?;
335
336                    self.encoding_buffer.clear();
337                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
338                    return Ok(Message::Envelope(envelope));
339                }
340                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
341                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
342                WebSocketMessage::Close(_) => break,
343                _ => {}
344            }
345        }
346        Err(anyhow!("connection closed"))
347    }
348}
349
350impl Into<SystemTime> for Timestamp {
351    fn into(self) -> SystemTime {
352        UNIX_EPOCH
353            .checked_add(Duration::new(self.seconds, self.nanos))
354            .unwrap()
355    }
356}
357
358impl From<SystemTime> for Timestamp {
359    fn from(time: SystemTime) -> Self {
360        let duration = time.duration_since(UNIX_EPOCH).unwrap();
361        Self {
362            seconds: duration.as_secs(),
363            nanos: duration.subsec_nanos(),
364        }
365    }
366}
367
368impl From<u128> for Nonce {
369    fn from(nonce: u128) -> Self {
370        let upper_half = (nonce >> 64) as u64;
371        let lower_half = nonce as u64;
372        Self {
373            upper_half,
374            lower_half,
375        }
376    }
377}
378
379impl From<Nonce> for u128 {
380    fn from(nonce: Nonce) -> Self {
381        let upper_half = (nonce.upper_half as u128) << 64;
382        let lower_half = nonce.lower_half as u128;
383        upper_half | lower_half
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` whose root name is `"abcdefg"` repeated
    /// `repetitions` times, so the encoded size scales with `repetitions`.
    fn worktree_envelope(repetitions: usize) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` of capacity, whether the
    /// message just handled was small or far larger than the limit.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(Message::Envelope(worktree_envelope(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(|msg| anyhow::Ok(msg)));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}