1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 const PRIORITY: MessagePriority;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
34pub trait AnyTypedEnvelope: 'static + Send + Sync {
35 fn payload_type_id(&self) -> TypeId;
36 fn payload_type_name(&self) -> &'static str;
37 fn as_any(&self) -> &dyn Any;
38 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
39 fn is_background(&self) -> bool;
40 fn original_sender_id(&self) -> Option<PeerId>;
41}
42
/// Priority class attached to every message type through
/// `EnvelopedMessage::PRIORITY`.
///
/// The enum is a plain two-state tag, so it derives the usual value-type
/// traits: it can be copied, compared and printed for diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// The message type is not a background message.
    Foreground,
    /// The message type is reported as background by
    /// `AnyTypedEnvelope::is_background`.
    Background,
}
47
48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
49 fn payload_type_id(&self) -> TypeId {
50 TypeId::of::<T>()
51 }
52
53 fn payload_type_name(&self) -> &'static str {
54 T::NAME
55 }
56
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
62 self
63 }
64
65 fn is_background(&self) -> bool {
66 matches!(T::PRIORITY, MessagePriority::Background)
67 }
68
69 fn original_sender_id(&self) -> Option<PeerId> {
70 self.original_sender_id
71 }
72}
73
// Given a list of `(MessageType, Priority)` pairs, generates:
//
// * `build_typed_envelope`, which turns an `Envelope` into a boxed
//   `TypedEnvelope` for whichever listed payload variant it carries, and
// * an `EnvelopedMessage` impl for every listed message type.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Builds a type-erased `TypedEnvelope` from `envelope`.
        ///
        /// Returns `None` when the envelope has no payload or its payload is
        /// not one of the message types listed in the `messages!` invocation.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            let message_id = envelope.id;
            match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(msg)) => Some(msg),
                        _ => None,
                    }
                }
            }
        )*
    };
}
120
// Given a list of `(Request, Response)` pairs, implements `RequestMessage`
// for each request type with the paired type as its `Response`.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
128
// Given the name of an id field followed by a list of message types,
// implements `EntityMessage` for each type by returning that field.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
138
139messages!(
140 (Ack, Foreground),
141 (AddProjectCollaborator, Foreground),
142 (ApplyCodeAction, Background),
143 (ApplyCodeActionResponse, Background),
144 (ApplyCompletionAdditionalEdits, Background),
145 (ApplyCompletionAdditionalEditsResponse, Background),
146 (BufferReloaded, Foreground),
147 (BufferSaved, Foreground),
148 (ChannelMessageSent, Foreground),
149 (Error, Foreground),
150 (FormatBuffers, Foreground),
151 (FormatBuffersResponse, Foreground),
152 (GetChannelMessages, Foreground),
153 (GetChannelMessagesResponse, Foreground),
154 (GetChannels, Foreground),
155 (GetChannelsResponse, Foreground),
156 (GetCodeActions, Background),
157 (GetCodeActionsResponse, Background),
158 (GetCompletions, Background),
159 (GetCompletionsResponse, Background),
160 (GetDefinition, Background),
161 (GetDefinitionResponse, Background),
162 (GetDocumentHighlights, Background),
163 (GetDocumentHighlightsResponse, Background),
164 (GetReferences, Background),
165 (GetReferencesResponse, Background),
166 (GetProjectSymbols, Background),
167 (GetProjectSymbolsResponse, Background),
168 (GetUsers, Foreground),
169 (GetUsersResponse, Foreground),
170 (JoinChannel, Foreground),
171 (JoinChannelResponse, Foreground),
172 (JoinProject, Foreground),
173 (JoinProjectResponse, Foreground),
174 (StartLanguageServer, Foreground),
175 (UpdateLanguageServer, Foreground),
176 (LeaveChannel, Foreground),
177 (LeaveProject, Foreground),
178 (OpenBuffer, Background),
179 (OpenBufferForSymbol, Background),
180 (OpenBufferForSymbolResponse, Background),
181 (OpenBufferResponse, Background),
182 (PerformRename, Background),
183 (PerformRenameResponse, Background),
184 (PrepareRename, Background),
185 (PrepareRenameResponse, Background),
186 (RegisterProjectResponse, Foreground),
187 (Ping, Foreground),
188 (RegisterProject, Foreground),
189 (RegisterWorktree, Foreground),
190 (RemoveProjectCollaborator, Foreground),
191 (SaveBuffer, Foreground),
192 (SearchProject, Background),
193 (SearchProjectResponse, Background),
194 (SendChannelMessage, Foreground),
195 (SendChannelMessageResponse, Foreground),
196 (ShareProject, Foreground),
197 (Test, Foreground),
198 (UnregisterProject, Foreground),
199 (UnregisterWorktree, Foreground),
200 (UnshareProject, Foreground),
201 (UpdateBuffer, Background),
202 (UpdateBufferFile, Foreground),
203 (UpdateContacts, Foreground),
204 (UpdateDiagnosticSummary, Foreground),
205 (UpdateWorktree, Foreground),
206);
207
208request_messages!(
209 (ApplyCodeAction, ApplyCodeActionResponse),
210 (
211 ApplyCompletionAdditionalEdits,
212 ApplyCompletionAdditionalEditsResponse
213 ),
214 (FormatBuffers, FormatBuffersResponse),
215 (GetChannelMessages, GetChannelMessagesResponse),
216 (GetChannels, GetChannelsResponse),
217 (GetCodeActions, GetCodeActionsResponse),
218 (GetCompletions, GetCompletionsResponse),
219 (GetDefinition, GetDefinitionResponse),
220 (GetDocumentHighlights, GetDocumentHighlightsResponse),
221 (GetReferences, GetReferencesResponse),
222 (GetProjectSymbols, GetProjectSymbolsResponse),
223 (GetUsers, GetUsersResponse),
224 (JoinChannel, JoinChannelResponse),
225 (JoinProject, JoinProjectResponse),
226 (OpenBuffer, OpenBufferResponse),
227 (OpenBufferForSymbol, OpenBufferForSymbolResponse),
228 (Ping, Ack),
229 (PerformRename, PerformRenameResponse),
230 (PrepareRename, PrepareRenameResponse),
231 (RegisterProject, RegisterProjectResponse),
232 (RegisterWorktree, Ack),
233 (SaveBuffer, BufferSaved),
234 (SearchProject, SearchProjectResponse),
235 (SendChannelMessage, SendChannelMessageResponse),
236 (ShareProject, Ack),
237 (Test, Test),
238 (UpdateBuffer, Ack),
239 (UpdateWorktree, Ack),
240);
241
242entity_messages!(
243 project_id,
244 AddProjectCollaborator,
245 ApplyCodeAction,
246 ApplyCompletionAdditionalEdits,
247 BufferReloaded,
248 BufferSaved,
249 FormatBuffers,
250 GetCodeActions,
251 GetCompletions,
252 GetDefinition,
253 GetDocumentHighlights,
254 GetReferences,
255 GetProjectSymbols,
256 JoinProject,
257 LeaveProject,
258 OpenBuffer,
259 OpenBufferForSymbol,
260 PerformRename,
261 PrepareRename,
262 RemoveProjectCollaborator,
263 SaveBuffer,
264 SearchProject,
265 StartLanguageServer,
266 UnregisterWorktree,
267 UnshareProject,
268 UpdateBuffer,
269 UpdateBufferFile,
270 UpdateDiagnosticSummary,
271 UpdateLanguageServer,
272 RegisterWorktree,
273 UpdateWorktree,
274);
275
276entity_messages!(channel_id, ChannelMessageSent);
277
/// A stream of protobuf messages layered over an underlying transport `S`.
pub struct MessageStream<S> {
    /// The wrapped transport that frames are sent to and received from.
    stream: S,
    /// Scratch buffer holding the uncompressed protobuf bytes of the envelope
    /// currently being written or read; reused across calls.
    encoding_buffer: Vec<u8>,
}
283
284#[derive(Debug)]
285pub enum Message {
286 Envelope(Envelope),
287 Ping,
288 Pong,
289}
290
291impl<S> MessageStream<S> {
292 pub fn new(stream: S) -> Self {
293 Self {
294 stream,
295 encoding_buffer: Vec::new(),
296 }
297 }
298
299 pub fn inner_mut(&mut self) -> &mut S {
300 &mut self.stream
301 }
302}
303
304impl<S> MessageStream<S>
305where
306 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
307{
308 pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
309 #[cfg(any(test, feature = "test-support"))]
310 const COMPRESSION_LEVEL: i32 = -7;
311
312 #[cfg(not(any(test, feature = "test-support")))]
313 const COMPRESSION_LEVEL: i32 = 4;
314
315 match message {
316 Message::Envelope(message) => {
317 self.encoding_buffer.resize(message.encoded_len(), 0);
318 self.encoding_buffer.clear();
319 message
320 .encode(&mut self.encoding_buffer)
321 .map_err(|err| io::Error::from(err))?;
322 let buffer =
323 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
324 .unwrap();
325 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
326 }
327 Message::Ping => {
328 self.stream
329 .send(WebSocketMessage::Ping(Default::default()))
330 .await?;
331 }
332 Message::Pong => {
333 self.stream
334 .send(WebSocketMessage::Pong(Default::default()))
335 .await?;
336 }
337 }
338
339 Ok(())
340 }
341}
342
343impl<S> MessageStream<S>
344where
345 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
346{
347 pub async fn read(&mut self) -> Result<Message, WebSocketError> {
348 while let Some(bytes) = self.stream.next().await {
349 match bytes? {
350 WebSocketMessage::Binary(bytes) => {
351 self.encoding_buffer.clear();
352 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
353 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
354 .map_err(io::Error::from)?;
355 return Ok(Message::Envelope(envelope));
356 }
357 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
358 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
359 WebSocketMessage::Close(_) => break,
360 _ => {}
361 }
362 }
363 Err(WebSocketError::ConnectionClosed)
364 }
365}
366
367impl Into<SystemTime> for Timestamp {
368 fn into(self) -> SystemTime {
369 UNIX_EPOCH
370 .checked_add(Duration::new(self.seconds, self.nanos))
371 .unwrap()
372 }
373}
374
375impl From<SystemTime> for Timestamp {
376 fn from(time: SystemTime) -> Self {
377 let duration = time.duration_since(UNIX_EPOCH).unwrap();
378 Self {
379 seconds: duration.as_secs(),
380 nanos: duration.subsec_nanos(),
381 }
382 }
383}
384
385impl From<u128> for Nonce {
386 fn from(nonce: u128) -> Self {
387 let upper_half = (nonce >> 64) as u64;
388 let lower_half = nonce as u64;
389 Self {
390 upper_half,
391 lower_half,
392 }
393 }
394}
395
396impl From<Nonce> for u128 {
397 fn from(nonce: Nonce) -> Self {
398 let upper_half = (nonce.upper_half as u128) << 64;
399 let lower_half = nonce.lower_half as u128;
400 upper_half | lower_half
401 }
402}