1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 fn into_envelope(
17 self,
18 id: u32,
19 responding_to: Option<u32>,
20 original_sender_id: Option<u32>,
21 ) -> Envelope;
22 fn from_envelope(envelope: Envelope) -> Option<Self>;
23}
24
25pub trait EntityMessage: EnvelopedMessage {
26 fn remote_entity_id(&self) -> u64;
27}
28
29pub trait RequestMessage: EnvelopedMessage {
30 type Response: EnvelopedMessage;
31}
32
/// A type-erased view of an envelope whose payload type is only known at runtime.
pub trait AnyTypedEnvelope: Send + Sync + 'static {
    /// Returns the [`TypeId`] identifying the payload's type.
    fn payload_type_id(&self) -> TypeId;

    /// Returns a static name for the payload's type.
    fn payload_type_name(&self) -> &'static str;

    /// Borrows the envelope as [`Any`] so callers can downcast by reference.
    fn as_any(&self) -> &dyn Any;

    /// Converts the boxed envelope into a boxed [`Any`] so callers can downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
39
40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
41 fn payload_type_id(&self) -> TypeId {
42 TypeId::of::<T>()
43 }
44
45 fn payload_type_name(&self) -> &'static str {
46 T::NAME
47 }
48
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52
53 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
54 self
55 }
56}
57
/// Registers the listed message types as envelope payloads.
///
/// For the given names this generates:
/// - `build_typed_envelope`, which turns an `Envelope` into a boxed
///   `TypedEnvelope` of the matching payload type, and
/// - an `EnvelopedMessage` impl for every listed type.
///
/// Each name must be both a message type and a variant of `envelope::Payload`.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        /// Builds a type-erased `TypedEnvelope` from `envelope`, tagging it with `sender_id`.
        ///
        /// Returns `None` if the envelope has no payload or its payload is not
        /// one of the registered message types.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            let Envelope {
                id: message_id,
                original_sender_id,
                payload,
                ..
            } = envelope;
            let original_sender_id = original_sender_id.map(PeerId);

            match payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = stringify!($name);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
103
/// Implements `RequestMessage` for each `(Request, Response)` pair, setting the
/// request's associated `Response` type to the second element of the pair.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements `EntityMessage` for each listed message type.
///
/// The first argument names the field that holds the entity id; the generated
/// `remote_entity_id` returns that field for every listed type.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
121
122messages!(
123 AddPeer,
124 BufferSaved,
125 ChannelMessageSent,
126 CloseBuffer,
127 CloseWorktree,
128 Error,
129 GetChannelMessages,
130 GetChannelMessagesResponse,
131 GetChannels,
132 GetChannelsResponse,
133 GetUsers,
134 GetUsersResponse,
135 JoinChannel,
136 JoinChannelResponse,
137 LeaveChannel,
138 OpenBuffer,
139 OpenBufferResponse,
140 OpenWorktree,
141 OpenWorktreeResponse,
142 Ping,
143 Pong,
144 RemovePeer,
145 SaveBuffer,
146 SendChannelMessage,
147 SendChannelMessageResponse,
148 ShareWorktree,
149 ShareWorktreeResponse,
150 UpdateBuffer,
151 UpdateWorktree,
152);
153
154request_messages!(
155 (GetChannels, GetChannelsResponse),
156 (GetUsers, GetUsersResponse),
157 (JoinChannel, JoinChannelResponse),
158 (OpenBuffer, OpenBufferResponse),
159 (OpenWorktree, OpenWorktreeResponse),
160 (Ping, Pong),
161 (SaveBuffer, BufferSaved),
162 (ShareWorktree, ShareWorktreeResponse),
163 (SendChannelMessage, SendChannelMessageResponse),
164 (GetChannelMessages, GetChannelMessagesResponse),
165);
166
167entity_messages!(
168 worktree_id,
169 AddPeer,
170 BufferSaved,
171 CloseBuffer,
172 CloseWorktree,
173 OpenBuffer,
174 OpenWorktree,
175 RemovePeer,
176 SaveBuffer,
177 UpdateBuffer,
178 UpdateWorktree,
179);
180
181entity_messages!(channel_id, ChannelMessageSent);
182
/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S`. Depending on which traits `S`
/// implements, envelopes can be written to it, read from it, or both.
pub struct MessageStream<S> {
    /// The wrapped transport that messages are sent to or received from.
    stream: S,
}
187
188impl<S> MessageStream<S> {
189 pub fn new(stream: S) -> Self {
190 Self { stream }
191 }
192
193 pub fn inner_mut(&mut self) -> &mut S {
194 &mut self.stream
195 }
196}
197
198impl<S> MessageStream<S>
199where
200 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
201{
202 /// Write a given protobuf message to the stream.
203 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
204 let mut buffer = Vec::with_capacity(message.encoded_len());
205 message
206 .encode(&mut buffer)
207 .map_err(|err| io::Error::from(err))?;
208 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
209 Ok(())
210 }
211}
212
213impl<S> MessageStream<S>
214where
215 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
216{
217 /// Read a protobuf message of the given type from the stream.
218 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
219 while let Some(bytes) = self.stream.next().await {
220 match bytes? {
221 WebSocketMessage::Binary(bytes) => {
222 let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
223 return Ok(envelope);
224 }
225 WebSocketMessage::Close(_) => break,
226 _ => {}
227 }
228 }
229 Err(WebSocketError::ConnectionClosed)
230 }
231}
232
233impl Into<SystemTime> for Timestamp {
234 fn into(self) -> SystemTime {
235 UNIX_EPOCH
236 .checked_add(Duration::new(self.seconds, self.nanos))
237 .unwrap()
238 }
239}
240
241impl From<SystemTime> for Timestamp {
242 fn from(time: SystemTime) -> Self {
243 let duration = time.duration_since(UNIX_EPOCH).unwrap();
244 Self {
245 seconds: duration.as_secs(),
246 nanos: duration.subsec_nanos(),
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Envelopes written to a `MessageStream` are read back unchanged and in
    /// the order they were written.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let ping = Ping { id: 5 }.into_envelope(3, None, None);
            let open_buffer = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let mut message_stream = MessageStream::new(test::Channel::new());

            // Write both envelopes before reading anything back.
            message_stream.write_message(&ping).await.unwrap();
            message_stream.write_message(&open_buffer).await.unwrap();

            let first = message_stream.read_message().await.unwrap();
            let second = message_stream.read_message().await.unwrap();

            assert_eq!(first, ping);
            assert_eq!(second, open_buffer);
        });
    }
}