proto.rs

  1use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  2use futures::{SinkExt as _, StreamExt as _};
  3use prost::Message;
  4use std::{
  5    io,
  6    time::{Duration, SystemTime, UNIX_EPOCH},
  7};
  8
// Brings the protobuf message types used throughout this file (`Envelope`,
// `envelope::Payload`, `Timestamp`, `Auth`, ...) into scope. They are not defined
// in this file; presumably the build script generates them with prost into
// `$OUT_DIR/zed.messages.rs` — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 10
 11pub trait EnvelopedMessage: Clone + Sized + Send + 'static {
 12    const NAME: &'static str;
 13    fn into_envelope(
 14        self,
 15        id: u32,
 16        responding_to: Option<u32>,
 17        original_sender_id: Option<u32>,
 18    ) -> Envelope;
 19    fn matches_envelope(envelope: &Envelope) -> bool;
 20    fn from_envelope(envelope: Envelope) -> Option<Self>;
 21}
 22
 23pub trait RequestMessage: EnvelopedMessage {
 24    type Response: EnvelopedMessage;
 25}
 26
 27macro_rules! message {
 28    ($name:ident) => {
 29        impl EnvelopedMessage for $name {
 30            const NAME: &'static str = std::stringify!($name);
 31
 32            fn into_envelope(
 33                self,
 34                id: u32,
 35                responding_to: Option<u32>,
 36                original_sender_id: Option<u32>,
 37            ) -> Envelope {
 38                Envelope {
 39                    id,
 40                    responding_to,
 41                    original_sender_id,
 42                    payload: Some(envelope::Payload::$name(self)),
 43                }
 44            }
 45
 46            fn matches_envelope(envelope: &Envelope) -> bool {
 47                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
 48            }
 49
 50            fn from_envelope(envelope: Envelope) -> Option<Self> {
 51                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
 52                    Some(msg)
 53                } else {
 54                    None
 55                }
 56            }
 57        }
 58    };
 59}
 60
 61macro_rules! request_message {
 62    ($req:ident, $resp:ident) => {
 63        message!($req);
 64        message!($resp);
 65        impl RequestMessage for $req {
 66            type Response = $resp;
 67        }
 68    };
 69}
 70
 71request_message!(Auth, AuthResponse);
 72request_message!(ShareWorktree, ShareWorktreeResponse);
 73request_message!(OpenWorktree, OpenWorktreeResponse);
 74message!(UpdateWorktree);
 75message!(CloseWorktree);
 76request_message!(OpenBuffer, OpenBufferResponse);
 77message!(CloseBuffer);
 78message!(UpdateBuffer);
 79request_message!(SaveBuffer, BufferSaved);
 80message!(AddPeer);
 81message!(RemovePeer);
 82
 83/// A stream of protobuf messages.
 84pub struct MessageStream<S> {
 85    stream: S,
 86}
 87
 88impl<S> MessageStream<S> {
 89    pub fn new(stream: S) -> Self {
 90        Self { stream }
 91    }
 92
 93    pub fn inner_mut(&mut self) -> &mut S {
 94        &mut self.stream
 95    }
 96}
 97
 98impl<S> MessageStream<S>
 99where
100    S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
101{
102    /// Write a given protobuf message to the stream.
103    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
104        let mut buffer = Vec::with_capacity(message.encoded_len());
105        message
106            .encode(&mut buffer)
107            .map_err(|err| io::Error::from(err))?;
108        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
109        Ok(())
110    }
111}
112
113impl<S> MessageStream<S>
114where
115    S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
116{
117    /// Read a protobuf message of the given type from the stream.
118    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
119        while let Some(bytes) = self.stream.next().await {
120            match bytes? {
121                WebSocketMessage::Binary(bytes) => {
122                    let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
123                    return Ok(envelope);
124                }
125                WebSocketMessage::Close(_) => break,
126                _ => {}
127            }
128        }
129        Err(WebSocketError::ConnectionClosed)
130    }
131}
132
133impl Into<SystemTime> for Timestamp {
134    fn into(self) -> SystemTime {
135        UNIX_EPOCH
136            .checked_add(Duration::new(self.seconds, self.nanos))
137            .unwrap()
138    }
139}
140
141impl From<SystemTime> for Timestamp {
142    fn from(time: SystemTime) -> Self {
143        let duration = time.duration_since(UNIX_EPOCH).unwrap();
144        Self {
145            seconds: duration.as_secs(),
146            nanos: duration.subsec_nanos(),
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Envelopes written to a `MessageStream` are read back unchanged and in order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);
            let sent = [auth, open_buffer];

            let mut stream = MessageStream::new(test::Channel::new());
            for envelope in &sent {
                stream.write_message(envelope).await.unwrap();
            }
            for expected in &sent {
                let received = stream.read_message().await.unwrap();
                assert_eq!(&received, expected);
            }
        });
    }
}