1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
// Textually includes `zed.messages.rs` from the build output directory (`OUT_DIR`).
// NOTE(review): presumably this is the prost-generated code that declares `Envelope`,
// `envelope::Payload` and the individual message structs used below — confirm against the
// build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
/// A message type that can be carried as the payload of an [`Envelope`].
///
/// Implementations for the concrete message types are generated by the `messages!` macro in
/// this file.
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
    /// Name of the message type as a static string (the `messages!` macro uses the
    /// stringified type name).
    const NAME: &'static str;

    /// Wraps this message in an [`Envelope`] whose `id`, `responding_to` and
    /// `original_sender_id` fields are set to the given values.
    ///
    /// `responding_to` is presumably the id of the message being answered — confirm against
    /// the callers.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;

    /// Extracts a message of this type from `envelope`.
    ///
    /// The macro-generated implementations return `None` when the payload is absent or holds a
    /// different message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
24
/// An [`EnvelopedMessage`] that carries a numeric entity id.
///
/// Implementations are generated by the `entity_messages!` macro in this file, which returns
/// one of the message's own fields (for example `project_id` or `channel_id`).
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id stored in this message's entity-id field.
    fn remote_entity_id(&self) -> u64;
}
28
/// An [`EnvelopedMessage`] that is paired with a response message type.
///
/// The pairs are declared with the `request_messages!` macro in this file.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
32
/// Object-safe, type-erased view of a `TypedEnvelope<T>`, for handling envelopes without
/// knowing the payload type at compile time.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type `T` (not of the envelope itself).
    fn payload_type_id(&self) -> TypeId;

    /// Static name of the payload type (`EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;

    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;

    /// Converts the boxed envelope into a boxed `Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
39
40impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
41 fn payload_type_id(&self) -> TypeId {
42 TypeId::of::<T>()
43 }
44
45 fn payload_type_name(&self) -> &'static str {
46 T::NAME
47 }
48
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52
53 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
54 self
55 }
56}
57
/// Registers the payload types that can travel inside an [`Envelope`].
///
/// Given a comma-separated list of message type names (each of which must also be a variant of
/// `envelope::Payload`), this expands to:
///
/// * `build_typed_envelope`, which turns a decoded [`Envelope`] into a boxed, type-erased
///   `TypedEnvelope` for whichever registered payload it holds, and
/// * an [`EnvelopedMessage`] impl for every listed type.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        /// Wraps the payload of `envelope` in a `TypedEnvelope` tagged with `sender_id`.
        ///
        /// Returns `None` when the envelope has no payload or its payload type is not
        /// registered.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            // Read the header fields up front; the payload is moved out by the match below.
            let message_id = envelope.id;
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
103
/// Declares request/response pairs.
///
/// Each `(Request, Response)` tuple expands to `impl RequestMessage for Request` with
/// `type Response = Response`.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
111
/// Implements [`EntityMessage`] for each listed message type.
///
/// The first argument names the field that holds the entity id; `remote_entity_id` returns
/// that field's value for every type in the list.
macro_rules! entity_messages {
    ($entity_id_field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$entity_id_field
                }
            }
        )*
    };
}
121
122messages!(
123 Ack,
124 AddProjectCollaborator,
125 ApplyCompletionAdditionalEdits,
126 ApplyCompletionAdditionalEditsResponse,
127 BufferReloaded,
128 BufferSaved,
129 ChannelMessageSent,
130 CloseBuffer,
131 DiskBasedDiagnosticsUpdated,
132 DiskBasedDiagnosticsUpdating,
133 Error,
134 FormatBuffer,
135 GetChannelMessages,
136 GetChannelMessagesResponse,
137 GetChannels,
138 GetChannelsResponse,
139 GetCompletions,
140 GetCompletionsResponse,
141 GetDefinition,
142 GetDefinitionResponse,
143 GetUsers,
144 GetUsersResponse,
145 JoinChannel,
146 JoinChannelResponse,
147 JoinProject,
148 JoinProjectResponse,
149 LeaveChannel,
150 LeaveProject,
151 OpenBuffer,
152 OpenBufferResponse,
153 RegisterProjectResponse,
154 Ping,
155 RegisterProject,
156 RegisterWorktree,
157 RemoveProjectCollaborator,
158 SaveBuffer,
159 SendChannelMessage,
160 SendChannelMessageResponse,
161 ShareProject,
162 ShareWorktree,
163 UnregisterProject,
164 UnregisterWorktree,
165 UnshareProject,
166 UpdateBuffer,
167 UpdateBufferFile,
168 UpdateContacts,
169 UpdateDiagnosticSummary,
170 UpdateWorktree,
171);
172
173request_messages!(
174 (
175 ApplyCompletionAdditionalEdits,
176 ApplyCompletionAdditionalEditsResponse
177 ),
178 (FormatBuffer, Ack),
179 (GetChannelMessages, GetChannelMessagesResponse),
180 (GetChannels, GetChannelsResponse),
181 (GetCompletions, GetCompletionsResponse),
182 (GetDefinition, GetDefinitionResponse),
183 (GetUsers, GetUsersResponse),
184 (JoinChannel, JoinChannelResponse),
185 (JoinProject, JoinProjectResponse),
186 (OpenBuffer, OpenBufferResponse),
187 (Ping, Ack),
188 (RegisterProject, RegisterProjectResponse),
189 (RegisterWorktree, Ack),
190 (SaveBuffer, BufferSaved),
191 (SendChannelMessage, SendChannelMessageResponse),
192 (ShareProject, Ack),
193 (ShareWorktree, Ack),
194 (UpdateBuffer, Ack),
195);
196
197entity_messages!(
198 project_id,
199 AddProjectCollaborator,
200 ApplyCompletionAdditionalEdits,
201 BufferReloaded,
202 BufferSaved,
203 CloseBuffer,
204 DiskBasedDiagnosticsUpdated,
205 DiskBasedDiagnosticsUpdating,
206 FormatBuffer,
207 GetCompletions,
208 GetDefinition,
209 JoinProject,
210 LeaveProject,
211 OpenBuffer,
212 RemoveProjectCollaborator,
213 SaveBuffer,
214 ShareWorktree,
215 UnregisterWorktree,
216 UnshareProject,
217 UpdateBuffer,
218 UpdateBufferFile,
219 UpdateDiagnosticSummary,
220 UpdateWorktree,
221);
222
223entity_messages!(channel_id, ChannelMessageSent);
224
/// Wraps an underlying stream `S` so that protobuf messages can be written to and read from it.
pub struct MessageStream<S> {
    /// The wrapped stream.
    stream: S,
    /// Scratch space for serialized message bytes, kept between calls so its allocation can be
    /// reused.
    encoding_buffer: Vec<u8>,
}

