1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod amplitude_telemetry;
5pub mod channel;
6pub mod http;
7pub mod telemetry;
8pub mod user;
9
10use amplitude_telemetry::AmplitudeTelemetry;
11use anyhow::{anyhow, Context, Result};
12use async_recursion::async_recursion;
13use async_tungstenite::tungstenite::{
14 error::Error as WebsocketError,
15 http::{Request, StatusCode},
16};
17use db::Db;
18use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
19use gpui::{
20 actions,
21 serde_json::{self, Value},
22 AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
23 AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
24 ViewHandle,
25};
26use http::HttpClient;
27use lazy_static::lazy_static;
28use parking_lot::RwLock;
29use postage::watch;
30use rand::prelude::*;
31use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
32use serde::Deserialize;
33use std::{
34 any::TypeId,
35 collections::HashMap,
36 convert::TryFrom,
37 fmt::Write as _,
38 future::Future,
39 path::PathBuf,
40 sync::{Arc, Weak},
41 time::{Duration, Instant},
42};
43use telemetry::Telemetry;
44use thiserror::Error;
45use url::Url;
46use util::{ResultExt, TryFutureExt};
47
48pub use channel::*;
49pub use rpc::*;
50pub use user::*;
51
52lazy_static! {
53 pub static ref PREVIEW_CHANNEL: bool = std::env::var("ZED_PREVIEW_CHANNEL")
54 .map_or(false, |var| !var.is_empty())
55 || option_env!("ZED_PREVIEW_CHANNEL").map_or(false, |var| !var.is_empty());
56 pub static ref ZED_SERVER_URL: String =
57 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
58 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
59 .ok()
60 .and_then(|s| if s.is_empty() { None } else { Some(s) });
61 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
62 .ok()
63 .and_then(|s| if s.is_empty() { None } else { Some(s) });
64}
65
/// Shared client token constant (its consumers are outside this part of the file).
pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894";
/// First delay used by the reconnection loop after a lost connection.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time allowed for opening a connection and receiving the server hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(5_000);
69
// Declares the `Authenticate` action in the `client` namespace; `init` registers its
// global handler.
actions!(client, [Authenticate]);
71
72pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
73 cx.add_global_action({
74 let client = client.clone();
75 move |_: &Authenticate, cx| {
76 let client = client.clone();
77 cx.spawn(
78 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
79 )
80 .detach();
81 }
82 });
83}
84
/// Client for the collab server's RPC endpoint.
///
/// Owns the RPC [`Peer`], the HTTP client, the telemetry reporters, and all mutable
/// connection state. In test builds (`test` or the `test-support` feature), the
/// authentication and connection steps can be replaced with callbacks.
pub struct Client {
    /// Identifier included in log lines; `0` unless changed with `set_id` (tests only).
    id: usize,
    /// RPC peer that owns the live connection and sends/receives messages.
    peer: Arc<Peer>,
    /// HTTP client used to resolve the RPC URL and for telemetry/admin requests.
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    amplitude_telemetry: Arc<AmplitudeTelemetry>,
    /// Credentials, connection status, and message-routing tables behind one lock.
    state: RwLock<ClientState>,

    /// Test-only replacement for the browser sign-in flow (see `override_authenticate`).
    // The boxed callback type is spelled out inline, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for opening the websocket connection
    /// (see `override_establish_connection`).
    // The boxed callback type is spelled out inline, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
115
/// Reasons an attempt to open the RPC connection can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP 426 Upgrade Required.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP 401 Unauthorized.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, wrapped as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error type of this crate's `http` module.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure, e.g. while opening the TCP stream.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error building the handshake request. Despite the variant name, this wraps the
    /// `http` crate error re-exported by tungstenite.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
131
132impl From<WebsocketError> for EstablishConnectionError {
133 fn from(error: WebsocketError) -> Self {
134 if let WebsocketError::Http(response) = &error {
135 match response.status() {
136 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
137 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
138 _ => {}
139 }
140 }
141 EstablishConnectionError::Other(error.into())
142 }
143}
144
145impl EstablishConnectionError {
146 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
147 Self::Other(error.into())
148 }
149}
150
/// Connection lifecycle of a [`Client`], published through its status watch channel.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// No connection and no attempt in progress; the initial state.
    SignedOut,
    /// The server demanded an upgrade; `authenticate_and_connect` refuses to run.
    UpgradeRequired,
    /// Obtaining credentials for a connection started from `SignedOut`.
    Authenticating,
    /// Opening a connection started from `SignedOut`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected. `peer_id` comes from the server's hello message.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; setting this status starts a reconnection loop.
    ConnectionLost,
    /// Obtaining credentials after a previous connection or failed attempt.
    Reauthenticating,
    /// Re-opening a connection after a previous connection or failed attempt.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
