proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
// Declares request/response pairings.
//
// For each `(Request, Response)` tuple (a trailing comma is allowed) this
// implements `RequestMessage` for `Request` with `Response` as its
// associated `Response` type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
// Implements `EntityMessage` for a list of message types that all store
// their entity id in the same field.
//
// The first argument names that field; the remaining arguments are the
// message types (a trailing comma is allowed).
macro_rules! entity_messages {
    ($id_field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    ApplyCodeAction,
126    ApplyCodeActionResponse,
127    ApplyCompletionAdditionalEdits,
128    ApplyCompletionAdditionalEditsResponse,
129    BufferReloaded,
130    BufferSaved,
131    ChannelMessageSent,
132    CloseBuffer,
133    DiskBasedDiagnosticsUpdated,
134    DiskBasedDiagnosticsUpdating,
135    Error,
136    FormatBuffer,
137    GetChannelMessages,
138    GetChannelMessagesResponse,
139    GetChannels,
140    GetChannelsResponse,
141    GetCodeActions,
142    GetCodeActionsResponse,
143    GetCompletions,
144    GetCompletionsResponse,
145    GetDefinition,
146    GetDefinitionResponse,
147    GetUsers,
148    GetUsersResponse,
149    JoinChannel,
150    JoinChannelResponse,
151    JoinProject,
152    JoinProjectResponse,
153    LeaveChannel,
154    LeaveProject,
155    OpenBuffer,
156    OpenBufferResponse,
157    RegisterProjectResponse,
158    Ping,
159    RegisterProject,
160    RegisterWorktree,
161    RemoveProjectCollaborator,
162    SaveBuffer,
163    SendChannelMessage,
164    SendChannelMessageResponse,
165    ShareProject,
166    ShareWorktree,
167    UnregisterProject,
168    UnregisterWorktree,
169    UnshareProject,
170    UpdateBuffer,
171    UpdateBufferFile,
172    UpdateContacts,
173    UpdateDiagnosticSummary,
174    UpdateWorktree,
175);
176
177request_messages!(
178    (ApplyCodeAction, ApplyCodeActionResponse),
179    (
180        ApplyCompletionAdditionalEdits,
181        ApplyCompletionAdditionalEditsResponse
182    ),
183    (FormatBuffer, Ack),
184    (GetChannelMessages, GetChannelMessagesResponse),
185    (GetChannels, GetChannelsResponse),
186    (GetCodeActions, GetCodeActionsResponse),
187    (GetCompletions, GetCompletionsResponse),
188    (GetDefinition, GetDefinitionResponse),
189    (GetUsers, GetUsersResponse),
190    (JoinChannel, JoinChannelResponse),
191    (JoinProject, JoinProjectResponse),
192    (OpenBuffer, OpenBufferResponse),
193    (Ping, Ack),
194    (RegisterProject, RegisterProjectResponse),
195    (RegisterWorktree, Ack),
196    (SaveBuffer, BufferSaved),
197    (SendChannelMessage, SendChannelMessageResponse),
198    (ShareProject, Ack),
199    (ShareWorktree, Ack),
200    (UpdateBuffer, Ack),
201);
202
203entity_messages!(
204    project_id,
205    AddProjectCollaborator,
206    ApplyCodeAction,
207    ApplyCompletionAdditionalEdits,
208    BufferReloaded,
209    BufferSaved,
210    CloseBuffer,
211    DiskBasedDiagnosticsUpdated,
212    DiskBasedDiagnosticsUpdating,
213    FormatBuffer,
214    GetCodeActions,
215    GetCompletions,
216    GetDefinition,
217    JoinProject,
218    LeaveProject,
219    OpenBuffer,
220    RemoveProjectCollaborator,
221    SaveBuffer,
222    ShareWorktree,
223    UnregisterWorktree,
224    UnshareProject,
225    UpdateBuffer,
226    UpdateBufferFile,
227    UpdateDiagnosticSummary,
228    UpdateWorktree,
229);
230
231entity_messages!(channel_id, ChannelMessageSent);
232
/// A stream of protobuf messages layered over an underlying transport `S`.
///
/// The read and write halves are implemented separately in this file, each
/// bounded on the capability (`Stream` or `Sink`) it needs from `S`.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused between calls to hold uncompressed protobuf bytes.
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        Self {
            stream,
            encoding_buffer,
        }
    }

    /// Gives mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
251
252impl<S> MessageStream<S>
253where
254    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
255{
256    /// Write a given protobuf message to the stream.
257    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
258        self.encoding_buffer.resize(message.encoded_len(), 0);
259        self.encoding_buffer.clear();
260        message
261            .encode(&mut self.encoding_buffer)
262            .map_err(|err| io::Error::from(err))?;
263        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
264        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
265        Ok(())
266    }
267}
268
269impl<S> MessageStream<S>
270where
271    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
272{
273    /// Read a protobuf message of the given type from the stream.
274    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
275        while let Some(bytes) = self.stream.next().await {
276            match bytes? {
277                WebSocketMessage::Binary(bytes) => {
278                    self.encoding_buffer.clear();
279                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
280                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
281                        .map_err(io::Error::from)?;
282                    return Ok(envelope);
283                }
284                WebSocketMessage::Close(_) => break,
285                _ => {}
286            }
287        }
288        Err(WebSocketError::ConnectionClosed)
289    }
290}
291
292impl Into<SystemTime> for Timestamp {
293    fn into(self) -> SystemTime {
294        UNIX_EPOCH
295            .checked_add(Duration::new(self.seconds, self.nanos))
296            .unwrap()
297    }
298}
299
300impl From<SystemTime> for Timestamp {
301    fn from(time: SystemTime) -> Self {
302        let duration = time.duration_since(UNIX_EPOCH).unwrap();
303        Self {
304            seconds: duration.as_secs(),
305            nanos: duration.subsec_nanos(),
306        }
307    }
308}
309
310impl From<u128> for Nonce {
311    fn from(nonce: u128) -> Self {
312        let upper_half = (nonce >> 64) as u64;
313        let lower_half = nonce as u64;
314        Self {
315            upper_half,
316            lower_half,
317        }
318    }
319}
320
321impl From<Nonce> for u128 {
322    fn from(nonce: Nonce) -> Self {
323        let upper_half = (nonce.upper_half as u128) << 64;
324        let lower_half = nonce.lower_half as u128;
325        upper_half | lower_half
326    }
327}