1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::Result;
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::{Any, TypeId};
7use std::{
8 io,
9 time::{Duration, SystemTime, UNIX_EPOCH},
10};
11
// Textually includes `zed.messages.rs` from the build output directory (`OUT_DIR`).
// `Envelope`, `envelope::Payload`, `Timestamp`, `Nonce` and the individual message
// structs used below are not defined in this file, so they presumably come from this
// generated file (likely produced by prost at build time; confirm in the build script).
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
13
/// A message type that can be carried as the payload of an [`Envelope`].
///
/// Implementations in this file are generated by the `messages!` macro.
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
    /// Name of the message type; the `messages!` macro sets it to the
    /// stringified type identifier.
    const NAME: &'static str;
    /// Priority class of the message type (see [`MessagePriority`]).
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] that carries the given `id`,
    /// `responding_to` and `original_sender_id` values alongside the payload.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`.
    ///
    /// The macro-generated implementations return `None` when the payload is
    /// absent or holds a different message type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
25
/// An [`EnvelopedMessage`] that is associated with an entity identified by a `u64`.
///
/// Implementations in this file are generated by `entity_messages!`, which
/// returns one of the message's own fields (`project_id` or `channel_id`).
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
29
/// An [`EnvelopedMessage`] that has an associated response message type.
///
/// Implementations in this file are generated by `request_messages!`.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type paired with this request (presumably the type of the
    /// reply a peer sends back; the pairing is only declared here).
    type Response: EnvelopedMessage;
}
33
/// Object-safe view of a `TypedEnvelope<T>` that hides the payload type `T`.
///
/// `build_typed_envelope` returns values as `Box<dyn AnyTypedEnvelope>`; the
/// concrete envelope can be recovered through [`Any`] downcasting.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// Returns the [`TypeId`] of the payload type.
    fn payload_type_id(&self) -> TypeId;
    /// Returns the payload type's `EnvelopedMessage::NAME`.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Returns `true` when the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
}
41
/// Priority class assigned to each message type by the `messages!` macro.
///
/// This file only exposes the value (via `EnvelopedMessage::PRIORITY` and
/// `AnyTypedEnvelope::is_background`); how the two classes are scheduled is
/// decided by callers outside this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// The default class; most message types in this file use it.
    Foreground,
    /// The class for which `AnyTypedEnvelope::is_background` returns `true`.
    Background,
}
46
47impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
48 fn payload_type_id(&self) -> TypeId {
49 TypeId::of::<T>()
50 }
51
52 fn payload_type_name(&self) -> &'static str {
53 T::NAME
54 }
55
56 fn as_any(&self) -> &dyn Any {
57 self
58 }
59
60 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
61 self
62 }
63
64 fn is_background(&self) -> bool {
65 matches!(T::PRIORITY, MessagePriority::Background)
66 }
67}
68
/// Generates the envelope glue for a list of `(MessageType, Priority)` pairs.
///
/// For the whole list it emits `build_typed_envelope`, and for every listed
/// type it emits an `EnvelopedMessage` impl. Each message type must have a
/// variant of the same name in `envelope::Payload`.
macro_rules! messages {
    ($(($name:ident, $priority:ident)),* $(,)?) => {
        /// Turns an untyped `Envelope` received from `sender_id` into a boxed
        /// `TypedEnvelope` for the payload's concrete message type.
        ///
        /// Returns `None` when the payload is missing or is not one of the
        /// message types listed in the `messages!` invocation.
        pub fn build_typed_envelope(
            sender_id: ConnectionId,
            envelope: Envelope,
        ) -> Option<Box<dyn AnyTypedEnvelope>> {
            // Read the plain fields first; the match below moves the payload out.
            let message_id = envelope.id;
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            match envelope.payload {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Box::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    })),
                )*
                _ => None,
            }
        }

        $(
            impl EnvelopedMessage for $name {
                const NAME: &'static str = std::stringify!($name);
                const PRIORITY: MessagePriority = MessagePriority::$priority;

                fn into_envelope(
                    self,
                    id: u32,
                    responding_to: Option<u32>,
                    original_sender_id: Option<u32>,
                ) -> Envelope {
                    let payload = Some(envelope::Payload::$name(self));
                    Envelope {
                        id,
                        responding_to,
                        original_sender_id,
                        payload,
                    }
                }

                fn from_envelope(envelope: Envelope) -> Option<Self> {
                    match envelope.payload {
                        Some(envelope::Payload::$name(message)) => Some(message),
                        _ => None,
                    }
                }
            }
        )*
    };
}
115
/// Implements `RequestMessage` for every `(Request, Response)` pair, setting
/// the request's associated `Response` type to the second element of the pair.
macro_rules! request_messages {
    ($(($request:ident, $response:ident)),* $(,)?) => {
        $(
            impl RequestMessage for $request {
                type Response = $response;
            }
        )*
    };
}
123
/// Implements `EntityMessage` for every listed message type.
///
/// The first argument names the field that holds the entity id; the generated
/// `remote_entity_id` returns that field.
macro_rules! entity_messages {
    ($field:ident, $($message:ident),* $(,)?) => {
        $(
            impl EntityMessage for $message {
                fn remote_entity_id(&self) -> u64 {
                    self.$field
                }
            }
        )*
    };
}
133
// Registers every message type together with its priority. Each name must match
// a variant of `envelope::Payload`, because the macro matches on
// `envelope::Payload::$name`. Types missing from this list get no
// `EnvelopedMessage` impl and make `build_typed_envelope` return `None`.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Foreground),
    (ApplyCodeActionResponse, Foreground),
    (ApplyCompletionAdditionalEdits, Foreground),
    (ApplyCompletionAdditionalEditsResponse, Foreground),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (ChannelMessageSent, Foreground),
    (CloseBuffer, Foreground),
    (DiskBasedDiagnosticsUpdated, Background),
    (DiskBasedDiagnosticsUpdating, Background),
    (Error, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Foreground),
    (GetCompletions, Background),
    (GetCompletionsResponse, Foreground),
    (GetDefinition, Foreground),
    (GetDefinitionResponse, Foreground),
    (GetUsers, Foreground),
    (GetUsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBuffer, Foreground),
    (OpenBufferResponse, Foreground),
    // NOTE(review): the only entry out of alphabetical order; harmless, but it
    // would normally sit after `RegisterProject`.
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterWorktree, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (SaveBuffer, Foreground),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareWorktree, Foreground),
    (Test, Foreground),
    (UnregisterProject, Foreground),
    (UnregisterWorktree, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateWorktree, Foreground),
);
190
// Pairs each request type (first element) with its `RequestMessage::Response`
// type (second element). Several requests share `Ack` as their response, and
// `Test` is paired with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetUsers, GetUsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBuffer, OpenBufferResponse),
    (Ping, Ack),
    (RegisterProject, RegisterProjectResponse),
    (RegisterWorktree, Ack),
    (SaveBuffer, BufferSaved),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, Ack),
    (ShareWorktree, Ack),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
