1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 const PRIORITY: MessagePriority;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
34pub trait AnyTypedEnvelope: 'static + Send + Sync {
35 fn payload_type_id(&self) -> TypeId;
36 fn payload_type_name(&self) -> &'static str;
37 fn as_any(&self) -> &dyn Any;
38 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
39 fn is_background(&self) -> bool;
40 fn original_sender_id(&self) -> Option<PeerId>;
41}
42
/// Scheduling priority attached to every message type by the `messages!` macro.
///
/// Fieldless and therefore cheap to copy; the derives let callers compare and log
/// priorities directly instead of pattern matching.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Declared with `Foreground` in `messages!`.
    Foreground,
    /// Declared with `Background` in `messages!`; reported by
    /// `AnyTypedEnvelope::is_background`.
    Background,
}
47
48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
49 fn payload_type_id(&self) -> TypeId {
50 TypeId::of::<T>()
51 }
52
53 fn payload_type_name(&self) -> &'static str {
54 T::NAME
55 }
56
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
62 self
63 }
64
65 fn is_background(&self) -> bool {
66 matches!(T::PRIORITY, MessagePriority::Background)
67 }
68
69 fn original_sender_id(&self) -> Option<PeerId> {
70 self.original_sender_id
71 }
72}
73
/// Declares the set of message types that can travel inside an [`Envelope`].
///
/// For each `(Name, Priority)` pair this generates an [`EnvelopedMessage`] impl mapping
/// `Name` to and from `envelope::Payload::Name`. It also emits one `build_typed_envelope`
/// function covering every listed payload.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Wraps the payload of `envelope` in a boxed `TypedEnvelope`, tagged with `sender_id`.
        ///
        /// Returns `None` if the envelope has no payload or a payload not listed in `messages!`.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            let Envelope { id: message_id, original_sender_id, payload, .. } = envelope;
            let original_sender_id = original_sender_id.map(PeerId);
            let typed: Box<dyn AnyTypedEnvelope> = match payload {
                $(Some(envelope::Payload::$name(payload)) => Box::new(TypedEnvelope {
                    sender_id,
                    original_sender_id,
                    message_id,
                    payload,
                }),)*
                _ => return None,
            };
            Some(typed)
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
120
/// Pairs each request message type with the message type sent back in response,
/// by implementing [`RequestMessage`] for every `(Request, Response)` tuple given.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
128
/// Implements [`EntityMessage`] for each listed message type.
///
/// The first argument names the `u64` field holding the target entity's id; it is
/// returned by `remote_entity_id` for every message type that follows.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
138
139messages!(
140 (Ack, Foreground),
141 (AddProjectCollaborator, Foreground),
142 (ApplyCodeAction, Background),
143 (ApplyCodeActionResponse, Background),
144 (ApplyCompletionAdditionalEdits, Background),
145 (ApplyCompletionAdditionalEditsResponse, Background),
146 (BufferReloaded, Foreground),
147 (BufferSaved, Foreground),
148 (ChannelMessageSent, Foreground),
149 (CloseBuffer, Foreground),
150 (DiskBasedDiagnosticsUpdated, Background),
151 (DiskBasedDiagnosticsUpdating, Background),
152 (Error, Foreground),
153 (FormatBuffers, Foreground),
154 (FormatBuffersResponse, Foreground),
155 (GetChannelMessages, Foreground),
156 (GetChannelMessagesResponse, Foreground),
157 (GetChannels, Foreground),
158 (GetChannelsResponse, Foreground),
159 (GetCodeActions, Background),
160 (GetCodeActionsResponse, Background),
161 (GetCompletions, Background),
162 (GetCompletionsResponse, Background),
163 (GetDefinition, Background),
164 (GetDefinitionResponse, Background),
165 (GetDocumentHighlights, Background),
166 (GetDocumentHighlightsResponse, Background),
167 (GetReferences, Background),
168 (GetReferencesResponse, Background),
169 (GetProjectSymbols, Background),
170 (GetProjectSymbolsResponse, Background),
171 (GetUsers, Foreground),
172 (GetUsersResponse, Foreground),
173 (JoinChannel, Foreground),
174 (JoinChannelResponse, Foreground),
175 (JoinProject, Foreground),
176 (JoinProjectResponse, Foreground),
177 (LeaveChannel, Foreground),
178 (LeaveProject, Foreground),
179 (OpenBuffer, Background),
180 (OpenBufferForSymbol, Background),
181 (OpenBufferForSymbolResponse, Background),
182 (OpenBufferResponse, Background),
183 (PerformRename, Background),
184 (PerformRenameResponse, Background),
185 (PrepareRename, Background),
186 (PrepareRenameResponse, Background),
187 (RegisterProjectResponse, Foreground),
188 (Ping, Foreground),
189 (RegisterProject, Foreground),
190 (RegisterWorktree, Foreground),
191 (RemoveProjectCollaborator, Foreground),
192 (SaveBuffer, Foreground),
193 (SearchProject, Background),
194 (SearchProjectResponse, Background),
195 (SendChannelMessage, Foreground),
196 (SendChannelMessageResponse, Foreground),
197 (ShareProject, Foreground),
198 (Test, Foreground),
199 (UnregisterProject, Foreground),
200 (UnregisterWorktree, Foreground),
201 (UnshareProject, Foreground),
202 (UpdateBuffer, Background),
203 (UpdateBufferFile, Foreground),
204 (UpdateContacts, Foreground),
205 (UpdateDiagnosticSummary, Foreground),
206 (UpdateWorktree, Foreground),
207);
208
209request_messages!(
210 (ApplyCodeAction, ApplyCodeActionResponse),
211 (
212 ApplyCompletionAdditionalEdits,
213 ApplyCompletionAdditionalEditsResponse
214 ),
215 (FormatBuffers, FormatBuffersResponse),
216 (GetChannelMessages, GetChannelMessagesResponse),
217 (GetChannels, GetChannelsResponse),
218 (GetCodeActions, GetCodeActionsResponse),
219 (GetCompletions, GetCompletionsResponse),
220 (GetDefinition, GetDefinitionResponse),
221 (GetDocumentHighlights, GetDocumentHighlightsResponse),
222 (GetReferences, GetReferencesResponse),
223 (GetProjectSymbols, GetProjectSymbolsResponse),
224 (GetUsers, GetUsersResponse),
225 (JoinChannel, JoinChannelResponse),
226 (JoinProject, JoinProjectResponse),
227 (OpenBuffer, OpenBufferResponse),
228 (OpenBufferForSymbol, OpenBufferForSymbolResponse),
229 (Ping, Ack),
230 (PerformRename, PerformRenameResponse),
231 (PrepareRename, PrepareRenameResponse),
232 (RegisterProject, RegisterProjectResponse),
233 (RegisterWorktree, Ack),
234 (SaveBuffer, BufferSaved),
235 (SearchProject, SearchProjectResponse),
236 (SendChannelMessage, SendChannelMessageResponse),
237 (ShareProject, Ack),
238 (Test, Test),
239 (UpdateBuffer, Ack),
240 (UpdateWorktree, Ack),
241);
242
243entity_messages!(
244 project_id,
245 AddProjectCollaborator,
246 ApplyCodeAction,
247 ApplyCompletionAdditionalEdits,
248 BufferReloaded,
249 BufferSaved,
250 CloseBuffer,
251 DiskBasedDiagnosticsUpdated,
252 DiskBasedDiagnosticsUpdating,
253 FormatBuffers,
254 GetCodeActions,
255 GetCompletions,
256 GetDefinition,
257 GetDocumentHighlights,
258 GetReferences,
259 GetProjectSymbols,
260 JoinProject,
261 LeaveProject,
262 OpenBuffer,
263 OpenBufferForSymbol,
264 PerformRename,
265 PrepareRename,
266 RemoveProjectCollaborator,
267 SaveBuffer,
268 SearchProject,
269 UnregisterWorktree,
270 UnshareProject,
271 UpdateBuffer,
272 UpdateBufferFile,
273 UpdateDiagnosticSummary,
274 RegisterWorktree,
275 UpdateWorktree,
276);
277
278entity_messages!(channel_id, ChannelMessageSent);
279
/// A stream of protobuf messages carried over an underlying WebSocket transport `S`.
pub struct MessageStream<S> {
    /// Scratch space reused between reads and writes for (de)serialized bytes.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport.
    stream: S,
}
285
286#[derive(Debug)]
287pub enum Message {
288 Envelope(Envelope),
289 Ping,
290 Pong,
291}
292
293impl<S> MessageStream<S> {
294 pub fn new(stream: S) -> Self {
295 Self {
296 stream,
297 encoding_buffer: Vec::new(),
298 }
299 }
300
301 pub fn inner_mut(&mut self) -> &mut S {
302 &mut self.stream
303 }
304}
305
306impl<S> MessageStream<S>
307where
308 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
309{
310 pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
311 #[cfg(any(test, feature = "test-support"))]
312 const COMPRESSION_LEVEL: i32 = -7;
313
314 #[cfg(not(any(test, feature = "test-support")))]
315 const COMPRESSION_LEVEL: i32 = 4;
316
317 match message {
318 Message::Envelope(message) => {
319 self.encoding_buffer.resize(message.encoded_len(), 0);
320 self.encoding_buffer.clear();
321 message
322 .encode(&mut self.encoding_buffer)
323 .map_err(|err| io::Error::from(err))?;
324 let buffer =
325 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
326 .unwrap();
327 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
328 }
329 Message::Ping => {
330 self.stream
331 .send(WebSocketMessage::Ping(Default::default()))
332 .await?;
333 }
334 Message::Pong => {
335 self.stream
336 .send(WebSocketMessage::Pong(Default::default()))
337 .await?;
338 }
339 }
340
341 Ok(())
342 }
343}
344
345impl<S> MessageStream<S>
346where
347 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
348{
349 pub async fn read(&mut self) -> Result<Message, WebSocketError> {
350 while let Some(bytes) = self.stream.next().await {
351 match bytes? {
352 WebSocketMessage::Binary(bytes) => {
353 self.encoding_buffer.clear();
354 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
355 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
356 .map_err(io::Error::from)?;
357 return Ok(Message::Envelope(envelope));
358 }
359 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
360 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
361 WebSocketMessage::Close(_) => break,
362 _ => {}
363 }
364 }
365 Err(WebSocketError::ConnectionClosed)
366 }
367}
368
369impl Into<SystemTime> for Timestamp {
370 fn into(self) -> SystemTime {
371 UNIX_EPOCH
372 .checked_add(Duration::new(self.seconds, self.nanos))
373 .unwrap()
374 }
375}
376
377impl From<SystemTime> for Timestamp {
378 fn from(time: SystemTime) -> Self {
379 let duration = time.duration_since(UNIX_EPOCH).unwrap();
380 Self {
381 seconds: duration.as_secs(),
382 nanos: duration.subsec_nanos(),
383 }
384 }
385}
386
387impl From<u128> for Nonce {
388 fn from(nonce: u128) -> Self {
389 let upper_half = (nonce >> 64) as u64;
390 let lower_half = nonce as u64;
391 Self {
392 upper_half,
393 lower_half,
394 }
395 }
396}
397
398impl From<Nonce> for u128 {
399 fn from(nonce: Nonce) -> Self {
400 let upper_half = (nonce.upper_half as u128) << 64;
401 let lower_half = nonce.lower_half as u128;
402 upper_half | lower_half
403 }
404}