169
170impl Status {
171 pub fn is_connected(&self) -> bool {
172 matches!(self, Self::Connected { .. })
173 }
174}
175
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current or last successful connection.
    credentials: Option<Credentials>,
    /// Sender and receiver halves of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: extracts the remote entity id from a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Pending reconnection loop, held only so it is dropped on replacement.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the reconnection backoff delay.
    reconnect_interval: Duration,
    /// Entities registered for remote ids, keyed by (entity type, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Single model receiving all messages of a given type (`add_message_handler`).
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type targeted by each entity-scoped message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message type; at most one per type.
    // The handler signature is spelled out inline, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
200
/// Weak, type-erased handle to an entity registered for a remote id: a model or a view.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
205
/// Strong, type-erased handle to the model or view an incoming message was routed to;
/// passed to the registered message handler.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
210
/// Credentials used to authenticate the RPC connection: the user id and the access
/// token, which are sent together in the `Authorization` header.
///
/// `Debug` is implemented by hand so the access token never ends up in log output.
#[derive(Clone)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            // Never print the secret; the user id is enough for diagnostics.
            .field("access_token", &"<redacted>")
            .finish()
    }
}
216
217impl Default for ClientState {
218 fn default() -> Self {
219 Self {
220 credentials: None,
221 status: watch::channel_with(Status::SignedOut),
222 entity_id_extractors: Default::default(),
223 _reconnect_task: None,
224 reconnect_interval: Duration::from_secs(5),
225 models_by_message_type: Default::default(),
226 entities_by_type_and_remote_id: Default::default(),
227 entity_types_by_message_type: Default::default(),
228 message_handlers: Default::default(),
229 }
230 }
231}
232
/// Registration handle returned by `Client`'s registration methods. Dropping it
/// removes the registration from the client, if the client is still alive.
pub enum Subscription {
    /// An entity registered for a (entity type, remote id) pair.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler registered for a message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
243
244impl Drop for Subscription {
245 fn drop(&mut self) {
246 match self {
247 Subscription::Entity { client, id } => {
248 if let Some(client) = client.upgrade() {
249 let mut state = client.state.write();
250 let _ = state.entities_by_type_and_remote_id.remove(id);
251 }
252 }
253 Subscription::Message { client, id } => {
254 if let Some(client) = client.upgrade() {
255 let mut state = client.state.write();
256 let _ = state.entity_types_by_message_type.remove(id);
257 let _ = state.message_handlers.remove(id);
258 }
259 }
260 }
261 }
262}
263
264impl Client {
265 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
266 Arc::new(Self {
267 id: 0,
268 peer: Peer::new(),
269 telemetry: Telemetry::new(http.clone(), cx),
270 amplitude_telemetry: AmplitudeTelemetry::new(http.clone(), cx),
271 http,
272 state: Default::default(),
273
274 #[cfg(any(test, feature = "test-support"))]
275 authenticate: Default::default(),
276 #[cfg(any(test, feature = "test-support"))]
277 establish_connection: Default::default(),
278 })
279 }
280
281 pub fn id(&self) -> usize {
282 self.id
283 }
284
285 pub fn http_client(&self) -> Arc<dyn HttpClient> {
286 self.http.clone()
287 }
288
289 #[cfg(any(test, feature = "test-support"))]
290 pub fn set_id(&mut self, id: usize) -> &Self {
291 self.id = id;
292 self
293 }
294
295 #[cfg(any(test, feature = "test-support"))]
296 pub fn tear_down(&self) {
297 let mut state = self.state.write();
298 state._reconnect_task.take();
299 state.message_handlers.clear();
300 state.models_by_message_type.clear();
301 state.entities_by_type_and_remote_id.clear();
302 state.entity_id_extractors.clear();
303 self.peer.reset();
304 }
305
306 #[cfg(any(test, feature = "test-support"))]
307 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
308 where
309 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
310 {
311 *self.authenticate.write() = Some(Box::new(authenticate));
312 self
313 }
314
315 #[cfg(any(test, feature = "test-support"))]
316 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
317 where
318 F: 'static
319 + Send
320 + Sync
321 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
322 {
323 *self.establish_connection.write() = Some(Box::new(connect));
324 self
325 }
326
327 pub fn user_id(&self) -> Option<u64> {
328 self.state
329 .read()
330 .credentials
331 .as_ref()
332 .map(|credentials| credentials.user_id)
333 }
334
335 pub fn peer_id(&self) -> Option<PeerId> {
336 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
337 Some(*peer_id)
338 } else {
339 None
340 }
341 }
342
343 pub fn status(&self) -> watch::Receiver<Status> {
344 self.state.read().status.1.clone()
345 }
346
    /// Publishes `status` on the status watch channel and applies its side effects:
    /// - `Connected` drops any pending reconnection task.
    /// - `ConnectionLost` spawns a reconnection loop. The loop retries
    ///   `authenticate_and_connect(true, ..)`, starting at `INITIAL_RECONNECTION_DELAY`
    ///   and multiplying the delay by a random factor in `[1, 2]`, capped at
    ///   `reconnect_interval`. It keeps retrying only while failures leave the status at
    ///   `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired` clear the telemetry user info and drop any
    ///   pending reconnection task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        // The state write lock is held for the rest of this function, including while
        // the telemetry calls below run.
        let mut state = self.state.write();
        // Publish the new status to receivers of the watch channel.
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Any other status (e.g. SignedOut, UpgradeRequired) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized growth, capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                self.amplitude_telemetry
                    .set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
