1use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{cmp, iter, mem};
9use std::{
10 fmt::Debug,
11 io,
12 time::{Duration, SystemTime, UNIX_EPOCH},
13};
14
// Brings in the message types emitted into `OUT_DIR` at build time (e.g. `Envelope`,
// `UpdateWorktree`, `Timestamp`, `Nonce` and the `envelope` module), which the rest
// of this file uses without defining them.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
16
/// A protobuf message type that can be wrapped in, and recovered from, an [`Envelope`].
///
/// NOTE(review): implementations appear to come from the `messages!` invocation in this
/// file (each entry pairs a type with a [`MessagePriority`]); confirm against the macro.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Type name, reported by `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority of this message type; `Background` makes
    /// `AnyTypedEnvelope::is_background` return `true`.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] tagged with the message `id`, and optionally the
    /// id it is `responding_to` and the `original_sender_id` (semantics per the field
    /// names; the encoding lives in the macro-generated impls).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<u32>,
    ) -> Envelope;
    /// Extracts `Self` from `envelope`; presumably `None` when the envelope carries a
    /// different payload type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
28
/// A message that targets a particular remote entity.
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the targeted entity.
    ///
    /// NOTE(review): the `entity_messages!` invocations in this file suggest this is the
    /// message's `project_id` (or `channel_id`) field; confirm against the macro.
    fn remote_entity_id(&self) -> u64;
}
32
/// A message that expects a reply of type [`RequestMessage::Response`].
///
/// The request/response pairs are listed in the `request_messages!` invocation in
/// this file.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
36
/// Type-erased view of a `TypedEnvelope<T>`, so envelopes with different payload types
/// can be handled through one trait object and downcast later via [`Any`].
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// The payload type's `EnvelopedMessage::NAME`.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a `Box<dyn Any>` for owned downcasting.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is `MessagePriority::Background`.
    fn is_background(&self) -> bool;
    /// Id of the peer that originally sent the message, if one was recorded.
    fn original_sender_id(&self) -> Option<PeerId>;
}
45
/// Priority attached to each message type through `EnvelopedMessage::PRIORITY`.
///
/// In this file, `Background` is what makes `AnyTypedEnvelope::is_background` return
/// `true`; `Foreground` is the alternative.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    Foreground,
    Background,
}
50
51impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
52 fn payload_type_id(&self) -> TypeId {
53 TypeId::of::<T>()
54 }
55
56 fn payload_type_name(&self) -> &'static str {
57 T::NAME
58 }
59
60 fn as_any(&self) -> &dyn Any {
61 self
62 }
63
64 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
65 self
66 }
67
68 fn is_background(&self) -> bool {
69 matches!(T::PRIORITY, MessagePriority::Background)
70 }
71
72 fn original_sender_id(&self) -> Option<PeerId> {
73 self.original_sender_id
74 }
75}
76
// Registry of every message type that can travel inside an `Envelope`, each paired with
// a `MessagePriority` variant. NOTE(review): presumably the `messages!` macro implements
// `EnvelopedMessage` for each type with `PRIORITY` set to the listed variant; confirm
// against the macro definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (RemoveContact, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinProjectRequestCancelled, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (ProjectUnshared, Foreground),
    (RegisterProjectResponse, Foreground),
    (Ping, Foreground),
    (RegisterProject, Foreground),
    (RegisterProjectActivity, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RequestJoinProject, Foreground),
    (RespondToContactRequest, Foreground),
    (RespondToJoinProjectRequest, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnregisterProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateProject, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateWorktreeExtensions, Background),
);
171
// (request, response) pairs: each request type's `RequestMessage::Response`.
// Several requests share one response type (e.g. `ProjectEntryResponse`, `Ack`,
// `UsersResponse`), and `Test` answers with itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (RegisterProject, RegisterProjectResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (Test, Test),
    (UnregisterProject, Ack),
    (UpdateBuffer, Ack),
    (UpdateWorktree, Ack),
);
217
// Message types addressed to a project. The first argument names the id field;
// NOTE(review): presumably the macro implements `EntityMessage::remote_entity_id` by
// returning each message's `project_id`. Confirm in the `entity_messages!` definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    JoinProjectRequestCancelled,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ProjectUnshared,
    RegisterProjectActivity,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    RequestJoinProject,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    Unfollow,
    UnregisterProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateWorktree,
    UpdateWorktreeExtensions,
);
267
// `ChannelMessageSent` is addressed by channel rather than by project, so its entity id
// comes from the `channel_id` field (presumably, per the macro's first argument).
entity_messages!(channel_id, ChannelMessageSent);
269
/// Bytes in one kibibyte (2^10).
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte (2^20).
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch buffer is shrunk back to after each message,
/// so one oversized message does not pin a large allocation for the stream's lifetime.
const MAX_BUFFER_LEN: usize = MIB;
273
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S`, used as a `Sink` by `MessageStream::write` and/or
/// as a `Stream` by `MessageStream::read`. Each [`Envelope`] travels as a
/// zstd-compressed binary WebSocket frame.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch buffer reused across messages. It holds the protobuf encoding while
    /// writing and the decompressed bytes while reading. After each envelope it
    /// successfully encodes or decodes, it is cleared and its capacity is shrunk to at
    /// most `MAX_BUFFER_LEN`.
    encoding_buffer: Vec<u8>,
}
279
/// One unit of traffic exchanged over a [`MessageStream`].
// `Envelope` is presumably much larger than the payload-free `Ping`/`Pong` variants,
// which is what `clippy::large_enum_variant` flags. The lint is silenced instead of
// boxing the envelope.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried as a zstd-compressed binary WebSocket frame.
    Envelope(Envelope),
    /// Corresponds to a WebSocket ping frame.
    Ping,
    /// Corresponds to a WebSocket pong frame.
    Pong,
}
287
288impl<S> MessageStream<S> {
289 pub fn new(stream: S) -> Self {
290 Self {
291 stream,
292 encoding_buffer: Vec::new(),
293 }
294 }
295
296 pub fn inner_mut(&mut self) -> &mut S {
297 &mut self.stream
298 }
299}
300
301impl<S> MessageStream<S>
302where
303 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
304{
305 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
306 #[cfg(any(test, feature = "test-support"))]
307 const COMPRESSION_LEVEL: i32 = -7;
308
309 #[cfg(not(any(test, feature = "test-support")))]
310 const COMPRESSION_LEVEL: i32 = 4;
311
312 match message {
313 Message::Envelope(message) => {
314 self.encoding_buffer.reserve(message.encoded_len());
315 message
316 .encode(&mut self.encoding_buffer)
317 .map_err(io::Error::from)?;
318 let buffer =
319 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
320 .unwrap();
321
322 self.encoding_buffer.clear();
323 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
324 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
325 }
326 Message::Ping => {
327 self.stream
328 .send(WebSocketMessage::Ping(Default::default()))
329 .await?;
330 }
331 Message::Pong => {
332 self.stream
333 .send(WebSocketMessage::Pong(Default::default()))
334 .await?;
335 }
336 }
337
338 Ok(())
339 }
340}
341
342impl<S> MessageStream<S>
343where
344 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
345{
346 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
347 while let Some(bytes) = self.stream.next().await {
348 match bytes? {
349 WebSocketMessage::Binary(bytes) => {
350 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
351 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
352 .map_err(io::Error::from)?;
353
354 self.encoding_buffer.clear();
355 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
356 return Ok(Message::Envelope(envelope));
357 }
358 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
359 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
360 WebSocketMessage::Close(_) => break,
361 _ => {}
362 }
363 }
364 Err(anyhow!("connection closed"))
365 }
366}
367
368impl From<Timestamp> for SystemTime {
369 fn from(val: Timestamp) -> Self {
370 UNIX_EPOCH
371 .checked_add(Duration::new(val.seconds, val.nanos))
372 .unwrap()
373 }
374}
375
376impl From<SystemTime> for Timestamp {
377 fn from(time: SystemTime) -> Self {
378 let duration = time.duration_since(UNIX_EPOCH).unwrap();
379 Self {
380 seconds: duration.as_secs(),
381 nanos: duration.subsec_nanos(),
382 }
383 }
384}
385
386impl From<u128> for Nonce {
387 fn from(nonce: u128) -> Self {
388 let upper_half = (nonce >> 64) as u64;
389 let lower_half = nonce as u64;
390 Self {
391 upper_half,
392 lower_half,
393 }
394 }
395}
396
397impl From<Nonce> for u128 {
398 fn from(nonce: Nonce) -> Self {
399 let upper_half = (nonce.upper_half as u128) << 64;
400 let lower_half = nonce.lower_half as u128;
401 upper_half | lower_half
402 }
403}
404
405pub fn split_worktree_update(
406 mut message: UpdateWorktree,
407 max_chunk_size: usize,
408) -> impl Iterator<Item = UpdateWorktree> {
409 let mut done = false;
410 iter::from_fn(move || {
411 if done {
412 return None;
413 }
414
415 let chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
416 let updated_entries = message.updated_entries.drain(..chunk_size).collect();
417 done = message.updated_entries.is_empty();
418 Some(UpdateWorktree {
419 project_id: message.project_id,
420 worktree_id: message.worktree_id,
421 root_name: message.root_name.clone(),
422 updated_entries,
423 removed_entries: mem::take(&mut message.removed_entries),
424 scan_id: message.scan_id,
425 is_last_update: done && message.is_last_update,
426 })
427 })
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose only payload is an `UpdateWorktree` with `root_name`.
    fn worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Writing a small message and then a very large one must never leave the
        // scratch buffer holding more than MAX_BUFFER_LEN of capacity.
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            let envelope = worktree_envelope("abcdefg".repeat(repetitions));
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both messages back must respect the same bound.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}