1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{
9 cmp,
10 fmt::Debug,
11 io, iter,
12 time::{Duration, SystemTime, UNIX_EPOCH},
13};
14use std::{fmt, mem};
15
16include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
17
/// A protobuf message type that can be wrapped in an [`Envelope`] for
/// transmission and recovered from one on receipt.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// The message type's name (surfaced via `AnyTypedEnvelope::payload_type_name`).
    const NAME: &'static str;
    /// Foreground/background priority for this message type.
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] tagged with message `id`, the id of the
    /// message it is `responding_to` (if any), and `original_sender_id` (if any).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts this message type from `envelope`; presumably returns `None`
    /// when the envelope holds a different payload (impls are macro-generated
    /// outside this file — confirm).
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}

/// A message addressed to a specific remote entity.
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity this message targets. Judging by the `entity_messages!`
    /// invocations in this file, this is presumably the message's `project_id`
    /// or `channel_id` field — TODO confirm against the macro.
    fn remote_entity_id(&self) -> u64;
}

/// A message that expects a reply of type [`RequestMessage::Response`].
pub trait RequestMessage: EnvelopedMessage {
    /// The reply type; the pairings are declared by `request_messages!` in this file.
    type Response: EnvelopedMessage;
}
37
/// Type-erased view of a [`TypedEnvelope`], so envelopes carrying different
/// payload types can be handled through one trait object (how callers store
/// them is not visible here — TODO confirm).
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// `TypeId` of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// The payload type's [`EnvelopedMessage::NAME`].
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into a boxed `dyn Any` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// `true` when the payload type's priority is [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// The peer recorded as the message's original sender, if any.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// The `ConnectionId` recorded as the message's sender.
    fn sender_id(&self) -> ConnectionId;
    /// The envelope's message id.
    fn message_id(&self) -> u32;
}
48
/// Priority attached to every message type through [`EnvelopedMessage::PRIORITY`].
///
/// This file only exposes the distinction: `AnyTypedEnvelope::is_background`
/// reports `true` for [`MessagePriority::Background`]. How receivers schedule
/// each kind is decided elsewhere.
///
/// The derives let the priority be copied, compared, and logged.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Foreground-priority message type.
    Foreground,
    /// Background-priority message type.
    Background,
}
53
54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
55 fn payload_type_id(&self) -> TypeId {
56 TypeId::of::<T>()
57 }
58
59 fn payload_type_name(&self) -> &'static str {
60 T::NAME
61 }
62
63 fn as_any(&self) -> &dyn Any {
64 self
65 }
66
67 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
68 self
69 }
70
71 fn is_background(&self) -> bool {
72 matches!(T::PRIORITY, MessagePriority::Background)
73 }
74
75 fn original_sender_id(&self) -> Option<PeerId> {
76 self.original_sender_id
77 }
78
79 fn sender_id(&self) -> ConnectionId {
80 self.sender_id
81 }
82
83 fn message_id(&self) -> u32 {
84 self.message_id
85 }
86}
87
88impl PeerId {
89 pub fn from_u64(peer_id: u64) -> Self {
90 let owner_id = (peer_id >> 32) as u32;
91 let id = peer_id as u32;
92 Self { owner_id, id }
93 }
94
95 pub fn as_u64(self) -> u64 {
96 ((self.owner_id as u64) << 32) | (self.id as u64)
97 }
98}
99
100impl Copy for PeerId {}
101
102impl Eq for PeerId {}
103
104impl Ord for PeerId {
105 fn cmp(&self, other: &Self) -> cmp::Ordering {
106 self.owner_id
107 .cmp(&other.owner_id)
108 .then_with(|| self.id.cmp(&other.id))
109 }
110}
111
112impl PartialOrd for PeerId {
113 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114 Some(self.cmp(other))
115 }
116}
117
118impl std::hash::Hash for PeerId {
119 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120 self.owner_id.hash(state);
121 self.id.hash(state);
122 }
123}
124
125impl fmt::Display for PeerId {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 write!(f, "{}/{}", self.owner_id, self.id)
128 }
129}
130
// Registry of every protocol message type together with its
// `MessagePriority` (`Foreground` or `Background`). The `messages!` macro is
// defined outside this file; presumably it generates the `EnvelopedMessage`
// impls (NAME/PRIORITY/into_envelope/from_envelope) — confirm in the macro.
// Every type sent over the wire must appear here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// Pairs each request message type with the response type it expects
// (`(Request, Response)`). Presumably this generates the `RequestMessage`
// impls with `type Response = Response` — confirm in the macro definition.
// Both types of every pair must also be listed in `messages!`.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Message types addressed to a project. The first argument names the field
// each listed message carries; presumably the macro implements
// `EntityMessage::remote_entity_id` by reading that field — confirm in the
// macro definition.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

