proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::Result;
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::Any;
  7use std::{
  8    io,
  9    time::{Duration, SystemTime, UNIX_EPOCH},
 10};
 11
// Pulls in the protobuf types generated at build time into Cargo's `OUT_DIR`
// (presumably by prost from the `zed.messages` package — confirm against the
// build script). `Envelope`, `envelope::Payload`, `Timestamp`, and the message
// structs named in the macro invocations below are expected to come from here.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 13
 14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 15    const NAME: &'static str;
 16    fn into_envelope(
 17        self,
 18        id: u32,
 19        responding_to: Option<u32>,
 20        original_sender_id: Option<u32>,
 21    ) -> Envelope;
 22    fn matches_envelope(envelope: &Envelope) -> bool;
 23    fn from_envelope(envelope: Envelope) -> Option<Self>;
 24}
 25
 26pub trait EntityMessage: EnvelopedMessage {
 27    fn remote_entity_id(&self) -> u64;
 28}
 29
 30pub trait RequestMessage: EnvelopedMessage {
 31    type Response: EnvelopedMessage;
 32}
 33
 34macro_rules! messages {
 35    ($($name:ident),* $(,)?) => {
 36        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn Any + Send + Sync>> {
 37            match envelope.payload {
 38                $(Some(envelope::Payload::$name(payload)) => {
 39                    Some(Box::new(TypedEnvelope {
 40                        sender_id,
 41                        original_sender_id: envelope.original_sender_id.map(PeerId),
 42                        message_id: envelope.id,
 43                        payload,
 44                    }))
 45                }, )*
 46                _ => None
 47            }
 48        }
 49
 50        $(
 51            impl EnvelopedMessage for $name {
 52                const NAME: &'static str = std::stringify!($name);
 53
 54                fn into_envelope(
 55                    self,
 56                    id: u32,
 57                    responding_to: Option<u32>,
 58                    original_sender_id: Option<u32>,
 59                ) -> Envelope {
 60                    Envelope {
 61                        id,
 62                        responding_to,
 63                        original_sender_id,
 64                        payload: Some(envelope::Payload::$name(self)),
 65                    }
 66                }
 67
 68                fn matches_envelope(envelope: &Envelope) -> bool {
 69                    matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 70                }
 71
 72                fn from_envelope(envelope: Envelope) -> Option<Self> {
 73                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 74                        Some(msg)
 75                    } else {
 76                        None
 77                    }
 78                }
 79            }
 80        )*
 81    };
 82}
 83
 84macro_rules! request_messages {
 85    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
 86        $(impl RequestMessage for $request_name {
 87            type Response = $response_name;
 88        })*
 89    };
 90}
 91
 92macro_rules! entity_messages {
 93    ($id_field:ident, $($name:ident),* $(,)?) => {
 94        $(impl EntityMessage for $name {
 95            fn remote_entity_id(&self) -> u64 {
 96                self.$id_field
 97            }
 98        })*
 99    };
100}
101
// Every message type that can be carried in an `Envelope` payload. Each name
// gets an `EnvelopedMessage` impl and a match arm in `build_typed_envelope`.
messages!(
    AddPeer,
    Auth,
    AuthResponse,
    BufferSaved,
    ChannelMessageSent,
    CloseBuffer,
    CloseWorktree,
    GetChannels,
    GetChannelsResponse,
    GetUsers,
    GetUsersResponse,
    JoinChannel,
    JoinChannelResponse,
    OpenBuffer,
    OpenBufferResponse,
    OpenWorktree,
    OpenWorktreeResponse,
    Ping,
    Pong,
    RemovePeer,
    SaveBuffer,
    SendChannelMessage,
    ShareWorktree,
    ShareWorktreeResponse,
    UpdateBuffer,
    UpdateWorktree,
);

// Request types, each paired with the message type sent back in reply.
request_messages!(
    (Auth, AuthResponse),
    (GetChannels, GetChannelsResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (OpenBuffer, OpenBufferResponse),
    (OpenWorktree, OpenWorktreeResponse),
    (Ping, Pong),
    (SaveBuffer, BufferSaved),
    (ShareWorktree, ShareWorktreeResponse),
);

// Entity messages whose `remote_entity_id` is read from their `worktree_id` field.
entity_messages!(
    worktree_id,
    AddPeer,
    BufferSaved,
    CloseBuffer,
    CloseWorktree,
    OpenBuffer,
    OpenWorktree,
    RemovePeer,
    SaveBuffer,
    UpdateBuffer,
    UpdateWorktree,
);

// `ChannelMessageSent` reports its `channel_id` field as the remote entity id.
entity_messages!(channel_id, ChannelMessageSent);
158
/// A protobuf message stream layered over an underlying transport `S`.
///
/// Reading and writing are available when `S` is, respectively, a stream or a
/// sink of WebSocket messages.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`; no I/O is performed here.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let MessageStream { stream } = self;
        stream
    }
}
173
174impl<S> MessageStream<S>
175where
176    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
177{
178    /// Write a given protobuf message to the stream.
179    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
180        let mut buffer = Vec::with_capacity(message.encoded_len());
181        message
182            .encode(&mut buffer)
183            .map_err(|err| io::Error::from(err))?;
184        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
185        Ok(())
186    }
187}
188
189impl<S> MessageStream<S>
190where
191    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
192{
193    /// Read a protobuf message of the given type from the stream.
194    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
195        while let Some(bytes) = self.stream.next().await {
196            match bytes? {
197                WebSocketMessage::Binary(bytes) => {
198                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
199                    return Ok(envelope);
200                }
201                WebSocketMessage::Close(_) => break,
202                _ => {}
203            }
204        }
205        Err(WebSocketError::ConnectionClosed)
206    }
207}
208
209impl Into<SystemTime> for Timestamp {
210    fn into(self) -> SystemTime {
211        UNIX_EPOCH
212            .checked_add(Duration::new(self.seconds, self.nanos))
213            .unwrap()
214    }
215}
216
217impl From<SystemTime> for Timestamp {
218    fn from(time: SystemTime) -> Self {
219        let duration = time.duration_since(UNIX_EPOCH).unwrap();
220        Self {
221            seconds: duration.as_secs(),
222            nanos: duration.subsec_nanos(),
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Writes two envelopes into an in-memory channel, then checks that
    /// reading them back yields the same envelopes in the same order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth_envelope = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);
            let open_buffer_envelope = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let mut stream = MessageStream::new(test::Channel::new());
            stream.write_message(&auth_envelope).await.unwrap();
            stream.write_message(&open_buffer_envelope).await.unwrap();

            let first = stream.read_message().await.unwrap();
            let second = stream.read_message().await.unwrap();
            assert_eq!(first, auth_envelope);
            assert_eq!(second, open_buffer_envelope);
        });
    }
}