1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::Any;
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 fn into_envelope(
17 self,
18 id: u32,
19 responding_to: Option<u32>,
20 original_sender_id: Option<u32>,
21 ) -> Envelope;
22 fn matches_envelope(envelope: &Envelope) -> bool;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
/// Registers the given message types as envelope payloads.
///
/// For the listed types, this generates:
/// * `build_typed_envelope`, which turns an incoming [`Envelope`] into a boxed
///   [`TypedEnvelope`] of the matching payload type, and
/// * an [`EnvelopedMessage`] impl for each type, keyed on the
///   `envelope::Payload` variant of the same name.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        /// Wraps the payload of `envelope` in a [`TypedEnvelope`] tagged with
        /// `sender_id`, boxed as `dyn Any` so callers can downcast to the concrete
        /// `TypedEnvelope<T>`.
        ///
        /// Returns `None` when the envelope has no payload, or has one whose type
        /// is not listed in this macro invocation.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn Any + Send + Sync>> {
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            let message_id = envelope.id;
            let typed: Box<dyn Any + Send + Sync> = match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    }),
                )*
                _ => return None,
            };
            Some(typed)
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn matches_envelope(envelope: &Envelope) -> bool {
                    // The wildcard binds nothing, so this only inspects the payload.
                    matches!(envelope.payload, Some(envelope::Payload::$name(_)))
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
83
/// Implements [`RequestMessage`] for each `(request, response)` pair, declaring
/// `response` as the request's associated `Response` type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
91
/// Implements [`EntityMessage`] for each listed type, using the field named by
/// the first argument as the remote entity id.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    let Self { $id_field, .. } = self;
                    *$id_field
                }
            }
        )*
    };
}
101
102messages!(
103 AddPeer,
104 Auth,
105 AuthResponse,
106 BufferSaved,
107 ChannelMessageSent,
108 CloseBuffer,
109 CloseWorktree,
110 GetChannels,
111 GetChannelsResponse,
112 GetUsers,
113 GetUsersResponse,
114 JoinChannel,
115 JoinChannelResponse,
116 OpenBuffer,
117 OpenBufferResponse,
118 OpenWorktree,
119 OpenWorktreeResponse,
120 Ping,
121 Pong,
122 RemovePeer,
123 SaveBuffer,
124 SendChannelMessage,
125 ShareWorktree,
126 ShareWorktreeResponse,
127 UpdateBuffer,
128 UpdateWorktree,
129);
130
131request_messages!(
132 (Auth, AuthResponse),
133 (GetChannels, GetChannelsResponse),
134 (GetUsers, GetUsersResponse),
135 (JoinChannel, JoinChannelResponse),
136 (OpenBuffer, OpenBufferResponse),
137 (OpenWorktree, OpenWorktreeResponse),
138 (Ping, Pong),
139 (SaveBuffer, BufferSaved),
140 (ShareWorktree, ShareWorktreeResponse),
141);
142
143entity_messages!(
144 worktree_id,
145 AddPeer,
146 BufferSaved,
147 CloseBuffer,
148 CloseWorktree,
149 OpenBuffer,
150 OpenWorktree,
151 RemovePeer,
152 SaveBuffer,
153 UpdateBuffer,
154 UpdateWorktree,
155);
156
157entity_messages!(channel_id, ChannelMessageSent);
158
/// Wraps a transport `S` and exchanges protobuf [`Envelope`]s over it.
///
/// Writing requires `S` to be a WebSocket sink, and reading requires it to be a
/// WebSocket stream; see the `impl` blocks for `write_message` and `read_message`.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream` without performing any I/O.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
173
174impl<S> MessageStream<S>
175where
176 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
177{
178 /// Write a given protobuf message to the stream.
179 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
180 let mut buffer = Vec::with_capacity(message.encoded_len());
181 message
182 .encode(&mut buffer)
183 .map_err(|err| io::Error::from(err))?;
184 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
185 Ok(())
186 }
187}
188
189impl<S> MessageStream<S>
190where
191 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
192{
193 /// Read a protobuf message of the given type from the stream.
194 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
195 while let Some(bytes) = self.stream.next().await {
196 match bytes? {
197 WebSocketMessage::Binary(bytes) => {
198 let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
199 return Ok(envelope);
200 }
201 WebSocketMessage::Close(_) => break,
202 _ => {}
203 }
204 }
205 Err(WebSocketError::ConnectionClosed)
206 }
207}
208
209impl Into<SystemTime> for Timestamp {
210 fn into(self) -> SystemTime {
211 UNIX_EPOCH
212 .checked_add(Duration::new(self.seconds, self.nanos))
213 .unwrap()
214 }
215}
216
217impl From<SystemTime> for Timestamp {
218 fn from(time: SystemTime) -> Self {
219 let duration = time.duration_since(UNIX_EPOCH).unwrap();
220 Self {
221 seconds: duration.as_secs(),
222 nanos: duration.subsec_nanos(),
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Writes two envelopes through a `MessageStream` backed by a test channel,
    /// then checks that they are read back unchanged and in order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let envelopes = vec![auth, open_buffer];
            let mut message_stream = MessageStream::new(test::Channel::new());
            for envelope in &envelopes {
                message_stream.write_message(envelope).await.unwrap();
            }
            for expected in &envelopes {
                let decoded = message_stream.read_message().await.unwrap();
                assert_eq!(&decoded, expected);
            }
        });
    }
}