1use super::{entity_messages, messages, request_messages, ConnectionId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::Message as WebSocketMessage;
4use collections::HashMap;
5use futures::{SinkExt as _, StreamExt as _};
6use prost::Message as _;
7use serde::Serialize;
8use std::any::{Any, TypeId};
9use std::{
10 cmp,
11 fmt::Debug,
12 io, iter,
13 time::{Duration, SystemTime, UNIX_EPOCH},
14};
15use std::{fmt, mem};
16
17include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
18
19pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 'static {
20 const NAME: &'static str;
21 const PRIORITY: MessagePriority;
22 fn into_envelope(
23 self,
24 id: u32,
25 responding_to: Option<u32>,
26 original_sender_id: Option<PeerId>,
27 ) -> Envelope;
28 fn from_envelope(envelope: Envelope) -> Option<Self>;
29}
30
31pub trait EntityMessage: EnvelopedMessage {
32 fn remote_entity_id(&self) -> u64;
33}
34
35pub trait RequestMessage: EnvelopedMessage {
36 type Response: EnvelopedMessage;
37}
38
39pub trait AnyTypedEnvelope: 'static + Send + Sync {
40 fn payload_type_id(&self) -> TypeId;
41 fn payload_type_name(&self) -> &'static str;
42 fn as_any(&self) -> &dyn Any;
43 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
44 fn is_background(&self) -> bool;
45 fn original_sender_id(&self) -> Option<PeerId>;
46 fn sender_id(&self) -> ConnectionId;
47 fn message_id(&self) -> u32;
48}
49
/// Scheduling class of a message type.
///
/// `Background` messages make `AnyTypedEnvelope::is_background` return `true`;
/// every other message is `Foreground`.
// Derives added so the priority can be logged, copied and compared like any
// other small public value type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessagePriority {
    /// Default class for ordinary messages.
    Foreground,
    /// Class reported as "background" by `AnyTypedEnvelope::is_background`.
    Background,
}
54
55impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
56 fn payload_type_id(&self) -> TypeId {
57 TypeId::of::<T>()
58 }
59
60 fn payload_type_name(&self) -> &'static str {
61 T::NAME
62 }
63
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
69 self
70 }
71
72 fn is_background(&self) -> bool {
73 matches!(T::PRIORITY, MessagePriority::Background)
74 }
75
76 fn original_sender_id(&self) -> Option<PeerId> {
77 self.original_sender_id
78 }
79
80 fn sender_id(&self) -> ConnectionId {
81 self.sender_id
82 }
83
84 fn message_id(&self) -> u32 {
85 self.message_id
86 }
87}
88
89impl PeerId {
90 pub fn from_u64(peer_id: u64) -> Self {
91 let owner_id = (peer_id >> 32) as u32;
92 let id = peer_id as u32;
93 Self { owner_id, id }
94 }
95
96 pub fn as_u64(self) -> u64 {
97 ((self.owner_id as u64) << 32) | (self.id as u64)
98 }
99}
100
101impl Copy for PeerId {}
102
103impl Eq for PeerId {}
104
105impl Ord for PeerId {
106 fn cmp(&self, other: &Self) -> cmp::Ordering {
107 self.owner_id
108 .cmp(&other.owner_id)
109 .then_with(|| self.id.cmp(&other.id))
110 }
111}
112
113impl PartialOrd for PeerId {
114 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
115 Some(self.cmp(other))
116 }
117}
118
119impl std::hash::Hash for PeerId {
120 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
121 self.owner_id.hash(state);
122 self.id.hash(state);
123 }
124}
125
126impl fmt::Display for PeerId {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 write!(f, "{}/{}", self.owner_id, self.id)
129 }
130}
131
// Registers every protobuf message type that may travel inside an `Envelope`.
// Each tuple is `(MessageType, priority)`; the priority names a
// `MessagePriority` variant (`Foreground` or `Background`). Presumably the
// macro (defined elsewhere) uses it for `EnvelopedMessage::PRIORITY`, which
// drives `AnyTypedEnvelope::is_background` — confirm in the macro definition.
// Every type used in `request_messages!` / `entity_messages!` below must also
// appear here.
messages!(
    (Ack, Foreground),
    (AddProjectCollaborator, Foreground),
    (ApplyCodeAction, Background),
    (ApplyCodeActionResponse, Background),
    (ApplyCompletionAdditionalEdits, Background),
    (ApplyCompletionAdditionalEditsResponse, Background),
    (BufferReloaded, Foreground),
    (BufferSaved, Foreground),
    (Call, Foreground),
    (CallCanceled, Foreground),
    (CancelCall, Foreground),
    (ChannelMessageSent, Foreground),
    (CopyProjectEntry, Foreground),
    (CreateBufferForPeer, Foreground),
    (CreateProjectEntry, Foreground),
    (CreateRoom, Foreground),
    (CreateRoomResponse, Foreground),
    (DeclineCall, Foreground),
    (DeleteProjectEntry, Foreground),
    (Error, Foreground),
    (Follow, Foreground),
    (FollowResponse, Foreground),
    (FormatBuffers, Foreground),
    (FormatBuffersResponse, Foreground),
    (FuzzySearchUsers, Foreground),
    (GetChannelMessages, Foreground),
    (GetChannelMessagesResponse, Foreground),
    (GetChannels, Foreground),
    (GetChannelsResponse, Foreground),
    (GetCodeActions, Background),
    (GetCodeActionsResponse, Background),
    (GetHover, Background),
    (GetHoverResponse, Background),
    (GetCompletions, Background),
    (GetCompletionsResponse, Background),
    (GetDefinition, Background),
    (GetDefinitionResponse, Background),
    (GetTypeDefinition, Background),
    (GetTypeDefinitionResponse, Background),
    (GetDocumentHighlights, Background),
    (GetDocumentHighlightsResponse, Background),
    (GetReferences, Background),
    (GetReferencesResponse, Background),
    (GetProjectSymbols, Background),
    (GetProjectSymbolsResponse, Background),
    (GetUsers, Foreground),
    (Hello, Foreground),
    (IncomingCall, Foreground),
    (UsersResponse, Foreground),
    (JoinChannel, Foreground),
    (JoinChannelResponse, Foreground),
    (JoinProject, Foreground),
    (JoinProjectResponse, Foreground),
    (JoinRoom, Foreground),
    (JoinRoomResponse, Foreground),
    (LeaveChannel, Foreground),
    (LeaveProject, Foreground),
    (LeaveRoom, Foreground),
    (OpenBufferById, Background),
    (OpenBufferByPath, Background),
    (OpenBufferForSymbol, Background),
    (OpenBufferForSymbolResponse, Background),
    (OpenBufferResponse, Background),
    (PerformRename, Background),
    (PerformRenameResponse, Background),
    (Ping, Foreground),
    (PrepareRename, Background),
    (PrepareRenameResponse, Background),
    (ProjectEntryResponse, Foreground),
    (RejoinRoom, Foreground),
    (RejoinRoomResponse, Foreground),
    (RemoveContact, Foreground),
    (ReloadBuffers, Foreground),
    (ReloadBuffersResponse, Foreground),
    (RemoveProjectCollaborator, Foreground),
    (RenameProjectEntry, Foreground),
    (RequestContact, Foreground),
    (RespondToContactRequest, Foreground),
    (RoomUpdated, Foreground),
    (SaveBuffer, Foreground),
    (SearchProject, Background),
    (SearchProjectResponse, Background),
    (SendChannelMessage, Foreground),
    (SendChannelMessageResponse, Foreground),
    (ShareProject, Foreground),
    (ShareProjectResponse, Foreground),
    (ShowContacts, Foreground),
    (StartLanguageServer, Foreground),
    (SynchronizeBuffers, Foreground),
    (SynchronizeBuffersResponse, Foreground),
    (Test, Foreground),
    (Unfollow, Foreground),
    (UnshareProject, Foreground),
    (UpdateBuffer, Foreground),
    (UpdateBufferFile, Foreground),
    (UpdateContacts, Foreground),
    (UpdateDiagnosticSummary, Foreground),
    (UpdateFollowers, Foreground),
    (UpdateInviteInfo, Foreground),
    (UpdateLanguageServer, Foreground),
    (UpdateParticipantLocation, Foreground),
    (UpdateProject, Foreground),
    (UpdateProjectCollaborator, Foreground),
    (UpdateWorktree, Foreground),
    (UpdateDiffBase, Foreground),
    (GetPrivateUserInfo, Foreground),
    (GetPrivateUserInfoResponse, Foreground),
);
241
// Pairs each request message with the message type sent back in reply,
// presumably by implementing `RequestMessage` with `Response = <second type>`
// (macro defined elsewhere — confirm). Several requests share a reply type,
// e.g. many are answered with a bare `Ack`, and `SaveBuffer` is answered with
// `BufferSaved`.
request_messages!(
    (ApplyCodeAction, ApplyCodeActionResponse),
    (
        ApplyCompletionAdditionalEdits,
        ApplyCompletionAdditionalEditsResponse
    ),
    (Call, Ack),
    (CancelCall, Ack),
    (CopyProjectEntry, ProjectEntryResponse),
    (CreateProjectEntry, ProjectEntryResponse),
    (CreateRoom, CreateRoomResponse),
    (DeclineCall, Ack),
    (DeleteProjectEntry, ProjectEntryResponse),
    (Follow, FollowResponse),
    (FormatBuffers, FormatBuffersResponse),
    (GetChannelMessages, GetChannelMessagesResponse),
    (GetChannels, GetChannelsResponse),
    (GetCodeActions, GetCodeActionsResponse),
    (GetHover, GetHoverResponse),
    (GetCompletions, GetCompletionsResponse),
    (GetDefinition, GetDefinitionResponse),
    (GetTypeDefinition, GetTypeDefinitionResponse),
    (GetDocumentHighlights, GetDocumentHighlightsResponse),
    (GetReferences, GetReferencesResponse),
    (GetPrivateUserInfo, GetPrivateUserInfoResponse),
    (GetProjectSymbols, GetProjectSymbolsResponse),
    (FuzzySearchUsers, UsersResponse),
    (GetUsers, UsersResponse),
    (JoinChannel, JoinChannelResponse),
    (JoinProject, JoinProjectResponse),
    (JoinRoom, JoinRoomResponse),
    (LeaveRoom, Ack),
    (RejoinRoom, RejoinRoomResponse),
    (IncomingCall, Ack),
    (OpenBufferById, OpenBufferResponse),
    (OpenBufferByPath, OpenBufferResponse),
    (OpenBufferForSymbol, OpenBufferForSymbolResponse),
    (Ping, Ack),
    (PerformRename, PerformRenameResponse),
    (PrepareRename, PrepareRenameResponse),
    (ReloadBuffers, ReloadBuffersResponse),
    (RequestContact, Ack),
    (RemoveContact, Ack),
    (RespondToContactRequest, Ack),
    (RenameProjectEntry, ProjectEntryResponse),
    (SaveBuffer, BufferSaved),
    (SearchProject, SearchProjectResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (ShareProject, ShareProjectResponse),
    (SynchronizeBuffers, SynchronizeBuffersResponse),
    // `Test` is answered with another `Test` message.
    (Test, Test),
    (UpdateBuffer, Ack),
    (UpdateParticipantLocation, Ack),
    (UpdateProject, Ack),
    (UpdateWorktree, Ack),
);
298
// Marks messages that target a single project. The first argument names the
// field holding the entity id; presumably the macro implements
// `EntityMessage::remote_entity_id` by reading that field (macro defined
// elsewhere — confirm).
entity_messages!(
    project_id,
    AddProjectCollaborator,
    ApplyCodeAction,
    ApplyCompletionAdditionalEdits,
    BufferReloaded,
    BufferSaved,
    CopyProjectEntry,
    CreateBufferForPeer,
    CreateProjectEntry,
    DeleteProjectEntry,
    Follow,
    FormatBuffers,
    GetCodeActions,
    GetCompletions,
    GetDefinition,
    GetTypeDefinition,
    GetDocumentHighlights,
    GetHover,
    GetReferences,
    GetProjectSymbols,
    JoinProject,
    LeaveProject,
    OpenBufferById,
    OpenBufferByPath,
    OpenBufferForSymbol,
    PerformRename,
    PrepareRename,
    ReloadBuffers,
    RemoveProjectCollaborator,
    RenameProjectEntry,
    SaveBuffer,
    SearchProject,
    StartLanguageServer,
    SynchronizeBuffers,
    Unfollow,
    UnshareProject,
    UpdateBuffer,
    UpdateBufferFile,
    UpdateDiagnosticSummary,
    UpdateFollowers,
    UpdateLanguageServer,
    UpdateProject,
    UpdateProjectCollaborator,
    UpdateWorktree,
    UpdateDiffBase
);

// Channel-scoped messages, keyed by their `channel_id` field.
entity_messages!(channel_id, ChannelMessageSent);
348
const KIB: usize = 1024;
const MIB: usize = 1024 * KIB;
/// Largest capacity the scratch buffer of a [`MessageStream`] keeps between
/// messages; it is shrunk back to this after each encode/decode.
const MAX_BUFFER_LEN: usize = MIB;

/// A stream of protobuf messages: `Envelope`s are exchanged as zstd-compressed,
/// protobuf-encoded binary frames over the wrapped transport `S`.
pub struct MessageStream<S> {
    /// The wrapped WebSocket-style sink and/or stream.
    stream: S,
    /// Scratch space reused across messages for encoding and decoding.
    encoding_buffer: Vec<u8>,
}
358
// `Envelope` is far larger than the unit variants; it is left unboxed,
// presumably to avoid a heap allocation per message — revisit if `Message`
// values start being stored in bulk.
/// One frame exchanged over a [`MessageStream`]: either an application
/// [`Envelope`] or a WebSocket ping/pong control frame.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// An application message, sent as a compressed binary frame.
    Envelope(Envelope),
    /// A WebSocket ping frame.
    Ping,
    /// A WebSocket pong frame.
    Pong,
}
366
367impl<S> MessageStream<S> {
368 pub fn new(stream: S) -> Self {
369 Self {
370 stream,
371 encoding_buffer: Vec::new(),
372 }
373 }
374
375 pub fn inner_mut(&mut self) -> &mut S {
376 &mut self.stream
377 }
378}
379
380impl<S> MessageStream<S>
381where
382 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
383{
384 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
385 #[cfg(any(test, feature = "test-support"))]
386 const COMPRESSION_LEVEL: i32 = -7;
387
388 #[cfg(not(any(test, feature = "test-support")))]
389 const COMPRESSION_LEVEL: i32 = 4;
390
391 match message {
392 Message::Envelope(message) => {
393 self.encoding_buffer.reserve(message.encoded_len());
394 message
395 .encode(&mut self.encoding_buffer)
396 .map_err(io::Error::from)?;
397 let buffer =
398 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
399 .unwrap();
400
401 self.encoding_buffer.clear();
402 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
403 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
404 }
405 Message::Ping => {
406 self.stream
407 .send(WebSocketMessage::Ping(Default::default()))
408 .await?;
409 }
410 Message::Pong => {
411 self.stream
412 .send(WebSocketMessage::Pong(Default::default()))
413 .await?;
414 }
415 }
416
417 Ok(())
418 }
419}
420
421impl<S> MessageStream<S>
422where
423 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
424{
425 pub async fn read(&mut self) -> Result<Message, anyhow::Error> {
426 while let Some(bytes) = self.stream.next().await {
427 match bytes? {
428 WebSocketMessage::Binary(bytes) => {
429 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
430 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
431 .map_err(io::Error::from)?;
432
433 self.encoding_buffer.clear();
434 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
435 return Ok(Message::Envelope(envelope));
436 }
437 WebSocketMessage::Ping(_) => return Ok(Message::Ping),
438 WebSocketMessage::Pong(_) => return Ok(Message::Pong),
439 WebSocketMessage::Close(_) => break,
440 _ => {}
441 }
442 }
443 Err(anyhow!("connection closed"))
444 }
445}
446
447impl From<Timestamp> for SystemTime {
448 fn from(val: Timestamp) -> Self {
449 UNIX_EPOCH
450 .checked_add(Duration::new(val.seconds, val.nanos))
451 .unwrap()
452 }
453}
454
455impl From<SystemTime> for Timestamp {
456 fn from(time: SystemTime) -> Self {
457 let duration = time.duration_since(UNIX_EPOCH).unwrap();
458 Self {
459 seconds: duration.as_secs(),
460 nanos: duration.subsec_nanos(),
461 }
462 }
463}
464
465impl From<u128> for Nonce {
466 fn from(nonce: u128) -> Self {
467 let upper_half = (nonce >> 64) as u64;
468 let lower_half = nonce as u64;
469 Self {
470 upper_half,
471 lower_half,
472 }
473 }
474}
475
476impl From<Nonce> for u128 {
477 fn from(nonce: Nonce) -> Self {
478 let upper_half = (nonce.upper_half as u128) << 64;
479 let lower_half = nonce.lower_half as u128;
480 upper_half | lower_half
481 }
482}
483
484pub fn split_worktree_update(
485 mut message: UpdateWorktree,
486 max_chunk_size: usize,
487) -> impl Iterator<Item = UpdateWorktree> {
488 let mut done_files = false;
489
490 let mut repository_map = message
491 .updated_repositories
492 .into_iter()
493 .map(|repo| (repo.work_directory_id, repo))
494 .collect::<HashMap<_, _>>();
495
496 iter::from_fn(move || {
497 if done_files {
498 return None;
499 }
500
501 let updated_entries_chunk_size = cmp::min(message.updated_entries.len(), max_chunk_size);
502 let updated_entries: Vec<_> = message
503 .updated_entries
504 .drain(..updated_entries_chunk_size)
505 .collect();
506
507 let removed_entries_chunk_size = cmp::min(message.removed_entries.len(), max_chunk_size);
508 let removed_entries = message
509 .removed_entries
510 .drain(..removed_entries_chunk_size)
511 .collect();
512
513 done_files = message.updated_entries.is_empty() && message.removed_entries.is_empty();
514
515 let mut updated_repositories = Vec::new();
516
517 if !repository_map.is_empty() {
518 for entry in &updated_entries {
519 if let Some(repo) = repository_map.remove(&entry.id) {
520 updated_repositories.push(repo)
521 }
522 }
523 }
524
525 let removed_repositories = if done_files {
526 mem::take(&mut message.removed_repositories)
527 } else {
528 Default::default()
529 };
530
531 if done_files {
532 updated_repositories.extend(mem::take(&mut repository_map).into_values());
533 }
534
535 Some(UpdateWorktree {
536 project_id: message.project_id,
537 worktree_id: message.worktree_id,
538 root_name: message.root_name.clone(),
539 abs_path: message.abs_path.clone(),
540 updated_entries,
541 removed_entries,
542 scan_id: message.scan_id,
543 is_last_update: done_files && message.is_last_update,
544 updated_repositories,
545 removed_repositories,
546 })
547 })
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an `UpdateWorktree` with the given root name in an envelope message.
    fn worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    /// The scratch buffer must never retain more than `MAX_BUFFER_LEN` bytes of
    /// capacity, on either the write or the read side, even after a huge message.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }

    /// `from_u64(as_u64(p)) == p`, including ids at the 32-bit limits.
    #[gpui::test]
    fn test_converting_peer_id_from_and_to_u64() {
        let cases = [
            (10, 3),
            (u32::MAX, 3),
            (10, u32::MAX),
            (u32::MAX, u32::MAX),
        ];
        for (owner_id, id) in cases {
            let peer_id = PeerId { owner_id, id };
            assert_eq!(PeerId::from_u64(peer_id.as_u64()), peer_id);
        }
    }
}