impl<S> MessageStream<S> {
    /// Creates a message stream around `stream`, starting with an empty scratch buffer.
    pub fn new(stream: S) -> Self {
        let encoding_buffer = Vec::new();
        Self {
            stream,
            encoding_buffer,
        }
    }

    /// Gives mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream, .. } = self;
        stream
    }
}
243
244impl<S> MessageStream<S>
245where
246 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
247{
248 /// Write a given protobuf message to the stream.
249 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
250 self.encoding_buffer.resize(message.encoded_len(), 0);
251 self.encoding_buffer.clear();
252 message
253 .encode(&mut self.encoding_buffer)
254 .map_err(|err| io::Error::from(err))?;
255 let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
256 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
257 Ok(())
258 }
259}
260
261impl<S> MessageStream<S>
262where
263 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
264{
265 /// Read a protobuf message of the given type from the stream.
266 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
267 while let Some(bytes) = self.stream.next().await {
268 match bytes? {
269 WebSocketMessage::Binary(bytes) => {
270 self.encoding_buffer.clear();
271 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
272 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
273 .map_err(io::Error::from)?;
274 return Ok(envelope);
275 }
276 WebSocketMessage::Close(_) => break,
277 _ => {}
278 }
279 }
280 Err(WebSocketError::ConnectionClosed)
281 }
282}
283
284impl Into<SystemTime> for Timestamp {
285 fn into(self) -> SystemTime {
286 UNIX_EPOCH
287 .checked_add(Duration::new(self.seconds, self.nanos))
288 .unwrap()
289 }
290}
291
292impl From<SystemTime> for Timestamp {
293 fn from(time: SystemTime) -> Self {
294 let duration = time.duration_since(UNIX_EPOCH).unwrap();
295 Self {
296 seconds: duration.as_secs(),
297 nanos: duration.subsec_nanos(),
298 }
299 }
300}
301
302impl From<u128> for Nonce {
303 fn from(nonce: u128) -> Self {
304 let upper_half = (nonce >> 64) as u64;
305 let lower_half = nonce as u64;
306 Self {
307 upper_half,
308 lower_half,
309 }
310 }
311}
312
313impl From<Nonce> for u128 {
314 fn from(nonce: Nonce) -> Self {
315 let upper_half = (nonce.upper_half as u128) << 64;
316 let lower_half = nonce.lower_half as u128;
317 upper_half | lower_half
318 }
319}