1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use collections::HashMap;
5use futures::{SinkExt as _, StreamExt as _};
6use prost::Message as _;
7use serde::Serialize;
8use std::any::{Any, TypeId};
9use std::{
10 cmp,
11 fmt::Debug,
12 io, iter,
13 time::{Duration, SystemTime, UNIX_EPOCH},
14};
15use std::{fmt, mem};
16
17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
18
19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
20 const NAME: &'static str;
21 const PRIORITY: MessagePriority;
22 fn into_envelope(
23 self,
24 id: u32,
25 responding_to: Option<u32>,
26 original_sender_id: Option<PeerId>,
27 ) -> Envelope;
28 fn from_envelope(envelope: Envelope) -> Option<Self>;
29}
30
31pub trait EntityMessage: EnvelopedMessage {
32 fn remote_entity_id(&self) -> u64;
33}
34
35pub trait RequestMessage: EnvelopedMessage {
36 type Response: EnvelopedMessage;
37}
38
39pub trait AnyTypedEnvelope: 'static + Send + Sync {
40 fn payload_type_id(&self) -> TypeId;
41 fn payload_type_name(&self) -> &'static str;
42 fn as_any(&self) -> &dyn Any;
43 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
44 fn is_background(&self) -> bool;
45 fn original_sender_id(&self) -> Option<PeerId>;
46 fn sender_id(&self) -> ConnectionId;
47 fn message_id(&self) -> u32;
48}
49
/// Delivery classification attached to every [`EnvelopedMessage`] type via `PRIORITY`.
pub enum MessagePriority {
    /// Regular traffic.
    Foreground,
    /// Traffic reported as background by [`AnyTypedEnvelope::is_background`].
    Background,
}
54
55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
56 fn payload_type_id(&self) -> TypeId {
57 TypeId::of::<T>()
58 }
59
60 fn payload_type_name(&self) -> &'static str {
61 T::NAME
62 }
63
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
69 self
70 }
71
72 fn is_background(&self) -> bool {
73 matches!(T::PRIORITY, MessagePriority::Background)
74 }
75
76 fn original_sender_id(&self) -> Option<PeerId> {
77 self.original_sender_id
78 }
79
80 fn sender_id(&self) -> ConnectionId {
81 self.sender_id
82 }
83
84 fn message_id(&self) -> u32 {
85 self.message_id
86 }
87}
88
89impl PeerId {
90 pub fn from_u64(peer_id: u64) -> Self {
91 let owner_id = (peer_id >> 32) as u32;
92 let id = peer_id as u32;
93 Self { owner_id, id }
94 }
95
96 pub fn as_u64(self) -> u64 {
97 ((self.owner_id as u64) << 32) | (self.id as u64)
98 }
99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106 fn cmp(&self, other: &Self) -> cmp::Ordering {
107 self.owner_id
108 .cmp(&other.owner_id)
109 .then_with(|| self.id.cmp(&other.id))
110 }
111}
112
113impl PartialOrd for PeerId {
114 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115 Some(self.cmp(other))
116 }
117}
118
119impl std::hash::Hash for PeerId {
120 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121 self.owner_id.hash(state);
122 self.id.hash(state);
123 }
124}
125
126impl fmt::Display for PeerId {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 write!(f, "{}/{}", self.owner_id, self.id)
129 }
130}
131
// Registers every payload type that can be carried in an `Envelope`, paired with its
// `MessagePriority` variant. The macro is defined in the parent module; it presumably
// generates the `EnvelopedMessage` impls (`NAME`, `PRIORITY`, envelope conversions) —
// TODO confirm against `super::messages`. Every type named in `request_messages!` or
// `entity_messages!` below also appears here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
243
// Pairs each request type with the message type sent back in reply. The macro (from
// the parent module) presumably implements `RequestMessage` with `type Response` set
// to the second element — TODO confirm. Several requests share a generic `Ack` or
// `ProjectEntryResponse` reply. `SaveBuffer` is answered with `BufferSaved`, and
// `Test` replies with `Test` itself.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
301
// Messages scoped to a project. The first argument names the payload field
// (`project_id`) that presumably backs `EntityMessage::remote_entity_id` for each type
// listed — TODO confirm against `super::entity_messages`.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
350
// Messages scoped to a channel, keyed by their `channel_id` field (presumably feeding
// `EntityMessage::remote_entity_id` — TODO confirm).
entity_messages!(channel_id, ChannelMessageSent);
352
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch buffer is shrunk back to after each message,
/// so one oversized message does not pin a large allocation indefinitely.
const MAX_BUFFER_LEN: usize = MIB;
356
/// A stream of protobuf messages.
///
/// Wraps an underlying WebSocket-style stream and/or sink `S` and reuses one scratch
/// buffer for encoding outgoing and decoding incoming envelopes.
pub struct MessageStream<S> {
    /// The wrapped transport.
    stream: S,
    /// Scratch space for protobuf bytes; cleared after every envelope.
    encoding_buffer: Vec<u8>,
}
362
363#[allow(clippy::large_enum_variant)]
364#[derive(Debug)]
365pub enum Message {
366 Envelope(Envelope),
367 Ping,
368 Pong,
369}
370
371impl<S> MessageStream<S> {
372 pub fn new(stream: S) -> Self {
373 Self {
374 stream,
375 encoding_buffer: Vec::new(),
376 }
377 }
378
379 pub fn inner_mut(&mut self) -> &mut S {
380 &mut self.stream
381 }
382}
383
384impl<S> MessageStream<S>
385where
386 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
387{
388 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
389 #[cfg(any(test, feature = "test-support"))]
390 const COMPRESSION_LEVEL: i32 = -7;
391
392 #[cfg(not(any(test, feature = "test-support")))]
393 const COMPRESSION_LEVEL: i32 = 4;
394
395 match message {
396 Message::Envelope(message) => {
397 self.encoding_buffer.reserve(message.encoded_len());
398 message
399 .encode(&mut self.encoding_buffer)
400 .map_err(io::Error::from)?;
401 let buffer =
402 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
403 .unwrap();
404
405 self.encoding_buffer.clear();
406 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
407 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
408 }
409 Message::Ping => {
410 self.stream
411 .send(WebSocketMessage::Ping(Default::default()))
412 .await?;
413 }
414 Message::Pong => {
415 self.stream
416 .send(WebSocketMessage::Pong(Default::default()))
417 .await?;
418 }
419 }
420
421 Ok(())
422 }
423}
424
425impl<S> MessageStream<S>
426where
427 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
428{
429 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
430 while let Some(bytes) = self.stream.next().await {
431 match bytes? {
432 WebSocketMessage::Binary(bytes) => {
433 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
434 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
435 .map_err(io::Error::from)?;
436
437 self.encoding_buffer.clear();
438 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
439 return Ok(Message::Envelope(envelope));
440 }
441 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
442 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
443 WebSocketMessage::Close(_) => break,
444 _ => {}
445 }
446 }
447 Err(anyhow!("connection closed"))
448 }
449}
450
451impl From<Timestamp> for SystemTime {
452 fn from(val: Timestamp) -> Self {
453 UNIX_EPOCH
454 .checked_add(Duration::new(val.seconds, val.nanos))
455 .unwrap()
456 }
457}
458
459impl From<SystemTime> for Timestamp {
460 fn from(time: SystemTime) -> Self {
461 let duration = time.duration_since(UNIX_EPOCH).unwrap();
462 Self {
463 seconds: duration.as_secs(),
464 nanos: duration.subsec_nanos(),
465 }
466 }
467}
468
469impl From<u128> for Nonce {
470 fn from(nonce: u128) -> Self {
471 let upper_half = (nonce >> 64) as u64;
472 let lower_half = nonce as u64;
473 Self {
474 upper_half,
475 lower_half,
476 }
477 }
478}
479
480impl From<Nonce> for u128 {
481 fn from(nonce: Nonce) -> Self {
482 let upper_half = (nonce.upper_half as u128) << 64;
483 let lower_half = nonce.lower_half as u128;
484 upper_half | lower_half
485 }
486}
487
488pub fn split_worktree_update(
489 mut message: UpdateWorktree,
490 max_chunk_size: usize,
491) -> impl Iterator<Item = UpdateWorktree> {
492 let mut done_files = false;
493
494 let mut repository_map = message
495 .updated_repositories
496 .into_iter()
497 .map(|repo| (repo.work_directory_id, repo))
498 .collect::<HashMap<_, _>>();
499
500 // Maintain a list of inflight repositories
501 // Every time we send a repository's work directory, stick it in the list of in-flight repositories
502 // Every go of the loop, drain each in-flight repository's statuses
503 // Until we have no more data
504
505 iter::from_fn(move || {
506 if done_files {
507 return None;
508 }
509
510 let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
511 let updated_entries: Vec<_> = message
512 .updated_entries
513 .drain(..updated_entries_chunk_size)
514 .collect();
515
516 let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
517 let removed_entries = message
518 .removed_entries
519 .drain(..removed_entries_chunk_size)
520 .collect();
521
522 done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
523
524 let mut updated_repositories = Vec::new();
525
526 if !repository_map.is_empty() {
527 for entry in &updated_entries {
528 if let Some(repo) = repository_map.remove(&entry.id) {
529 updated_repositories.push(repo)
530 }
531 }
532 }
533
534 let removed_repositories = if done_files {
535 mem::take(&mut message.removed_repositories)
536 } else {
537 Default::default()
538 };
539
540 if done_files {
541 updated_repositories.extend(mem::take(&mut repository_map).into_values());
542 }
543
544 Some(UpdateWorktree {
545 project_id: message.project_id,
546 worktree_id: message.worktree_id,
547 root_name: message.root_name.clone(),
548 abs_path: message.abs_path.clone(),
549 updated_entries,
550 removed_entries,
551 scan_id: message.scan_id,
552 is_last_update: done_files && message.is_last_update,
553 updated_repositories,
554 removed_repositories,
555 })
556 })
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Writing and reading envelopes, small and very large, must leave the scratch
    /// buffer's capacity at or below `MAX_BUFFER_LEN`.
    #[gpui::test]
    async fn test_buffer_size() {
        let (sender, receiver) = futures::channel::mpsc::unbounded();
        let mut writer = MessageStream::new(sender.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            let envelope = Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                    root_name: "abcdefg".repeat(repetitions),
                    ..Default::default()
                })),
                ..Default::default()
            };
            writer.write(Message::Envelope(envelope)).await.unwrap();
            assert!(writer.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut reader = MessageStream::new(receiver.map(anyhow::Ok));
        for _ in 0..2 {
            reader.read().await.unwrap();
            assert!(reader.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `PeerId` must survive a round trip through its packed `u64` form, including
    /// when either half is at its maximum value.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [(10, 3), (u32::MAX, 3), (10, u32::MAX), (u32::MAX, u32::MAX)];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}