1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{cmp, iter, mem};
9use std::{
10 fmt::Debug,
11 io,
12 time::{Duration, SystemTime, UNIX_EPOCH},
13};
14
// Textually includes `zed.messages.rs` from the build's `OUT_DIR` into this module.
// NOTE(review): presumably the prost-generated protobuf types used below
// (`Envelope`, `UpdateWorktree`, `Timestamp`, `Nonce`, ...) — confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
16
/// A message type that can be wrapped in, and recovered from, an `Envelope`.
///
/// NOTE(review): implementations are presumably generated by the `messages!`
/// invocation in this file — confirm against the macro definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Static name of the message type (exposed through
    /// `AnyTypedEnvelope::payload_type_name`).
    const NAME: &'static str;
    /// Priority class of the message type (drives `AnyTypedEnvelope::is_background`).
    const PRIORITY: MessagePriority;
    /// Consumes the message and wraps it in an `Envelope`.
    ///
    /// * `id` - identifier to assign to the envelope.
    /// * `responding_to` - id of the message this one answers, if any.
    /// * `original_sender_id` - id of the original sender, if any.
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Attempts to extract a message of this type from `envelope`.
    ///
    /// NOTE(review): presumably returns `None` when the envelope carries a
    /// different payload type — confirm against the generated implementation.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
