1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
/// A protobuf message that can be carried as the payload of an [`Envelope`].
///
/// The `messages!` macro in this file generates the implementations.
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
    /// Name of the message type; impls generated by `messages!` set this to the
    /// type's identifier.
    const NAME: &'static str;

    /// Wraps `self` in an [`Envelope`].
    ///
    /// * `id` - stored in the envelope's `id` field.
    /// * `responding_to` - stored in the envelope's `responding_to` field;
    ///   presumably the id of the message being answered (TODO confirm).
    /// * `original_sender_id` - stored in the envelope's `original_sender_id`
    ///   field; on receipt, `build_typed_envelope` maps it to a `PeerId`.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;

    /// Extracts this message type from `envelope`.
    ///
    /// Returns `None` when the envelope has no payload or carries a different
    /// message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
24
/// A message associated with a particular remote entity.
///
/// Implementations generated by `entity_messages!` return one of the message's
/// own id fields (for example `project_id` or `channel_id`).
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message is associated with.
    fn remote_entity_id(&self) -> u64;
}
28
/// A message that is paired with a response message type.
///
/// Pairings are declared with the `request_messages!` macro.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type paired with this request as its response.
    type Response: EnvelopedMessage;
}
32
/// Type-erased view of a [`TypedEnvelope`], letting envelopes with different
/// payload types be handled through one trait object.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// [`TypeId`] of the payload type (the impl for `TypedEnvelope<T>` returns
    /// `TypeId::of::<T>()`, not the id of the envelope wrapper).
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload type (the impl for `TypedEnvelope<T>` returns `T::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any`, e.g. for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a `Box<dyn Any + Send + Sync>`, e.g. for
    /// downcasting by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
39
40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
41 fn payload_type_id(&self) -> TypeId {
42 TypeId::of::<T>()
43 }
44
45 fn payload_type_name(&self) -> &'static str {
46 T::NAME
47 }
48
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52
53 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
54 self
55 }
56}
57
/// Generates the glue between protobuf payload types and [`Envelope`].
///
/// For the listed types this expands to:
/// * a single `build_typed_envelope` function matching on `envelope.payload`, and
/// * an [`EnvelopedMessage`] impl for each listed type.
///
/// Each identifier is used both as a type name and as the `envelope::Payload`
/// variant name, so the two must agree.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        /// Wraps a received `envelope` in a boxed [`TypedEnvelope`] whose payload
        /// has its concrete message type.
        ///
        /// The result stores `sender_id`, converts the envelope's
        /// `original_sender_id` into a `PeerId`, and uses the envelope's `id` as
        /// `message_id`. Returns `None` if the payload is absent or is not one of
        /// the types listed in the `messages!` invocation.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            match envelope.payload {
                $(Some(envelope::Payload::$name(payload)) => {
                    Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id: envelope.original_sender_id.map(PeerId),
                        message_id: envelope.id,
                        payload,
                    }))
                }, )*
                _ => None
            }
        }

        $(
            impl EnvelopedMessage for $name {
                // The type's identifier exactly as written in the invocation.
                const NAME: &'static str = std::stringify!($name);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload: Some(envelope::Payload::$name(self)),
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    // A missing payload or any other variant yields `None`.
                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
                        Some(msg)
                    } else {
                        None
                    }
                }
            }
        )*
    };
}
103
/// Implements [`RequestMessage`] for each `(request, response)` pair, setting
/// the request's `Response` associated type to the second element of the pair.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(impl RequestMessage for $request_name {
            type Response = $response_name;
        })*
    };
}
111
/// Implements [`EntityMessage`] for each listed type.
///
/// The first argument names the field returned by `remote_entity_id`; that
/// field must exist on every listed type and be a `u64`.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(impl EntityMessage for $name {
            fn remote_entity_id(&self) -> u64 {
                self.$id_field
            }
        })*
    };
}
121
122messages!(
123 Ack,
124 AddProjectCollaborator,
125 BufferSaved,
126 ChannelMessageSent,
127 CloseBuffer,
128 DiskBasedDiagnosticsUpdated,
129 DiskBasedDiagnosticsUpdating,
130 Error,
131 FormatBuffer,
132 GetChannelMessages,
133 GetChannelMessagesResponse,
134 GetChannels,
135 GetChannelsResponse,
136 GetUsers,
137 GetUsersResponse,
138 JoinChannel,
139 JoinChannelResponse,
140 JoinProject,
141 JoinProjectResponse,
142 LeaveChannel,
143 LeaveProject,
144 OpenBuffer,
145 OpenBufferResponse,
146 RegisterProjectResponse,
147 Ping,
148 RegisterProject,
149 RegisterWorktree,
150 RemoveProjectCollaborator,
151 SaveBuffer,
152 SendChannelMessage,
153 SendChannelMessageResponse,
154 ShareProject,
155 ShareWorktree,
156 UnregisterProject,
157 UnregisterWorktree,
158 UnshareProject,
159 UpdateBuffer,
160 UpdateContacts,
161 UpdateDiagnosticSummary,
162 UpdateWorktree,
163);
164
// Request types (left) paired with their response types (right).
request_messages!(
    (FormatBuffer, Ack),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBuffer, OpenBufferResponse),
    (Ping, Ack),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (SaveBuffer, BufferSaved),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, Ack),
    (ShareWorktree, Ack),
    (UpdateBuffer, Ack),
);
182
// Messages whose `project_id` field is reported as their remote entity id.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    BufferSaved,
    CloseBuffer,
    DiskBasedDiagnosticsUpdated,
    DiskBasedDiagnosticsUpdating,
    FormatBuffer,
    JoinProject,
    LeaveProject,
    OpenBuffer,
    RemoveProjectCollaborator,
    SaveBuffer,
    ShareWorktree,
    UnregisterWorktree,
    UnshareProject,
    UpdateBuffer,
    UpdateDiagnosticSummary,
    UpdateWorktree,
);
203
// Messages whose `channel_id` field is reported as their remote entity id.
entity_messages!(channel_id, ChannelMessageSent);
205
/// A stream of protobuf messages.
///
/// Wraps a transport `S` together with a scratch buffer that the read and write
/// paths reuse to hold uncompressed, protobuf-encoded message bytes.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch space for encoded message bytes, cleared before each use.
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Wraps `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        Self {
            encoding_buffer,
            stream,
        }
    }

    /// Returns mutable access to the wrapped transport.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
224
225impl<S> MessageStream<S>
226where
227 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
228{
229 /// Write a given protobuf message to the stream.
230 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
231 self.encoding_buffer.resize(message.encoded_len(), 0);
232 self.encoding_buffer.clear();
233 message
234 .encode(&mut self.encoding_buffer)
235 .map_err(|err| io::Error::from(err))?;
236 let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
237 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
238 Ok(())
239 }
240}
241
242impl<S> MessageStream<S>
243where
244 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
245{
246 /// Read a protobuf message of the given type from the stream.
247 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
248 while let Some(bytes) = self.stream.next().await {
249 match bytes? {
250 WebSocketMessage::Binary(bytes) => {
251 self.encoding_buffer.clear();
252 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
253 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
254 .map_err(io::Error::from)?;
255 return Ok(envelope);
256 }
257 WebSocketMessage::Close(_) => break,
258 _ => {}
259 }
260 }
261 Err(WebSocketError::ConnectionClosed)
262 }
263}
264
265impl Into<SystemTime> for Timestamp {
266 fn into(self) -> SystemTime {
267 UNIX_EPOCH
268 .checked_add(Duration::new(self.seconds, self.nanos))
269 .unwrap()
270 }
271}
272
273impl From<SystemTime> for Timestamp {
274 fn from(time: SystemTime) -> Self {
275 let duration = time.duration_since(UNIX_EPOCH).unwrap();
276 Self {
277 seconds: duration.as_secs(),
278 nanos: duration.subsec_nanos(),
279 }
280 }
281}
282
283impl From<u128> for Nonce {
284 fn from(nonce: u128) -> Self {
285 let upper_half = (nonce >> 64) as u64;
286 let lower_half = nonce as u64;
287 Self {
288 upper_half,
289 lower_half,
290 }
291 }
292}
293
294impl From<Nonce> for u128 {
295 fn from(nonce: Nonce) -> Self {
296 let upper_half = (nonce.upper_half as u128) << 64;
297 let lower_half = nonce.lower_half as u128;
298 upper_half | lower_half
299 }
300}