1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{
9 fmt::Debug,
10 io,
11 time::{Duration, SystemTime, UNIX_EPOCH},
12};
13
// Pull in the code generated at build time into OUT_DIR (presumably prost output for the
// `zed.messages` protobuf package). It defines `Envelope`, `Timestamp`, `Nonce` and the
// message structs referenced throughout this module.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
15
16pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
17 const NAME: &'static str;
18 const PRIORITY: MessagePriority;
19 fn into_envelope(
20 self,
21 id: u32,
22 responding_to: Option<u32>,
23 original_sender_id: Option<u32>,
24 ) -> Envelope;
25 fn from_envelope(envelope: Envelope) -> Option<Self>;
26}
27
28pub trait EntityMessage: EnvelopedMessage {
29 fn remote_entity_id(&self) -> u64;
30}
31
32pub trait RequestMessage: EnvelopedMessage {
33 type Response: EnvelopedMessage;
34}
35
36pub trait AnyTypedEnvelope: 'static + Send + Sync {
37 fn payload_type_id(&self) -> TypeId;
38 fn payload_type_name(&self) -> &'static str;
39 fn as_any(&self) -> &dyn Any;
40 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
41 fn is_background(&self) -> bool;
42 fn original_sender_id(&self) -> Option<PeerId>;
43}
44
/// Scheduling priority of an [`EnvelopedMessage`] type.
///
/// [`AnyTypedEnvelope::is_background`] returns `true` for `Background` types.
///
/// The derives let callers compare, copy, hash and log priorities. The type
/// previously derived nothing, which made a public type awkward to use.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MessagePriority {
    Foreground,
    Background,
}
49
50impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
51 fn payload_type_id(&self) -> TypeId {
52 TypeId::of::<T>()
53 }
54
55 fn payload_type_name(&self) -> &'static str {
56 T::NAME
57 }
58
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
64 self
65 }
66
67 fn is_background(&self) -> bool {
68 matches!(T::PRIORITY, MessagePriority::Background)
69 }
70
71 fn original_sender_id(&self) -> Option<PeerId> {
72 self.original_sender_id
73 }
74}
75
// Register every protobuf message type carried over the wire together with its
// `MessagePriority`. The `messages!` macro is imported from the parent module; it
// presumably generates the `EnvelopedMessage` impls (NAME / PRIORITY / envelope
// conversions) — TODO confirm against the macro definition.
//
// `Background` entries are reported as background by `AnyTypedEnvelope::is_background`.
// NOTE(review): the list is mostly alphabetical, but a few entries (e.g. RemoveContact,
// ProjectEntryResponse, RegisterProjectResponse) are out of order.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
);
166
// Pair each request message type with the message type sent back in reply. The macro is
// imported from the parent module and presumably generates `RequestMessage` impls with
// `type Response = <second>` — TODO confirm against the macro definition.
// Several requests share a response type (e.g. `ProjectEntryResponse`, `UsersResponse`,
// `Ack`), and `Test` replies with `Test` itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
211
// Mark these messages as addressed to a project. The first argument names the field that
// identifies the target entity; the macro (imported from the parent module) presumably
// generates `EntityMessage::remote_entity_id` returning that `project_id` field — TODO
// confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
);
258
// Channel-scoped messages: the target entity is identified by the `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
260
/// Maximum capacity (1 MiB) a `MessageStream`'s scratch encoding buffer keeps between
/// messages; after each message the buffer is shrunk back to at most this size.
const MAX_BUFFER_LEN: usize = 1024 * 1024;
262
/// A stream of protobuf messages layered over a WebSocket transport `S`.
///
/// Depending on what `S` implements, the stream can write messages (when `S` is a
/// sink), read them (when `S` is a stream), or both.
pub struct MessageStream<S> {
    /// Scratch space reused across messages for protobuf (de)serialization.
    encoding_buffer: Vec<u8>,
    /// The wrapped WebSocket transport.
    stream: S,
}
268
/// A single unit of traffic read from or written to a `MessageStream`.
#[derive(Debug)]
pub enum Message {
    /// A protobuf RPC envelope; sent as a zstd-compressed binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame.
    Ping,
    /// A WebSocket pong control frame.
    Pong,
}
275
276impl<S> MessageStream<S> {
277 pub fn new(stream: S) -> Self {
278 Self {
279 stream,
280 encoding_buffer: Vec::new(),
281 }
282 }
283
284 pub fn inner_mut(&mut self) -> &mut S {
285 &mut self.stream
286 }
287}
288
289impl<S> MessageStream<S>
290where
291 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
292{
293 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
294 #[cfg(any(test, feature = "test-support"))]
295 const COMPRESSION_LEVEL: i32 = -7;
296
297 #[cfg(not(any(test, feature = "test-support")))]
298 const COMPRESSION_LEVEL: i32 = 4;
299
300 match message {
301 Message::Envelope(message) => {
302 self.encoding_buffer.reserve(message.encoded_len());
303 message
304 .encode(&mut self.encoding_buffer)
305 .map_err(|err| io::Error::from(err))?;
306 let buffer =
307 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
308 .unwrap();
309
310 self.encoding_buffer.clear();
311 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
312 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
313 }
314 Message::Ping => {
315 self.stream
316 .send(WebSocketMessage::Ping(Default::default()))
317 .await?;
318 }
319 Message::Pong => {
320 self.stream
321 .send(WebSocketMessage::Pong(Default::default()))
322 .await?;
323 }
324 }
325
326 Ok(())
327 }
328}
329
330impl<S> MessageStream<S>
331where
332 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
333{
334 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
335 while let Some(bytes) = self.stream.next().await {
336 match bytes? {
337 WebSocketMessage::Binary(bytes) => {
338 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
339 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
340 .map_err(io::Error::from)?;
341
342 self.encoding_buffer.clear();
343 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
344 return Ok(Message::Envelope(envelope));
345 }
346 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
347 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
348 WebSocketMessage::Close(_) => break,
349 _ => {}
350 }
351 }
352 Err(anyhow!("connection closed"))
353 }
354}
355
356impl Into<SystemTime> for Timestamp {
357 fn into(self) -> SystemTime {
358 UNIX_EPOCH
359 .checked_add(Duration::new(self.seconds, self.nanos))
360 .unwrap()
361 }
362}
363
364impl From<SystemTime> for Timestamp {
365 fn from(time: SystemTime) -> Self {
366 let duration = time.duration_since(UNIX_EPOCH).unwrap();
367 Self {
368 seconds: duration.as_secs(),
369 nanos: duration.subsec_nanos(),
370 }
371 }
372}
373
374impl From<u128> for Nonce {
375 fn from(nonce: u128) -> Self {
376 let upper_half = (nonce >> 64) as u64;
377 let lower_half = nonce as u64;
378 Self {
379 upper_half,
380 lower_half,
381 }
382 }
383}
384
385impl From<Nonce> for u128 {
386 fn from(nonce: Nonce) -> Self {
387 let upper_half = (nonce.upper_half as u128) << 64;
388 let lower_half = nonce.lower_half as u128;
389 upper_half | lower_half
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope message whose only populated field is `root_name`.
    fn update_worktree_message(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write one small and one very large message. After each write, the scratch
        // buffer must have been shrunk back under the cap.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for root_name in ["abcdefg".repeat(10), "abcdefg".repeat(1000000)] {
            sink.write(update_worktree_message(root_name)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both messages back; the same capacity bound must hold after each read.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}