1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 const PRIORITY: MessagePriority;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
/// Object-safe view of a typed envelope, so envelopes carrying different
/// payload types can share a single `Box<dyn AnyTypedEnvelope>`.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload carried by the envelope.
    fn payload_type_id(&self) -> TypeId;

    /// Name of the payload type.
    fn payload_type_name(&self) -> &'static str;

    /// Borrows the envelope as `dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any;

    /// Turns the boxed envelope into a boxed `dyn Any` for owned downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;

    /// Whether the payload's message type is flagged as background priority.
    fn is_background(&self) -> bool;
}
41
/// Priority class attached to each message type by the `messages!` macro.
///
/// `AnyTypedEnvelope::is_background` reports whether a message's type was
/// declared `Background`; how that flag is acted on is up to the caller.
// Plain fieldless enum: derive the standard traits so priorities can be
// copied, compared and logged instead of only pattern-matched.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    Foreground,
    Background,
}
46
47impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
48 fn payload_type_id(&self) -> TypeId {
49 TypeId::of::<T>()
50 }
51
52 fn payload_type_name(&self) -> &'static str {
53 T::NAME
54 }
55
56 fn as_any(&self) -> &dyn Any {
57 self
58 }
59
60 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
61 self
62 }
63
64 fn is_background(&self) -> bool {
65 matches!(T::PRIORITY, MessagePriority::Background)
66 }
67}
68
/// Declares every protobuf message type that can travel inside an `Envelope`.
///
/// Each `(Name, Priority)` pair yields an `EnvelopedMessage` impl for `Name`.
/// The full list also yields `build_typed_envelope`, which wraps whichever
/// listed payload an `Envelope` carries in a boxed `TypedEnvelope`.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Wraps the payload of `envelope` in a `TypedEnvelope` tagged with
        /// `sender_id`. Returns `None` when the envelope has no payload or a
        /// payload type that is not listed in `messages!`.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            let message_id = envelope.id;
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            let typed: Box<dyn AnyTypedEnvelope> = match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    }),
                )*
                _ => return None,
            };
            Some(typed)
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
115
/// Pairs each request type with the message type sent back in reply, by
/// generating a `RequestMessage` impl for every `(Request, Response)` pair.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
123
/// Implements `EntityMessage` for each listed message type, reading the
/// remote entity id from the field named by the first argument.
macro_rules! entity_messages {
    ($id_field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
133
134messages!(
135 (Ack, Foreground),
136 (AddProjectCollaborator, Foreground),
137 (ApplyCodeAction, Foreground),
138 (ApplyCodeActionResponse, Foreground),
139 (ApplyCompletionAdditionalEdits, Foreground),
140 (ApplyCompletionAdditionalEditsResponse, Foreground),
141 (BufferReloaded, Foreground),
142 (BufferSaved, Foreground),
143 (ChannelMessageSent, Foreground),
144 (CloseBuffer, Foreground),
145 (DiskBasedDiagnosticsUpdated, Background),
146 (DiskBasedDiagnosticsUpdating, Background),
147 (Error, Foreground),
148 (FormatBuffers, Foreground),
149 (FormatBuffersResponse, Foreground),
150 (GetChannelMessages, Foreground),
151 (GetChannelMessagesResponse, Foreground),
152 (GetChannels, Foreground),
153 (GetChannelsResponse, Foreground),
154 (GetCodeActions, Background),
155 (GetCodeActionsResponse, Foreground),
156 (GetCompletions, Background),
157 (GetCompletionsResponse, Foreground),
158 (GetDefinition, Foreground),
159 (GetDefinitionResponse, Foreground),
160 (GetUsers, Foreground),
161 (GetUsersResponse, Foreground),
162 (JoinChannel, Foreground),
163 (JoinChannelResponse, Foreground),
164 (JoinProject, Foreground),
165 (JoinProjectResponse, Foreground),
166 (LeaveChannel, Foreground),
167 (LeaveProject, Foreground),
168 (OpenBuffer, Foreground),
169 (OpenBufferResponse, Foreground),
170 (PerformRename, Background),
171 (PerformRenameResponse, Background),
172 (PrepareRename, Background),
173 (PrepareRenameResponse, Background),
174 (RegisterProjectResponse, Foreground),
175 (Ping, Foreground),
176 (RegisterProject, Foreground),
177 (RegisterWorktree, Foreground),
178 (RemoveProjectCollaborator, Foreground),
179 (SaveBuffer, Foreground),
180 (SendChannelMessage, Foreground),
181 (SendChannelMessageResponse, Foreground),
182 (ShareProject, Foreground),
183 (ShareWorktree, Foreground),
184 (Test, Foreground),
185 (UnregisterProject, Foreground),
186 (UnregisterWorktree, Foreground),
187 (UnshareProject, Foreground),
188 (UpdateBuffer, Foreground),
189 (UpdateBufferFile, Foreground),
190 (UpdateContacts, Foreground),
191 (UpdateDiagnosticSummary, Foreground),
192 (UpdateWorktree, Foreground),
193);
194
195request_messages!(
196 (ApplyCodeAction, ApplyCodeActionResponse),
197 (
198 ApplyCompletionAdditionalEdits,
199 ApplyCompletionAdditionalEditsResponse
200 ),
201 (FormatBuffers, FormatBuffersResponse),
202 (GetChannelMessages, GetChannelMessagesResponse),
203 (GetChannels, GetChannelsResponse),
204 (GetCodeActions, GetCodeActionsResponse),
205 (GetCompletions, GetCompletionsResponse),
206 (GetDefinition, GetDefinitionResponse),
207 (GetUsers, GetUsersResponse),
208 (JoinChannel, JoinChannelResponse),
209 (JoinProject, JoinProjectResponse),
210 (OpenBuffer, OpenBufferResponse),
211 (Ping, Ack),
212 (PerformRename, PerformRenameResponse),
213 (PrepareRename, PrepareRenameResponse),
214 (RegisterProject, RegisterProjectResponse),
215 (RegisterWorktree, Ack),
216 (SaveBuffer, BufferSaved),
217 (SendChannelMessage, SendChannelMessageResponse),
218 (ShareProject, Ack),
219 (ShareWorktree, Ack),
220 (Test, Test),
221 (UpdateBuffer, Ack),
222 (UpdateWorktree, Ack),
223);
224
225entity_messages!(
226 project_id,
227 AddProjectCollaborator,
228 ApplyCodeAction,
229 ApplyCompletionAdditionalEdits,
230 BufferReloaded,
231 BufferSaved,
232 CloseBuffer,
233 DiskBasedDiagnosticsUpdated,
234 DiskBasedDiagnosticsUpdating,
235 FormatBuffers,
236 GetCodeActions,
237 GetCompletions,
238 GetDefinition,
239 JoinProject,
240 LeaveProject,
241 OpenBuffer,
242 PerformRename,
243 PrepareRename,
244 RemoveProjectCollaborator,
245 SaveBuffer,
246 ShareWorktree,
247 UnregisterWorktree,
248 UnshareProject,
249 UpdateBuffer,
250 UpdateBufferFile,
251 UpdateDiagnosticSummary,
252 UpdateWorktree,
253);
254
255entity_messages!(channel_id, ChannelMessageSent);
256
/// A transport `S` carrying zstd-compressed protobuf `Envelope`s, one per
/// binary WebSocket message.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused by every read and write, so the allocation is
    /// kept between messages.
    encoding_buffer: Vec<u8>,
}
262
263impl<S> MessageStream<S> {
264 pub fn new(stream: S) -> Self {
265 Self {
266 stream,
267 encoding_buffer: Vec::new(),
268 }
269 }
270
271 pub fn inner_mut(&mut self) -> &mut S {
272 &mut self.stream
273 }
274}
275
276impl<S> MessageStream<S>
277where
278 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
279{
280 /// Write a given protobuf message to the stream.
281 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
282 #[cfg(any(test, feature = "test-support"))]
283 const COMPRESSION_LEVEL: i32 = -7;
284
285 #[cfg(not(any(test, feature = "test-support")))]
286 const COMPRESSION_LEVEL: i32 = 4;
287
288 self.encoding_buffer.resize(message.encoded_len(), 0);
289 self.encoding_buffer.clear();
290 message
291 .encode(&mut self.encoding_buffer)
292 .map_err(|err| io::Error::from(err))?;
293 let buffer =
294 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
295 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
296 Ok(())
297 }
298}
299
300impl<S> MessageStream<S>
301where
302 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
303{
304 /// Read a protobuf message of the given type from the stream.
305 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
306 while let Some(bytes) = self.stream.next().await {
307 match bytes? {
308 WebSocketMessage::Binary(bytes) => {
309 self.encoding_buffer.clear();
310 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
311 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
312 .map_err(io::Error::from)?;
313 return Ok(envelope);
314 }
315 WebSocketMessage::Close(_) => break,
316 _ => {}
317 }
318 }
319 Err(WebSocketError::ConnectionClosed)
320 }
321}
322
323impl Into<SystemTime> for Timestamp {
324 fn into(self) -> SystemTime {
325 UNIX_EPOCH
326 .checked_add(Duration::new(self.seconds, self.nanos))
327 .unwrap()
328 }
329}
330
331impl From<SystemTime> for Timestamp {
332 fn from(time: SystemTime) -> Self {
333 let duration = time.duration_since(UNIX_EPOCH).unwrap();
334 Self {
335 seconds: duration.as_secs(),
336 nanos: duration.subsec_nanos(),
337 }
338 }
339}
340
341impl From<u128> for Nonce {
342 fn from(nonce: u128) -> Self {
343 let upper_half = (nonce >> 64) as u64;
344 let lower_half = nonce as u64;
345 Self {
346 upper_half,
347 lower_half,
348 }
349 }
350}
351
352impl From<Nonce> for u128 {
353 fn from(nonce: Nonce) -> Self {
354 let upper_half = (nonce.upper_half as u128) << 64;
355 let lower_half = nonce.lower_half as u128;
356 upper_half | lower_half
357 }
358}