1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
12include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
14pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
15 const NAME: &'static str;
16 const PRIORITY: MessagePriority;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn from_envelope(envelope: Envelope) -> Option<Self>;
24}
25
26pub trait EntityMessage: EnvelopedMessage {
27 fn remote_entity_id(&self) -> u64;
28}
29
30pub trait RequestMessage: EnvelopedMessage {
31 type Response: EnvelopedMessage;
32}
33
34pub trait AnyTypedEnvelope: 'static + Send + Sync {
35 fn payload_type_id(&self) -> TypeId;
36 fn payload_type_name(&self) -> &'static str;
37 fn as_any(&self) -> &dyn Any;
38 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
39 fn is_background(&self) -> bool;
40 fn original_sender_id(&self) -> Option<PeerId>;
41}
42
/// Priority class of a message type, exposed via [`EnvelopedMessage::PRIORITY`].
///
/// `Copy`/`PartialEq`/`Debug` are derived so callers can compare, log and pass
/// priorities by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Reported as not background by `AnyTypedEnvelope::is_background`.
    Foreground,
    /// Reported as background by `AnyTypedEnvelope::is_background`.
    Background,
}
47
48impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
49 fn payload_type_id(&self) -> TypeId {
50 TypeId::of::<T>()
51 }
52
53 fn payload_type_name(&self) -> &'static str {
54 T::NAME
55 }
56
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
62 self
63 }
64
65 fn is_background(&self) -> bool {
66 matches!(T::PRIORITY, MessagePriority::Background)
67 }
68
69 fn original_sender_id(&self) -> Option<PeerId> {
70 self.original_sender_id
71 }
72}
73
/// Declares every protobuf message that may appear as an [`Envelope`] payload.
///
/// Each `(Name, Priority)` pair generates an [`EnvelopedMessage`] impl for
/// `Name` with priority `MessagePriority::Priority`, and contributes one arm to
/// the generated `build_typed_envelope` function.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Converts a raw [`Envelope`] received on connection `sender_id` into a
        /// boxed, type-erased [`TypedEnvelope`].
        ///
        /// Returns `None` when the envelope has no payload or its payload type
        /// is not one listed in `messages!`.
        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
            let message_id = envelope.id;
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
120
/// Pairs each request message with its response type by implementing
/// [`RequestMessage`] for every `(Request, Response)` tuple given.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
128
/// Implements [`EntityMessage`] for each listed message type, reading the
/// entity id from the field named by the first macro argument.
macro_rules! entity_messages {
    ($id_field:ident, $($name:ident),* $(,)?) => {
        $(
            impl EntityMessage for $name {
                fn remote_entity_id(&self) -> u64 {
                    self.$id_field
                }
            }
        )*
    };
}
138
139messages!(
140 (Ack, Foreground),
141 (AddProjectCollaborator, Foreground),
142 (ApplyCodeAction, Background),
143 (ApplyCodeActionResponse, Background),
144 (ApplyCompletionAdditionalEdits, Background),
145 (ApplyCompletionAdditionalEditsResponse, Background),
146 (BufferReloaded, Foreground),
147 (BufferSaved, Foreground),
148 (ChannelMessageSent, Foreground),
149 (Error, Foreground),
150 (Follow, Foreground),
151 (FollowResponse, Foreground),
152 (FormatBuffers, Foreground),
153 (FormatBuffersResponse, Foreground),
154 (GetChannelMessages, Foreground),
155 (GetChannelMessagesResponse, Foreground),
156 (GetChannels, Foreground),
157 (GetChannelsResponse, Foreground),
158 (GetCodeActions, Background),
159 (GetCodeActionsResponse, Background),
160 (GetCompletions, Background),
161 (GetCompletionsResponse, Background),
162 (GetDefinition, Background),
163 (GetDefinitionResponse, Background),
164 (GetDocumentHighlights, Background),
165 (GetDocumentHighlightsResponse, Background),
166 (GetReferences, Background),
167 (GetReferencesResponse, Background),
168 (GetProjectSymbols, Background),
169 (GetProjectSymbolsResponse, Background),
170 (GetUsers, Foreground),
171 (GetUsersResponse, Foreground),
172 (JoinChannel, Foreground),
173 (JoinChannelResponse, Foreground),
174 (JoinProject, Foreground),
175 (JoinProjectResponse, Foreground),
176 (StartLanguageServer, Foreground),
177 (UpdateLanguageServer, Foreground),
178 (LeaveChannel, Foreground),
179 (LeaveProject, Foreground),
180 (OpenBufferById, Background),
181 (OpenBufferByPath, Background),
182 (OpenBufferForSymbol, Background),
183 (OpenBufferForSymbolResponse, Background),
184 (OpenBufferResponse, Background),
185 (PerformRename, Background),
186 (PerformRenameResponse, Background),
187 (PrepareRename, Background),
188 (PrepareRenameResponse, Background),
189 (RegisterProjectResponse, Foreground),
190 (Ping, Foreground),
191 (RegisterProject, Foreground),
192 (RegisterWorktree, Foreground),
193 (RemoveProjectCollaborator, Foreground),
194 (SaveBuffer, Foreground),
195 (SearchProject, Background),
196 (SearchProjectResponse, Background),
197 (SendChannelMessage, Foreground),
198 (SendChannelMessageResponse, Foreground),
199 (ShareProject, Foreground),
200 (Test, Foreground),
201 (Unfollow, Foreground),
202 (UnregisterProject, Foreground),
203 (UnregisterWorktree, Foreground),
204 (UnshareProject, Foreground),
205 (UpdateBuffer, Foreground),
206 (UpdateBufferFile, Foreground),
207 (UpdateContacts, Foreground),
208 (UpdateDiagnosticSummary, Foreground),
209 (UpdateFollowers, Foreground),
210 (UpdateWorktree, Foreground),
211);
212
213request_messages!(
214 (ApplyCodeAction, ApplyCodeActionResponse),
215 (
216 ApplyCompletionAdditionalEdits,
217 ApplyCompletionAdditionalEditsResponse
218 ),
219 (Follow, FollowResponse),
220 (FormatBuffers, FormatBuffersResponse),
221 (GetChannelMessages, GetChannelMessagesResponse),
222 (GetChannels, GetChannelsResponse),
223 (GetCodeActions, GetCodeActionsResponse),
224 (GetCompletions, GetCompletionsResponse),
225 (GetDefinition, GetDefinitionResponse),
226 (GetDocumentHighlights, GetDocumentHighlightsResponse),
227 (GetReferences, GetReferencesResponse),
228 (GetProjectSymbols, GetProjectSymbolsResponse),
229 (GetUsers, GetUsersResponse),
230 (JoinChannel, JoinChannelResponse),
231 (JoinProject, JoinProjectResponse),
232 (OpenBufferById, OpenBufferResponse),
233 (OpenBufferByPath, OpenBufferResponse),
234 (OpenBufferForSymbol, OpenBufferForSymbolResponse),
235 (Ping, Ack),
236 (PerformRename, PerformRenameResponse),
237 (PrepareRename, PrepareRenameResponse),
238 (RegisterProject, RegisterProjectResponse),
239 (RegisterWorktree, Ack),
240 (SaveBuffer, BufferSaved),
241 (SearchProject, SearchProjectResponse),
242 (SendChannelMessage, SendChannelMessageResponse),
243 (ShareProject, Ack),
244 (Test, Test),
245 (UpdateBuffer, Ack),
246 (UpdateWorktree, Ack),
247);
248
249entity_messages!(
250 project_id,
251 AddProjectCollaborator,
252 ApplyCodeAction,
253 ApplyCompletionAdditionalEdits,
254 BufferReloaded,
255 BufferSaved,
256 Follow,
257 FormatBuffers,
258 GetCodeActions,
259 GetCompletions,
260 GetDefinition,
261 GetDocumentHighlights,
262 GetReferences,
263 GetProjectSymbols,
264 JoinProject,
265 LeaveProject,
266 OpenBufferById,
267 OpenBufferByPath,
268 OpenBufferForSymbol,
269 PerformRename,
270 PrepareRename,
271 RemoveProjectCollaborator,
272 SaveBuffer,
273 SearchProject,
274 StartLanguageServer,
275 Unfollow,
276 UnregisterWorktree,
277 UnshareProject,
278 UpdateBuffer,
279 UpdateBufferFile,
280 UpdateDiagnosticSummary,
281 UpdateFollowers,
282 UpdateLanguageServer,
283 RegisterWorktree,
284 UpdateWorktree,
285);
286
287entity_messages!(channel_id, ChannelMessageSent);
288
/// A stream of protobuf messages layered over an underlying transport `S`.
///
/// Reading and writing are provided by the `impl` blocks below, which require
/// `S` to be a WebSocket message stream and/or sink respectively.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across calls for protobuf encoding and zstd decoding.
    encoding_buffer: Vec<u8>,
}
294
295#[derive(Debug)]
296pub enum Message {
297 Envelope(Envelope),
298 Ping,
299 Pong,
300}
301
302impl<S> MessageStream<S> {
303 pub fn new(stream: S) -> Self {
304 Self {
305 stream,
306 encoding_buffer: Vec::new(),
307 }
308 }
309
310 pub fn inner_mut(&mut self) -> &mut S {
311 &mut self.stream
312 }
313}
314
315impl<S> MessageStream<S>
316where
317 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
318{
319 pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
320 #[cfg(any(test, feature = "test-support"))]
321 const COMPRESSION_LEVEL: i32 = -7;
322
323 #[cfg(not(any(test, feature = "test-support")))]
324 const COMPRESSION_LEVEL: i32 = 4;
325
326 match message {
327 Message::Envelope(message) => {
328 self.encoding_buffer.resize(message.encoded_len(), 0);
329 self.encoding_buffer.clear();
330 message
331 .encode(&mut self.encoding_buffer)
332 .map_err(|err| io::Error::from(err))?;
333 let buffer =
334 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
335 .unwrap();
336 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
337 }
338 Message::Ping => {
339 self.stream
340 .send(WebSocketMessage::Ping(Default::default()))
341 .await?;
342 }
343 Message::Pong => {
344 self.stream
345 .send(WebSocketMessage::Pong(Default::default()))
346 .await?;
347 }
348 }
349
350 Ok(())
351 }
352}
353
354impl<S> MessageStream<S>
355where
356 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
357{
358 pub async fn read(&mut self) -> Result<Message, WebSocketError> {
359 while let Some(bytes) = self.stream.next().await {
360 match bytes? {
361 WebSocketMessage::Binary(bytes) => {
362 self.encoding_buffer.clear();
363 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
364 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
365 .map_err(io::Error::from)?;
366 return Ok(Message::Envelope(envelope));
367 }
368 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
369 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
370 WebSocketMessage::Close(_) => break,
371 _ => {}
372 }
373 }
374 Err(WebSocketError::ConnectionClosed)
375 }
376}
377
378impl Into<SystemTime> for Timestamp {
379 fn into(self) -> SystemTime {
380 UNIX_EPOCH
381 .checked_add(Duration::new(self.seconds, self.nanos))
382 .unwrap()
383 }
384}
385
386impl From<SystemTime> for Timestamp {
387 fn from(time: SystemTime) -> Self {
388 let duration = time.duration_since(UNIX_EPOCH).unwrap();
389 Self {
390 seconds: duration.as_secs(),
391 nanos: duration.subsec_nanos(),
392 }
393 }
394}
395
396impl From<u128> for Nonce {
397 fn from(nonce: u128) -> Self {
398 let upper_half = (nonce >> 64) as u64;
399 let lower_half = nonce as u64;
400 Self {
401 upper_half,
402 lower_half,
403 }
404 }
405}
406
407impl From<Nonce> for u128 {
408 fn from(nonce: Nonce) -> Self {
409 let upper_half = (nonce.upper_half as u128) << 64;
410 let lower_half = nonce.lower_half as u128;
411 upper_half | lower_half
412 }
413}