1use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
2use futures::{SinkExt as _, StreamExt as _};
3use prost::Message;
4use std::{
5 io,
6 time::{Duration, SystemTime, UNIX_EPOCH},
7};
8
9include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
10
11pub trait EnvelopedMessage: Clone + Sized + Send + 'static {
12 const NAME: &'static str;
13 fn into_envelope(
14 self,
15 id: u32,
16 responding_to: Option<u32>,
17 original_sender_id: Option<u32>,
18 ) -> Envelope;
19 fn matches_envelope(envelope: &Envelope) -> bool;
20 fn from_envelope(envelope: Envelope) -> Option<Self>;
21}
22
23pub trait RequestMessage: EnvelopedMessage {
24 type Response: EnvelopedMessage;
25}
26
/// Implements [`EnvelopedMessage`] for the message type `$name`.
///
/// The generated code refers to `envelope::Payload::$name`, so the payload
/// enum must have a variant with the same identifier as the type.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            const NAME: &'static str = std::stringify!($name);

            fn into_envelope(
                self,
                id: u32,
                responding_to: Option<u32>,
                original_sender_id: Option<u32>,
            ) -> Envelope {
                // Wrap the message in its payload variant before assembling the envelope.
                let payload = Some(envelope::Payload::$name(self));
                Envelope {
                    id,
                    responding_to,
                    original_sender_id,
                    payload,
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                matches!(envelope.payload, Some(envelope::Payload::$name(_)))
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$name(message)) => Some(message),
                    // Either no payload at all, or a payload of another message type.
                    _ => None,
                }
            }
        }
    };
}
60
/// Declares a request/response pair of message types.
///
/// Both types get an `EnvelopedMessage` impl through `message!`, and
/// `$request` additionally implements `RequestMessage` with `$response` as
/// its associated `Response` type.
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for $request {
            type Response = $response;
        }

        message!($request);
        message!($response);
    };
}
70
71request_message!(Auth, AuthResponse);
72request_message!(ShareWorktree, ShareWorktreeResponse);
73request_message!(OpenWorktree, OpenWorktreeResponse);
74message!(UpdateWorktree);
75message!(CloseWorktree);
76request_message!(OpenBuffer, OpenBufferResponse);
77message!(CloseBuffer);
78message!(UpdateBuffer);
79request_message!(SaveBuffer, BufferSaved);
80message!(AddPeer);
81message!(RemovePeer);
82
/// A stream of protobuf messages, layered over an underlying transport `S`.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream` in a new `MessageStream`.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
97
98impl<S> MessageStream<S>
99where
100 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
101{
102 /// Write a given protobuf message to the stream.
103 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
104 let mut buffer = Vec::with_capacity(message.encoded_len());
105 message
106 .encode(&mut buffer)
107 .map_err(|err| io::Error::from(err))?;
108 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
109 Ok(())
110 }
111}
112
113impl<S> MessageStream<S>
114where
115 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
116{
117 /// Read a protobuf message of the given type from the stream.
118 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
119 while let Some(bytes) = self.stream.next().await {
120 match bytes? {
121 WebSocketMessage::Binary(bytes) => {
122 let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
123 return Ok(envelope);
124 }
125 WebSocketMessage::Close(_) => break,
126 _ => {}
127 }
128 }
129 Err(WebSocketError::ConnectionClosed)
130 }
131}
132
133impl Into<SystemTime> for Timestamp {
134 fn into(self) -> SystemTime {
135 UNIX_EPOCH
136 .checked_add(Duration::new(self.seconds, self.nanos))
137 .unwrap()
138 }
139}
140
141impl From<SystemTime> for Timestamp {
142 fn from(time: SystemTime) -> Self {
143 let duration = time.duration_since(UNIX_EPOCH).unwrap();
144 Self {
145 seconds: duration.as_secs(),
146 nanos: duration.subsec_nanos(),
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Writes two envelopes to a `MessageStream` over a test channel and
    /// checks that reading them back yields equal envelopes in the same order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            };
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            };
            let sent = [
                auth.into_envelope(3, None, None),
                open_buffer.into_envelope(5, None, None),
            ];

            let mut message_stream = MessageStream::new(test::Channel::new());

            // Queue every envelope first, then read them all back.
            for envelope in &sent {
                message_stream.write_message(envelope).await.unwrap();
            }
            for expected in &sent {
                let received = message_stream.read_message().await.unwrap();
                assert_eq!(&received, expected);
            }
        });
    }
}