218
// Messages whose entity id is their `project_id` field: the first argument is
// the field name, and every following type gets an `EntityMessage` impl that
// returns that field.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CloseBuffer,
    DiskBasedDiagnosticsUpdated,
    DiskBasedDiagnosticsUpdating,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    JoinProject,
    LeaveProject,
    OpenBuffer,
    RemoveProjectCollaborator,
    SaveBuffer,
    ShareWorktree,
    UnregisterWorktree,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateWorktree,
);
246
// `ChannelMessageSent` is keyed by its `channel_id` field rather than `project_id`.
entity_messages!(channel_id, ChannelMessageSent);
248
/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S` and exchanges `Envelope`s over it as
/// zstd-compressed binary WebSocket frames (see `write_message` and
/// `read_message`).
pub struct MessageStream<S> {
    /// The wrapped transport; reachable through `inner_mut`.
    stream: S,
    /// Scratch buffer reused across calls: it holds the encoded protobuf bytes
    /// while writing and the decompressed bytes while reading.
    encoding_buffer: Vec<u8>,
}
254
255impl<S> MessageStream<S> {
256 pub fn new(stream: S) -> Self {
257 Self {
258 stream,
259 encoding_buffer: Vec::new(),
260 }
261 }
262
263 pub fn inner_mut(&mut self) -> &mut S {
264 &mut self.stream
265 }
266}
267
268impl<S> MessageStream<S>
269where
270 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
271{
272 /// Write a given protobuf message to the stream.
273 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
274 #[cfg(any(test, feature = "test-support"))]
275 const COMPRESSION_LEVEL: i32 = -7;
276
277 #[cfg(not(any(test, feature = "test-support")))]
278 const COMPRESSION_LEVEL: i32 = 4;
279
280 self.encoding_buffer.resize(message.encoded_len(), 0);
281 self.encoding_buffer.clear();
282 message
283 .encode(&mut self.encoding_buffer)
284 .map_err(|err| io::Error::from(err))?;
285 let buffer =
286 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
287 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
288 Ok(())
289 }
290}
291
292impl<S> MessageStream<S>
293where
294 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
295{
296 /// Read a protobuf message of the given type from the stream.
297 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
298 while let Some(bytes) = self.stream.next().await {
299 match bytes? {
300 WebSocketMessage::Binary(bytes) => {
301 self.encoding_buffer.clear();
302 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
303 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
304 .map_err(io::Error::from)?;
305 return Ok(envelope);
306 }
307 WebSocketMessage::Close(_) => break,
308 _ => {}
309 }
310 }
311 Err(WebSocketError::ConnectionClosed)
312 }
313}
314
315impl Into<SystemTime> for Timestamp {
316 fn into(self) -> SystemTime {
317 UNIX_EPOCH
318 .checked_add(Duration::new(self.seconds, self.nanos))
319 .unwrap()
320 }
321}
322
323impl From<SystemTime> for Timestamp {
324 fn from(time: SystemTime) -> Self {
325 let duration = time.duration_since(UNIX_EPOCH).unwrap();
326 Self {
327 seconds: duration.as_secs(),
328 nanos: duration.subsec_nanos(),
329 }
330 }
331}
332
333impl From<u128> for Nonce {
334 fn from(nonce: u128) -> Self {
335 let upper_half = (nonce >> 64) as u64;
336 let lower_half = nonce as u64;
337 Self {
338 upper_half,
339 lower_half,
340 }
341 }
342}
343
344impl From<Nonce> for u128 {
345 fn from(nonce: Nonce) -> Self {
346 let upper_half = (nonce.upper_half as u128) << 64;
347 let lower_half = nonce.lower_half as u128;
348 upper_half | lower_half
349 }
350}