1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
// Pull in the generated protobuf message types (e.g. `Envelope`, used
// throughout this module) from `$OUT_DIR/zed.messages.rs`.
// NOTE(review): presumably generated by this crate's build script via prost;
// `Timestamp` and `Nonce` are likely defined here too — confirm.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 const PRIORITY: MessagePriority;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
34pub trait AnyTypedEnvelope: 'static + Send + Sync {
35 fn payload_type_id(&self) -> TypeId;
36 fn payload_type_name(&self) -> &'static str;
37 fn as_any(&self) -> &dyn Any;
38 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
39 fn is_background(&self) -> bool;
40 fn original_sender_id(&self) -> Option<PeerId>;
41}
42
/// Scheduling priority attached to every [`EnvelopedMessage`] type.
///
/// Derives the standard value traits so priorities can be copied, compared
/// with `==`, and printed in diagnostics; previously none of these were
/// available on this public type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// `AnyTypedEnvelope::is_background` returns `false` for these messages.
    Foreground,
    /// `AnyTypedEnvelope::is_background` returns `true` for these messages.
    Background,
}
47
48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
49 fn payload_type_id(&self) -> TypeId {
50 TypeId::of::<T>()
51 }
52
53 fn payload_type_name(&self) -> &'static str {
54 T::NAME
55 }
56
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
62 self
63 }
64
65 fn is_background(&self) -> bool {
66 matches!(T::PRIORITY, MessagePriority::Background)
67 }
68
69 fn original_sender_id(&self) -> Option<PeerId> {
70 self.original_sender_id
71 }
72}
73
/// Declares the set of message types that may appear as an envelope payload.
///
/// Each `(Name, Priority)` pair gets an [`EnvelopedMessage`] impl. The macro
/// also emits `build_typed_envelope`, which converts a received [`Envelope`]
/// into a boxed, type-erased `TypedEnvelope` for the matching payload type.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Wraps `envelope`'s payload in a `TypedEnvelope` tagged with
        /// `sender_id`. Returns `None` when the payload is absent or is not one
        /// of the declared message types.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            let Envelope {
                id,
                original_sender_id,
                payload,
                ..
            } = envelope;
            match payload {
                $(Some(envelope::Payload::$name(payload)) => {
                    let typed = TypedEnvelope {
                        sender_id,
                        original_sender_id: original_sender_id.map(PeerId),
                        message_id: id,
                        payload,
                    };
                    Some(Box::new(typed))
                })*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
120
/// Pairs each request type with its reply type by implementing
/// [`RequestMessage`] for every `(Request, Response)` tuple.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
128
/// Implements [`EntityMessage`] for each listed message type, reading the
/// remote entity id from the field named by the first argument.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
138
139messages!(
140 (Ack, Foreground),
141 (AddProjectCollaborator, Foreground),
142 (ApplyCodeAction, Background),
143 (ApplyCodeActionResponse, Background),
144 (ApplyCompletionAdditionalEdits, Background),
145 (ApplyCompletionAdditionalEditsResponse, Background),
146 (BufferReloaded, Foreground),
147 (BufferSaved, Foreground),
148 (ChannelMessageSent, Foreground),
149 (Error, Foreground),
150 (Follow, Foreground),
151 (FollowResponse, Foreground),
152 (FormatBuffers, Foreground),
153 (FormatBuffersResponse, Foreground),
154 (GetChannelMessages, Foreground),
155 (GetChannelMessagesResponse, Foreground),
156 (GetChannels, Foreground),
157 (GetChannelsResponse, Foreground),
158 (GetCodeActions, Background),
159 (GetCodeActionsResponse, Background),
160 (GetCompletions, Background),
161 (GetCompletionsResponse, Background),
162 (GetDefinition, Background),
163 (GetDefinitionResponse, Background),
164 (GetDocumentHighlights, Background),
165 (GetDocumentHighlightsResponse, Background),
166 (GetReferences, Background),
167 (GetReferencesResponse, Background),
168 (GetProjectSymbols, Background),
169 (GetProjectSymbolsResponse, Background),
170 (GetUsers, Foreground),
171 (GetUsersResponse, Foreground),
172 (JoinChannel, Foreground),
173 (JoinChannelResponse, Foreground),
174 (JoinProject, Foreground),
175 (JoinProjectResponse, Foreground),
176 (StartLanguageServer, Foreground),
177 (UpdateLanguageServer, Foreground),
178 (LeaveChannel, Foreground),
179 (LeaveProject, Foreground),
180 (OpenBufferById, Background),
181 (OpenBufferByPath, Background),
182 (OpenBufferForSymbol, Background),
183 (OpenBufferForSymbolResponse, Background),
184 (OpenBufferResponse, Background),
185 (PerformRename, Background),
186 (PerformRenameResponse, Background),
187 (PrepareRename, Background),
188 (PrepareRenameResponse, Background),
189 (RegisterProjectResponse, Foreground),
190 (Ping, Foreground),
191 (RegisterProject, Foreground),
192 (RegisterWorktree, Foreground),
193 (ReloadBuffers, Foreground),
194 (ReloadBuffersResponse, Foreground),
195 (RemoveProjectCollaborator, Foreground),
196 (SaveBuffer, Foreground),
197 (SearchProject, Background),
198 (SearchProjectResponse, Background),
199 (SendChannelMessage, Foreground),
200 (SendChannelMessageResponse, Foreground),
201 (ShareProject, Foreground),
202 (Test, Foreground),
203 (Unfollow, Foreground),
204 (UnregisterProject, Foreground),
205 (UnregisterWorktree, Foreground),
206 (UnshareProject, Foreground),
207 (UpdateBuffer, Foreground),
208 (UpdateBufferFile, Foreground),
209 (UpdateContacts, Foreground),
210 (UpdateDiagnosticSummary, Foreground),
211 (UpdateFollowers, Foreground),
212 (UpdateWorktree, Foreground),
213);
214
215request_messages!(
216 (ApplyCodeAction, ApplyCodeActionResponse),
217 (
218 ApplyCompletionAdditionalEdits,
219 ApplyCompletionAdditionalEditsResponse
220 ),
221 (Follow, FollowResponse),
222 (FormatBuffers, FormatBuffersResponse),
223 (GetChannelMessages, GetChannelMessagesResponse),
224 (GetChannels, GetChannelsResponse),
225 (GetCodeActions, GetCodeActionsResponse),
226 (GetCompletions, GetCompletionsResponse),
227 (GetDefinition, GetDefinitionResponse),
228 (GetDocumentHighlights, GetDocumentHighlightsResponse),
229 (GetReferences, GetReferencesResponse),
230 (GetProjectSymbols, GetProjectSymbolsResponse),
231 (GetUsers, GetUsersResponse),
232 (JoinChannel, JoinChannelResponse),
233 (JoinProject, JoinProjectResponse),
234 (OpenBufferById, OpenBufferResponse),
235 (OpenBufferByPath, OpenBufferResponse),
236 (OpenBufferForSymbol, OpenBufferForSymbolResponse),
237 (Ping, Ack),
238 (PerformRename, PerformRenameResponse),
239 (PrepareRename, PrepareRenameResponse),
240 (RegisterProject, RegisterProjectResponse),
241 (RegisterWorktree, Ack),
242 (ReloadBuffers, ReloadBuffersResponse),
243 (SaveBuffer, BufferSaved),
244 (SearchProject, SearchProjectResponse),
245 (SendChannelMessage, SendChannelMessageResponse),
246 (ShareProject, Ack),
247 (Test, Test),
248 (UpdateBuffer, Ack),
249 (UpdateWorktree, Ack),
250);
251
252entity_messages!(
253 project_id,
254 AddProjectCollaborator,
255 ApplyCodeAction,
256 ApplyCompletionAdditionalEdits,
257 BufferReloaded,
258 BufferSaved,
259 Follow,
260 FormatBuffers,
261 GetCodeActions,
262 GetCompletions,
263 GetDefinition,
264 GetDocumentHighlights,
265 GetReferences,
266 GetProjectSymbols,
267 JoinProject,
268 LeaveProject,
269 OpenBufferById,
270 OpenBufferByPath,
271 OpenBufferForSymbol,
272 PerformRename,
273 PrepareRename,
274 ReloadBuffers,
275 RemoveProjectCollaborator,
276 SaveBuffer,
277 SearchProject,
278 StartLanguageServer,
279 Unfollow,
280 UnregisterWorktree,
281 UnshareProject,
282 UpdateBuffer,
283 UpdateBufferFile,
284 UpdateDiagnosticSummary,
285 UpdateFollowers,
286 UpdateLanguageServer,
287 RegisterWorktree,
288 UpdateWorktree,
289);
290
291entity_messages!(channel_id, ChannelMessageSent);
292
/// A stream of protobuf messages.
///
/// Wraps a WebSocket-style transport `S`; see the `read`/`write` impls for the
/// framing (zstd-compressed protobuf envelopes in binary frames).
pub struct MessageStream<S> {
    /// Scratch buffer reused across reads and writes to hold one envelope's
    /// uncompressed protobuf bytes, avoiding a fresh allocation per message.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport.
    stream: S,
}
298
/// A single item read from or written to a [`MessageStream`].
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope; on the wire it is zstd-compressed in a binary frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame (sent with an empty payload).
    Ping,
    /// A WebSocket pong control frame (sent with an empty payload).
    Pong,
}
305
306impl<S> MessageStream<S> {
307 pub fn new(stream: S) -> Self {
308 Self {
309 stream,
310 encoding_buffer: Vec::new(),
311 }
312 }
313
314 pub fn inner_mut(&mut self) -> &mut S {
315 &mut self.stream
316 }
317}
318
319impl<S> MessageStream<S>
320where
321 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
322{
323 pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
324 #[cfg(any(test, feature = "test-support"))]
325 const COMPRESSION_LEVEL: i32 = -7;
326
327 #[cfg(not(any(test, feature = "test-support")))]
328 const COMPRESSION_LEVEL: i32 = 4;
329
330 match message {
331 Message::Envelope(message) => {
332 self.encoding_buffer.resize(message.encoded_len(), 0);
333 self.encoding_buffer.clear();
334 message
335 .encode(&mut self.encoding_buffer)
336 .map_err(|err| io::Error::from(err))?;
337 let buffer =
338 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
339 .unwrap();
340 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
341 }
342 Message::Ping => {
343 self.stream
344 .send(WebSocketMessage::Ping(Default::default()))
345 .await?;
346 }
347 Message::Pong => {
348 self.stream
349 .send(WebSocketMessage::Pong(Default::default()))
350 .await?;
351 }
352 }
353
354 Ok(())
355 }
356}
357
358impl<S> MessageStream<S>
359where
360 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
361{
362 pub async fn read(&mut self) -> Result<Message, WebSocketError> {
363 while let Some(bytes) = self.stream.next().await {
364 match bytes? {
365 WebSocketMessage::Binary(bytes) => {
366 self.encoding_buffer.clear();
367 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
368 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
369 .map_err(io::Error::from)?;
370 return Ok(Message::Envelope(envelope));
371 }
372 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
373 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
374 WebSocketMessage::Close(_) => break,
375 _ => {}
376 }
377 }
378 Err(WebSocketError::ConnectionClosed)
379 }
380}
381
382impl Into<SystemTime> for Timestamp {
383 fn into(self) -> SystemTime {
384 UNIX_EPOCH
385 .checked_add(Duration::new(self.seconds, self.nanos))
386 .unwrap()
387 }
388}
389
390impl From<SystemTime> for Timestamp {
391 fn from(time: SystemTime) -> Self {
392 let duration = time.duration_since(UNIX_EPOCH).unwrap();
393 Self {
394 seconds: duration.as_secs(),
395 nanos: duration.subsec_nanos(),
396 }
397 }
398}
399
400impl From<u128> for Nonce {
401 fn from(nonce: u128) -> Self {
402 let upper_half = (nonce >> 64) as u64;
403 let lower_half = nonce as u64;
404 Self {
405 upper_half,
406 lower_half,
407 }
408 }
409}
410
411impl From<Nonce> for u128 {
412 fn from(nonce: Nonce) -> Self {
413 let upper_half = (nonce.upper_half as u128) << 64;
414 let lower_half = nonce.lower_half as u128;
415 upper_half | lower_half
416 }
417}