1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use collections::HashMap;
5use futures::{SinkExt as _, StreamExt as _};
6use prost::Message as _;
7use serde::Serialize;
8use std::any::{Any, TypeId};
9use std::{
10 cmp,
11 fmt::Debug,
12 io, iter,
13 time::{Duration, SystemTime, UNIX_EPOCH},
14};
15use std::{fmt, mem};
16
// Message types generated at build time into `$OUT_DIR/zed.messages.rs`
// (presumably by a prost build script — confirm in build.rs). Names used in
// this file but not imported above — `Envelope`, `envelope::Payload`,
// `PeerId`, `Timestamp`, `Nonce`, and the individual message structs — are
// presumably defined there.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
18
/// A message type that can be wrapped in, and extracted from, an [`Envelope`].
///
/// Implementations are presumably generated by the `messages!` invocation in
/// this file — TODO confirm against the macro definition.
pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
    /// Name of the message type; reported by
    /// `AnyTypedEnvelope::payload_type_name`.
    const NAME: &'static str;
    /// Priority class of the message type (see [`MessagePriority`]).
    const PRIORITY: MessagePriority;
    /// Wraps `self` in an [`Envelope`].
    ///
    /// * `id` - id assigned to this message.
    /// * `responding_to` - presumably the id of the request this message
    ///   answers, if any (confirm with callers).
    /// * `original_sender_id` - presumably the peer that originally sent the
    ///   message when it is being forwarded (confirm with callers).
    fn into_envelope(
        self,
        id: u32,
        responding_to: Option<u32>,
        original_sender_id: Option<PeerId>,
    ) -> Envelope;
    /// Extracts a message of this type from `envelope`; presumably returns
    /// `None` when the envelope carries a different payload type.
    fn from_envelope(envelope: Envelope) -> Option<Self>;
}
30
/// A message tied to a remote entity identified by a `u64` id.
pub trait EntityMessage: EnvelopedMessage {
    /// Id of the entity the message concerns — presumably the field named in
    /// the matching `entity_messages!` invocation (`project_id` or
    /// `channel_id` in this file).
    fn remote_entity_id(&self) -> u64;
}
34
/// A message that is answered with a message of type
/// [`RequestMessage::Response`].
///
/// Request/response pairs are presumably wired up by the `request_messages!`
/// invocation in this file.
pub trait RequestMessage: EnvelopedMessage {
    /// The message type sent back in reply to this request.
    type Response: EnvelopedMessage;
}
38
/// Object-safe, type-erased view of a [`TypedEnvelope`], so envelopes carrying
/// different payload types can be handled uniformly.
pub trait AnyTypedEnvelope: 'static + Send + Sync {
    /// The [`TypeId`] of the payload message type.
    fn payload_type_id(&self) -> TypeId;
    /// The payload message type's [`EnvelopedMessage::NAME`].
    fn payload_type_name(&self) -> &'static str;
    /// Borrows the envelope as `&dyn Any`, for downcasting to the concrete
    /// `TypedEnvelope<T>`.
    fn as_any(&self) -> &dyn Any;
    /// Converts the boxed envelope into `Box<dyn Any>`, for downcasting by
    /// value.
    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
    /// Whether the payload type's priority is [`MessagePriority::Background`].
    fn is_background(&self) -> bool;
    /// The peer recorded as the original sender, if any.
    fn original_sender_id(&self) -> Option<PeerId>;
    /// The connection that sent the message (presumably the one it arrived on).
    fn sender_id(&self) -> ConnectionId;
    /// The id of the message.
    fn message_id(&self) -> u32;
}
49
/// Priority class of a message type, declared per type in `messages!`.
///
/// `AnyTypedEnvelope::is_background` reports whether a message's type is
/// [`MessagePriority::Background`]. The derives let callers compare, copy and
/// debug-print priorities directly instead of pattern-matching every time.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagePriority {
    /// Foreground priority.
    Foreground,
    /// Background priority.
    Background,
}
54
55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
56 fn payload_type_id(&self) -> TypeId {
57 TypeId::of::<T>()
58 }
59
60 fn payload_type_name(&self) -> &'static str {
61 T::NAME
62 }
63
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
69 self
70 }
71
72 fn is_background(&self) -> bool {
73 matches!(T::PRIORITY, MessagePriority::Background)
74 }
75
76 fn original_sender_id(&self) -> Option<PeerId> {
77 self.original_sender_id
78 }
79
80 fn sender_id(&self) -> ConnectionId {
81 self.sender_id
82 }
83
84 fn message_id(&self) -> u32 {
85 self.message_id
86 }
87}
88
89impl PeerId {
90 pub fn from_u64(peer_id: u64) -> Self {
91 let owner_id = (peer_id >> 32) as u32;
92 let id = peer_id as u32;
93 Self { owner_id, id }
94 }
95
96 pub fn as_u64(self) -> u64 {
97 ((self.owner_id as u64) << 32) | (self.id as u64)
98 }
99}
100
// `PeerId` consists of two `u32` fields (`owner_id`, `id`), so it is cheap to
// copy by value. Implemented by hand, presumably because the struct comes from
// the generated `include!` above.
impl Copy for PeerId {}

