proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::{Any, TypeId};
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
 12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn from_envelope(envelope: Envelope) -> Option<Self>;
 23}
 24
 25pub trait EntityMessage: EnvelopedMessage {
 26    fn remote_entity_id(&self) -> u64;
 27}
 28
 29pub trait RequestMessage: EnvelopedMessage {
 30    type Response: EnvelopedMessage;
 31}
 32
 33pub trait AnyTypedEnvelope: 'static + Send + Sync {
 34    fn payload_type_id(&self) -> TypeId;
 35    fn payload_type_name(&self) -> &'static str;
 36    fn as_any(&self) -> &dyn Any;
 37    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
 38}
 39
 40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
 41    fn payload_type_id(&self) -> TypeId {
 42        TypeId::of::<T>()
 43    }
 44
 45    fn payload_type_name(&self) -> &'static str {
 46        T::NAME
 47    }
 48
 49    fn as_any(&self) -> &dyn Any {
 50        self
 51    }
 52
 53    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
 54        self
 55    }
 56}
 57
 58macro_rules! messages {
 59    ($($name:ident),* $(,)?) => {
 60        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
 61            match envelope.payload {
 62                $(Some(envelope::Payload::$name(payload)) => {
 63                    Some(Box::new(TypedEnvelope {
 64                        sender_id,
 65                        original_sender_id: envelope.original_sender_id.map(PeerId),
 66                        message_id: envelope.id,
 67                        payload,
 68                    }))
 69                }, )*
 70                _ => None
 71            }
 72        }
 73
 74        $(
 75            impl EnvelopedMessage for $name {
 76                const NAME: &'static str = std::stringify!($name);
 77
 78                fn into_envelope(
 79                    self,
 80                    id: u32,
 81                    responding_to: Option<u32>,
 82                    original_sender_id: Option<u32>,
 83                ) -> Envelope {
 84                    Envelope {
 85                        id,
 86                        responding_to,
 87                        original_sender_id,
 88                        payload: Some(envelope::Payload::$name(self)),
 89                    }
 90                }
 91
 92                fn from_envelope(envelope: Envelope) -> Option<Self> {
 93                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 94                        Some(msg)
 95                    } else {
 96                        None
 97                    }
 98                }
 99            }
100        )*
101    };
102}
103
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, binding
/// `Request::Response` to the paired type.
macro_rules! request_messages {
    ($(($req:ident, $resp:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $req {
                type Response = $resp;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed type, reporting the value of
/// the named `u64` field as the remote entity id.
macro_rules! entity_messages {
    ($field:ident, $($ty:ident),* $(,)?) => {
        $(
            impl EntityMessage for $ty {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123    Ack,
124    AddProjectCollaborator,
125    BufferReloaded,
126    BufferSaved,
127    ChannelMessageSent,
128    CloseBuffer,
129    DiskBasedDiagnosticsUpdated,
130    DiskBasedDiagnosticsUpdating,
131    Error,
132    FormatBuffer,
133    GetChannelMessages,
134    GetChannelMessagesResponse,
135    GetChannels,
136    GetChannelsResponse,
137    GetCompletions,
138    GetCompletionsResponse,
139    GetDefinition,
140    GetDefinitionResponse,
141    GetUsers,
142    GetUsersResponse,
143    JoinChannel,
144    JoinChannelResponse,
145    JoinProject,
146    JoinProjectResponse,
147    LeaveChannel,
148    LeaveProject,
149    OpenBuffer,
150    OpenBufferResponse,
151    RegisterProjectResponse,
152    Ping,
153    RegisterProject,
154    RegisterWorktree,
155    RemoveProjectCollaborator,
156    SaveBuffer,
157    SendChannelMessage,
158    SendChannelMessageResponse,
159    ShareProject,
160    ShareWorktree,
161    UnregisterProject,
162    UnregisterWorktree,
163    UnshareProject,
164    UpdateBuffer,
165    UpdateBufferFile,
166    UpdateContacts,
167    UpdateDiagnosticSummary,
168    UpdateWorktree,
169);
170
171request_messages!(
172    (FormatBuffer, Ack),
173    (GetChannelMessages, GetChannelMessagesResponse),
174    (GetChannels, GetChannelsResponse),
175    (GetCompletions, GetCompletionsResponse),
176    (GetDefinition, GetDefinitionResponse),
177    (GetUsers, GetUsersResponse),
178    (JoinChannel, JoinChannelResponse),
179    (JoinProject, JoinProjectResponse),
180    (OpenBuffer, OpenBufferResponse),
181    (Ping, Ack),
182    (RegisterProject, RegisterProjectResponse),
183    (RegisterWorktree, Ack),
184    (SaveBuffer, BufferSaved),
185    (SendChannelMessage, SendChannelMessageResponse),
186    (ShareProject, Ack),
187    (ShareWorktree, Ack),
188    (UpdateBuffer, Ack),
189);
190
191entity_messages!(
192    project_id,
193    AddProjectCollaborator,
194    BufferReloaded,
195    BufferSaved,
196    CloseBuffer,
197    DiskBasedDiagnosticsUpdated,
198    DiskBasedDiagnosticsUpdating,
199    FormatBuffer,
200    GetCompletions,
201    GetDefinition,
202    JoinProject,
203    LeaveProject,
204    OpenBuffer,
205    RemoveProjectCollaborator,
206    SaveBuffer,
207    ShareWorktree,
208    UnregisterWorktree,
209    UnshareProject,
210    UpdateBuffer,
211    UpdateBufferFile,
212    UpdateDiagnosticSummary,
213    UpdateWorktree,
214);
215
216entity_messages!(channel_id, ChannelMessageSent);
217
/// A transport `S` over which protobuf [`Envelope`]s are exchanged as
/// zstd-compressed WebSocket binary frames.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across messages for the uncompressed protobuf bytes.
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        MessageStream {
            stream,
            encoding_buffer,
        }
    }

    /// Mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let MessageStream { stream, .. } = self;
        stream
    }
}
236
237impl<S> MessageStream<S>
238where
239    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
240{
241    /// Write a given protobuf message to the stream.
242    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
243        self.encoding_buffer.resize(message.encoded_len(), 0);
244        self.encoding_buffer.clear();
245        message
246            .encode(&mut self.encoding_buffer)
247            .map_err(|err| io::Error::from(err))?;
248        let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
249        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
250        Ok(())
251    }
252}
253
254impl<S> MessageStream<S>
255where
256    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
257{
258    /// Read a protobuf message of the given type from the stream.
259    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
260        while let Some(bytes) = self.stream.next().await {
261            match bytes? {
262                WebSocketMessage::Binary(bytes) => {
263                    self.encoding_buffer.clear();
264                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
265                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
266                        .map_err(io::Error::from)?;
267                    return Ok(envelope);
268                }
269                WebSocketMessage::Close(_) => break,
270                _ => {}
271            }
272        }
273        Err(WebSocketError::ConnectionClosed)
274    }
275}
276
277impl Into<SystemTime> for Timestamp {
278    fn into(self) -> SystemTime {
279        UNIX_EPOCH
280            .checked_add(Duration::new(self.seconds, self.nanos))
281            .unwrap()
282    }
283}
284
285impl From<SystemTime> for Timestamp {
286    fn from(time: SystemTime) -> Self {
287        let duration = time.duration_since(UNIX_EPOCH).unwrap();
288        Self {
289            seconds: duration.as_secs(),
290            nanos: duration.subsec_nanos(),
291        }
292    }
293}
294
295impl From<u128> for Nonce {
296    fn from(nonce: u128) -> Self {
297        let upper_half = (nonce >> 64) as u64;
298        let lower_half = nonce as u64;
299        Self {
300            upper_half,
301            lower_half,
302        }
303    }
304}
305
306impl From<Nonce> for u128 {
307    fn from(nonce: Nonce) -> Self {
308        let upper_half = (nonce.upper_half as u128) << 64;
309        let lower_half = nonce.lower_half as u128;
310        upper_half | lower_half
311    }
312}