390
391 pub fn add_view_for_remote_entity<T: View>(
392 self: &Arc<Self>,
393 remote_id: u64,
394 cx: &mut ViewContext<T>,
395 ) -> Subscription {
396 let id = (TypeId::of::<T>(), remote_id);
397 self.state
398 .write()
399 .entities_by_type_and_remote_id
400 .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
401 Subscription::Entity {
402 client: Arc::downgrade(self),
403 id,
404 }
405 }
406
407 pub fn add_model_for_remote_entity<T: Entity>(
408 self: &Arc<Self>,
409 remote_id: u64,
410 cx: &mut ModelContext<T>,
411 ) -> Subscription {
412 let id = (TypeId::of::<T>(), remote_id);
413 self.state
414 .write()
415 .entities_by_type_and_remote_id
416 .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
417 Subscription::Entity {
418 client: Arc::downgrade(self),
419 id,
420 }
421 }
422
423 pub fn add_message_handler<M, E, H, F>(
424 self: &Arc<Self>,
425 model: ModelHandle<E>,
426 handler: H,
427 ) -> Subscription
428 where
429 M: EnvelopedMessage,
430 E: Entity,
431 H: 'static
432 + Send
433 + Sync
434 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
435 F: 'static + Future<Output = Result<()>>,
436 {
437 let message_type_id = TypeId::of::<M>();
438
439 let mut state = self.state.write();
440 state
441 .models_by_message_type
442 .insert(message_type_id, model.downgrade().into());
443
444 let prev_handler = state.message_handlers.insert(
445 message_type_id,
446 Arc::new(move |handle, envelope, client, cx| {
447 let handle = if let AnyEntityHandle::Model(handle) = handle {
448 handle
449 } else {
450 unreachable!();
451 };
452 let model = handle.downcast::<E>().unwrap();
453 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
454 handler(model, *envelope, client.clone(), cx).boxed_local()
455 }),
456 );
457 if prev_handler.is_some() {
458 panic!("registered handler for the same message twice");
459 }
460
461 Subscription::Message {
462 client: Arc::downgrade(self),
463 id: message_type_id,
464 }
465 }
466
467 pub fn add_request_handler<M, E, H, F>(
468 self: &Arc<Self>,
469 model: ModelHandle<E>,
470 handler: H,
471 ) -> Subscription
472 where
473 M: RequestMessage,
474 E: Entity,
475 H: 'static
476 + Send
477 + Sync
478 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
479 F: 'static + Future<Output = Result<M::Response>>,
480 {
481 self.add_message_handler(model, move |handle, envelope, this, cx| {
482 Self::respond_to_request(
483 envelope.receipt(),
484 handler(handle, envelope, this.clone(), cx),
485 this,
486 )
487 })
488 }
489
490 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
491 where
492 M: EntityMessage,
493 E: View,
494 H: 'static
495 + Send
496 + Sync
497 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
498 F: 'static + Future<Output = Result<()>>,
499 {
500 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
501 if let AnyEntityHandle::View(handle) = handle {
502 handler(handle.downcast::<E>().unwrap(), message, client, cx)
503 } else {
504 unreachable!();
505 }
506 })
507 }
508
509 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
510 where
511 M: EntityMessage,
512 E: Entity,
513 H: 'static
514 + Send
515 + Sync
516 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
517 F: 'static + Future<Output = Result<()>>,
518 {
519 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
520 if let AnyEntityHandle::Model(handle) = handle {
521 handler(handle.downcast::<E>().unwrap(), message, client, cx)
522 } else {
523 unreachable!();
524 }
525 })
526 }
527
528 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
529 where
530 M: EntityMessage,
531 E: Entity,
532 H: 'static
533 + Send
534 + Sync
535 + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
536 F: 'static + Future<Output = Result<()>>,
537 {
538 let model_type_id = TypeId::of::<E>();
539 let message_type_id = TypeId::of::<M>();
540
541 let mut state = self.state.write();
542 state
543 .entity_types_by_message_type
544 .insert(message_type_id, model_type_id);
545 state
546 .entity_id_extractors
547 .entry(message_type_id)
548 .or_insert_with(|| {
549 |envelope| {
550 envelope
551 .as_any()
552 .downcast_ref::<TypedEnvelope<M>>()
553 .unwrap()
554 .payload
555 .remote_entity_id()
556 }
557 });
558 let prev_handler = state.message_handlers.insert(
559 message_type_id,
560 Arc::new(move |handle, envelope, client, cx| {
561 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
562 handler(handle, *envelope, client.clone(), cx).boxed_local()
563 }),
564 );
565 if prev_handler.is_some() {
566 panic!("registered handler for the same message twice");
567 }
568 }
569
570 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
571 where
572 M: EntityMessage + RequestMessage,
573 E: Entity,
574 H: 'static
575 + Send
576 + Sync
577 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
578 F: 'static + Future<Output = Result<M::Response>>,
579 {
580 self.add_model_message_handler(move |entity, envelope, client, cx| {
581 Self::respond_to_request::<M, _>(
582 envelope.receipt(),
583 handler(entity, envelope, client.clone(), cx),
584 client,
585 )
586 })
587 }
588
589 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
590 where
591 M: EntityMessage + RequestMessage,
592 E: View,
593 H: 'static
594 + Send
595 + Sync
596 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
597 F: 'static + Future<Output = Result<M::Response>>,
598 {
599 self.add_view_message_handler(move |entity, envelope, client, cx| {
600 Self::respond_to_request::<M, _>(
601 envelope.receipt(),
602 handler(entity, envelope, client.clone(), cx),
603 client,
604 )
605 })
606 }
607
608 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
609 receipt: Receipt<T>,
610 response: F,
611 client: Arc<Self>,
612 ) -> Result<()> {
613 match response.await {
614 Ok(response) => {
615 client.respond(receipt, response)?;
616 Ok(())
617 }
618 Err(error) => {
619 client.respond_with_error(
620 receipt,
621 proto::Error {
622 message: format!("{:?}", error),
623 },
624 )?;
625 Err(error)
626 }
627 }
628 }
629
630 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
631 read_credentials_from_keychain(cx).is_some()
632 }
633
    /// Authenticates if needed, then connects this client to the server.
    ///
    /// Returns `Ok(())` right away when already `Connected`, `Connecting`, or
    /// `Reconnecting`. Fails right away with `EstablishConnectionError::UpgradeRequired`
    /// when the status is `UpgradeRequired`.
    ///
    /// Credentials are taken, in order, from the client state, the platform keychain
    /// (only when `try_keychain` is true), or `authenticate` (browser flow or test
    /// override). A single `CONNECTION_TIMEOUT` timer bounds both opening the
    /// connection and waiting for the server hello in `set_connection`.
    ///
    /// After the connection opens, the credentials are stored in client state. They are
    /// also written to the keychain unless they came from it or `ZED_IMPERSONATE` is
    /// set. If keychain credentials are rejected as unauthorized, they are deleted and
    /// the flow restarts once with `try_keychain = false`, hence the recursion.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Whether we start from `SignedOut` picks the first-connect status variants
        // over the re-connect ones below.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                self.report_event("read credentials from keychain", Default::default());
            }
        }
        if credentials.is_none() {
            // Consume the status value already pending on this fresh receiver so the
            // `status_rx.next()` arm below only fires on a later status change, which
            // cancels authentication. NOTE(review): relies on postage `watch` yielding
            // the current value on the first `next()` — confirm.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer covers both opening the connection and receiving the hello.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        // Wait for the server hello with whatever time remains.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry through
                            // `authenticate` without consulting the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
739
    /// Adds `conn` to the RPC peer, waits for the server's `Hello`, and marks the
    /// client `Connected` with the peer id the hello carries.
    ///
    /// If no valid hello arrives, the connection is disconnected and the error is
    /// returned. After a successful hello, two detached foreground tasks are spawned:
    /// - a dispatcher that routes each incoming message to its registered handler;
    /// - a watcher on the connection's I/O task. When that task ends cleanly and this
    ///   connection is still current, the watcher sets `SignedOut`; when it errors, the
    ///   watcher sets `ConnectionLost`.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first incoming message must be the server hello.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            Ok(PeerId(hello.payload.peer_id))
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {}, peer id: {})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    // Local counter, used only to correlate log lines.
                    let mut message_id = 0_usize;
                    while let Some(message) = incoming.next().await {
                        // The state lock is held while routing and released before the
                        // handler is invoked.
                        let mut state = this.state.write();
                        message_id += 1;
                        let type_name = message.payload_type_name();
                        let payload_type_id = message.payload_type_id();
                        let sender_id = message.original_sender_id().map(|id| id.0);

                        // Route to the model registered for this message type. Otherwise
                        // use the entity registered for (entity type, remote id extracted
                        // from the payload), pruning the entry if that entity is gone.
                        let model = state
                            .models_by_message_type
                            .get(&payload_type_id)
                            .and_then(|model| model.upgrade(&cx))
                            .map(AnyEntityHandle::Model)
                            .or_else(|| {
                                let entity_type_id =
                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
                                let entity_id = state
                                    .entity_id_extractors
                                    .get(&message.payload_type_id())
                                    .map(|extract_entity_id| {
                                        (extract_entity_id)(message.as_ref())
                                    })?;

                                let entity = state
                                    .entities_by_type_and_remote_id
                                    .get(&(entity_type_id, entity_id))?;
                                if let Some(entity) = entity.upgrade(&cx) {
                                    Some(entity)
                                } else {
                                    state
                                        .entities_by_type_and_remote_id
                                        .remove(&(entity_type_id, entity_id));
                                    None
                                }
                            });

                        // NOTE(review): this `continue` also skips the `yield_now` at
                        // the bottom of the loop.
                        let model = if let Some(model) = model {
                            model
                        } else {
                            log::info!("unhandled message {}", type_name);
                            continue;
                        };

                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
                        {
                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
                            let future = handler(model, message, &this, cx.clone());

                            let client_id = this.id;
                            log::debug!(
                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
                                client_id,
                                message_id,
                                sender_id,
                                type_name
                            );
                            // Each handler runs in its own task so a slow handler does
                            // not block dispatch of later messages.
                            cx.foreground()
                                .spawn(async move {
                                    match future.await {
                                        Ok(()) => {
                                            log::debug!(
                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
                                                client_id,
                                                message_id,
                                                sender_id,
                                                type_name
                                            );
                                        }
                                        Err(error) => {
                                            log::error!(
                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
                                                client_id,
                                                message_id,
                                                sender_id,
                                                type_name,
                                                error
                                            );
                                        }
                                    }
                                })
                                .detach();
                        } else {
                            log::info!("unhandled message {}", type_name);
                        }

                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
914
915 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
916 #[cfg(any(test, feature = "test-support"))]
917 if let Some(callback) = self.authenticate.read().as_ref() {
918 return callback(cx);
919 }
920
921 self.authenticate_with_browser(cx)
922 }
923
924 fn establish_connection(
925 self: &Arc<Self>,
926 credentials: &Credentials,
927 cx: &AsyncAppContext,
928 ) -> Task<Result<Connection, EstablishConnectionError>> {
929 #[cfg(any(test, feature = "test-support"))]
930 if let Some(callback) = self.establish_connection.read().as_ref() {
931 return callback(credentials, cx);
932 }
933
934 self.establish_websocket_connection(credentials, cx)
935 }
936
937 async fn get_rpc_url(http: Arc<dyn HttpClient>) -> Result<Url> {
938 let url = format!("{}/rpc", *ZED_SERVER_URL);
939 let response = http.get(&url, Default::default(), false).await?;
940
941 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
942 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
943 // which requires authorization via an HTTP header.
944 //
945 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
946 // of a collab server. In that case, a request to the /rpc endpoint will
947 // return an 'unauthorized' response.
948 let collab_url = if response.status().is_redirection() {
949 response
950 .headers()
951 .get("Location")
952 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
953 .to_str()
954 .map_err(EstablishConnectionError::other)?
955 .to_string()
956 } else if response.status() == StatusCode::UNAUTHORIZED {
957 url
958 } else {
959 Err(anyhow!(
960 "unexpected /rpc response status {}",
961 response.status()
962 ))?
963 };
964
965 Url::parse(&collab_url).context("invalid rpc url")
966 }
967
    /// Opens a websocket connection to the collab server's `/rpc` endpoint.
    ///
    /// The handshake request carries `Authorization: <user_id> <access_token>` and the
    /// protocol version in `x-zed-protocol-version`. The endpoint comes from
    /// `get_rpc_url`, and a TCP stream is opened to its host and port. The handshake
    /// then uses TLS (`wss`) for `https` URLs and plain `ws` for `http` URLs; any other
    /// scheme is an error. HTTP 401 / 426 handshake responses become
    /// `EstablishConnectionError::Unauthorized` / `UpgradeRequired` through the
    /// `From<WebsocketError>` conversion applied by `?`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            match rpc_url.scheme() {
                "https" => {
                    // https -> wss stays within the "special" schemes, so this is accepted.
                    rpc_url.set_scheme("wss").unwrap();
                    // Replaces any existing query. NOTE(review): only this TLS branch
                    // sends `preview=1`; the plain `http` branch never does — confirm
                    // this is intended.
                    rpc_url.set_query(if *PREVIEW_CHANNEL {
                        Some("preview=1")
                    } else {
                        None
                    });
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1022
1023 pub fn authenticate_with_browser(
1024 self: &Arc<Self>,
1025 cx: &AsyncAppContext,
1026 ) -> Task<Result<Credentials>> {
1027 let platform = cx.platform();
1028 let executor = cx.background();
1029 let telemetry = self.telemetry.clone();
1030 let amplitude_telemetry = self.amplitude_telemetry.clone();
1031 let http = self.http.clone();
1032 executor.clone().spawn(async move {
1033 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1034 // zed server to encrypt the user's access token, so that it can'be intercepted by
1035 // any other app running on the user's device.
1036 let (public_key, private_key) =
1037 rpc::auth::keypair().expect("failed to generate keypair for auth");
1038 let public_key_string =
1039 String::try_from(public_key).expect("failed to serialize public key for auth");
1040
1041 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1042 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1043 }
1044
1045 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1046 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1047 let port = server.server_addr().port();
1048
1049 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1050 // that the user is signing in from a Zed app running on the same device.
1051 let mut url = format!(
1052 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1053 *ZED_SERVER_URL, port, public_key_string
1054 );
1055
1056 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1057 log::info!("impersonating user @{}", impersonate_login);
1058 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1059 }
1060
1061 platform.open_url(&url);
1062
1063 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1064 // access token from the query params.
1065 //
1066 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1067 // custom URL scheme instead of this local HTTP server.
1068 let (user_id, access_token) = executor
1069 .spawn(async move {
1070 for _ in 0..100 {
1071 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1072 let path = req.url();
1073 let mut user_id = None;
1074 let mut access_token = None;
1075 let url = Url::parse(&format!("http://example.com{}", path))
1076 .context("failed to parse login notification url")?;
1077 for (key, value) in url.query_pairs() {
1078 if key == "access_token" {
1079 access_token = Some(value.to_string());
1080 } else if key == "user_id" {
1081 user_id = Some(value.to_string());
1082 }
1083 }
1084
1085 let post_auth_url =
1086 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1087 req.respond(
1088 tiny_http::Response::empty(302).with_header(
1089 tiny_http::Header::from_bytes(
1090 &b"Location"[..],
1091 post_auth_url.as_bytes(),
1092 )
1093 .unwrap(),
1094 ),
1095 )
1096 .context("failed to respond to login http request")?;
1097 return Ok((
1098 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1099 access_token
1100 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1101 ));
1102 }
1103 }
1104
1105 Err(anyhow!("didn't receive login redirect"))
1106 })
1107 .await?;
1108
1109 let access_token = private_key
1110 .decrypt_string(&access_token)
1111 .context("failed to decrypt access token")?;
1112 platform.activate(true);
1113
1114 telemetry.report_event("authenticate with browser", Default::default());
1115 amplitude_telemetry.report_event("authenticate with browser", Default::default());
1116
1117 Ok(Credentials {
1118 user_id: user_id.parse()?,
1119 access_token,
1120 })
1121 })
1122 }
1123
1124 async fn authenticate_as_admin(
1125 http: Arc<dyn HttpClient>,
1126 login: String,
1127 mut api_token: String,
1128 ) -> Result<Credentials> {
1129 #[derive(Deserialize)]
1130 struct AuthenticatedUserResponse {
1131 user: User,
1132 }
1133
1134 #[derive(Deserialize)]
1135 struct User {
1136 id: u64,
1137 }
1138
1139 // Use the collab server's admin API to retrieve the id
1140 // of the impersonated user.
1141 let mut url = Self::get_rpc_url(http.clone()).await?;
1142 url.set_path("/user");
1143 url.set_query(Some(&format!("github_login={login}")));
1144 let request = Request::get(url.as_str())
1145 .header("Authorization", format!("token {api_token}"))
1146 .body("".into())?;
1147
1148 let mut response = http.send(request).await?;
1149 let mut body = String::new();
1150 response.body_mut().read_to_string(&mut body).await?;
1151 if !response.status().is_success() {
1152 Err(anyhow!(
1153 "admin user request failed {} - {}",
1154 response.status().as_u16(),
1155 body,
1156 ))?;
1157 }
1158 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1159
1160 // Use the admin API token to authenticate as the impersonated user.
1161 api_token.insert_str(0, "ADMIN_TOKEN:");
1162 Ok(Credentials {
1163 user_id: response.user.id,
1164 access_token: api_token,
1165 })
1166 }
1167
1168 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
1169 let conn_id = self.connection_id()?;
1170 self.peer.disconnect(conn_id);
1171 self.set_status(Status::SignedOut, cx);
1172 Ok(())
1173 }
1174
1175 fn connection_id(&self) -> Result<ConnectionId> {
1176 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1177 Ok(connection_id)
1178 } else {
1179 Err(anyhow!("not connected"))
1180 }
1181 }
1182
1183 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1184 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1185 self.peer.send(self.connection_id()?, message)
1186 }
1187
1188 pub fn request<T: RequestMessage>(
1189 &self,
1190 request: T,
1191 ) -> impl Future<Output = Result<T::Response>> {
1192 let client_id = self.id;
1193 log::debug!(
1194 "rpc request start. client_id:{}. name:{}",
1195 client_id,
1196 T::NAME
1197 );
1198 let response = self
1199 .connection_id()
1200 .map(|conn_id| self.peer.request(conn_id, request));
1201 async move {
1202 let response = response?.await;
1203 log::debug!(
1204 "rpc request finish. client_id:{}. name:{}",
1205 client_id,
1206 T::NAME
1207 );
1208 response
1209 }
1210 }
1211
1212 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1213 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1214 self.peer.respond(receipt, response)
1215 }
1216
1217 fn respond_with_error<T: RequestMessage>(
1218 &self,
1219 receipt: Receipt<T>,
1220 error: proto::Error,
1221 ) -> Result<()> {
1222 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1223 self.peer.respond_with_error(receipt, error)
1224 }
1225
1226 pub fn start_telemetry(&self, db: Db) {
1227 self.telemetry.start(db.clone());
1228 self.amplitude_telemetry.start(db);
1229 }
1230
1231 pub fn report_event(&self, kind: &str, properties: Value) {
1232 self.telemetry.report_event(kind, properties.clone());
1233 self.amplitude_telemetry.report_event(kind, properties);
1234 }
1235
1236 pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1237 self.amplitude_telemetry.log_file_path();
1238 self.telemetry.log_file_path()
1239 }
1240}
1241
1242impl AnyWeakEntityHandle {
1243 fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1244 match self {
1245 AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1246 AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1247 }
1248 }
1249}
1250
1251fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1252 if IMPERSONATE_LOGIN.is_some() {
1253 return None;
1254 }
1255
1256 let (user_id, access_token) = cx
1257 .platform()
1258 .read_credentials(&ZED_SERVER_URL)
1259 .log_err()
1260 .flatten()?;
1261 Some(Credentials {
1262 user_id: user_id.parse().ok()?,
1263 access_token: String::from_utf8(access_token).ok()?,
1264 })
1265}
1266
1267fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1268 cx.platform().write_credentials(
1269 &ZED_SERVER_URL,
1270 &credentials.user_id.to_string(),
1271 credentials.access_token.as_bytes(),
1272 )
1273}
1274
/// Scheme and path prefix shared by [`encode_worktree_url`] and
/// [`decode_worktree_url`].
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/{id}/{access_token}` URL.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the first `/` that follows
/// the id is treated as the access token, so tokens that contain `/` round-trip
/// intact.
///
/// Returns `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only once: the id never contains '/', but the token may.
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1291
// Tests for `Client`: the connection lifecycle (reconnection, connection timeouts,
// repeated authentication), worktree URL encoding, and message-handler subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // Checks that, after a dropped connection, the client reconnects using its cached
    // credentials (no new authentication). It re-authenticates only after the server's
    // access token has been rolled.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, and wait until the client
        // reports a failed reconnection attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock so the next
        // reconnection attempt fires.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        // (`roll_access_token` presumably invalidates the token the client has cached).
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Checks that a connection attempt that never completes is abandoned after
    // `CONNECTION_TIMEOUT`. This holds both for the initial connect (ConnectionError)
    // and for a later reconnect (ReconnectionError).
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        // Authentication succeeds immediately...
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // ...but establishing the connection never finishes.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Advance past the reconnection delay so a reconnection attempt starts.
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Checks that starting a second `authenticate_and_connect` while the first is
    // still pending starts a new authentication and drops the first, still-pending
    // authentication future.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        // Each authentication attempt bumps `auth_count` and then never completes.
        // The `util::defer` guard bumps `dropped_auth_count` when the attempt's
        // future is dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // A second attempt supersedes the first: one more authentication, one drop.
        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Round-trips a worktree URL. It also checks that surrounding whitespace is
    // ignored and that a foreign URL is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Checks that entity messages (`JoinProject`, keyed by `project_id`) are routed to
    // the model registered for the matching remote entity id.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signal the channel that corresponds to the model that received the message;
        // model 3 must never receive one.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Checks that, after dropping a message-handler subscription, a new handler for
    // the same model and message type can be registered and receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // Checks that a handler can drop its own subscription (stored on the model) while
    // it is running, and that the message is still delivered.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                // Drops the very subscription that dispatched this handler.
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: `id` identifies which model handled a message, and
    // `subscription` lets a test tie a subscription's lifetime to the model.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}