proto.rs

  1use super::{ConnectionId, PeerId, TypedEnvelope};
  2use anyhow::{anyhow, Result};
  3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  4use futures::{SinkExt as _, StreamExt as _};
  5use prost::Message;
  6use std::any::Any;
  7use std::sync::Arc;
  8use std::{
  9    io,
 10    time::{Duration, SystemTime, UNIX_EPOCH},
 11};
 12
 13include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 14
 15pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
 16    const NAME: &'static str;
 17    fn into_envelope(
 18        self,
 19        id: u32,
 20        responding_to: Option<u32>,
 21        original_sender_id: Option<u32>,
 22    ) -> Envelope;
 23    fn matches_envelope(envelope: &Envelope) -> bool;
 24    fn from_envelope(envelope: Envelope) -> Option<Self>;
 25}
 26
 27pub trait RequestMessage: EnvelopedMessage {
 28    type Response: EnvelopedMessage;
 29}
 30
 31macro_rules! messages {
 32    ($($name:ident),*) => {
 33        fn unicast_message_into_typed_envelope(sender_id: ConnectionId, envelope: &mut Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
 34            match &mut envelope.payload {
 35                $(payload @ Some(envelope::Payload::$name(_)) => Some(Arc::new(TypedEnvelope {
 36                    sender_id,
 37                    original_sender_id: envelope.original_sender_id.map(PeerId),
 38                    message_id: envelope.id,
 39                    payload: payload.take().unwrap(),
 40                })), )*
 41                _ => None
 42            }
 43        }
 44
 45        $(
 46            message!($name);
 47        )*
 48    };
 49}
 50
 51macro_rules! request_messages {
 52    ($(($request_name:ident, $response_name:ident)),*) => {
 53        fn request_message_into_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
 54            match envelope.payload {
 55                $(
 56                    Some(envelope::Payload::$request_name(payload)) => Some(Arc::new(TypedEnvelope {
 57                        sender_id,
 58                        original_sender_id: envelope.original_sender_id.map(PeerId),
 59                        message_id: envelope.id,
 60                        payload,
 61                    })),
 62                    Some(envelope::Payload::$response_name(payload)) => Some(Arc::new(TypedEnvelope {
 63                        sender_id,
 64                        original_sender_id: envelope.original_sender_id.map(PeerId),
 65                        message_id: envelope.id,
 66                        payload,
 67                    })),
 68                )*
 69                _ => None
 70            }
 71        }
 72
 73        $(
 74            message!($request_name);
 75            message!($response_name);
 76        )*
 77
 78        $(impl RequestMessage for $request_name {
 79            type Response = $response_name;
 80        })*
 81    };
 82}
 83
 84macro_rules! message {
 85    ($name:ident) => {
 86        impl EnvelopedMessage for $name {
 87            const NAME: &'static str = std::stringify!($name);
 88
 89            fn into_envelope(
 90                self,
 91                id: u32,
 92                responding_to: Option<u32>,
 93                original_sender_id: Option<u32>,
 94            ) -> Envelope {
 95                Envelope {
 96                    id,
 97                    responding_to,
 98                    original_sender_id,
 99                    payload: Some(envelope::Payload::$name(self)),
100                }
101            }
102
103            fn matches_envelope(envelope: &Envelope) -> bool {
104                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
105            }
106
107            fn from_envelope(envelope: Envelope) -> Option<Self> {
108                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
109                    Some(msg)
110                } else {
111                    None
112                }
113            }
114        }
115    };
116}
117
118messages!(
119    UpdateWorktree,
120    CloseWorktree,
121    CloseBuffer,
122    UpdateBuffer,
123    AddPeer,
124    RemovePeer,
125    SendChannelMessage,
126    ChannelMessageSent
127);
128
129request_messages!(
130    (Auth, AuthResponse),
131    (ShareWorktree, ShareWorktreeResponse),
132    (OpenWorktree, OpenWorktreeResponse),
133    (OpenBuffer, OpenBufferResponse),
134    (SaveBuffer, BufferSaved),
135    (GetChannels, GetChannelsResponse),
136    (JoinChannel, JoinChannelResponse),
137    (GetUsers, GetUsersResponse)
138);
139
140pub fn build_typed_envelope(
141    sender_id: ConnectionId,
142    mut envelope: Envelope,
143) -> Result<Arc<dyn Any + Send + Sync>> {
144    unicast_message_into_typed_envelope(sender_id, &mut envelope)
145        .or_else(|| request_message_into_typed_envelope(sender_id, envelope))
146        .ok_or_else(|| anyhow!("unrecognized payload type"))
147}
148
/// Wraps a transport `S` over which protobuf [`Envelope`]s are written and read.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`; no I/O is performed.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Returns a mutable reference to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let MessageStream { stream } = self;
        stream
    }
}
163
164impl<S> MessageStream<S>
165where
166    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
167{
168    /// Write a given protobuf message to the stream.
169    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
170        let mut buffer = Vec::with_capacity(message.encoded_len());
171        message
172            .encode(&mut buffer)
173            .map_err(|err| io::Error::from(err))?;
174        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
175        Ok(())
176    }
177}
178
179impl<S> MessageStream<S>
180where
181    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
182{
183    /// Read a protobuf message of the given type from the stream.
184    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
185        while let Some(bytes) = self.stream.next().await {
186            match bytes? {
187                WebSocketMessage::Binary(bytes) => {
188                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
189                    return Ok(envelope);
190                }
191                WebSocketMessage::Close(_) => break,
192                _ => {}
193            }
194        }
195        Err(WebSocketError::ConnectionClosed)
196    }
197}
198
199impl Into<SystemTime> for Timestamp {
200    fn into(self) -> SystemTime {
201        UNIX_EPOCH
202            .checked_add(Duration::new(self.seconds, self.nanos))
203            .unwrap()
204    }
205}
206
207impl From<SystemTime> for Timestamp {
208    fn from(time: SystemTime) -> Self {
209        let duration = time.duration_since(UNIX_EPOCH).unwrap();
210        Self {
211            seconds: duration.as_secs(),
212            nanos: duration.subsec_nanos(),
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Envelopes written to a `MessageStream` are read back unchanged and in the
    /// order they were written.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let mut message_stream = MessageStream::new(test::Channel::new());

            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let sent = [auth, open_buffer];
            for envelope in sent.iter() {
                message_stream.write_message(envelope).await.unwrap();
            }
            for expected in sent.iter() {
                let received = message_stream.read_message().await.unwrap();
                assert_eq!(&received, expected);
            }
        });
    }
}