proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    const PRIORITY: MessagePriority;
 17    fn into_envelope(
 18        self,
 19        id: u32,
 20        responding_to: Option<u32>,
 21        original_sender_id: Option<u32>,
 22    ) -> Envelope;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34pub trait AnyTypedEnvelope: 'static + Send + Sync {
 35    fn payload_type_id(&self) -> TypeId;
 36    fn payload_type_name(&self) -> &'static str;
 37    fn as_any(&self) -> &dyn Any;
 38    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 39    fn is_background(&self) -> bool;
 40}
 41
 42pub enum MessagePriority {
 43    Foreground,
 44    Background,
 45}
 46
 47impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 48    fn payload_type_id(&self) -> TypeId {
 49        TypeId::of::<T>()
 50    }
 51
 52    fn payload_type_name(&self) -> &'static str {
 53        T::NAME
 54    }
 55
 56    fn as_any(&self) -> &dyn Any {
 57        self
 58    }
 59
 60    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 61        self
 62    }
 63
 64    fn is_background(&self) -> bool {
 65        matches!(T::PRIORITY, MessagePriority::Background)
 66    }
 67}
 68
 69macro_rules! messages {
 70    ($(($name:ident, $priority:ident)),* $(,)?) => {
 71        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 72            match envelope.payload {
 73                $(Some(envelope::Payload::$name(payload)) => {
 74                    Some(Box::new(TypedEnvelope {
 75                        sender_id,
 76                        original_sender_id: envelope.original_sender_id.map(PeerId),
 77                        message_id: envelope.id,
 78                        payload,
 79                    }))
 80                }, )*
 81                _ => None
 82            }
 83        }
 84
 85        $(
 86            impl EnvelopedMessage for $name {
 87                const NAME: &'static str = std::stringify!($name);
 88                const PRIORITY: MessagePriority = MessagePriority::$priority;
 89
 90                fn into_envelope(
 91                    self,
 92                    id: u32,
 93                    responding_to: Option<u32>,
 94                    original_sender_id: Option<u32>,
 95                ) -> Envelope {
 96                    Envelope {
 97                        id,
 98                        responding_to,
 99                        original_sender_id,
100                        payload: Some(envelope::Payload::$name(self)),
101                    }
102                }
103
104                fn from_envelope(envelope: Envelope) -> Option<Self> {
105                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
106                        Some(msg)
107                    } else {
108                        None
109                    }
110                }
111            }
112        )*
113    };
114}
115
/// Declares request/response pairs: each `(Request, Response)` entry implements
/// [`RequestMessage`] for `Request` with `Response` as its reply type.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
123
/// Implements [`EntityMessage`] for each listed message type. The entity id is
/// read from the field named by the first argument (e.g. `project_id`).
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
133
134messages!(
135    (Ack, Foreground),
136    (AddProjectCollaborator, Foreground),
137    (ApplyCodeAction, Foreground),
138    (ApplyCodeActionResponse, Foreground),
139    (ApplyCompletionAdditionalEdits, Foreground),
140    (ApplyCompletionAdditionalEditsResponse, Foreground),
141    (BufferReloaded, Foreground),
142    (BufferSaved, Foreground),
143    (ChannelMessageSent, Foreground),
144    (CloseBuffer, Foreground),
145    (DiskBasedDiagnosticsUpdated, Background),
146    (DiskBasedDiagnosticsUpdating, Background),
147    (Error, Foreground),
148    (FormatBuffers, Foreground),
149    (FormatBuffersResponse, Foreground),
150    (GetChannelMessages, Foreground),
151    (GetChannelMessagesResponse, Foreground),
152    (GetChannels, Foreground),
153    (GetChannelsResponse, Foreground),
154    (GetCodeActions, Background),
155    (GetCodeActionsResponse, Foreground),
156    (GetCompletions, Background),
157    (GetCompletionsResponse, Foreground),
158    (GetDefinition, Foreground),
159    (GetDefinitionResponse, Foreground),
160    (GetUsers, Foreground),
161    (GetUsersResponse, Foreground),
162    (JoinChannel, Foreground),
163    (JoinChannelResponse, Foreground),
164    (JoinProject, Foreground),
165    (JoinProjectResponse, Foreground),
166    (LeaveChannel, Foreground),
167    (LeaveProject, Foreground),
168    (OpenBuffer, Foreground),
169    (OpenBufferResponse, Foreground),
170    (PerformRename, Background),
171    (PerformRenameResponse, Background),
172    (PrepareRename, Background),
173    (PrepareRenameResponse, Background),
174    (RegisterProjectResponse, Foreground),
175    (Ping, Foreground),
176    (RegisterProject, Foreground),
177    (RegisterWorktree, Foreground),
178    (RemoveProjectCollaborator, Foreground),
179    (SaveBuffer, Foreground),
180    (SendChannelMessage, Foreground),
181    (SendChannelMessageResponse, Foreground),
182    (ShareProject, Foreground),
183    (ShareWorktree, Foreground),
184    (Test, Foreground),
185    (UnregisterProject, Foreground),
186    (UnregisterWorktree, Foreground),
187    (UnshareProject, Foreground),
188    (UpdateBuffer, Foreground),
189    (UpdateBufferFile, Foreground),
190    (UpdateContacts, Foreground),
191    (UpdateDiagnosticSummary, Foreground),
192    (UpdateWorktree, Foreground),
193);
194
195request_messages!(
196    (ApplyCodeAction, ApplyCodeActionResponse),
197    (
198        ApplyCompletionAdditionalEdits,
199        ApplyCompletionAdditionalEditsResponse
200    ),
201    (FormatBuffers, FormatBuffersResponse),
202    (GetChannelMessages, GetChannelMessagesResponse),
203    (GetChannels, GetChannelsResponse),
204    (GetCodeActions, GetCodeActionsResponse),
205    (GetCompletions, GetCompletionsResponse),
206    (GetDefinition, GetDefinitionResponse),
207    (GetUsers, GetUsersResponse),
208    (JoinChannel, JoinChannelResponse),
209    (JoinProject, JoinProjectResponse),
210    (OpenBuffer, OpenBufferResponse),
211    (Ping, Ack),
212    (PerformRename, PerformRenameResponse),
213    (PrepareRename, PrepareRenameResponse),
214    (RegisterProject, RegisterProjectResponse),
215    (RegisterWorktree, Ack),
216    (SaveBuffer, BufferSaved),
217    (SendChannelMessage, SendChannelMessageResponse),
218    (ShareProject, Ack),
219    (ShareWorktree, Ack),
220    (Test, Test),
221    (UpdateBuffer, Ack),
222    (UpdateWorktree, Ack),
223);
224
225entity_messages!(
226    project_id,
227    AddProjectCollaborator,
228    ApplyCodeAction,
229    ApplyCompletionAdditionalEdits,
230    BufferReloaded,
231    BufferSaved,
232    CloseBuffer,
233    DiskBasedDiagnosticsUpdated,
234    DiskBasedDiagnosticsUpdating,
235    FormatBuffers,
236    GetCodeActions,
237    GetCompletions,
238    GetDefinition,
239    JoinProject,
240    LeaveProject,
241    OpenBuffer,
242    PerformRename,
243    PrepareRename,
244    RemoveProjectCollaborator,
245    SaveBuffer,
246    ShareWorktree,
247    UnregisterWorktree,
248    UnshareProject,
249    UpdateBuffer,
250    UpdateBufferFile,
251    UpdateDiagnosticSummary,
252    UpdateWorktree,
253);
254
255entity_messages!(channel_id, ChannelMessageSent);
256
/// Exchanges zstd-compressed protobuf [`Envelope`]s over an underlying
/// WebSocket message stream/sink.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across calls for encoded/decoded protobuf bytes.
    encoding_buffer: Vec<u8>,
}
262
263impl<S> MessageStream<S> {
264    pub fn new(stream: S) -> Self {
265        Self {
266            stream,
267            encoding_buffer: Vec::new(),
268        }
269    }
270
271    pub fn inner_mut(&mut self) -> &mut S {
272        &mut self.stream
273    }
274}
275
276impl<S> MessageStream<S>
277where
278    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
279{
280    /// Write a given protobuf message to the stream.
281    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
282        #[cfg(any(test, feature = "test-support"))]
283        const COMPRESSION_LEVEL: i32 = -7;
284
285        #[cfg(not(any(test, feature = "test-support")))]
286        const COMPRESSION_LEVEL: i32 = 4;
287
288        self.encoding_buffer.resize(message.encoded_len(), 0);
289        self.encoding_buffer.clear();
290        message
291            .encode(&mut self.encoding_buffer)
292            .map_err(|err| io::Error::from(err))?;
293        let buffer =
294            zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
295        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
296        Ok(())
297    }
298}
299
300impl<S> MessageStream<S>
301where
302    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
303{
304    /// Read a protobuf message of the given type from the stream.
305    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
306        while let Some(bytes) = self.stream.next().await {
307            match bytes? {
308                WebSocketMessage::Binary(bytes) => {
309                    self.encoding_buffer.clear();
310                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
311                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
312                        .map_err(io::Error::from)?;
313                    return Ok(envelope);
314                }
315                WebSocketMessage::Close(_) => break,
316                _ => {}
317            }
318        }
319        Err(WebSocketError::ConnectionClosed)
320    }
321}
322
323impl Into<SystemTime> for Timestamp {
324    fn into(self) -> SystemTime {
325        UNIX_EPOCH
326            .checked_add(Duration::new(self.seconds, self.nanos))
327            .unwrap()
328    }
329}
330
331impl From<SystemTime> for Timestamp {
332    fn from(time: SystemTime) -> Self {
333        let duration = time.duration_since(UNIX_EPOCH).unwrap();
334        Self {
335            seconds: duration.as_secs(),
336            nanos: duration.subsec_nanos(),
337        }
338    }
339}
340
341impl From<u128> for Nonce {
342    fn from(nonce: u128) -> Self {
343        let upper_half = (nonce >> 64) as u64;
344        let lower_half = nonce as u64;
345        Self {
346            upper_half,
347            lower_half,
348        }
349    }
350}
351
352impl From<Nonce> for u128 {
353    fn from(nonce: Nonce) -> Self {
354        let upper_half = (nonce.upper_half as u128) << 64;
355        let lower_half = nonce.lower_half as u128;
356        upper_half | lower_half
357    }
358}