// Message types addressed to a channel, keyed by their `channel_id` field
// (same caveat as above about how the macro uses it).
entity_messages!(channel_id, ChannelMessageSent);
347
// Binary size units: 1 KiB = 2^10 bytes, 1 MiB = 2^20 bytes.
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;
/// Capacity bound passed to `shrink_to` on `MessageStream`'s scratch buffer
/// after every envelope, so one oversized message does not pin a large
/// allocation for the rest of the connection.
const MAX_BUFFER_LEN: usize = MIB;
351
/// A stream of protobuf messages.
///
/// Wraps a WebSocket-message sink and/or stream `S`. `write` sends envelopes
/// as zstd-compressed protobuf in binary frames and `read` decodes them back;
/// pings and pongs map to WebSocket ping/pong frames.
pub struct MessageStream<S> {
    /// The underlying WebSocket-message sink/stream.
    stream: S,
    /// Reusable scratch space: holds the protobuf encoding on write and the
    /// decompressed bytes on read. It is cleared and shrunk toward
    /// `MAX_BUFFER_LEN` after each envelope.
    encoding_buffer: Vec<u8>,
}
357
/// One unit read from or written to a [`MessageStream`].
// Clippy's large_enum_variant lint is silenced: the `Envelope` variant stores
// an `Envelope` inline, presumably much larger than the fieldless `Ping`/`Pong`
// variants; boxing it would cost an allocation per message.
// NOTE(review): confirm this is the intended rationale.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope (sent as a zstd-compressed binary frame).
    Envelope(Envelope),
    /// A WebSocket ping frame.
    Ping,
    /// A WebSocket pong frame.
    Pong,
}
365
366impl<S> MessageStream<S> {
367 pub fn new(stream: S) -> Self {
368 Self {
369 stream,
370 encoding_buffer: Vec::new(),
371 }
372 }
373
374 pub fn inner_mut(&mut self) -> &mut S {
375 &mut self.stream
376 }
377}
378
379impl<S> MessageStream<S>
380where
381 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
382{
383 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
384 #[cfg(any(test, feature = "test-support"))]
385 const COMPRESSION_LEVEL: i32 = -7;
386
387 #[cfg(not(any(test, feature = "test-support")))]
388 const COMPRESSION_LEVEL: i32 = 4;
389
390 match message {
391 Message::Envelope(message) => {
392 self.encoding_buffer.reserve(message.encoded_len());
393 message
394 .encode(&mut self.encoding_buffer)
395 .map_err(io::Error::from)?;
396 let buffer =
397 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
398 .unwrap();
399
400 self.encoding_buffer.clear();
401 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
402 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
403 }
404 Message::Ping => {
405 self.stream
406 .send(WebSocketMessage::Ping(Default::default()))
407 .await?;
408 }
409 Message::Pong => {
410 self.stream
411 .send(WebSocketMessage::Pong(Default::default()))
412 .await?;
413 }
414 }
415
416 Ok(())
417 }
418}
419
420impl<S> MessageStream<S>
421where
422 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
423{
424 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
425 while let Some(bytes) = self.stream.next().await {
426 match bytes? {
427 WebSocketMessage::Binary(bytes) => {
428 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
429 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
430 .map_err(io::Error::from)?;
431
432 self.encoding_buffer.clear();
433 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
434 return Ok(Message::Envelope(envelope));
435 }
436 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
437 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
438 WebSocketMessage::Close(_) => break,
439 _ => {}
440 }
441 }
442 Err(anyhow!("connection closed"))
443 }
444}
445
446impl From<Timestamp> for SystemTime {
447 fn from(val: Timestamp) -> Self {
448 UNIX_EPOCH
449 .checked_add(Duration::new(val.seconds, val.nanos))
450 .unwrap()
451 }
452}
453
454impl From<SystemTime> for Timestamp {
455 fn from(time: SystemTime) -> Self {
456 let duration = time.duration_since(UNIX_EPOCH).unwrap();
457 Self {
458 seconds: duration.as_secs(),
459 nanos: duration.subsec_nanos(),
460 }
461 }
462}
463
464impl From<u128> for Nonce {
465 fn from(nonce: u128) -> Self {
466 let upper_half = (nonce >> 64) as u64;
467 let lower_half = nonce as u64;
468 Self {
469 upper_half,
470 lower_half,
471 }
472 }
473}
474
475impl From<Nonce> for u128 {
476 fn from(nonce: Nonce) -> Self {
477 let upper_half = (nonce.upper_half as u128) << 64;
478 let lower_half = nonce.lower_half as u128;
479 upper_half | lower_half
480 }
481}
482
483pub fn split_worktree_update(
484 mut message: UpdateWorktree,
485 max_chunk_size: usize,
486) -> impl Iterator<Item = UpdateWorktree> {
487 let mut done_files = false;
488 let mut done_statuses = false;
489 let mut repository_index = 0;
490 let mut root_repo_found = false;
491 iter::from_fn(move || {
492 if done_files && done_statuses {
493 return None;
494 }
495
496 let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
497 let updated_entries: Vec<_> = message
498 .updated_entries
499 .drain(..updated_entries_chunk_size)
500 .collect();
501
502 let mut updated_repositories: Vec<_> = Default::default();
503
504 if !root_repo_found {
505 for entry in updated_entries.iter() {
506 if let Some(repo) = message.updated_repositories.get(0) {
507 if repo.work_directory_id == entry.id {
508 root_repo_found = true;
509 updated_repositories.push(RepositoryEntry {
510 work_directory_id: repo.work_directory_id,
511 branch: repo.branch.clone(),
512 removed_worktree_repo_paths: Default::default(),
513 updated_worktree_statuses: Default::default(),
514 });
515 break;
516 }
517 }
518 }
519 }
520
521 let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
522 let removed_entries = message
523 .removed_entries
524 .drain(..removed_entries_chunk_size)
525 .collect();
526
527 done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
528
529 // Wait to send repositories until after we've guaranteed that their associated entries
530 // will be read
531 if done_files {
532 let mut total_statuses = 0;
533 while total_statuses < max_chunk_size
534 && repository_index < message.updated_repositories.len()
535 {
536 let updated_statuses_chunk_size = cmp::min(
537 message.updated_repositories[repository_index]
538 .updated_worktree_statuses
539 .len(),
540 max_chunk_size - total_statuses,
541 );
542
543 let updated_statuses: Vec<_> = message.updated_repositories[repository_index]
544 .updated_worktree_statuses
545 .drain(..updated_statuses_chunk_size)
546 .collect();
547
548 total_statuses += updated_statuses.len();
549
550 let done_this_repo = message.updated_repositories[repository_index]
551 .updated_worktree_statuses
552 .is_empty();
553
554 let removed_repo_paths = if done_this_repo {
555 mem::take(
556 &mut message.updated_repositories[repository_index]
557 .removed_worktree_repo_paths,
558 )
559 } else {
560 Default::default()
561 };
562
563 updated_repositories.push(RepositoryEntry {
564 work_directory_id: message.updated_repositories[repository_index]
565 .work_directory_id,
566 branch: message.updated_repositories[repository_index]
567 .branch
568 .clone(),
569 updated_worktree_statuses: updated_statuses,
570 removed_worktree_repo_paths: removed_repo_paths,
571 });
572
573 if done_this_repo {
574 repository_index += 1;
575 }
576 }
577 } else {
578 Default::default()
579 };
580
581 let removed_repositories = if done_files && done_statuses {
582 mem::take(&mut message.removed_repositories)
583 } else {
584 Default::default()
585 };
586
587 done_statuses = repository_index >= message.updated_repositories.len();
588
589 Some(UpdateWorktree {
590 project_id: message.project_id,
591 worktree_id: message.worktree_id,
592 root_name: message.root_name.clone(),
593 abs_path: message.abs_path.clone(),
594 updated_entries,
595 removed_entries,
596 scan_id: message.scan_id,
597 is_last_update: done_files && message.is_last_update,
598 updated_repositories,
599 removed_repositories,
600 })
601 })
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// The scratch buffer's capacity must stay bounded by `MAX_BUFFER_LEN` after
    /// writing and reading both a small and a very large envelope.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            let envelope = Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repetitions),
                    ..Default::default()
                })),
                ..Default::default()
            };
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// Packing into a `u64` and unpacking again must be lossless, including
    /// at the `u32::MAX` extremes of either half.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}