1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{
9 io,
10 time::{Duration, SystemTime, UNIX_EPOCH},
11};
12
13include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
14
/// A protobuf message type that can be carried as the payload of an `Envelope`.
///
/// Implementations are generated by the `messages!` macro in this module.
pub trait EnvelopedMessage: Clone + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type (`messages!` sets it to the stringified type name).
    const NAME: &'static str;
    /// Priority class of this message type (foreground or background).
    const PRIORITY: MessagePriority;
    /// Wraps `self` as the payload of a new `Envelope`.
    ///
    /// `id` is the envelope's message id, `responding_to` the id of the request this message
    /// answers (if any), and `original_sender_id` the originating peer's id (if any).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`, or returns `None` if the envelope's
    /// payload is missing or is a different message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
26
/// A message addressed to a specific remote entity identified by a `u64` id.
///
/// The `entity_messages!` macro implements this by returning a named id field of the
/// message (for example `project_id` or `channel_id`).
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
30
/// A message that expects a reply; the pairing is declared with `request_messages!`.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in response to this request.
    type Response: EnvelopedMessage;
}
34
/// Type-erased view of a `TypedEnvelope<T>`.
///
/// This lets envelopes with different payload types be handled uniformly; for example,
/// `build_typed_envelope` returns `Box<dyn AnyTypedEnvelope>`.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload type `T`.
    fn payload_type_id(&self) -> TypeId;
    /// Name of the payload type (`EnvelopedMessage::NAME`).
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast back to its concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` for by-value downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// Peer that originally sent the message, if the envelope recorded one.
    fn original_sender_id(&self) -> Option<PeerId>;
}
43
/// Priority class assigned to each message type through the `messages!` macro.
///
/// `AnyTypedEnvelope::is_background` reports whether a payload type is `Background`.
/// The enum is a plain fieldless tag, so it is cheap to copy, compare, and debug-print.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Message type marked as foreground priority.
    Foreground,
    /// Message type marked as background priority.
    Background,
}
48
49impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
50 fn payload_type_id(&self) -> TypeId {
51 TypeId::of::<T>()
52 }
53
54 fn payload_type_name(&self) -> &'static str {
55 T::NAME
56 }
57
58 fn as_any(&self) -> &dyn Any {
59 self
60 }
61
62 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
63 self
64 }
65
66 fn is_background(&self) -> bool {
67 matches!(T::PRIORITY, MessagePriority::Background)
68 }
69
70 fn original_sender_id(&self) -> Option<PeerId> {
71 self.original_sender_id
72 }
73}
74
/// Declares the set of message types carried in an `Envelope`.
///
/// For each `(Type, Priority)` entry this generates an `EnvelopedMessage` impl, and it emits a
/// single `build_typed_envelope` function that dispatches on the envelope's payload variant.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Wraps `envelope` in a `TypedEnvelope` matching its payload variant.
        ///
        /// Returns `None` when the envelope has no payload or carries a variant that was not
        /// listed in the `messages!` invocation.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            // Copy the header fields out first; the payload is moved by the match below.
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            let message_id = envelope.id;
            match envelope.payload {
                $(Some(envelope::Payload::$name(payload)) => {
                    Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    }))
                }, )*
                _ => None
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
121
/// Implements `RequestMessage` for each `(Request, Response)` pair, fixing the reply type
/// expected for that request.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(impl RequestMessage for $request_name {
            type Response = $response_name;
        })*
    };
}
129
/// Implements `EntityMessage` for each listed type.
///
/// The entity id is read from the field named by the first argument (for example
/// `project_id`).
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(impl EntityMessage for $name {
            fn remote_entity_id(&self) -> u64 {
                self.$id_field
            }
        })*
    };
}
139
// Every message type that can travel in an `Envelope`, paired with its `MessagePriority`.
// Each entry gets an `EnvelopedMessage` impl and an arm in `build_typed_envelope`.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (ChannelMessageSent, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (GetUsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterWorktree, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UnregisterWorktree, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateWorktree, Foreground),
);
219
// Request -> response pairings. Several requests (e.g. `Ping`, `RegisterWorktree`,
// `UpdateBuffer`) are answered with a bare `Ack`, and `Test` is answered with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CreateProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (ReloadBuffers, ReloadBuffersResponse),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, Ack),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
258
// Messages scoped to a project: `remote_entity_id()` returns the message's `project_id`.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CreateProjectEntry,
    RenameProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterWorktree,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    RegisterWorktree,
    UpdateWorktree,
);

// Messages scoped to a channel: `remote_entity_id()` returns the message's `channel_id`.
entity_messages!(channel_id, ChannelMessageSent);
302
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S`. `Envelope`s are sent and received as zstd-compressed
/// binary frames, and pings/pongs map to WebSocket control frames.
pub struct MessageStream<S> {
    /// The wrapped WebSocket sink and/or stream.
    stream: S,
    /// Scratch buffer reused across calls.
    ///
    /// `write` uses it to hold protobuf-encoded bytes before compression; `read` uses it to
    /// hold decompressed bytes before decoding.
    encoding_buffer: Vec<u8>,
}
308
/// One unit of traffic read from or written to a `MessageStream`.
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried in a zstd-compressed binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame.
    Ping,
    /// A WebSocket pong control frame.
    Pong,
}
315
316impl<S> MessageStream<S> {
317 pub fn new(stream: S) -> Self {
318 Self {
319 stream,
320 encoding_buffer: Vec::new(),
321 }
322 }
323
324 pub fn inner_mut(&mut self) -> &mut S {
325 &mut self.stream
326 }
327}
328
329impl<S> MessageStream<S>
330where
331 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
332{
333 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
334 #[cfg(any(test, feature = "test-support"))]
335 const COMPRESSION_LEVEL: i32 = -7;
336
337 #[cfg(not(any(test, feature = "test-support")))]
338 const COMPRESSION_LEVEL: i32 = 4;
339
340 match message {
341 Message::Envelope(message) => {
342 self.encoding_buffer.resize(message.encoded_len(), 0);
343 self.encoding_buffer.clear();
344 message
345 .encode(&mut self.encoding_buffer)
346 .map_err(|err| io::Error::from(err))?;
347 let buffer =
348 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
349 .unwrap();
350 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
351 }
352 Message::Ping => {
353 self.stream
354 .send(WebSocketMessage::Ping(Default::default()))
355 .await?;
356 }
357 Message::Pong => {
358 self.stream
359 .send(WebSocketMessage::Pong(Default::default()))
360 .await?;
361 }
362 }
363
364 Ok(())
365 }
366}
367
368impl<S> MessageStream<S>
369where
370 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
371{
372 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
373 while let Some(bytes) = self.stream.next().await {
374 match bytes? {
375 WebSocketMessage::Binary(bytes) => {
376 self.encoding_buffer.clear();
377 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
378 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
379 .map_err(io::Error::from)?;
380 return Ok(Message::Envelope(envelope));
381 }
382 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
383 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
384 WebSocketMessage::Close(_) => break,
385 _ => {}
386 }
387 }
388 Err(anyhow!("connection closed"))
389 }
390}
391
392impl Into<SystemTime> for Timestamp {
393 fn into(self) -> SystemTime {
394 UNIX_EPOCH
395 .checked_add(Duration::new(self.seconds, self.nanos))
396 .unwrap()
397 }
398}
399
400impl From<SystemTime> for Timestamp {
401 fn from(time: SystemTime) -> Self {
402 let duration = time.duration_since(UNIX_EPOCH).unwrap();
403 Self {
404 seconds: duration.as_secs(),
405 nanos: duration.subsec_nanos(),
406 }
407 }
408}
409
410impl From<u128> for Nonce {
411 fn from(nonce: u128) -> Self {
412 let upper_half = (nonce >> 64) as u64;
413 let lower_half = nonce as u64;
414 Self {
415 upper_half,
416 lower_half,
417 }
418 }
419}
420
421impl From<Nonce> for u128 {
422 fn from(nonce: Nonce) -> Self {
423 let upper_half = (nonce.upper_half as u128) << 64;
424 let lower_half = nonce.lower_half as u128;
425 upper_half | lower_half
426 }
427}