1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
// Splices in the generated protobuf definitions from `$OUT_DIR/zed.messages.rs`. Types used
// throughout this module but not defined in it (`Envelope`, the `envelope::Payload` enum,
// `Timestamp`, `Nonce`, and the individual message structs) presumably come from this file.
// NOTE(review): likely emitted by a prost-build step in this crate's build.rs — confirm.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 const PRIORITY: MessagePriority;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
34pub trait AnyTypedEnvelope: 'static + Send + Sync {
35 fn payload_type_id(&self) -> TypeId;
36 fn payload_type_name(&self) -> &'static str;
37 fn as_any(&self) -> &dyn Any;
38 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
39 fn is_background(&self) -> bool;
40 fn original_sender_id(&self) -> Option<PeerId>;
41}
42
/// Priority class of a message type.
///
/// Each message type declares its class in the `messages!` table. At runtime the class is
/// reported through [`AnyTypedEnvelope::is_background`].
// NOTE(review): how receivers actually schedule each class is decided outside this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Messages for which `is_background()` reports `false`.
    Foreground,
    /// Messages for which `is_background()` reports `true`.
    Background,
}
47
48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
49 fn payload_type_id(&self) -> TypeId {
50 TypeId::of::<T>()
51 }
52
53 fn payload_type_name(&self) -> &'static str {
54 T::NAME
55 }
56
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
62 self
63 }
64
65 fn is_background(&self) -> bool {
66 matches!(T::PRIORITY, MessagePriority::Background)
67 }
68
69 fn original_sender_id(&self) -> Option<PeerId> {
70 self.original_sender_id
71 }
72}
73
/// Registers the set of protobuf message types that can travel inside an [`Envelope`].
///
/// Each entry is `(MessageType, Priority)`. `MessageType` names both the generated struct and
/// its `envelope::Payload` variant, and `Priority` is a [`MessagePriority`] variant. Expands to:
///
/// * `build_typed_envelope`, which converts an incoming [`Envelope`] into a boxed
///   [`AnyTypedEnvelope`], and
/// * an [`EnvelopedMessage`] impl for every listed type.
macro_rules! messages {
    ($(($message:ident, $priority:ident)),* $(,)?) => {
        /// Wraps the payload of `envelope` in a [`TypedEnvelope`] attributed to `sender_id`.
        ///
        /// Returns `None` if the envelope has no payload, or if its payload is not one of the
        /// registered message types.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            let Envelope {
                id: message_id,
                original_sender_id,
                payload,
                ..
            } = envelope;
            let original_sender_id = original_sender_id.map(PeerId);
            match payload {
                $(
                    Some(envelope::Payload::$message(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $message {
                const NAME: &'static str = std::stringify!($message);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$message(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$message(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
120
/// Declares request/response pairings.
///
/// For each `(Request, Response)` entry, implements [`RequestMessage`] for `Request`, with
/// `Response` as its reply type.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
128
/// Implements [`EntityMessage`] for each listed message type.
///
/// The first argument names the struct field that holds the entity id. Every following
/// identifier is a message type that reads its id from that field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
138
139messages!(
140 (Ack, Foreground),
141 (AddProjectCollaborator, Foreground),
142 (ApplyCodeAction, Background),
143 (ApplyCodeActionResponse, Background),
144 (ApplyCompletionAdditionalEdits, Background),
145 (ApplyCompletionAdditionalEditsResponse, Background),
146 (BufferReloaded, Foreground),
147 (BufferSaved, Foreground),
148 (ChannelMessageSent, Foreground),
149 (DiskBasedDiagnosticsUpdated, Background),
150 (DiskBasedDiagnosticsUpdating, Background),
151 (Error, Foreground),
152 (FormatBuffers, Foreground),
153 (FormatBuffersResponse, Foreground),
154 (GetChannelMessages, Foreground),
155 (GetChannelMessagesResponse, Foreground),
156 (GetChannels, Foreground),
157 (GetChannelsResponse, Foreground),
158 (GetCodeActions, Background),
159 (GetCodeActionsResponse, Background),
160 (GetCompletions, Background),
161 (GetCompletionsResponse, Background),
162 (GetDefinition, Background),
163 (GetDefinitionResponse, Background),
164 (GetDocumentHighlights, Background),
165 (GetDocumentHighlightsResponse, Background),
166 (GetReferences, Background),
167 (GetReferencesResponse, Background),
168 (GetProjectSymbols, Background),
169 (GetProjectSymbolsResponse, Background),
170 (GetUsers, Foreground),
171 (GetUsersResponse, Foreground),
172 (JoinChannel, Foreground),
173 (JoinChannelResponse, Foreground),
174 (JoinProject, Foreground),
175 (JoinProjectResponse, Foreground),
176 (LeaveChannel, Foreground),
177 (LeaveProject, Foreground),
178 (OpenBuffer, Background),
179 (OpenBufferForSymbol, Background),
180 (OpenBufferForSymbolResponse, Background),
181 (OpenBufferResponse, Background),
182 (PerformRename, Background),
183 (PerformRenameResponse, Background),
184 (PrepareRename, Background),
185 (PrepareRenameResponse, Background),
186 (RegisterProjectResponse, Foreground),
187 (Ping, Foreground),
188 (RegisterProject, Foreground),
189 (RegisterWorktree, Foreground),
190 (RemoveProjectCollaborator, Foreground),
191 (SaveBuffer, Foreground),
192 (SearchProject, Background),
193 (SearchProjectResponse, Background),
194 (SendChannelMessage, Foreground),
195 (SendChannelMessageResponse, Foreground),
196 (ShareProject, Foreground),
197 (Test, Foreground),
198 (UnregisterProject, Foreground),
199 (UnregisterWorktree, Foreground),
200 (UnshareProject, Foreground),
201 (UpdateBuffer, Background),
202 (UpdateBufferFile, Foreground),
203 (UpdateContacts, Foreground),
204 (UpdateDiagnosticSummary, Foreground),
205 (UpdateWorktree, Foreground),
206);
207
208request_messages!(
209 (ApplyCodeAction, ApplyCodeActionResponse),
210 (
211 ApplyCompletionAdditionalEdits,
212 ApplyCompletionAdditionalEditsResponse
213 ),
214 (FormatBuffers, FormatBuffersResponse),
215 (GetChannelMessages, GetChannelMessagesResponse),
216 (GetChannels, GetChannelsResponse),
217 (GetCodeActions, GetCodeActionsResponse),
218 (GetCompletions, GetCompletionsResponse),
219 (GetDefinition, GetDefinitionResponse),
220 (GetDocumentHighlights, GetDocumentHighlightsResponse),
221 (GetReferences, GetReferencesResponse),
222 (GetProjectSymbols, GetProjectSymbolsResponse),
223 (GetUsers, GetUsersResponse),
224 (JoinChannel, JoinChannelResponse),
225 (JoinProject, JoinProjectResponse),
226 (OpenBuffer, OpenBufferResponse),
227 (OpenBufferForSymbol, OpenBufferForSymbolResponse),
228 (Ping, Ack),
229 (PerformRename, PerformRenameResponse),
230 (PrepareRename, PrepareRenameResponse),
231 (RegisterProject, RegisterProjectResponse),
232 (RegisterWorktree, Ack),
233 (SaveBuffer, BufferSaved),
234 (SearchProject, SearchProjectResponse),
235 (SendChannelMessage, SendChannelMessageResponse),
236 (ShareProject, Ack),
237 (Test, Test),
238 (UpdateBuffer, Ack),
239 (UpdateWorktree, Ack),
240);
241
242entity_messages!(
243 project_id,
244 AddProjectCollaborator,
245 ApplyCodeAction,
246 ApplyCompletionAdditionalEdits,
247 BufferReloaded,
248 BufferSaved,
249 DiskBasedDiagnosticsUpdated,
250 DiskBasedDiagnosticsUpdating,
251 FormatBuffers,
252 GetCodeActions,
253 GetCompletions,
254 GetDefinition,
255 GetDocumentHighlights,
256 GetReferences,
257 GetProjectSymbols,
258 JoinProject,
259 LeaveProject,
260 OpenBuffer,
261 OpenBufferForSymbol,
262 PerformRename,
263 PrepareRename,
264 RemoveProjectCollaborator,
265 SaveBuffer,
266 SearchProject,
267 UnregisterWorktree,
268 UnshareProject,
269 UpdateBuffer,
270 UpdateBufferFile,
271 UpdateDiagnosticSummary,
272 RegisterWorktree,
273 UpdateWorktree,
274);
275
276entity_messages!(channel_id, ChannelMessageSent);
277
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` (used as a sink for writing and/or a stream for reading)
/// together with one scratch buffer that is reused for every message encoded or decoded.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch space, cleared and refilled on each read and write.
    encoding_buffer: Vec<u8>,
}
283
284#[derive(Debug)]
285pub enum Message {
286 Envelope(Envelope),
287 Ping,
288 Pong,
289}
290
291impl<S> MessageStream<S> {
292 pub fn new(stream: S) -> Self {
293 Self {
294 stream,
295 encoding_buffer: Vec::new(),
296 }
297 }
298
299 pub fn inner_mut(&mut self) -> &mut S {
300 &mut self.stream
301 }
302}
303
304impl<S> MessageStream<S>
305where
306 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
307{
308 pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
309 #[cfg(any(test, feature = "test-support"))]
310 const COMPRESSION_LEVEL: i32 = -7;
311
312 #[cfg(not(any(test, feature = "test-support")))]
313 const COMPRESSION_LEVEL: i32 = 4;
314
315 match message {
316 Message::Envelope(message) => {
317 self.encoding_buffer.resize(message.encoded_len(), 0);
318 self.encoding_buffer.clear();
319 message
320 .encode(&mut self.encoding_buffer)
321 .map_err(|err| io::Error::from(err))?;
322 let buffer =
323 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
324 .unwrap();
325 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
326 }
327 Message::Ping => {
328 self.stream
329 .send(WebSocketMessage::Ping(Default::default()))
330 .await?;
331 }
332 Message::Pong => {
333 self.stream
334 .send(WebSocketMessage::Pong(Default::default()))
335 .await?;
336 }
337 }
338
339 Ok(())
340 }
341}
342
343impl<S> MessageStream<S>
344where
345 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
346{
347 pub async fn read(&mut self) -> Result<Message, WebSocketError> {
348 while let Some(bytes) = self.stream.next().await {
349 match bytes? {
350 WebSocketMessage::Binary(bytes) => {
351 self.encoding_buffer.clear();
352 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
353 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
354 .map_err(io::Error::from)?;
355 return Ok(Message::Envelope(envelope));
356 }
357 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
358 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
359 WebSocketMessage::Close(_) => break,
360 _ => {}
361 }
362 }
363 Err(WebSocketError::ConnectionClosed)
364 }
365}
366
367impl Into<SystemTime> for Timestamp {
368 fn into(self) -> SystemTime {
369 UNIX_EPOCH
370 .checked_add(Duration::new(self.seconds, self.nanos))
371 .unwrap()
372 }
373}
374
375impl From<SystemTime> for Timestamp {
376 fn from(time: SystemTime) -> Self {
377 let duration = time.duration_since(UNIX_EPOCH).unwrap();
378 Self {
379 seconds: duration.as_secs(),
380 nanos: duration.subsec_nanos(),
381 }
382 }
383}
384
385impl From<u128> for Nonce {
386 fn from(nonce: u128) -> Self {
387 let upper_half = (nonce >> 64) as u64;
388 let lower_half = nonce as u64;
389 Self {
390 upper_half,
391 lower_half,
392 }
393 }
394}
395
396impl From<Nonce> for u128 {
397 fn from(nonce: Nonce) -> Self {
398 let upper_half = (nonce.upper_half as u128) << 64;
399 let lower_half = nonce.lower_half as u128;
400 upper_half | lower_half
401 }
402}