// Equality over the two integer fields is a total equivalence, so `Eq` holds
// (assuming `PartialEq` is the field-wise derive from the generated code).
impl Eq for PeerId {}
104
105impl Ord for PeerId {
106 fn cmp(&self, other: &Self) -> cmp::Ordering {
107 self.owner_id
108 .cmp(&other.owner_id)
109 .then_with(|| self.id.cmp(&other.id))
110 }
111}
112
113impl PartialOrd for PeerId {
114 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115 Some(self.cmp(other))
116 }
117}
118
119impl std::hash::Hash for PeerId {
120 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121 self.owner_id.hash(state);
122 self.id.hash(state);
123 }
124}
125
126impl fmt::Display for PeerId {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 write!(f, "{}/{}", self.owner_id, self.id)
129 }
130}
131
// Registers every message type that can travel in an `Envelope`, each paired
// with its `MessagePriority` variant. Presumably this generates the
// `EnvelopedMessage` impls (`NAME`, `PRIORITY`, `into_envelope`,
// `from_envelope`) — confirm against the `messages!` macro definition.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (OnTypeFormatting, Background),
    (OnTypeFormattingResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
243
// Pairs each request message type with the message type sent back as its
// response. Presumably this generates the `RequestMessage` impls
// (`type Response = ...`) — confirm against the macro definition.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (OnTypeFormatting, OnTypeFormattingResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
301
// Message types scoped to a project. Presumably their
// `EntityMessage::remote_entity_id` returns the `project_id` field named as
// the first argument — confirm against the `entity_messages!` macro.
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    OnTypeFormatting,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);
350
// Message types scoped to a channel; presumably keyed by their `channel_id`
// field (confirm against the `entity_messages!` macro).
entity_messages!(channel_id, ChannelMessageSent);
352
/// One kibibyte, in bytes.
const KIB: usize = 1 << 10;
/// One mebibyte, in bytes.
const MIB: usize = KIB << 10;
/// Capacity that `MessageStream`'s scratch encoding buffer is shrunk back to
/// after each message, so a single large message does not keep a large
/// allocation alive.
const MAX_BUFFER_LEN: usize = MIB;
356
/// A stream of protobuf messages.
///
/// Wraps a WebSocket stream and/or sink `S`. [`Envelope`]s are written as
/// zstd-compressed protobuf `Binary` frames and read back the same way (see
/// the `write` and `read` impls below).
pub struct MessageStream<S> {
    /// The underlying WebSocket stream and/or sink.
    stream: S,
    /// Scratch buffer reused for encoding/decoding across messages; its
    /// capacity is shrunk to at most `MAX_BUFFER_LEN` after each message.
    encoding_buffer: Vec<u8>,
}
362
/// A single item sent or received over a [`MessageStream`].
// `Envelope` is large enough to trip clippy's `large_enum_variant`; the lint is
// silenced rather than boxing the variant (presumably to avoid an allocation
// per message — confirm before changing).
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, carried in a WebSocket `Binary` frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame.
    Ping,
    /// A WebSocket pong control frame.
    Pong,
}
370
371impl<S> MessageStream<S> {
372 pub fn new(stream: S) -> Self {
373 Self {
374 stream,
375 encoding_buffer: Vec::new(),
376 }
377 }
378
379 pub fn inner_mut(&mut self) -> &mut S {
380 &mut self.stream
381 }
382}
383
384impl<S> MessageStream<S>
385where
386 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
387{
388 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
389 #[cfg(any(test, feature = "test-support"))]
390 const COMPRESSION_LEVEL: i32 = -7;
391
392 #[cfg(not(any(test, feature = "test-support")))]
393 const COMPRESSION_LEVEL: i32 = 4;
394
395 match message {
396 Message::Envelope(message) => {
397 self.encoding_buffer.reserve(message.encoded_len());
398 message
399 .encode(&mut self.encoding_buffer)
400 .map_err(io::Error::from)?;
401 let buffer =
402 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
403 .unwrap();
404
405 self.encoding_buffer.clear();
406 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
407 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
408 }
409 Message::Ping => {
410 self.stream
411 .send(WebSocketMessage::Ping(Default::default()))
412 .await?;
413 }
414 Message::Pong => {
415 self.stream
416 .send(WebSocketMessage::Pong(Default::default()))
417 .await?;
418 }
419 }
420
421 Ok(())
422 }
423}
424
425impl<S> MessageStream<S>
426where
427 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
428{
429 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
430 while let Some(bytes) = self.stream.next().await {
431 match bytes? {
432 WebSocketMessage::Binary(bytes) => {
433 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
434 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
435 .map_err(io::Error::from)?;
436
437 self.encoding_buffer.clear();
438 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
439 return Ok(Message::Envelope(envelope));
440 }
441 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
442 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
443 WebSocketMessage::Close(_) => break,
444 _ => {}
445 }
446 }
447 Err(anyhow!("connection closed"))
448 }
449}
450
451impl From<Timestamp> for SystemTime {
452 fn from(val: Timestamp) -> Self {
453 UNIX_EPOCH
454 .checked_add(Duration::new(val.seconds, val.nanos))
455 .unwrap()
456 }
457}
458
459impl From<SystemTime> for Timestamp {
460 fn from(time: SystemTime) -> Self {
461 let duration = time.duration_since(UNIX_EPOCH).unwrap();
462 Self {
463 seconds: duration.as_secs(),
464 nanos: duration.subsec_nanos(),
465 }
466 }
467}
468
469impl From<u128> for Nonce {
470 fn from(nonce: u128) -> Self {
471 let upper_half = (nonce >> 64) as u64;
472 let lower_half = nonce as u64;
473 Self {
474 upper_half,
475 lower_half,
476 }
477 }
478}
479
480impl From<Nonce> for u128 {
481 fn from(nonce: Nonce) -> Self {
482 let upper_half = (nonce.upper_half as u128) << 64;
483 let lower_half = nonce.lower_half as u128;
484 upper_half | lower_half
485 }
486}
487
488pub fn split_worktree_update(
489 mut message: UpdateWorktree,
490 max_chunk_size: usize,
491) -> impl Iterator<Item = UpdateWorktree> {
492 let mut done_files = false;
493
494 let mut repository_map = message
495 .updated_repositories
496 .into_iter()
497 .map(|repo| (repo.work_directory_id, repo))
498 .collect::<HashMap<_, _>>();
499
500 iter::from_fn(move || {
501 if done_files {
502 return None;
503 }
504
505 let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
506 let updated_entries: Vec<_> = message
507 .updated_entries
508 .drain(..updated_entries_chunk_size)
509 .collect();
510
511 let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
512 let removed_entries = message
513 .removed_entries
514 .drain(..removed_entries_chunk_size)
515 .collect();
516
517 done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
518
519 let mut updated_repositories = Vec::new();
520
521 if !repository_map.is_empty() {
522 for entry in &updated_entries {
523 if let Some(repo) = repository_map.remove(&entry.id) {
524 updated_repositories.push(repo)
525 }
526 }
527 }
528
529 let removed_repositories = if done_files {
530 mem::take(&mut message.removed_repositories)
531 } else {
532 Default::default()
533 };
534
535 if done_files {
536 updated_repositories.extend(mem::take(&mut repository_map).into_values());
537 }
538
539 Some(UpdateWorktree {
540 project_id: message.project_id,
541 worktree_id: message.worktree_id,
542 root_name: message.root_name.clone(),
543 abs_path: message.abs_path.clone(),
544 updated_entries,
545 removed_entries,
546 scan_id: message.scan_id,
547 is_last_update: done_files && message.is_last_update,
548 updated_repositories,
549 removed_repositories,
550 })
551 })
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose `UpdateWorktree` payload has a root name made
    /// of `"abcdefg"` repeated `repetitions` times.
    fn worktree_update_envelope(repetitions: usize) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repetitions),
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        // A small message followed by a ~7 MB one: the scratch buffer must be
        // shrunk back after every write.
        for repetitions in [10, 1000000] {
            sink.write(Message::Envelope(worktree_update_envelope(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // ...and after every read of those same two messages.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}