proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Pairs each request message type with its response type by implementing
/// [`RequestMessage`] for every `(request, response)` tuple given.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type. The target
/// entity id is read from the field named by the first argument.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    ApplyCodeAction,
126    ApplyCodeActionResponse,
127    ApplyCompletionAdditionalEdits,
128    ApplyCompletionAdditionalEditsResponse,
129    BufferReloaded,
130    BufferSaved,
131    ChannelMessageSent,
132    CloseBuffer,
133    DiskBasedDiagnosticsUpdated,
134    DiskBasedDiagnosticsUpdating,
135    Error,
136    FormatBuffers,
137    FormatBuffersResponse,
138    GetChannelMessages,
139    GetChannelMessagesResponse,
140    GetChannels,
141    GetChannelsResponse,
142    GetCodeActions,
143    GetCodeActionsResponse,
144    GetCompletions,
145    GetCompletionsResponse,
146    GetDefinition,
147    GetDefinitionResponse,
148    GetUsers,
149    GetUsersResponse,
150    JoinChannel,
151    JoinChannelResponse,
152    JoinProject,
153    JoinProjectResponse,
154    LeaveChannel,
155    LeaveProject,
156    OpenBuffer,
157    OpenBufferResponse,
158    RegisterProjectResponse,
159    Ping,
160    RegisterProject,
161    RegisterWorktree,
162    RemoveProjectCollaborator,
163    SaveBuffer,
164    SendChannelMessage,
165    SendChannelMessageResponse,
166    ShareProject,
167    ShareWorktree,
168    UnregisterProject,
169    UnregisterWorktree,
170    UnshareProject,
171    UpdateBuffer,
172    UpdateBufferFile,
173    UpdateContacts,
174    UpdateDiagnosticSummary,
175    UpdateWorktree,
176);
177
178request_messages!(
179    (ApplyCodeAction, ApplyCodeActionResponse),
180    (
181        ApplyCompletionAdditionalEdits,
182        ApplyCompletionAdditionalEditsResponse
183    ),
184    (FormatBuffers, FormatBuffersResponse),
185    (GetChannelMessages, GetChannelMessagesResponse),
186    (GetChannels, GetChannelsResponse),
187    (GetCodeActions, GetCodeActionsResponse),
188    (GetCompletions, GetCompletionsResponse),
189    (GetDefinition, GetDefinitionResponse),
190    (GetUsers, GetUsersResponse),
191    (JoinChannel, JoinChannelResponse),
192    (JoinProject, JoinProjectResponse),
193    (OpenBuffer, OpenBufferResponse),
194    (Ping, Ack),
195    (RegisterProject, RegisterProjectResponse),
196    (RegisterWorktree, Ack),
197    (SaveBuffer, BufferSaved),
198    (SendChannelMessage, SendChannelMessageResponse),
199    (ShareProject, Ack),
200    (ShareWorktree, Ack),
201    (UpdateBuffer, Ack),
202);
203
204entity_messages!(
205    project_id,
206    AddProjectCollaborator,
207    ApplyCodeAction,
208    ApplyCompletionAdditionalEdits,
209    BufferReloaded,
210    BufferSaved,
211    CloseBuffer,
212    DiskBasedDiagnosticsUpdated,
213    DiskBasedDiagnosticsUpdating,
214    FormatBuffers,
215    GetCodeActions,
216    GetCompletions,
217    GetDefinition,
218    JoinProject,
219    LeaveProject,
220    OpenBuffer,
221    RemoveProjectCollaborator,
222    SaveBuffer,
223    ShareWorktree,
224    UnregisterWorktree,
225    UnshareProject,
226    UpdateBuffer,
227    UpdateBufferFile,
228    UpdateDiagnosticSummary,
229    UpdateWorktree,
230);
231
232entity_messages!(channel_id, ChannelMessageSent);
233
/// A stream of protobuf messages.
///
/// Wraps a transport `S` together with a scratch byte buffer that the read
/// and write paths clear and reuse for each message.
pub struct MessageStream<S> {
    stream: S,
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        Self { stream, encoding_buffer }
    }

    /// Returns mutable access to the underlying transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
252
253impl<S> MessageStream<S>
254where
255    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
256{
257    /// Write a given protobuf message to the stream.
258    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
259        self.encoding_buffer.resize(message.encoded_len(), 0);
260        self.encoding_buffer.clear();
261        message
262            .encode(&mut self.encoding_buffer)
263            .map_err(|err| io::Error::from(err))?;
264        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
265        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
266        Ok(())
267    }
268}
269
270impl<S> MessageStream<S>
271where
272    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
273{
274    /// Read a protobuf message of the given type from the stream.
275    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
276        while let Some(bytes) = self.stream.next().await {
277            match bytes? {
278                WebSocketMessage::Binary(bytes) => {
279                    self.encoding_buffer.clear();
280                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
281                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
282                        .map_err(io::Error::from)?;
283                    return Ok(envelope);
284                }
285                WebSocketMessage::Close(_) => break,
286                _ => {}
287            }
288        }
289        Err(WebSocketError::ConnectionClosed)
290    }
291}
292
293impl Into<SystemTime> for Timestamp {
294    fn into(self) -> SystemTime {
295        UNIX_EPOCH
296            .checked_add(Duration::new(self.seconds, self.nanos))
297            .unwrap()
298    }
299}
300
301impl From<SystemTime> for Timestamp {
302    fn from(time: SystemTime) -> Self {
303        let duration = time.duration_since(UNIX_EPOCH).unwrap();
304        Self {
305            seconds: duration.as_secs(),
306            nanos: duration.subsec_nanos(),
307        }
308    }
309}
310
311impl From<u128> for Nonce {
312    fn from(nonce: u128) -> Self {
313        let upper_half = (nonce >> 64) as u64;
314        let lower_half = nonce as u64;
315        Self {
316            upper_half,
317            lower_half,
318        }
319    }
320}
321
322impl From<Nonce> for u128 {
323    fn from(nonce: Nonce) -> Self {
324        let upper_half = (nonce.upper_half as u128) << 64;
325        let lower_half = nonce.lower_half as u128;
326        upper_half | lower_half
327    }
328}