1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod amplitude_telemetry;
5pub mod channel;
6pub mod http;
7pub mod telemetry;
8pub mod user;
9
10use amplitude_telemetry::AmplitudeTelemetry;
11use anyhow::{anyhow, Context, Result};
12use async_recursion::async_recursion;
13use async_tungstenite::tungstenite::{
14 error::Error as WebsocketError,
15 http::{Request, StatusCode},
16};
17use db::Db;
18use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
19use gpui::{
20 actions,
21 serde_json::{self, Value},
22 AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
23 AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
24 ViewHandle,
25};
26use http::HttpClient;
27use lazy_static::lazy_static;
28use parking_lot::RwLock;
29use postage::watch;
30use rand::prelude::*;
31use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
32use serde::Deserialize;
33use std::{
34 any::TypeId,
35 collections::HashMap,
36 convert::TryFrom,
37 fmt::Write as _,
38 future::Future,
39 path::PathBuf,
40 sync::{Arc, Weak},
41 time::{Duration, Instant},
42};
43use telemetry::Telemetry;
44use thiserror::Error;
45use url::Url;
46use util::{ResultExt, TryFutureExt};
47
48pub use channel::*;
49pub use rpc::*;
50pub use user::*;
51
52lazy_static! {
53 pub static ref ZED_SERVER_URL: String =
54 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
55 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
56 .ok()
57 .and_then(|s| if s.is_empty() { None } else { Some(s) });
58 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
59 .ok()
60 .and_then(|s| if s.is_empty() { None } else { Some(s) });
61}
62
// NOTE(review): not referenced in this part of the file; presumably a shared token checked
// elsewhere — confirm its purpose at the use site.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// Initial wait between reconnection attempts after the connection is lost. Each failed
/// attempt multiplies it by a random factor in [1, 2], capped at the client's reconnect interval.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// One deadline shared by establishing the connection and receiving the server's hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

// Declares the `client::Authenticate` action, whose global handler is registered by `init`.
actions!(client, [Authenticate]);
68
69pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
70 cx.add_global_action({
71 let client = client.clone();
72 move |_: &Authenticate, cx| {
73 let client = client.clone();
74 cx.spawn(
75 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
76 )
77 .detach();
78 }
79 });
80}
81
/// Client for the Zed server's RPC endpoint.
///
/// Owns the RPC `Peer`, the HTTP client used for authentication and telemetry, and the
/// lock-guarded `ClientState` (credentials, connection status, message-handler registries).
pub struct Client {
    /// Identifier included in log lines; 0 unless changed with `set_id` (test-support only).
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    amplitude_telemetry: Arc<AmplitudeTelemetry>,
    state: RwLock<ClientState>,