28
/// An enveloped message that is addressed to a particular remote entity.
///
/// NOTE(review): implementations are presumably generated by `entity_messages!`,
/// whose first argument (`project_id`, `channel_id`) looks like the field that
/// supplies the id — confirm against the macro definition.
pub trait EntityMessage: EnvelopedMessage {
    /// Returns the id of the remote entity this message refers to.
    fn remote_entity_id(&self) -> u64;
}
32
/// An enveloped message that has an associated response message type.
///
/// NOTE(review): implementations are presumably generated by `request_messages!`
/// from its `(Request, Response)` pairs — confirm against the macro definition.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type associated with this request as its response.
    type Response: EnvelopedMessage;
}
36
/// Object-safe, type-erased view of a `TypedEnvelope<T>`, allowing envelopes with
/// different payload types to be handled through a single trait object.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// Static name of the payload message type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `Any` so callers can downcast it by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type has `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// The original sender recorded on the envelope, if any.
    fn original_sender_id(&self) -> Option<PeerId>;
}
45
/// Priority class attached to every message type via `EnvelopedMessage::PRIORITY`.
///
/// The standard traits are derived so the value can be logged, copied and
/// compared like any other small field-less enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MessagePriority {
    /// The message type is classified as foreground.
    Foreground,
    /// The message type is classified as background
    /// (reported by `AnyTypedEnvelope::is_background`).
    Background,
}
50
51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
52 fn payload_type_id(&self) -> TypeId {
53 TypeId::of::<T>()
54 }
55
56 fn payload_type_name(&self) -> &'static str {
57 T::NAME
58 }
59
60 fn as_any(&self) -> &dyn Any {
61 self
62 }
63
64 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
65 self
66 }
67
68 fn is_background(&self) -> bool {
69 matches!(T::PRIORITY, MessagePriority::Background)
70 }
71
72 fn original_sender_id(&self) -> Option<PeerId> {
73 self.original_sender_id
74 }
75}
76
// Declares every message type together with its `MessagePriority` variant.
// NOTE(review): `messages!` is imported from the parent module; presumably it
// implements `EnvelopedMessage` for each listed type with the given priority —
// confirm against the macro definition before reordering or removing entries.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
);
170
// Pairs each request message type (first element) with a response message type
// (second element). Several requests share a response type (`Ack`,
// `ProjectEntryResponse`, `UsersResponse`, `OpenBufferResponse`).
// NOTE(review): `request_messages!` is imported from the parent module;
// presumably it implements `RequestMessage` with `Response` set to the second
// element — confirm against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
216
// Message types associated with a project.
// NOTE(review): `entity_messages!` is imported from the parent module; the
// leading `project_id` is presumably the field each type's
// `EntityMessage::remote_entity_id` returns — confirm against the macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
);
265
// Message types associated with a channel.
// NOTE(review): `channel_id` is presumably the field that
// `EntityMessage::remote_entity_id` returns for these types — confirm against
// the `entity_messages!` macro definition.
entity_messages!(channel_id, ChannelMessageSent);
267
/// Number of bytes in one kibibyte (2^10).
const KIB: usize = 1 << 10;
/// Number of bytes in one mebibyte (2^20).
const MIB: usize = KIB << 10;
/// Upper bound passed to `Vec::shrink_to` for the reusable encoding buffer
/// after each message, so one large message does not pin a large allocation.
const MAX_BUFFER_LEN: usize = MIB;
271
/// A stream of protobuf messages.
///
/// Wraps an underlying WebSocket sink and/or stream `S`; envelopes are
/// zstd-compressed on `write` and decompressed on `read`.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across calls: holds the protobuf-encoded bytes
    /// before compression in `write`, and the decompressed bytes before
    /// decoding in `read`. Cleared and shrunk after each message.
    encoding_buffer: Vec<u8>,
}
277
/// A unit of traffic exchanged through a `MessageStream`.
// NOTE(review): the lint is silenced presumably because `Envelope` is much
// larger than the unit variants and boxing it was not wanted — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried as a binary WebSocket frame.
    Envelope(Envelope),
    /// A WebSocket ping frame (payload is not preserved).
    Ping,
    /// A WebSocket pong frame (payload is not preserved).
    Pong,
}
285
286impl<S> MessageStream<S> {
287 pub fn new(stream: S) -> Self {
288 Self {
289 stream,
290 encoding_buffer: Vec::new(),
291 }
292 }
293
294 pub fn inner_mut(&mut self) -> &mut S {
295 &mut self.stream
296 }
297}
298
299impl<S> MessageStream<S>
300where
301 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
302{
303 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
304 #[cfg(any(test, feature = "test-support"))]
305 const COMPRESSION_LEVEL: i32 = -7;
306
307 #[cfg(not(any(test, feature = "test-support")))]
308 const COMPRESSION_LEVEL: i32 = 4;
309
310 match message {
311 Message::Envelope(message) => {
312 self.encoding_buffer.reserve(message.encoded_len());
313 message
314 .encode(&mut self.encoding_buffer)
315 .map_err(io::Error::from)?;
316 let buffer =
317 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
318 .unwrap();
319
320 self.encoding_buffer.clear();
321 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
322 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
323 }
324 Message::Ping => {
325 self.stream
326 .send(WebSocketMessage::Ping(Default::default()))
327 .await?;
328 }
329 Message::Pong => {
330 self.stream
331 .send(WebSocketMessage::Pong(Default::default()))
332 .await?;
333 }
334 }
335
336 Ok(())
337 }
338}
339
340impl<S> MessageStream<S>
341where
342 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
343{
344 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
345 while let Some(bytes) = self.stream.next().await {
346 match bytes? {
347 WebSocketMessage::Binary(bytes) => {
348 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
349 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
350 .map_err(io::Error::from)?;
351
352 self.encoding_buffer.clear();
353 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
354 return Ok(Message::Envelope(envelope));
355 }
356 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
357 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
358 WebSocketMessage::Close(_) => break,
359 _ => {}
360 }
361 }
362 Err(anyhow!("connection closed"))
363 }
364}
365
366impl From<Timestamp> for SystemTime {
367 fn from(val: Timestamp) -> Self {
368 UNIX_EPOCH
369 .checked_add(Duration::new(val.seconds, val.nanos))
370 .unwrap()
371 }
372}
373
374impl From<SystemTime> for Timestamp {
375 fn from(time: SystemTime) -> Self {
376 let duration = time.duration_since(UNIX_EPOCH).unwrap();
377 Self {
378 seconds: duration.as_secs(),
379 nanos: duration.subsec_nanos(),
380 }
381 }
382}
383
384impl From<u128> for Nonce {
385 fn from(nonce: u128) -> Self {
386 let upper_half = (nonce >> 64) as u64;
387 let lower_half = nonce as u64;
388 Self {
389 upper_half,
390 lower_half,
391 }
392 }
393}
394
395impl From<Nonce> for u128 {
396 fn from(nonce: Nonce) -> Self {
397 let upper_half = (nonce.upper_half as u128) << 64;
398 let lower_half = nonce.lower_half as u128;
399 upper_half | lower_half
400 }
401}
402
403pub fn split_worktree_update(
404 mut message: UpdateWorktree,
405 max_chunk_size: usize,
406) -> impl Iterator<Item = UpdateWorktree> {
407 let mut done = false;
408 iter::from_fn(move || {
409 if done {
410 return None;
411 }
412
413 let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
414 let updated_entries = message.updated_entries.drain(..chunk_size).collect();
415 done = message.updated_entries.is_empty();
416 Some(UpdateWorktree {
417 project_id: message.project_id,
418 worktree_id: message.worktree_id,
419 root_name: message.root_name.clone(),
420 updated_entries,
421 removed_entries: mem::take(&mut message.removed_entries),
422 scan_id: message.scan_id,
423 is_last_update: done && message.is_last_update,
424 })
425 })
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the reusable scratch buffer never retains more than
    /// `MAX_BUFFER_LEN` bytes of capacity, on both the writing and the reading
    /// side, after a small message and after a message far larger than the limit.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        // Small message: 70-byte root name.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(10),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        // Large message: 7,000,000-byte root name, well above `MAX_BUFFER_LEN`.
        sink.write(Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(1000000),
                ..Default::default()
            })),
            ..Default::default()
        }))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        // Read both messages back through the receiving half of the channel.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        stream.read().await.unwrap();
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
    }
}