1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{
9 io,
10 time::{Duration, SystemTime, UNIX_EPOCH},
11};
12
13include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
14
15pub trait EnvelopedMessage: Clone + Serialize + Sized + Send + Sync + 'static {
16 const NAME: &'static str;
17 const PRIORITY: MessagePriority;
18 fn into_envelope(
19 self,
20 id: u32,
21 responding_to: Option<u32>,
22 original_sender_id: Option<u32>,
23 ) -> Envelope;
24 fn from_envelope(envelope: Envelope) -> Option<Self>;
25}
26
27pub trait EntityMessage: EnvelopedMessage {
28 fn remote_entity_id(&self) -> u64;
29}
30
31pub trait RequestMessage: EnvelopedMessage {
32 type Response: EnvelopedMessage;
33}
34
35pub trait AnyTypedEnvelope: 'static + Send + Sync {
36 fn payload_type_id(&self) -> TypeId;
37 fn payload_type_name(&self) -> &'static str;
38 fn as_any(&self) -> &dyn Any;
39 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
40 fn is_background(&self) -> bool;
41 fn original_sender_id(&self) -> Option<PeerId>;
42}
43
/// Classifies a message type as foreground or background work.
///
/// `AnyTypedEnvelope::is_background` reports `true` exactly for `Background`.
// Derives let callers copy, compare, hash and debug-print priorities; the enum is a
// fieldless public type, so these are all cheap and backward compatible.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MessagePriority {
    /// The message type is treated as foreground work.
    Foreground,
    /// The message type is treated as background work.
    Background,
}
48
49impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
50 fn payload_type_id(&self) -> TypeId {
51 TypeId::of::<T>()
52 }
53
54 fn payload_type_name(&self) -> &'static str {
55 T::NAME
56 }
57
58 fn as_any(&self) -> &dyn Any {
59 self
60 }
61
62 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
63 self
64 }
65
66 fn is_background(&self) -> bool {
67 matches!(T::PRIORITY, MessagePriority::Background)
68 }
69
70 fn original_sender_id(&self) -> Option<PeerId> {
71 self.original_sender_id
72 }
73}
74
/// Declares the set of envelope payload types.
///
/// For every `(MessageType, Priority)` pair this generates:
/// - one arm of `build_typed_envelope`, turning an `Envelope` whose payload is the
///   `envelope::Payload::MessageType` variant into a boxed `TypedEnvelope<MessageType>`;
/// - an `EnvelopedMessage` impl for `MessageType` that wraps/unwraps that same variant and
///   records `MessagePriority::Priority`.
macro_rules! messages {
    ($(($message:ident, $level:ident)),* $(,)?) => {
        /// Converts an envelope received on connection `sender_id` into a type-erased typed
        /// envelope.
        ///
        /// Returns `None` when the envelope has no payload or its payload variant is not one
        /// of the types passed to `messages!`.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            let Envelope { id, original_sender_id, payload, .. } = envelope;
            let original_sender_id = original_sender_id.map(PeerId);
            match payload {
                $(
                    Some(envelope::Payload::$message(inner)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id: id,
                        payload: inner,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $message {
                const NAME: &'static str = std::stringify!($message);
                const PRIORITY: MessagePriority = MessagePriority::$level;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$message(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$message(inner)) => Some(inner),
                        _ => None,
                    }
                }
            }
        )*
    };
}
121
/// Declares request/response pairs.
///
/// Each `(Request, Response)` pair implements `RequestMessage` for `Request` with
/// `Response` as its associated reply type.
macro_rules! request_messages {
    ($(($request:ident, $reply:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $reply;
            }
        )*
    };
}
129
/// Implements `EntityMessage` for each listed message type.
///
/// The first argument names the field (e.g. `project_id`) that each message's
/// `remote_entity_id` reads.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
139
140messages!(
141 (Ack, Foreground),
142 (AddProjectCollaborator, Foreground),
143 (ApplyCodeAction, Background),
144 (ApplyCodeActionResponse, Background),
145 (ApplyCompletionAdditionalEdits, Background),
146 (ApplyCompletionAdditionalEditsResponse, Background),
147 (BufferReloaded, Foreground),
148 (BufferSaved, Foreground),
149 (ChannelMessageSent, Foreground),
150 (Error, Foreground),
151 (Follow, Foreground),
152 (FollowResponse, Foreground),
153 (FormatBuffers, Foreground),
154 (FormatBuffersResponse, Foreground),
155 (GetChannelMessages, Foreground),
156 (GetChannelMessagesResponse, Foreground),
157 (GetChannels, Foreground),
158 (GetChannelsResponse, Foreground),
159 (GetCodeActions, Background),
160 (GetCodeActionsResponse, Background),
161 (GetCompletions, Background),
162 (GetCompletionsResponse, Background),
163 (GetDefinition, Background),
164 (GetDefinitionResponse, Background),
165 (GetDocumentHighlights, Background),
166 (GetDocumentHighlightsResponse, Background),
167 (GetReferences, Background),
168 (GetReferencesResponse, Background),
169 (GetProjectSymbols, Background),
170 (GetProjectSymbolsResponse, Background),
171 (GetUsers, Foreground),
172 (GetUsersResponse, Foreground),
173 (JoinChannel, Foreground),
174 (JoinChannelResponse, Foreground),
175 (JoinProject, Foreground),
176 (JoinProjectResponse, Foreground),
177 (StartLanguageServer, Foreground),
178 (UpdateLanguageServer, Foreground),
179 (LeaveChannel, Foreground),
180 (LeaveProject, Foreground),
181 (OpenBufferById, Background),
182 (OpenBufferByPath, Background),
183 (OpenBufferForSymbol, Background),
184 (OpenBufferForSymbolResponse, Background),
185 (OpenBufferResponse, Background),
186 (PerformRename, Background),
187 (PerformRenameResponse, Background),
188 (PrepareRename, Background),
189 (PrepareRenameResponse, Background),
190 (RegisterProjectResponse, Foreground),
191 (Ping, Foreground),
192 (RegisterProject, Foreground),
193 (RegisterWorktree, Foreground),
194 (ReloadBuffers, Foreground),
195 (ReloadBuffersResponse, Foreground),
196 (RemoveProjectCollaborator, Foreground),
197 (SaveBuffer, Foreground),
198 (SearchProject, Background),
199 (SearchProjectResponse, Background),
200 (SendChannelMessage, Foreground),
201 (SendChannelMessageResponse, Foreground),
202 (ShareProject, Foreground),
203 (Test, Foreground),
204 (Unfollow, Foreground),
205 (UnregisterProject, Foreground),
206 (UnregisterWorktree, Foreground),
207 (UnshareProject, Foreground),
208 (UpdateBuffer, Foreground),
209 (UpdateBufferFile, Foreground),
210 (UpdateContacts, Foreground),
211 (UpdateDiagnosticSummary, Foreground),
212 (UpdateFollowers, Foreground),
213 (UpdateWorktree, Foreground),
214);
215
216request_messages!(
217 (ApplyCodeAction, ApplyCodeActionResponse),
218 (
219 ApplyCompletionAdditionalEdits,
220 ApplyCompletionAdditionalEditsResponse
221 ),
222 (Follow, FollowResponse),
223 (FormatBuffers, FormatBuffersResponse),
224 (GetChannelMessages, GetChannelMessagesResponse),
225 (GetChannels, GetChannelsResponse),
226 (GetCodeActions, GetCodeActionsResponse),
227 (GetCompletions, GetCompletionsResponse),
228 (GetDefinition, GetDefinitionResponse),
229 (GetDocumentHighlights, GetDocumentHighlightsResponse),
230 (GetReferences, GetReferencesResponse),
231 (GetProjectSymbols, GetProjectSymbolsResponse),
232 (GetUsers, GetUsersResponse),
233 (JoinChannel, JoinChannelResponse),
234 (JoinProject, JoinProjectResponse),
235 (OpenBufferById, OpenBufferResponse),
236 (OpenBufferByPath, OpenBufferResponse),
237 (OpenBufferForSymbol, OpenBufferForSymbolResponse),
238 (Ping, Ack),
239 (PerformRename, PerformRenameResponse),
240 (PrepareRename, PrepareRenameResponse),
241 (RegisterProject, RegisterProjectResponse),
242 (RegisterWorktree, Ack),
243 (ReloadBuffers, ReloadBuffersResponse),
244 (SaveBuffer, BufferSaved),
245 (SearchProject, SearchProjectResponse),
246 (SendChannelMessage, SendChannelMessageResponse),
247 (ShareProject, Ack),
248 (Test, Test),
249 (UpdateBuffer, Ack),
250 (UpdateWorktree, Ack),
251);
252
253entity_messages!(
254 project_id,
255 AddProjectCollaborator,
256 ApplyCodeAction,
257 ApplyCompletionAdditionalEdits,
258 BufferReloaded,
259 BufferSaved,
260 Follow,
261 FormatBuffers,
262 GetCodeActions,
263 GetCompletions,
264 GetDefinition,
265 GetDocumentHighlights,
266 GetReferences,
267 GetProjectSymbols,
268 JoinProject,
269 LeaveProject,
270 OpenBufferById,
271 OpenBufferByPath,
272 OpenBufferForSymbol,
273 PerformRename,
274 PrepareRename,
275 ReloadBuffers,
276 RemoveProjectCollaborator,
277 SaveBuffer,
278 SearchProject,
279 StartLanguageServer,
280 Unfollow,
281 UnregisterWorktree,
282 UnshareProject,
283 UpdateBuffer,
284 UpdateBufferFile,
285 UpdateDiagnosticSummary,
286 UpdateFollowers,
287 UpdateLanguageServer,
288 RegisterWorktree,
289 UpdateWorktree,
290);
291
292entity_messages!(channel_id, ChannelMessageSent);
293
/// A stream of protobuf messages.
///
/// Wraps a WebSocket-message transport `S`. `write` sends zstd-compressed protobuf
/// envelopes, and `read` decodes them.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer that is cleared and reused by every `write` (protobuf encoding) and
    /// `read` (zstd decompression) call.
    encoding_buffer: Vec<u8>,
}
299
300#[derive(Debug)]
301pub enum Message {
302 Envelope(Envelope),
303 Ping,
304 Pong,
305}
306
307impl<S> MessageStream<S> {
308 pub fn new(stream: S) -> Self {
309 Self {
310 stream,
311 encoding_buffer: Vec::new(),
312 }
313 }
314
315 pub fn inner_mut(&mut self) -> &mut S {
316 &mut self.stream
317 }
318}
319
320impl<S> MessageStream<S>
321where
322 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
323{
324 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
325 #[cfg(any(test, feature = "test-support"))]
326 const COMPRESSION_LEVEL: i32 = -7;
327
328 #[cfg(not(any(test, feature = "test-support")))]
329 const COMPRESSION_LEVEL: i32 = 4;
330
331 match message {
332 Message::Envelope(message) => {
333 self.encoding_buffer.resize(message.encoded_len(), 0);
334 self.encoding_buffer.clear();
335 message
336 .encode(&mut self.encoding_buffer)
337 .map_err(|err| io::Error::from(err))?;
338 let buffer =
339 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
340 .unwrap();
341 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
342 }
343 Message::Ping => {
344 self.stream
345 .send(WebSocketMessage::Ping(Default::default()))
346 .await?;
347 }
348 Message::Pong => {
349 self.stream
350 .send(WebSocketMessage::Pong(Default::default()))
351 .await?;
352 }
353 }
354
355 Ok(())
356 }
357}
358
359impl<S> MessageStream<S>
360where
361 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
362{
363 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
364 while let Some(bytes) = self.stream.next().await {
365 match bytes? {
366 WebSocketMessage::Binary(bytes) => {
367 self.encoding_buffer.clear();
368 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
369 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
370 .map_err(io::Error::from)?;
371 return Ok(Message::Envelope(envelope));
372 }
373 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
374 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
375 WebSocketMessage::Close(_) => break,
376 _ => {}
377 }
378 }
379 Err(anyhow!("connection closed"))
380 }
381}
382
383impl Into<SystemTime> for Timestamp {
384 fn into(self) -> SystemTime {
385 UNIX_EPOCH
386 .checked_add(Duration::new(self.seconds, self.nanos))
387 .unwrap()
388 }
389}
390
391impl From<SystemTime> for Timestamp {
392 fn from(time: SystemTime) -> Self {
393 let duration = time.duration_since(UNIX_EPOCH).unwrap();
394 Self {
395 seconds: duration.as_secs(),
396 nanos: duration.subsec_nanos(),
397 }
398 }
399}
400
401impl From<u128> for Nonce {
402 fn from(nonce: u128) -> Self {
403 let upper_half = (nonce >> 64) as u64;
404 let lower_half = nonce as u64;
405 Self {
406 upper_half,
407 lower_half,
408 }
409 }
410}
411
412impl From<Nonce> for u128 {
413 fn from(nonce: Nonce) -> Self {
414 let upper_half = (nonce.upper_half as u128) << 64;
415 let lower_half = nonce.lower_half as u128;
416 upper_half | lower_half
417 }
418}