    /// Test-only replacement for the browser sign-in used by `Client::authenticate`.
    // The boxed callback type is long; clippy's type_complexity lint is silenced deliberately.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for the websocket connection used by
    /// `Client::establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
112
/// Failure while opening the RPC connection to the server.
///
/// `Unauthorized` and `UpgradeRequired` are produced from HTTP 401 / 426 responses to the
/// websocket handshake (see `From<WebsocketError>`). Every other failure is carried as a
/// wrapped error.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server requires a newer client protocol version.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server rejected the supplied credentials.
    #[error("unauthorized")]
    Unauthorized,
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    // NOTE(review): `http` here resolves to this crate's `http` module — confirm which
    // error type it re-exports.
    #[error("{0}")]
    Http(#[from] http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Errors from the `http` crate used by tungstenite (e.g. building the handshake request).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
128
129impl From<WebsocketError> for EstablishConnectionError {
130 fn from(error: WebsocketError) -> Self {
131 if let WebsocketError::Http(response) = &error {
132 match response.status() {
133 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
134 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
135 _ => {}
136 }
137 }
138 EstablishConnectionError::Other(error.into())
139 }
140}
141
142impl EstablishConnectionError {
143 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
144 Self::Other(error.into())
145 }
146}
147
/// Connection state of a `Client`, published through `Client::status`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Initial state. Also set after an explicit disconnect or a clean close of the current
    /// connection.
    SignedOut,
    /// The server demanded a newer client; `authenticate_and_connect` refuses to retry.
    UpgradeRequired,
    /// Obtaining credentials for a first connection (previously `SignedOut`).
    Authenticating,
    /// Opening the connection after `Authenticating`.
    Connecting,
    /// Authentication or connection failed.
    ConnectionError,
    /// Connected; carries the server-assigned peer id and the local connection id.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's IO loop failed; entering this state starts the reconnect loop.
    ConnectionLost,
    /// Obtaining credentials again after an earlier connection or failure.
    Reauthenticating,
    /// Opening the connection after `Reauthenticating`.
    Reconnecting,
    /// A reconnect attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
166
167impl Status {
168 pub fn is_connected(&self) -> bool {
169 matches!(self, Self::Connected { .. })
170 }
171}
172
/// Mutable state of a `Client`, guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current or last connection; cleared when the server rejects them.
    credentials: Option<Credentials>,
    /// Watch channel carrying the current `Status`; `Client::status` hands out receiver clones.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// For each entity-message type, extracts the remote entity id from a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`. Dropped on `Connected`, `SignedOut` and
    /// `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Weak handles to registered views/models, keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// The model receiving each message type registered through `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that handles each entity-message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handlers keyed by message type; registering a second one for a type panics.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
197
/// Type-erased weak handle to a registered remote entity, which is either a model or a view.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
202
/// Type-erased strong handle passed to message handlers; the model or view a message targets.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
207
/// Identity presented to the server. It is sent as `Authorization: <user_id> <access_token>`
/// in the websocket handshake.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    /// Decrypted access token. For admin impersonation it is `ADMIN_TOKEN:<api token>`.
    pub access_token: String,
}
213
214impl Default for ClientState {
215 fn default() -> Self {
216 Self {
217 credentials: None,
218 status: watch::channel_with(Status::SignedOut),
219 entity_id_extractors: Default::default(),
220 _reconnect_task: None,
221 reconnect_interval: Duration::from_secs(5),
222 models_by_message_type: Default::default(),
223 entities_by_type_and_remote_id: Default::default(),
224 entity_types_by_message_type: Default::default(),
225 message_handlers: Default::default(),
226 }
227 }
228}
229
/// Registration guard returned by the `Client::add_*` methods; dropping it unregisters.
///
/// It holds only a weak reference, so it does not keep the client alive.
pub enum Subscription {
    /// A view/model registered for a remote entity, keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler, keyed by message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
240
241impl Drop for Subscription {
242 fn drop(&mut self) {
243 match self {
244 Subscription::Entity { client, id } => {
245 if let Some(client) = client.upgrade() {
246 let mut state = client.state.write();
247 let _ = state.entities_by_type_and_remote_id.remove(id);
248 }
249 }
250 Subscription::Message { client, id } => {
251 if let Some(client) = client.upgrade() {
252 let mut state = client.state.write();
253 let _ = state.entity_types_by_message_type.remove(id);
254 let _ = state.message_handlers.remove(id);
255 }
256 }
257 }
258 }
259}
260
261impl Client {
262 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
263 Arc::new(Self {
264 id: 0,
265 peer: Peer::new(),
266 telemetry: Telemetry::new(http.clone(), cx),
267 amplitude_telemetry: AmplitudeTelemetry::new(http.clone(), cx),
268 http,
269 state: Default::default(),
270
271 #[cfg(any(test, feature = "test-support"))]
272 authenticate: Default::default(),
273 #[cfg(any(test, feature = "test-support"))]
274 establish_connection: Default::default(),
275 })
276 }
277
278 pub fn id(&self) -> usize {
279 self.id
280 }
281
282 pub fn http_client(&self) -> Arc<dyn HttpClient> {
283 self.http.clone()
284 }
285
286 #[cfg(any(test, feature = "test-support"))]
287 pub fn set_id(&mut self, id: usize) -> &Self {
288 self.id = id;
289 self
290 }
291
292 #[cfg(any(test, feature = "test-support"))]
293 pub fn tear_down(&self) {
294 let mut state = self.state.write();
295 state._reconnect_task.take();
296 state.message_handlers.clear();
297 state.models_by_message_type.clear();
298 state.entities_by_type_and_remote_id.clear();
299 state.entity_id_extractors.clear();
300 self.peer.reset();
301 }
302
303 #[cfg(any(test, feature = "test-support"))]
304 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
305 where
306 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
307 {
308 *self.authenticate.write() = Some(Box::new(authenticate));
309 self
310 }
311
312 #[cfg(any(test, feature = "test-support"))]
313 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
314 where
315 F: 'static
316 + Send
317 + Sync
318 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
319 {
320 *self.establish_connection.write() = Some(Box::new(connect));
321 self
322 }
323
324 pub fn user_id(&self) -> Option<u64> {
325 self.state
326 .read()
327 .credentials
328 .as_ref()
329 .map(|credentials| credentials.user_id)
330 }
331
332 pub fn peer_id(&self) -> Option<PeerId> {
333 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
334 Some(*peer_id)
335 } else {
336 None
337 }
338 }
339
340 pub fn status(&self) -> watch::Receiver<Status> {
341 self.state.read().status.1.clone()
342 }
343
    /// Publishes `status` to every `status()` receiver and applies its side effects.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop.
    ///   - Each attempt calls `authenticate_and_connect` with the keychain allowed.
    ///   - The delay starts at `INITIAL_RECONNECTION_DELAY`, is multiplied by a random factor
    ///     in [1, 2] after each failure, and is capped at `reconnect_interval`.
    ///   - The loop ends on success, or when a failed attempt leaves the status at anything
    ///     other than `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired`: clears the authenticated user in both telemetry
    ///   reporters and drops any reconnect task.
    ///
    /// The state write lock is held for the whole function, including the telemetry calls.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                // NOTE(review): a successful attempt inside this task reaches the `Connected`
                // arm above, which drops this task's own handle; presumably the executor
                // tolerates that — confirm.
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only for plain connection errors. Any other status (e.g.
                        // `UpgradeRequired`, or a status changed by someone else) stops the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized exponential back-off, capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                self.amplitude_telemetry
                    .set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
387
388 pub fn add_view_for_remote_entity<T: View>(
389 self: &Arc<Self>,
390 remote_id: u64,
391 cx: &mut ViewContext<T>,
392 ) -> Subscription {
393 let id = (TypeId::of::<T>(), remote_id);
394 self.state
395 .write()
396 .entities_by_type_and_remote_id
397 .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
398 Subscription::Entity {
399 client: Arc::downgrade(self),
400 id,
401 }
402 }
403
404 pub fn add_model_for_remote_entity<T: Entity>(
405 self: &Arc<Self>,
406 remote_id: u64,
407 cx: &mut ModelContext<T>,
408 ) -> Subscription {
409 let id = (TypeId::of::<T>(), remote_id);
410 self.state
411 .write()
412 .entities_by_type_and_remote_id
413 .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
414 Subscription::Entity {
415 client: Arc::downgrade(self),
416 id,
417 }
418 }
419
420 pub fn add_message_handler<M, E, H, F>(
421 self: &Arc<Self>,
422 model: ModelHandle<E>,
423 handler: H,
424 ) -> Subscription
425 where
426 M: EnvelopedMessage,
427 E: Entity,
428 H: 'static
429 + Send
430 + Sync
431 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
432 F: 'static + Future<Output = Result<()>>,
433 {
434 let message_type_id = TypeId::of::<M>();
435
436 let mut state = self.state.write();
437 state
438 .models_by_message_type
439 .insert(message_type_id, model.downgrade().into());
440
441 let prev_handler = state.message_handlers.insert(
442 message_type_id,
443 Arc::new(move |handle, envelope, client, cx| {
444 let handle = if let AnyEntityHandle::Model(handle) = handle {
445 handle
446 } else {
447 unreachable!();
448 };
449 let model = handle.downcast::<E>().unwrap();
450 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
451 handler(model, *envelope, client.clone(), cx).boxed_local()
452 }),
453 );
454 if prev_handler.is_some() {
455 panic!("registered handler for the same message twice");
456 }
457
458 Subscription::Message {
459 client: Arc::downgrade(self),
460 id: message_type_id,
461 }
462 }
463
464 pub fn add_request_handler<M, E, H, F>(
465 self: &Arc<Self>,
466 model: ModelHandle<E>,
467 handler: H,
468 ) -> Subscription
469 where
470 M: RequestMessage,
471 E: Entity,
472 H: 'static
473 + Send
474 + Sync
475 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
476 F: 'static + Future<Output = Result<M::Response>>,
477 {
478 self.add_message_handler(model, move |handle, envelope, this, cx| {
479 Self::respond_to_request(
480 envelope.receipt(),
481 handler(handle, envelope, this.clone(), cx),
482 this,
483 )
484 })
485 }
486
487 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
488 where
489 M: EntityMessage,
490 E: View,
491 H: 'static
492 + Send
493 + Sync
494 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
495 F: 'static + Future<Output = Result<()>>,
496 {
497 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
498 if let AnyEntityHandle::View(handle) = handle {
499 handler(handle.downcast::<E>().unwrap(), message, client, cx)
500 } else {
501 unreachable!();
502 }
503 })
504 }
505
506 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
507 where
508 M: EntityMessage,
509 E: Entity,
510 H: 'static
511 + Send
512 + Sync
513 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
514 F: 'static + Future<Output = Result<()>>,
515 {
516 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
517 if let AnyEntityHandle::Model(handle) = handle {
518 handler(handle.downcast::<E>().unwrap(), message, client, cx)
519 } else {
520 unreachable!();
521 }
522 })
523 }
524
525 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
526 where
527 M: EntityMessage,
528 E: Entity,
529 H: 'static
530 + Send
531 + Sync
532 + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
533 F: 'static + Future<Output = Result<()>>,
534 {
535 let model_type_id = TypeId::of::<E>();
536 let message_type_id = TypeId::of::<M>();
537
538 let mut state = self.state.write();
539 state
540 .entity_types_by_message_type
541 .insert(message_type_id, model_type_id);
542 state
543 .entity_id_extractors
544 .entry(message_type_id)
545 .or_insert_with(|| {
546 |envelope| {
547 envelope
548 .as_any()
549 .downcast_ref::<TypedEnvelope<M>>()
550 .unwrap()
551 .payload
552 .remote_entity_id()
553 }
554 });
555 let prev_handler = state.message_handlers.insert(
556 message_type_id,
557 Arc::new(move |handle, envelope, client, cx| {
558 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
559 handler(handle, *envelope, client.clone(), cx).boxed_local()
560 }),
561 );
562 if prev_handler.is_some() {
563 panic!("registered handler for the same message twice");
564 }
565 }
566
567 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
568 where
569 M: EntityMessage + RequestMessage,
570 E: Entity,
571 H: 'static
572 + Send
573 + Sync
574 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
575 F: 'static + Future<Output = Result<M::Response>>,
576 {
577 self.add_model_message_handler(move |entity, envelope, client, cx| {
578 Self::respond_to_request::<M, _>(
579 envelope.receipt(),
580 handler(entity, envelope, client.clone(), cx),
581 client,
582 )
583 })
584 }
585
586 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
587 where
588 M: EntityMessage + RequestMessage,
589 E: View,
590 H: 'static
591 + Send
592 + Sync
593 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
594 F: 'static + Future<Output = Result<M::Response>>,
595 {
596 self.add_view_message_handler(move |entity, envelope, client, cx| {
597 Self::respond_to_request::<M, _>(
598 envelope.receipt(),
599 handler(entity, envelope, client.clone(), cx),
600 client,
601 )
602 })
603 }
604
605 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
606 receipt: Receipt<T>,
607 response: F,
608 client: Arc<Self>,
609 ) -> Result<()> {
610 match response.await {
611 Ok(response) => {
612 client.respond(receipt, response)?;
613 Ok(())
614 }
615 Err(error) => {
616 client.respond_with_error(
617 receipt,
618 proto::Error {
619 message: format!("{:?}", error),
620 },
621 )?;
622 Err(error)
623 }
624 }
625 }
626
627 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
628 read_credentials_from_keychain(cx).is_some()
629 }
630
    /// Authenticates if needed and connects to the server.
    ///
    /// Drives `status` through `Authenticating`/`Reauthenticating`, then
    /// `Connecting`/`Reconnecting`, then `Connected`.
    ///
    /// - Returns `Ok(())` immediately if already connected or an attempt is in progress.
    /// - Fails immediately with `UpgradeRequired` if the server previously required a newer
    ///   client.
    /// - Credentials come from memory, then (when `try_keychain`) the keychain, then
    ///   `authenticate` (browser flow or test override).
    /// - After the websocket connects, credentials not read from the keychain are written to
    ///   it, unless impersonating.
    /// - If keychain credentials are rejected as unauthorized, they are deleted and the flow
    ///   runs once more without the keychain.
    ///
    /// # Errors
    /// Authentication failure or cancellation (the status changed while authenticating),
    /// connection errors, or exceeding `CONNECTION_TIMEOUT`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                self.report_event("read credentials from keychain", Default::default());
            }
        }
        if credentials.is_none() {
            // Consume the current status value, so that the `status_rx.next()` in the select
            // below fires only on a *later* status change, which cancels authentication.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single deadline bounds both opening the connection and receiving the server hello.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        // NOTE(review): if the timeout wins here, the in-flight `set_connection`
                        // future is dropped after it has already added the connection to
                        // `peer`, and nothing in this function disconnects it — confirm whether
                        // the peer cleans it up.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stored credentials are stale: delete them and retry once without
                            // the keychain (the recursive call passes `try_keychain = false`).
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
736
737 async fn set_connection(
738 self: &Arc<Self>,
739 conn: Connection,
740 cx: &AsyncAppContext,
741 ) -> Result<()> {
742 let executor = cx.background();
743 log::info!("add connection to peer");
744 let (connection_id, handle_io, mut incoming) = self
745 .peer
746 .add_connection(conn, move |duration| executor.timer(duration));
747 let handle_io = cx.background().spawn(handle_io);
748
749 let peer_id = async {
750 log::info!("waiting for server hello");
751 let message = incoming
752 .next()
753 .await
754 .ok_or_else(|| anyhow!("no hello message received"))?;
755 log::info!("got server hello");
756 let hello_message_type_name = message.payload_type_name().to_string();
757 let hello = message
758 .into_any()
759 .downcast::<TypedEnvelope<proto::Hello>>()
760 .map_err(|_| {
761 anyhow!(
762 "invalid hello message received: {:?}",
763 hello_message_type_name
764 )
765 })?;
766 Ok(PeerId(hello.payload.peer_id))
767 };
768
769 let peer_id = match peer_id.await {
770 Ok(peer_id) => peer_id,
771 Err(error) => {
772 self.peer.disconnect(connection_id);
773 return Err(error);
774 }
775 };
776
777 log::info!(
778 "set status to connected (connection id: {}, peer id: {})",
779 connection_id,
780 peer_id
781 );
782 self.set_status(
783 Status::Connected {
784 peer_id,
785 connection_id,
786 },
787 cx,
788 );
789 cx.foreground()
790 .spawn({
791 let cx = cx.clone();
792 let this = self.clone();
793 async move {
794 let mut message_id = 0_usize;
795 while let Some(message) = incoming.next().await {
796 let mut state = this.state.write();
797 message_id += 1;
798 let type_name = message.payload_type_name();
799 let payload_type_id = message.payload_type_id();
800 let sender_id = message.original_sender_id().map(|id| id.0);
801
802 let model = state
803 .models_by_message_type
804 .get(&payload_type_id)
805 .and_then(|model| model.upgrade(&cx))
806 .map(AnyEntityHandle::Model)
807 .or_else(|| {
808 let entity_type_id =
809 *state.entity_types_by_message_type.get(&payload_type_id)?;
810 let entity_id = state
811 .entity_id_extractors
812 .get(&message.payload_type_id())
813 .map(|extract_entity_id| {
814 (extract_entity_id)(message.as_ref())
815 })?;
816
817 let entity = state
818 .entities_by_type_and_remote_id
819 .get(&(entity_type_id, entity_id))?;
820 if let Some(entity) = entity.upgrade(&cx) {
821 Some(entity)
822 } else {
823 state
824 .entities_by_type_and_remote_id
825 .remove(&(entity_type_id, entity_id));
826 None
827 }
828 });
829
830 let model = if let Some(model) = model {
831 model
832 } else {
833 log::info!("unhandled message {}", type_name);
834 continue;
835 };
836
837 if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
838 {
839 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
840 let future = handler(model, message, &this, cx.clone());
841
842 let client_id = this.id;
843 log::debug!(
844 "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
845 client_id,
846 message_id,
847 sender_id,
848 type_name
849 );
850 cx.foreground()
851 .spawn(async move {
852 match future.await {
853 Ok(()) => {
854 log::debug!(
855 "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
856 client_id,
857 message_id,
858 sender_id,
859 type_name
860 );
861 }
862 Err(error) => {
863 log::error!(
864 "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
865 client_id,
866 message_id,
867 sender_id,
868 type_name,
869 error
870 );
871 }
872 }
873 })
874 .detach();
875 } else {
876 log::info!("unhandled message {}", type_name);
877 }
878
879 // Don't starve the main thread when receiving lots of messages at once.
880 smol::future::yield_now().await;
881 }
882 }
883 })
884 .detach();
885
886 let this = self.clone();
887 let cx = cx.clone();
888 cx.foreground()
889 .spawn(async move {
890 match handle_io.await {
891 Ok(()) => {
892 if *this.status().borrow()
893 == (Status::Connected {
894 connection_id,
895 peer_id,
896 })
897 {
898 this.set_status(Status::SignedOut, &cx);
899 }
900 }
901 Err(err) => {
902 log::error!("connection error: {:?}", err);
903 this.set_status(Status::ConnectionLost, &cx);
904 }
905 }
906 })
907 .detach();
908
909 Ok(())
910 }
911
912 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
913 #[cfg(any(test, feature = "test-support"))]
914 if let Some(callback) = self.authenticate.read().as_ref() {
915 return callback(cx);
916 }
917
918 self.authenticate_with_browser(cx)
919 }
920
921 fn establish_connection(
922 self: &Arc<Self>,
923 credentials: &Credentials,
924 cx: &AsyncAppContext,
925 ) -> Task<Result<Connection, EstablishConnectionError>> {
926 #[cfg(any(test, feature = "test-support"))]
927 if let Some(callback) = self.establish_connection.read().as_ref() {
928 return callback(credentials, cx);
929 }
930
931 self.establish_websocket_connection(credentials, cx)
932 }
933
934 async fn get_rpc_url(http: Arc<dyn HttpClient>) -> Result<Url> {
935 let url = format!("{}/rpc", *ZED_SERVER_URL);
936 let response = http.get(&url, Default::default(), false).await?;
937
938 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
939 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
940 // which requires authorization via an HTTP header.
941 //
942 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
943 // of a collab server. In that case, a request to the /rpc endpoint will
944 // return an 'unauthorized' response.
945 let collab_url = if response.status().is_redirection() {
946 response
947 .headers()
948 .get("Location")
949 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
950 .to_str()
951 .map_err(EstablishConnectionError::other)?
952 .to_string()
953 } else if response.status() == StatusCode::UNAUTHORIZED {
954 url
955 } else {
956 Err(anyhow!(
957 "unexpected /rpc response status {}",
958 response.status()
959 ))?
960 };
961
962 Url::parse(&collab_url).context("invalid rpc url")
963 }
964
965 fn establish_websocket_connection(
966 self: &Arc<Self>,
967 credentials: &Credentials,
968 cx: &AsyncAppContext,
969 ) -> Task<Result<Connection, EstablishConnectionError>> {
970 let request = Request::builder()
971 .header(
972 "Authorization",
973 format!("{} {}", credentials.user_id, credentials.access_token),
974 )
975 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
976
977 let http = self.http.clone();
978 cx.background().spawn(async move {
979 let mut rpc_url = Self::get_rpc_url(http).await?;
980 let rpc_host = rpc_url
981 .host_str()
982 .zip(rpc_url.port_or_known_default())
983 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
984 let stream = smol::net::TcpStream::connect(rpc_host).await?;
985
986 log::info!("connected to rpc endpoint {}", rpc_url);
987
988 match rpc_url.scheme() {
989 "https" => {
990 rpc_url.set_scheme("wss").unwrap();
991 let request = request.uri(rpc_url.as_str()).body(())?;
992 let (stream, _) =
993 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
994 Ok(Connection::new(
995 stream
996 .map_err(|error| anyhow!(error))
997 .sink_map_err(|error| anyhow!(error)),
998 ))
999 }
1000 "http" => {
1001 rpc_url.set_scheme("ws").unwrap();
1002 let request = request.uri(rpc_url.as_str()).body(())?;
1003 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1004 Ok(Connection::new(
1005 stream
1006 .map_err(|error| anyhow!(error))
1007 .sink_map_err(|error| anyhow!(error)),
1008 ))
1009 }
1010 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1011 }
1012 })
1013 }
1014
1015 pub fn authenticate_with_browser(
1016 self: &Arc<Self>,
1017 cx: &AsyncAppContext,
1018 ) -> Task<Result<Credentials>> {
1019 let platform = cx.platform();
1020 let executor = cx.background();
1021 let telemetry = self.telemetry.clone();
1022 let amplitude_telemetry = self.amplitude_telemetry.clone();
1023 let http = self.http.clone();
1024 executor.clone().spawn(async move {
1025 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1026 // zed server to encrypt the user's access token, so that it can'be intercepted by
1027 // any other app running on the user's device.
1028 let (public_key, private_key) =
1029 rpc::auth::keypair().expect("failed to generate keypair for auth");
1030 let public_key_string =
1031 String::try_from(public_key).expect("failed to serialize public key for auth");
1032
1033 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1034 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1035 }
1036
1037 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1038 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1039 let port = server.server_addr().port();
1040
1041 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1042 // that the user is signing in from a Zed app running on the same device.
1043 let mut url = format!(
1044 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1045 *ZED_SERVER_URL, port, public_key_string
1046 );
1047
1048 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1049 log::info!("impersonating user @{}", impersonate_login);
1050 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1051 }
1052
1053 platform.open_url(&url);
1054
1055 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1056 // access token from the query params.
1057 //
1058 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1059 // custom URL scheme instead of this local HTTP server.
1060 let (user_id, access_token) = executor
1061 .spawn(async move {
1062 for _ in 0..100 {
1063 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1064 let path = req.url();
1065 let mut user_id = None;
1066 let mut access_token = None;
1067 let url = Url::parse(&format!("http://example.com{}", path))
1068 .context("failed to parse login notification url")?;
1069 for (key, value) in url.query_pairs() {
1070 if key == "access_token" {
1071 access_token = Some(value.to_string());
1072 } else if key == "user_id" {
1073 user_id = Some(value.to_string());
1074 }
1075 }
1076
1077 let post_auth_url =
1078 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1079 req.respond(
1080 tiny_http::Response::empty(302).with_header(
1081 tiny_http::Header::from_bytes(
1082 &b"Location"[..],
1083 post_auth_url.as_bytes(),
1084 )
1085 .unwrap(),
1086 ),
1087 )
1088 .context("failed to respond to login http request")?;
1089 return Ok((
1090 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1091 access_token
1092 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1093 ));
1094 }
1095 }
1096
1097 Err(anyhow!("didn't receive login redirect"))
1098 })
1099 .await?;
1100
1101 let access_token = private_key
1102 .decrypt_string(&access_token)
1103 .context("failed to decrypt access token")?;
1104 platform.activate(true);
1105
1106 telemetry.report_event("authenticate with browser", Default::default());
1107 amplitude_telemetry.report_event("authenticate with browser", Default::default());
1108
1109 Ok(Credentials {
1110 user_id: user_id.parse()?,
1111 access_token,
1112 })
1113 })
1114 }
1115
1116 async fn authenticate_as_admin(
1117 http: Arc<dyn HttpClient>,
1118 login: String,
1119 mut api_token: String,
1120 ) -> Result<Credentials> {
1121 #[derive(Deserialize)]
1122 struct AuthenticatedUserResponse {
1123 user: User,
1124 }
1125
1126 #[derive(Deserialize)]
1127 struct User {
1128 id: u64,
1129 }
1130
1131 // Use the collab server's admin API to retrieve the id
1132 // of the impersonated user.
1133 let mut url = Self::get_rpc_url(http.clone()).await?;
1134 url.set_path("/user");
1135 url.set_query(Some(&format!("github_login={login}")));
1136 let request = Request::get(url.as_str())
1137 .header("Authorization", format!("token {api_token}"))
1138 .body("".into())?;
1139
1140 let mut response = http.send(request).await?;
1141 let mut body = String::new();
1142 response.body_mut().read_to_string(&mut body).await?;
1143 if !response.status().is_success() {
1144 Err(anyhow!(
1145 "admin user request failed {} - {}",
1146 response.status().as_u16(),
1147 body,
1148 ))?;
1149 }
1150 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1151
1152 // Use the admin API token to authenticate as the impersonated user.
1153 api_token.insert_str(0, "ADMIN_TOKEN:");
1154 Ok(Credentials {
1155 user_id: response.user.id,
1156 access_token: api_token,
1157 })
1158 }
1159
1160 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
1161 let conn_id = self.connection_id()?;
1162 self.peer.disconnect(conn_id);
1163 self.set_status(Status::SignedOut, cx);
1164 Ok(())
1165 }
1166
1167 fn connection_id(&self) -> Result<ConnectionId> {
1168 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1169 Ok(connection_id)
1170 } else {
1171 Err(anyhow!("not connected"))
1172 }
1173 }
1174
1175 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1176 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1177 self.peer.send(self.connection_id()?, message)
1178 }
1179
1180 pub fn request<T: RequestMessage>(
1181 &self,
1182 request: T,
1183 ) -> impl Future<Output = Result<T::Response>> {
1184 let client_id = self.id;
1185 log::debug!(
1186 "rpc request start. client_id:{}. name:{}",
1187 client_id,
1188 T::NAME
1189 );
1190 let response = self
1191 .connection_id()
1192 .map(|conn_id| self.peer.request(conn_id, request));
1193 async move {
1194 let response = response?.await;
1195 log::debug!(
1196 "rpc request finish. client_id:{}. name:{}",
1197 client_id,
1198 T::NAME
1199 );
1200 response
1201 }
1202 }
1203
1204 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1205 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1206 self.peer.respond(receipt, response)
1207 }
1208
1209 fn respond_with_error<T: RequestMessage>(
1210 &self,
1211 receipt: Receipt<T>,
1212 error: proto::Error,
1213 ) -> Result<()> {
1214 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1215 self.peer.respond_with_error(receipt, error)
1216 }
1217
1218 pub fn start_telemetry(&self, db: Db) {
1219 self.telemetry.start(db.clone());
1220 self.amplitude_telemetry.start(db);
1221 }
1222
1223 pub fn report_event(&self, kind: &str, properties: Value) {
1224 self.telemetry.report_event(kind, properties.clone());
1225 self.amplitude_telemetry.report_event(kind, properties);
1226 }
1227
1228 pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1229 self.amplitude_telemetry.log_file_path();
1230 self.telemetry.log_file_path()
1231 }
1232}
1233
1234impl AnyWeakEntityHandle {
1235 fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1236 match self {
1237 AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1238 AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1239 }
1240 }
1241}
1242
1243fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1244 if IMPERSONATE_LOGIN.is_some() {
1245 return None;
1246 }
1247
1248 let (user_id, access_token) = cx
1249 .platform()
1250 .read_credentials(&ZED_SERVER_URL)
1251 .log_err()
1252 .flatten()?;
1253 Some(Credentials {
1254 user_id: user_id.parse().ok()?,
1255 access_token: String::from_utf8(access_token).ok()?,
1256 })
1257}
1258
1259fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1260 cx.platform().write_credentials(
1261 &ZED_SERVER_URL,
1262 &credentials.user_id.to_string(),
1263 credentials.access_token.as_bytes(),
1264 )
1265}
1266
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` when:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - there is no `/` after the id;
/// - the access token is empty.
///
/// Everything after the first `/` following the id is taken as the token. This keeps the
/// round trip with `encode_worktree_url` exact even when the token contains `/`.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1283
#[cfg(test)]
// Unit tests for `Client`: reconnection and credential reuse, connection timeouts,
// overlapping authentication attempts, worktree URL encoding, and message-handler
// subscription lifetimes.
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // After the server drops the client, the client reconnects once connections are
    // allowed again. It reuses its cached credentials (auth count stays at 1) and only
    // re-authenticates after the server rolls the access token (auth count becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Advance the fake clock so the pending reconnection attempt fires.
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes moves the status to `ConnectionError`
    // after `CONNECTION_TIMEOUT`. A reconnection attempt that never completes likewise
    // ends in `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Connection establishment that never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        // NOTE(review): presumably `FakeServer::for_client` installs working connection
        // overrides on `client` — confirm in `crate::test`.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still authenticating
    // drops the first authentication future: the auth count goes to 2 and the dropped
    // count to 1.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        // Number of authentication futures started / dropped before completing.
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this future is dropped, recording the cancellation.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs round-trip through encode/decode, tolerate surrounding whitespace,
    // and foreign URLs are rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Messages are routed to the model registered for the message's remote entity id.
    // Dropping one subscription (entity 3) does not stop delivery for other ids (1 and 2).
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                // Signal the channel matching the model that received the message; model 3
                // was unsubscribed and must never be reached.
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // A handler registered after an earlier subscription for the same message type
    // (`proto::Ping`) was dropped still receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the model) while it is running,
    // and the handler still runs to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                // Take (and thereby drop) the subscription that dispatched this handler.
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal entity used as a message-handler target in the tests above; `id` identifies
    // the model and `subscription` lets a test store (and later drop) a handler
    // subscription on it.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}