1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
// Pull in the protobuf types used throughout this file (`Envelope`, the `envelope::Payload`
// enum, `Timestamp`, and every message struct). The file is generated into Cargo's `OUT_DIR`,
// presumably by this crate's build script via prost — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
/// A protobuf message type that can travel as the payload of an [`Envelope`].
///
/// Implementations are generated by the `messages!` macro in this file.
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
    /// The type's name, produced by `stringify!` on the identifier passed to `messages!`.
    const NAME: &'static str;
    /// Wraps `self` in an [`Envelope`] with the given message `id`, the optional
    /// `responding_to` id (presumably the id of the request being answered — confirm with
    /// callers), and the optional `original_sender_id`.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Returns `true` if `envelope`'s payload holds this message type.
    fn matches_envelope(envelope: &Envelope) -> bool;
    /// Takes the payload out of `envelope` if it holds this message type, otherwise `None`.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
25
/// A message that refers to a specific remote entity, identified by a `u64` id.
///
/// Implementations are generated by `entity_messages!`, which reads the id from a named
/// field of the message (e.g. `worktree_id` or `channel_id`).
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the entity this message concerns.
    fn remote_entity_id(&self) -> u64;
}
29
/// A message type that is paired with a `Response` message type (e.g. `Ping` → `Pong`).
///
/// Implementations are generated by `request_messages!`.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type paired with this request as its reply.
    type Response: EnvelopedMessage;
}
33
/// Type-erased view of a [`TypedEnvelope`].
///
/// Envelopes with different payload types can be handled uniformly and later downcast
/// through [`Any`].
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// The [`TypeId`] of the payload type (not of the envelope wrapper).
    fn payload_type_id(&self) -> TypeId;
    /// The payload type's `EnvelopedMessage::NAME`.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
40
41impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
42 fn payload_type_id(&self) -> TypeId {
43 TypeId::of::<T>()
44 }
45
46 fn payload_type_name(&self) -> &'static str {
47 T::NAME
48 }
49
50 fn as_any(&self) -> &dyn Any {
51 self
52 }
53
54 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
55 self
56 }
57}
58
/// Registers the message types that can travel inside an [`Envelope`].
///
/// For the listed types this generates:
/// - one `build_typed_envelope` function, which unwraps a raw `Envelope` into a boxed
///   `TypedEnvelope` of whichever listed type its payload holds, and
/// - an `EnvelopedMessage` impl per type, which wraps/unwraps it via `envelope::Payload`.
///
/// Each identifier must name both a message struct and the matching `envelope::Payload`
/// variant.
macro_rules! messages {
    ($($message:ident),* $(,)?) => {
        /// Converts a raw `envelope` received on connection `sender_id` into a type-erased
        /// `TypedEnvelope`.
        ///
        /// Returns `None` when the envelope has no payload or its payload is not one of the
        /// types registered with `messages!`.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            let Envelope {
                id: message_id,
                original_sender_id,
                payload: raw_payload,
                ..
            } = envelope;
            let original_sender_id = original_sender_id.map(PeerId);

            match raw_payload {
                $(
                    Some(envelope::Payload::$message(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $message {
                const NAME: &'static str = std::stringify!($message);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$message(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn matches_envelope(envelope: &Envelope) -> bool {
                    matches!(envelope.payload.as_ref(), Some(envelope::Payload::$message(_)))
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$message(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
108
/// Pairs each request message type with its reply type by implementing `RequestMessage`.
///
/// Invoked as `request_messages!((Request, Response), ...)`.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
116
/// Implements `EntityMessage` for each listed message type.
///
/// The first argument names the `u64` field that holds the entity id; the remaining
/// arguments are the message types that carry that field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
126
// Every message type that can travel inside an `Envelope`. Each one gets an `EnvelopedMessage`
// impl and a match arm in `build_typed_envelope`; a type missing from this list makes
// `build_typed_envelope` return `None` for it.
messages!(
    AddPeer,
    BufferSaved,
    ChannelMessageSent,
    CloseBuffer,
    CloseWorktree,
    GetChannels,
    GetChannelsResponse,
    GetUsers,
    GetUsersResponse,
    JoinChannel,
    JoinChannelResponse,
    LeaveChannel,
    OpenBuffer,
    OpenBufferResponse,
    OpenWorktree,
    OpenWorktreeResponse,
    Ping,
    Pong,
    RemovePeer,
    SaveBuffer,
    SendChannelMessage,
    SendChannelMessageResponse,
    ShareWorktree,
    ShareWorktreeResponse,
    UpdateBuffer,
    UpdateWorktree,
);
155
// Request types paired with the message type used as their reply (`RequestMessage::Response`).
request_messages!(
    (GetChannels, GetChannelsResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (OpenBuffer, OpenBufferResponse),
    (OpenWorktree, OpenWorktreeResponse),
    (Ping, Pong),
    (SaveBuffer, BufferSaved),
    (ShareWorktree, ShareWorktreeResponse),
    (SendChannelMessage, SendChannelMessageResponse),
);
167
// Worktree-scoped messages: `remote_entity_id` returns their `worktree_id` field.
entity_messages!(
    worktree_id,
    AddPeer,
    BufferSaved,
    CloseBuffer,
    CloseWorktree,
    OpenBuffer,
    OpenWorktree,
    RemovePeer,
    SaveBuffer,
    UpdateBuffer,
    UpdateWorktree,
);

// Channel-scoped messages: `remote_entity_id` returns their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
183
/// A transport wrapper that reads and writes protobuf [`Envelope`]s as WebSocket frames.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`; no I/O is performed.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let MessageStream { stream } = self;
        stream
    }
}
198
199impl<S> MessageStream<S>
200where
201 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
202{
203 /// Write a given protobuf message to the stream.
204 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
205 let mut buffer = Vec::with_capacity(message.encoded_len());
206 message
207 .encode(&mut buffer)
208 .map_err(|err| io::Error::from(err))?;
209 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
210 Ok(())
211 }
212}
213
214impl<S> MessageStream<S>
215where
216 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
217{
218 /// Read a protobuf message of the given type from the stream.
219 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
220 while let Some(bytes) = self.stream.next().await {
221 match bytes? {
222 WebSocketMessage::Binary(bytes) => {
223 let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
224 return Ok(envelope);
225 }
226 WebSocketMessage::Close(_) => break,
227 _ => {}
228 }
229 }
230 Err(WebSocketError::ConnectionClosed)
231 }
232}
233
234impl Into<SystemTime> for Timestamp {
235 fn into(self) -> SystemTime {
236 UNIX_EPOCH
237 .checked_add(Duration::new(self.seconds, self.nanos))
238 .unwrap()
239 }
240}
241
242impl From<SystemTime> for Timestamp {
243 fn from(time: SystemTime) -> Self {
244 let duration = time.duration_since(UNIX_EPOCH).unwrap();
245 Self {
246 seconds: duration.as_secs(),
247 nanos: duration.subsec_nanos(),
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Envelopes written to a `MessageStream` are read back unchanged and in order.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let mut message_stream = MessageStream::new(test::Channel::new());

            let sent = vec![
                Ping { id: 5 }.into_envelope(3, None, None),
                OpenBuffer {
                    worktree_id: 0,
                    path: "some/path".to_string(),
                }
                .into_envelope(5, None, None),
            ];

            // Write every envelope first, then read them all back in the same order.
            for envelope in &sent {
                message_stream.write_message(envelope).await.unwrap();
            }
            for expected in &sent {
                let received = message_stream.read_message().await.unwrap();
                assert_eq!(&received, expected);
            }
        });
    }
}