1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 fn into_envelope(
17 self,
18 id: u32,
19 responding_to: Option<u32>,
20 original_sender_id: Option<u32>,
21 ) -> Envelope;
22 fn from_envelope(envelope: Envelope) -> Option<Self>;
23}
24
25pub trait EntityMessage: EnvelopedMessage {
26 fn remote_entity_id(&self) -> u64;
27}
28
29pub trait RequestMessage: EnvelopedMessage {
30 type Response: EnvelopedMessage;
31}
32
/// A type-erased view of a `TypedEnvelope<T>`.
///
/// Lets envelopes with different payload types be handled behind a single
/// trait object and downcast back to their concrete type later.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// Returns the [`TypeId`] of the payload type.
    fn payload_type_id(&self) -> TypeId;

    /// Returns the name of the payload type.
    fn payload_type_name(&self) -> &'static str;

    /// Borrows the envelope as [`Any`] so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;

    /// Converts the boxed envelope into a boxed [`Any`] so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
39
40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
41 fn payload_type_id(&self) -> TypeId {
42 TypeId::of::<T>()
43 }
44
45 fn payload_type_name(&self) -> &'static str {
46 T::NAME
47 }
48
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52
53 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
54 self
55 }
56}
57
/// Registers protobuf message types as [`Envelope`] payloads.
///
/// Takes a comma-separated list of names, each of which must be both a message
/// type and a variant of `envelope::Payload`, and generates:
///
/// * `build_typed_envelope`, which turns an [`Envelope`] into a boxed
///   `TypedEnvelope` of the matching payload type, and
/// * an [`EnvelopedMessage`] impl for every listed type.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        /// Converts `envelope` into a type-erased `TypedEnvelope` attributed to `sender_id`.
        ///
        /// Returns `None` when the envelope has no payload or its payload is not
        /// one of the registered message types.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            let message_id = envelope.id;
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = stringify!($name);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
103
/// Implements [`RequestMessage`] for each `(Request, Response)` pair, setting the
/// request's associated `Response` type to the second element of the pair.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type.
///
/// The first argument names the field that holds the entity id; the generated
/// `remote_entity_id` returns that field for every type that follows.
macro_rules! entity_messages {
    ($entity_id_field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$entity_id_field
                }
            }
        )*
    };
}
121
122messages!(
123 AddPeer,
124 BufferSaved,
125 ChannelMessageSent,
126 CloseBuffer,
127 CloseWorktree,
128 GetChannels,
129 GetChannelsResponse,
130 GetUsers,
131 GetUsersResponse,
132 JoinChannel,
133 JoinChannelResponse,
134 LeaveChannel,
135 OpenBuffer,
136 OpenBufferResponse,
137 OpenWorktree,
138 OpenWorktreeResponse,
139 Ping,
140 Pong,
141 RemovePeer,
142 SaveBuffer,
143 SendChannelMessage,
144 SendChannelMessageResponse,
145 ShareWorktree,
146 ShareWorktreeResponse,
147 UpdateBuffer,
148 UpdateWorktree,
149);
150
151request_messages!(
152 (GetChannels, GetChannelsResponse),
153 (GetUsers, GetUsersResponse),
154 (JoinChannel, JoinChannelResponse),
155 (OpenBuffer, OpenBufferResponse),
156 (OpenWorktree, OpenWorktreeResponse),
157 (Ping, Pong),
158 (SaveBuffer, BufferSaved),
159 (ShareWorktree, ShareWorktreeResponse),
160 (SendChannelMessage, SendChannelMessageResponse),
161);
162
163entity_messages!(
164 worktree_id,
165 AddPeer,
166 BufferSaved,
167 CloseBuffer,
168 CloseWorktree,
169 OpenBuffer,
170 OpenWorktree,
171 RemovePeer,
172 SaveBuffer,
173 UpdateBuffer,
174 UpdateWorktree,
175);
176
177entity_messages!(channel_id, ChannelMessageSent);
178
/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S`; the reading and writing methods are only
/// available when `S` implements the corresponding `Stream` / `Sink` trait.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
}
183
184impl<S> MessageStream<S> {
185 pub fn new(stream: S) -> Self {
186 Self { stream }
187 }
188
189 pub fn inner_mut(&mut self) -> &mut S {
190 &mut self.stream
191 }
192}
193
194impl<S> MessageStream<S>
195where
196 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
197{
198 /// Write a given protobuf message to the stream.
199 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
200 let mut buffer = Vec::with_capacity(message.encoded_len());
201 message
202 .encode(&mut buffer)
203 .map_err(|err| io::Error::from(err))?;
204 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
205 Ok(())
206 }
207}
208
209impl<S> MessageStream<S>
210where
211 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
212{
213 /// Read a protobuf message of the given type from the stream.
214 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
215 while let Some(bytes) = self.stream.next().await {
216 match bytes? {
217 WebSocketMessage::Binary(bytes) => {
218 let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
219 return Ok(envelope);
220 }
221 WebSocketMessage::Close(_) => break,
222 _ => {}
223 }
224 }
225 Err(WebSocketError::ConnectionClosed)
226 }
227}
228
229impl Into<SystemTime> for Timestamp {
230 fn into(self) -> SystemTime {
231 UNIX_EPOCH
232 .checked_add(Duration::new(self.seconds, self.nanos))
233 .unwrap()
234 }
235}
236
237impl From<SystemTime> for Timestamp {
238 fn from(time: SystemTime) -> Self {
239 let duration = time.duration_since(UNIX_EPOCH).unwrap();
240 Self {
241 seconds: duration.as_secs(),
242 nanos: duration.subsec_nanos(),
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Writes two envelopes to a `MessageStream` and checks that they are read
    /// back unchanged and in the order they were written.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let ping_envelope = Ping { id: 5 }.into_envelope(3, None, None);
            let open_buffer_envelope = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            // NOTE(review): presumably `test::Channel` yields whatever was sent into
            // it, so one stream serves as both writer and reader -- confirm.
            let mut message_stream = MessageStream::new(test::Channel::new());
            message_stream.write_message(&ping_envelope).await.unwrap();
            message_stream
                .write_message(&open_buffer_envelope)
                .await
                .unwrap();

            let first_read = message_stream.read_message().await.unwrap();
            let second_read = message_stream.read_message().await.unwrap();
            assert_eq!(first_read, ping_envelope);
            assert_eq!(second_read, open_buffer_envelope);
        });
    }
}