1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message as _;
6use serde::Serialize;
7use std::any::{Any, TypeId};
8use std::{
9 cmp,
10 fmt::Debug,
11 io, iter,
12 time::{Duration, SystemTime, UNIX_EPOCH},
13};
14use std::{fmt, mem};
15
// Pull in the message types generated at build time into `OUT_DIR`
// (presumably by prost-build from the `zed.messages` protobuf schema — TODO confirm
// against build.rs). `Envelope`, `PeerId`, `Timestamp`, `Nonce`, `UpdateWorktree`, etc.
// used throughout this module come from this file.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
17
/// A protocol message that can be wrapped in, and extracted from, an [`Envelope`].
///
/// Implementations are presumably generated by the `messages!` invocation in this module
/// (TODO confirm against `super::messages`).
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type, reported by [`AnyTypedEnvelope::payload_type_name`].
    const NAME: &'static str;
    /// Scheduling priority of this message type; see [`AnyTypedEnvelope::is_background`].
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`] with message id `id`.
    ///
    /// `responding_to` is the id of the request this message answers (if any), and
    /// `original_sender_id` identifies the peer the message originated from when it is
    /// being forwarded (presumably — confirm in `messages!`).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`; `None` presumably means the
    /// envelope carries a different payload type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
29
/// A message addressed to a specific remote entity (a project or a channel, per the
/// `entity_messages!` invocations in this module).
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity this message targets (e.g. the message's `project_id` field).
    fn remote_entity_id(&self) -> u64;
}
33
/// A message that expects a reply of a specific type.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in response (see the `request_messages!` pairs).
    type Response: EnvelopedMessage;
}
37
/// Type-erased view of a [`TypedEnvelope`], letting envelopes with different payload
/// types be handled uniformly and downcast back via [`Any`].
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// [`TypeId`] of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// [`EnvelopedMessage::NAME`] of the payload message type.
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>` for downcasting by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type has [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// Peer the message originated from, when it was forwarded on someone's behalf.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// Connection the message was received on.
    fn sender_id(&self) -> ConnectionId;
    /// Id of the message within its connection.
    fn message_id(&self) -> u32;
}
48
/// How urgently a message type should be handled.
///
/// Each message type declares its priority through [`EnvelopedMessage::PRIORITY`];
/// [`AnyTypedEnvelope::is_background`] reports whether it is `Background`.
///
/// This is a plain two-state flag, so it derives the cheap value traits. That lets
/// callers copy, compare and log it instead of pattern-matching on it everywhere.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Handled as soon as possible.
    Foreground,
    /// May be handled off the foreground path.
    Background,
}
53
54impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
55 fn payload_type_id(&self) -> TypeId {
56 TypeId::of::<T>()
57 }
58
59 fn payload_type_name(&self) -> &'static str {
60 T::NAME
61 }
62
63 fn as_any(&self) -> &dyn Any {
64 self
65 }
66
67 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
68 self
69 }
70
71 fn is_background(&self) -> bool {
72 matches!(T::PRIORITY, MessagePriority::Background)
73 }
74
75 fn original_sender_id(&self) -> Option<PeerId> {
76 self.original_sender_id
77 }
78
79 fn sender_id(&self) -> ConnectionId {
80 self.sender_id
81 }
82
83 fn message_id(&self) -> u32 {
84 self.message_id
85 }
86}
87
88impl PeerId {
89 pub fn from_u64(peer_id: u64) -> Self {
90 let owner_id = (peer_id >> 32) as u32;
91 let id = peer_id as u32;
92 Self { owner_id, id }
93 }
94
95 pub fn as_u64(self) -> u64 {
96 ((self.owner_id as u64) << 32) | (self.id as u64)
97 }
98}
99
// `PeerId` is generated code, so these marker traits are implemented by hand here.
// It holds only two `u32` fields (`owner_id`, `id`), so bitwise copying is fine.
impl Copy for PeerId {}

// Equality is total over the two integer fields, which upgrades the generated
// `PartialEq` to `Eq` (required by the `Ord` and `Hash` impls below to be meaningful).
impl Eq for PeerId {}
103
104impl Ord for PeerId {
105 fn cmp(&self, other: &Self) -> cmp::Ordering {
106 self.owner_id
107 .cmp(&other.owner_id)
108 .then_with(|| self.id.cmp(&other.id))
109 }
110}
111
112impl PartialOrd for PeerId {
113 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
114 Some(self.cmp(other))
115 }
116}
117
118impl std::hash::Hash for PeerId {
119 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120 self.owner_id.hash(state);
121 self.id.hash(state);
122 }
123}
124
125impl fmt::Display for PeerId {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 write!(f, "{}/{}", self.owner_id, self.id)
128 }
129}
130
// Registry of every message type that can travel inside an `Envelope`, each paired with
// the `MessagePriority` it is declared with. The `messages!` macro (defined in a sibling
// module, see `super::messages`) presumably derives the `EnvelopedMessage` impls from
// this list — TODO confirm. A new message type must be added here before it can appear
// in `request_messages!` or `entity_messages!` below.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
240
// Request/response pairings: each `(Request, Response)` entry declares that `Request`
// expects a `Response` reply (i.e. the `RequestMessage::Response` associated type,
// presumably implemented by `super::request_messages` — TODO confirm). Requests that
// carry no data in their reply answer with `Ack`.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
297
// Messages scoped to a single project: the first argument names the field
// (`project_id`) that presumably backs `EntityMessage::remote_entity_id` for every
// listed type — TODO confirm against `super::entity_messages`.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
345
// Messages scoped to a channel; their entity id presumably comes from `channel_id`.
entity_messages!(channel_id, ChannelMessageSent);
347
/// One kibibyte, in bytes.
const KIB: usize = 1 << 10;
/// One mebibyte, in bytes.
const MIB: usize = 1024 * KIB;
/// Largest capacity `MessageStream`'s scratch buffer keeps between messages; after each
/// message the buffer is shrunk back to at most this size, releasing any larger allocation.
const MAX_BUFFER_LEN: usize = MIB;
351
/// A bidirectional channel of protocol [`Message`]s over a WebSocket-like transport `S`.
///
/// When `S` is a sink, envelopes are protobuf-encoded and zstd-compressed into binary
/// frames (`write`); when `S` is a stream, binary frames are decompressed and decoded
/// (`read`).
pub struct MessageStream<S> {
    /// The underlying transport.
    stream: S,
    /// Scratch buffer reused for encoding and decompression. After each message is
    /// processed it is cleared and its capacity shrunk to at most `MAX_BUFFER_LEN`, so
    /// one huge message does not pin a huge allocation.
    encoding_buffer: Vec<u8>,
}
357
// The `Envelope` variant is far larger than the unit `Ping`/`Pong` variants, which is what
// this lint flags. It is presumably allowed because almost every message is an envelope,
// so boxing it would add an allocation per message without saving space in practice.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
/// A single unit read from or written to a [`MessageStream`].
pub enum Message {
    /// A protocol envelope, sent as a zstd-compressed binary frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame.
    Ping,
    /// A WebSocket pong control frame.
    Pong,
}
365
366impl<S> MessageStream<S> {
367 pub fn new(stream: S) -> Self {
368 Self {
369 stream,
370 encoding_buffer: Vec::new(),
371 }
372 }
373
374 pub fn inner_mut(&mut self) -> &mut S {
375 &mut self.stream
376 }
377}
378
379impl<S> MessageStream<S>
380where
381 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
382{
383 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
384 #[cfg(any(test, feature = "test-support"))]
385 const COMPRESSION_LEVEL: i32 = -7;
386
387 #[cfg(not(any(test, feature = "test-support")))]
388 const COMPRESSION_LEVEL: i32 = 4;
389
390 match message {
391 Message::Envelope(message) => {
392 self.encoding_buffer.reserve(message.encoded_len());
393 message
394 .encode(&mut self.encoding_buffer)
395 .map_err(io::Error::from)?;
396 let buffer =
397 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
398 .unwrap();
399
400 self.encoding_buffer.clear();
401 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
402 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
403 }
404 Message::Ping => {
405 self.stream
406 .send(WebSocketMessage::Ping(Default::default()))
407 .await?;
408 }
409 Message::Pong => {
410 self.stream
411 .send(WebSocketMessage::Pong(Default::default()))
412 .await?;
413 }
414 }
415
416 Ok(())
417 }
418}
419
420impl<S> MessageStream<S>
421where
422 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
423{
424 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
425 while let Some(bytes) = self.stream.next().await {
426 match bytes? {
427 WebSocketMessage::Binary(bytes) => {
428 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
429 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
430 .map_err(io::Error::from)?;
431
432 self.encoding_buffer.clear();
433 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
434 return Ok(Message::Envelope(envelope));
435 }
436 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
437 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
438 WebSocketMessage::Close(_) => break,
439 _ => {}
440 }
441 }
442 Err(anyhow!("connection closed"))
443 }
444}
445
446impl From<Timestamp> for SystemTime {
447 fn from(val: Timestamp) -> Self {
448 UNIX_EPOCH
449 .checked_add(Duration::new(val.seconds, val.nanos))
450 .unwrap()
451 }
452}
453
454impl From<SystemTime> for Timestamp {
455 fn from(time: SystemTime) -> Self {
456 let duration = time.duration_since(UNIX_EPOCH).unwrap();
457 Self {
458 seconds: duration.as_secs(),
459 nanos: duration.subsec_nanos(),
460 }
461 }
462}
463
464impl From<u128> for Nonce {
465 fn from(nonce: u128) -> Self {
466 let upper_half = (nonce >> 64) as u64;
467 let lower_half = nonce as u64;
468 Self {
469 upper_half,
470 lower_half,
471 }
472 }
473}
474
475impl From<Nonce> for u128 {
476 fn from(nonce: Nonce) -> Self {
477 let upper_half = (nonce.upper_half as u128) << 64;
478 let lower_half = nonce.lower_half as u128;
479 upper_half | lower_half
480 }
481}
482
483pub fn split_worktree_update(
484 mut message: UpdateWorktree,
485 max_chunk_size: usize,
486) -> impl Iterator<Item = UpdateWorktree> {
487 let mut done_files = false;
488 let mut done_statuses = false;
489 let mut repository_index = 0;
490 iter::from_fn(move || {
491 if done_files && done_statuses {
492 return None;
493 }
494
495 let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
496 let updated_entries = message
497 .updated_entries
498 .drain(..updated_entries_chunk_size)
499 .collect();
500
501 let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
502 let removed_entries = message
503 .removed_entries
504 .drain(..removed_entries_chunk_size)
505 .collect();
506
507 done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
508
509 // Wait to send repositories until after we've guaranteed that their associated entries
510 // will be read
511 let updated_repositories = if done_files {
512 let mut total_statuses = 0;
513 let mut updated_repositories = Vec::new();
514 while total_statuses < max_chunk_size
515 && repository_index < message.updated_repositories.len()
516 {
517 let updated_statuses_chunk_size = cmp::min(
518 message.updated_repositories[repository_index]
519 .updated_worktree_statuses
520 .len(),
521 max_chunk_size - total_statuses,
522 );
523
524 let updated_statuses: Vec<_> = message.updated_repositories[repository_index]
525 .updated_worktree_statuses
526 .drain(..updated_statuses_chunk_size)
527 .collect();
528
529 total_statuses += updated_statuses.len();
530
531 let done_this_repo = message.updated_repositories[repository_index]
532 .updated_worktree_statuses
533 .is_empty();
534
535 let removed_repo_paths = if done_this_repo {
536 mem::take(
537 &mut message.updated_repositories[repository_index]
538 .removed_worktree_repo_paths,
539 )
540 } else {
541 Default::default()
542 };
543
544 updated_repositories.push(RepositoryEntry {
545 work_directory_id: message.updated_repositories[repository_index]
546 .work_directory_id,
547 branch: message.updated_repositories[repository_index]
548 .branch
549 .clone(),
550 updated_worktree_statuses: updated_statuses,
551 removed_worktree_repo_paths: removed_repo_paths,
552 });
553
554 if done_this_repo {
555 repository_index += 1;
556 }
557 }
558
559 updated_repositories
560 } else {
561 Default::default()
562 };
563
564 let removed_repositories = if done_files && done_statuses {
565 mem::take(&mut message.removed_repositories)
566 } else {
567 Default::default()
568 };
569
570 done_statuses = repository_index >= message.updated_repositories.len();
571
572 Some(UpdateWorktree {
573 project_id: message.project_id,
574 worktree_id: message.worktree_id,
575 root_name: message.root_name.clone(),
576 abs_path: message.abs_path.clone(),
577 updated_entries,
578 removed_entries,
579 scan_id: message.scan_id,
580 is_last_update: done_files && message.is_last_update,
581 updated_repositories,
582 removed_repositories,
583 })
584 })
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` of capacity,
    /// whether the message just handled was small or very large.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            let envelope = Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repetitions),
                    ..Default::default()
                })),
                ..Default::default()
            };
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// Packing a `PeerId` into a `u64` and back is lossless, including at the extremes.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}