1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4mod llm_token;
5mod proxy;
6pub mod telemetry;
7pub mod user;
8pub mod zed_urls;
9
10use anyhow::{Context as _, Result, anyhow};
11use async_tungstenite::tungstenite::{
12 client::IntoClientRequest,
13 error::Error as WebsocketError,
14 http::{HeaderValue, Request, StatusCode},
15};
16use clock::SystemClock;
17use cloud_api_client::LlmApiToken;
18use cloud_api_client::websocket_protocol::MessageToClient;
19use cloud_api_client::{ClientApiError, CloudApiClient};
20use cloud_api_types::OrganizationId;
21use credentials_provider::CredentialsProvider;
22use feature_flags::FeatureFlagAppExt as _;
23use futures::{
24 AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
25 channel::{mpsc, oneshot},
26 future::BoxFuture,
27};
28use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
29use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
30use parking_lot::{Mutex, RwLock};
31use postage::watch;
32use proxy::connect_proxy_stream;
33use rand::prelude::*;
34use release_channel::{AppVersion, ReleaseChannel};
35use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
36use serde::{Deserialize, Serialize};
37use settings::{RegisterSetting, Settings, SettingsContent};
38use std::{
39 any::TypeId,
40 convert::TryFrom,
41 future::Future,
42 marker::PhantomData,
43 path::PathBuf,
44 sync::{
45 Arc, LazyLock, Weak,
46 atomic::{AtomicU64, Ordering},
47 },
48 time::{Duration, Instant},
49};
50use std::{cmp, pin::Pin};
51use telemetry::Telemetry;
52use thiserror::Error;
53use tokio::net::TcpStream;
54use url::Url;
55use util::{ConnectionResult, ResultExt};
56
57pub use llm_token::*;
58pub use rpc::*;
59pub use telemetry_events::Event;
60pub use user::*;
61
/// Server URL override read from the `ZED_SERVER_URL` environment variable.
/// When set, it takes precedence over the `server_url` setting.
static ZED_SERVER_URL: LazyLock<Option<String>> =
    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
/// Raw value of the `ZED_RPC_URL` environment variable, if set.
static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());

/// Login to impersonate, from `ZED_IMPERSONATE`; an empty value counts as unset.
/// While set, stored credentials are neither read nor written.
pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
    std::env::var("ZED_IMPERSONATE")
        .ok()
        .filter(|login| !login.is_empty())
});

/// `true` whenever `ZED_WEB_LOGIN` is present in the environment (any value).
pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());

/// Admin API token from `ZED_ADMIN_API_TOKEN`; an empty value counts as unset.
pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
    std::env::var("ZED_ADMIN_API_TOKEN")
        .ok()
        .filter(|token| !token.is_empty())
});

/// Path taken from the `ZED_APP_PATH` environment variable, if set.
pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));

/// `true` when `ZED_ALWAYS_ACTIVE` is set to a non-empty value.
pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|value| !value.is_empty()));

/// First back-off delay used by the reconnection loop after a lost connection.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Upper bound on the (doubling) reconnection back-off delay.
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
/// Time budget for establishing a Collab connection and completing its handshake.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
89
// Global account/connection actions; `init` registers a handler for each that forwards
// to the `Client` passed there. The `///` lines inside the macro are left verbatim
// because `actions!` presumably exposes them as action descriptions at runtime.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
101
/// Connection settings for the Zed client.
#[derive(Deserialize, RegisterSetting)]
pub struct ClientSettings {
    /// URL of the Zed server. The `ZED_SERVER_URL` environment variable overrides it.
    pub server_url: String,
    /// Overrides the key used to store credentials in the system keychain.
    /// Defaults to `server_url` when unset.
    ///
    /// Useful when running multiple Zed instances side by side without them
    /// overwriting each other's keychain entries.
    ///
    /// Note: changing this after signing in will require signing in again, as
    /// existing credentials are stored under the old key.
    pub credentials_url: Option<String>,
}
115
116impl Settings for ClientSettings {
117 fn from_settings(content: &settings::SettingsContent) -> Self {
118 if let Some(server_url) = &*ZED_SERVER_URL {
119 return Self {
120 server_url: server_url.clone(),
121 credentials_url: content.credentials_url.clone(),
122 };
123 }
124 Self {
125 server_url: content.server_url.clone().unwrap(),
126 credentials_url: content.credentials_url.clone(),
127 }
128 }
129}
130
/// User-configured proxy for outgoing connections.
#[derive(Deserialize, Default, RegisterSetting)]
pub struct ProxySettings {
    /// Proxy URL as written in settings. `from_settings` trims it and stores `None`
    /// when it is blank.
    pub proxy: Option<String>,
}
135
136impl ProxySettings {
137 pub fn proxy_url(&self) -> Option<Url> {
138 self.proxy
139 .as_deref()
140 .map(str::trim)
141 .filter(|input| !input.is_empty())
142 .and_then(|input| {
143 input
144 .parse::<Url>()
145 .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
146 .ok()
147 })
148 .or_else(read_proxy_from_env)
149 }
150}
151
152impl Settings for ProxySettings {
153 fn from_settings(content: &settings::SettingsContent) -> Self {
154 Self {
155 proxy: content
156 .proxy
157 .as_deref()
158 .map(str::trim)
159 .filter(|proxy| !proxy.is_empty())
160 .map(ToOwned::to_owned),
161 }
162 }
163}
164
165pub fn init(client: &Arc<Client>, cx: &mut App) {
166 let client = Arc::downgrade(client);
167 cx.on_action({
168 let client = client.clone();
169 move |_: &SignIn, cx| {
170 if let Some(client) = client.upgrade() {
171 cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
172 .detach_and_log_err(cx);
173 }
174 }
175 })
176 .on_action({
177 let client = client.clone();
178 move |_: &SignOut, cx| {
179 if let Some(client) = client.upgrade() {
180 cx.spawn(async move |cx| {
181 client.sign_out(cx).await;
182 })
183 .detach();
184 }
185 }
186 })
187 .on_action({
188 let client = client;
189 move |_: &Reconnect, cx| {
190 if let Some(client) = client.upgrade() {
191 cx.spawn(async move |cx| {
192 client.reconnect(cx);
193 })
194 .detach();
195 }
196 }
197 });
198}
199
/// Callback invoked with messages that Cloud pushes to the client.
/// Handlers are stored in `Client::message_to_client_handlers`.
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;

/// App-global slot holding the shared [`Client`]; see `Client::global` / `Client::set_global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
205
/// Zed's connection to its backend.
///
/// Holds the HTTP and Cloud API clients, the RPC peer used for Collab, stored
/// credentials, connection status, and the registry of incoming-message handlers.
pub struct Client {
    /// Client identifier returned by `id()`; `sign_in` sets it to the signed-in user id.
    id: AtomicU64,
    /// RPC peer that owns the Collab connection(s).
    peer: Arc<Peer>,
    /// HTTP client bound to the configured server URL.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client; receives the credentials after each successful sign-in.
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    /// Keychain-backed storage for the user's credentials.
    credentials_provider: ClientCredentialsProvider,
    /// In-memory credentials, status channel and pending reconnect task.
    state: RwLock<ClientState>,
    /// Registered proto message handlers and entity subscriptions.
    handler_set: Mutex<ProtoMessageHandlerSet>,
    /// Handlers for messages pushed over the Cloud websocket
    /// (presumably invoked by `handle_message_to_client`, defined outside this view).
    message_to_client_handlers: Mutex<Vec<MessageToClientHandler>>,
    /// NOTE(review): used by sign-out logic not visible here — confirm semantics.
    sign_out_tx: Mutex<Option<mpsc::UnboundedSender<()>>>,

    /// Test-only replacement for the interactive `authenticate` flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only replacement for the websocket connection setup.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only override for the RPC URL.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
242
/// Reasons opening a Collab connection can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server rejected the client version (HTTP 426 during the websocket upgrade).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server rejected the credentials (HTTP 401 during the websocket upgrade).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including non-401/426 websocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A header value (e.g. authorization) could not be encoded.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite its name, this wraps an `http::Error` (e.g. from building the request).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
258
259impl From<WebsocketError> for EstablishConnectionError {
260 fn from(error: WebsocketError) -> Self {
261 if let WebsocketError::Http(response) = &error {
262 match response.status() {
263 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
264 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
265 _ => {}
266 }
267 }
268 EstablishConnectionError::Other(error.into())
269 }
270}
271
272impl EstablishConnectionError {
273 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
274 Self::Other(error.into())
275 }
276}
277
/// Authentication and connection state of a [`Client`], published via `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    SignedOut,
    /// A Collab connection attempt failed with an upgrade-required response.
    UpgradeRequired,
    /// `sign_in` started from a signed-out state.
    Authenticating,
    Authenticated,
    AuthenticationError,
    Connecting,
    ConnectionError,
    /// Live Collab connection identified by the peer id from the server hello.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O task failed; `set_status` starts a reconnect loop.
    ConnectionLost,
    /// `sign_in` started from a not-signed-out state.
    Reauthenticating,
    Reauthenticated,
    Reconnecting,
    /// A reconnect attempt failed; the next attempt is scheduled around `next_reconnection`
    /// (the loop adds random jitter on top).
    ReconnectionError {
        next_reconnection: Instant,
    },
}
299
300impl Status {
301 pub fn is_connected(&self) -> bool {
302 matches!(self, Self::Connected { .. })
303 }
304
305 pub fn was_connected(&self) -> bool {
306 matches!(
307 self,
308 Self::ConnectionLost
309 | Self::Reauthenticating
310 | Self::Reauthenticated
311 | Self::Reconnecting
312 )
313 }
314
315 /// Returns whether the client is currently connected or was connected at some point.
316 pub fn is_or_was_connected(&self) -> bool {
317 self.is_connected() || self.was_connected()
318 }
319
320 pub fn is_signing_in(&self) -> bool {
321 matches!(
322 self,
323 Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
324 )
325 }
326
327 pub fn is_signed_out(&self) -> bool {
328 matches!(self, Self::SignedOut | Self::UpgradeRequired)
329 }
330}
331
/// Mutable per-client state guarded by `Client::state`.
struct ClientState {
    /// Credentials from the last successful `sign_in`, if any.
    credentials: Option<Credentials>,
    /// Watch channel; `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnection loop spawned on `ConnectionLost`; cleared on connect or sign-out.
    _reconnect_task: Option<Task<()>>,
}
337
/// A user id paired with the access token that authenticates as that user.
// NOTE(review): the derived `Debug` prints `access_token` in clear text.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl Credentials {
    /// Formats the credentials as `"<user_id> <access_token>"`
    /// (presumably sent as an `Authorization` header value — callers not visible here).
    pub fn authorization_header(&self) -> String {
        let Self {
            user_id,
            access_token,
        } = self;
        format!("{user_id} {access_token}")
    }
}
349
/// Reads, writes and deletes the user's [`Credentials`] in the system keychain.
/// Entries are keyed by `ClientSettings::credentials_url`, falling back to `server_url`.
pub struct ClientCredentialsProvider {
    provider: Arc<dyn CredentialsProvider>,
}
353
354impl ClientCredentialsProvider {
355 pub fn new(cx: &App) -> Self {
356 Self {
357 provider: zed_credentials_provider::global(cx),
358 }
359 }
360
361 fn server_url(&self, cx: &AsyncApp) -> Result<String> {
362 Ok(cx.update(|cx| ClientSettings::get_global(cx).server_url.clone()))
363 }
364
365 /// Returns the key used for credential storage in the system keychain.
366 fn credentials_url(&self, cx: &AsyncApp) -> Result<String> {
367 let from_settings = cx.update(|cx| ClientSettings::get_global(cx).credentials_url.clone());
368 Ok(from_settings.unwrap_or(self.server_url(cx)?))
369 }
370
371 /// Reads the credentials from the provider.
372 fn read_credentials<'a>(
373 &'a self,
374 cx: &'a AsyncApp,
375 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
376 async move {
377 if IMPERSONATE_LOGIN.is_some() {
378 return None;
379 }
380
381 let credentials_url = self.credentials_url(cx).ok()?;
382 let (user_id, access_token) = self
383 .provider
384 .read_credentials(&credentials_url, cx)
385 .await
386 .log_err()
387 .flatten()?;
388
389 Some(Credentials {
390 user_id: user_id.parse().ok()?,
391 access_token: String::from_utf8(access_token).ok()?,
392 })
393 }
394 .boxed_local()
395 }
396
397 /// Writes the credentials to the provider.
398 fn write_credentials<'a>(
399 &'a self,
400 user_id: u64,
401 access_token: String,
402 cx: &'a AsyncApp,
403 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
404 async move {
405 let credentials_url = self.credentials_url(cx)?;
406 self.provider
407 .write_credentials(
408 &credentials_url,
409 &user_id.to_string(),
410 access_token.as_bytes(),
411 cx,
412 )
413 .await
414 }
415 .boxed_local()
416 }
417
418 /// Deletes the credentials from the provider.
419 fn delete_credentials<'a>(
420 &'a self,
421 cx: &'a AsyncApp,
422 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
423 async move {
424 let credentials_url = self.credentials_url(cx)?;
425 self.provider.delete_credentials(&credentials_url, cx).await
426 }
427 .boxed_local()
428 }
429}
430
431impl Default for ClientState {
432 fn default() -> Self {
433 Self {
434 credentials: None,
435 status: watch::channel_with(Status::SignedOut),
436 _reconnect_task: None,
437 }
438 }
439}
440
/// Handle for a registration in the client's handler set; dropping it unregisters.
///
/// Holds only a weak reference, so it does not keep the [`Client`] alive.
pub enum Subscription {
    /// Entity subscription keyed by (entity `TypeId`, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message handler keyed by the message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
451
452impl Drop for Subscription {
453 fn drop(&mut self) {
454 match self {
455 Subscription::Entity { client, id } => {
456 if let Some(client) = client.upgrade() {
457 let mut state = client.handler_set.lock();
458 let _ = state.entities_by_type_and_remote_id.remove(id);
459 }
460 }
461 Subscription::Message { client, id } => {
462 if let Some(client) = client.upgrade() {
463 let mut state = client.handler_set.lock();
464 let _ = state.entity_types_by_message_type.remove(id);
465 let _ = state.message_handlers.remove(id);
466 }
467 }
468 }
469 }
470}
471
/// Subscription for messages about a remote entity whose local [`Entity`] is not set yet.
///
/// While pending, the handler set holds an `EntityMessageSubscriber::Pending` queue for
/// `(TypeId::of::<T>(), remote_id)`. `set_entity` swaps in the real entity and replays
/// the queue. Dropping without `set_entity` discards the queue, logging each message.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    // Ties the subscription to `T`, whose `TypeId` is part of the handler-set key.
    _entity_type: PhantomData<T>,
    // Set by `set_entity` so `Drop` leaves the (now non-pending) entry alone.
    consumed: bool,
}
478
impl<T: 'static> PendingEntitySubscription<T> {
    /// Binds the subscription to `entity` and replays any messages queued while pending.
    ///
    /// Returns a [`Subscription`] that unregisters the entity when dropped.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) if the handler-set entry is no longer a pending queue.
    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
        self.consumed = true;
        let mut handlers = self.client.handler_set.lock();
        let id = (TypeId::of::<T>(), self.remote_id);
        let Some(EntityMessageSubscriber::Pending(messages)) =
            handlers.entities_by_type_and_remote_id.remove(&id)
        else {
            unreachable!()
        };

        handlers.entities_by_type_and_remote_id.insert(
            id,
            EntityMessageSubscriber::Entity {
                handle: entity.downgrade().into(),
            },
        );
        // Release the handler-set lock before replaying: `handle_message` needs it too.
        drop(handlers);
        for message in messages {
            let client_id = self.client.id();
            let type_name = message.payload_type_name();
            let sender_id = message.original_sender_id();
            log::debug!(
                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                sender_id,
                type_name
            );
            self.client.handle_message(message, cx);
        }
        Subscription::Entity {
            client: Arc::downgrade(&self.client),
            id,
        }
    }
}
515
516impl<T: 'static> Drop for PendingEntitySubscription<T> {
517 fn drop(&mut self) {
518 if !self.consumed {
519 let mut state = self.client.handler_set.lock();
520 if let Some(EntityMessageSubscriber::Pending(messages)) = state
521 .entities_by_type_and_remote_id
522 .remove(&(TypeId::of::<T>(), self.remote_id))
523 {
524 for message in messages {
525 log::info!("unhandled message {}", message.payload_type_name());
526 }
527 }
528 }
529 }
530}
531
/// Telemetry toggles read from the `telemetry` section of the settings.
#[derive(Copy, Clone, Deserialize, Debug, RegisterSetting)]
pub struct TelemetrySettings {
    /// Value of `telemetry.diagnostics`.
    pub diagnostics: bool,
    /// Value of `telemetry.metrics`.
    pub metrics: bool,
}
537
538impl settings::Settings for TelemetrySettings {
539 fn from_settings(content: &SettingsContent) -> Self {
540 Self {
541 diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
542 metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
543 }
544 }
545}
546
547impl Client {
548 pub fn new(
549 clock: Arc<dyn SystemClock>,
550 http: Arc<HttpClientWithUrl>,
551 cx: &mut App,
552 ) -> Arc<Self> {
553 Arc::new(Self {
554 id: AtomicU64::new(0),
555 peer: Peer::new(0),
556 telemetry: Telemetry::new(clock, http.clone(), cx),
557 cloud_client: Arc::new(CloudApiClient::new(http.clone())),
558 http,
559 credentials_provider: ClientCredentialsProvider::new(cx),
560 state: Default::default(),
561 handler_set: Default::default(),
562 message_to_client_handlers: Mutex::new(Vec::new()),
563 sign_out_tx: Mutex::new(None),
564
565 #[cfg(any(test, feature = "test-support"))]
566 authenticate: Default::default(),
567 #[cfg(any(test, feature = "test-support"))]
568 establish_connection: Default::default(),
569 #[cfg(any(test, feature = "test-support"))]
570 rpc_url: RwLock::default(),
571 })
572 }
573
574 pub fn production(cx: &mut App) -> Arc<Self> {
575 let clock = Arc::new(clock::RealSystemClock);
576 let http = Arc::new(HttpClientWithUrl::new_url(
577 cx.http_client(),
578 &ClientSettings::get_global(cx).server_url,
579 cx.http_client().proxy().cloned(),
580 ));
581 Self::new(clock, http, cx)
582 }
583
584 pub fn id(&self) -> u64 {
585 self.id.load(Ordering::SeqCst)
586 }
587
588 pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
589 self.http.clone()
590 }
591
592 pub fn credentials_provider(&self) -> Arc<dyn CredentialsProvider> {
593 self.credentials_provider.provider.clone()
594 }
595
596 pub fn cloud_client(&self) -> Arc<CloudApiClient> {
597 self.cloud_client.clone()
598 }
599
600 pub fn set_id(&self, id: u64) -> &Self {
601 self.id.store(id, Ordering::SeqCst);
602 self
603 }
604
605 #[cfg(any(test, feature = "test-support"))]
606 pub fn teardown(&self) {
607 let mut state = self.state.write();
608 state._reconnect_task.take();
609 self.handler_set.lock().clear();
610 self.peer.teardown();
611 }
612
613 #[cfg(any(test, feature = "test-support"))]
614 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
615 where
616 F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
617 {
618 *self.authenticate.write() = Some(Box::new(authenticate));
619 self
620 }
621
622 #[cfg(any(test, feature = "test-support"))]
623 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
624 where
625 F: 'static
626 + Send
627 + Sync
628 + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
629 {
630 *self.establish_connection.write() = Some(Box::new(connect));
631 self
632 }
633
634 #[cfg(any(test, feature = "test-support"))]
635 pub fn override_rpc_url(&self, url: Url) -> &Self {
636 *self.rpc_url.write() = Some(url);
637 self
638 }
639
640 pub fn global(cx: &App) -> Arc<Self> {
641 cx.global::<GlobalClient>().0.clone()
642 }
643 pub fn set_global(client: Arc<Client>, cx: &mut App) {
644 cx.set_global(GlobalClient(client))
645 }
646
647 pub fn user_id(&self) -> Option<u64> {
648 self.state
649 .read()
650 .credentials
651 .as_ref()
652 .map(|credentials| credentials.user_id)
653 }
654
655 pub fn peer_id(&self) -> Option<PeerId> {
656 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
657 Some(*peer_id)
658 } else {
659 None
660 }
661 }
662
663 pub fn status(&self) -> watch::Receiver<Status> {
664 self.state.read().status.1.clone()
665 }
666
    /// Publishes `status` on the status watch channel and applies its side effects.
    ///
    /// - `Connected`: clears the stored reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop that calls [`Client::connect`]
    ///   repeatedly. While an attempt leaves the client in `AuthenticationError` or
    ///   `ConnectionError`, it publishes `ReconnectionError` and waits `delay` plus
    ///   random jitter in `[0, delay)`. The delay starts at [`INITIAL_RECONNECTION_DELAY`]
    ///   and doubles up to [`MAX_RECONNECTION_DELAY`].
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and the
    ///   reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                // Dropping the stored task stops any running reconnect loop (tasks that must
                // outlive their handle are `.detach()`ed elsewhere in this file).
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // Fixed seed in tests keeps the jitter deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_os_rng();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        match client.connect(true, cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    break;
                                }
                            }
                        }

                        // Retry only if the attempt left the client in an error state;
                        // any other status (e.g. signed out) ends the loop.
                        if matches!(
                            *client.status().borrow(),
                            Status::AuthenticationError | Status::ConnectionError
                        ) {
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                cx,
                            );
                            let jitter = Duration::from_millis(
                                rng.random_range(0..delay.as_millis() as u64),
                            );
                            cx.background_executor().timer(delay + jitter).await;
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
730
731 pub fn subscribe_to_entity<T>(
732 self: &Arc<Self>,
733 remote_id: u64,
734 ) -> Result<PendingEntitySubscription<T>>
735 where
736 T: 'static,
737 {
738 let id = (TypeId::of::<T>(), remote_id);
739
740 let mut state = self.handler_set.lock();
741 anyhow::ensure!(
742 !state.entities_by_type_and_remote_id.contains_key(&id),
743 "already subscribed to entity"
744 );
745
746 state
747 .entities_by_type_and_remote_id
748 .insert(id, EntityMessageSubscriber::Pending(Default::default()));
749
750 Ok(PendingEntitySubscription {
751 client: self.clone(),
752 remote_id,
753 consumed: false,
754 _entity_type: PhantomData,
755 })
756 }
757
758 #[track_caller]
759 pub fn add_message_handler<M, E, H, F>(
760 self: &Arc<Self>,
761 entity: WeakEntity<E>,
762 handler: H,
763 ) -> Subscription
764 where
765 M: EnvelopedMessage,
766 E: 'static,
767 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
768 F: 'static + Future<Output = Result<()>>,
769 {
770 self.add_message_handler_impl(entity, move |entity, message, _, cx| {
771 handler(entity, message, cx)
772 })
773 }
774
775 fn add_message_handler_impl<M, E, H, F>(
776 self: &Arc<Self>,
777 entity: WeakEntity<E>,
778 handler: H,
779 ) -> Subscription
780 where
781 M: EnvelopedMessage,
782 E: 'static,
783 H: 'static
784 + Sync
785 + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
786 + Send
787 + Sync,
788 F: 'static + Future<Output = Result<()>>,
789 {
790 let message_type_id = TypeId::of::<M>();
791 let mut state = self.handler_set.lock();
792 state
793 .entities_by_message_type
794 .insert(message_type_id, entity.into());
795
796 let prev_handler = state.message_handlers.insert(
797 message_type_id,
798 Arc::new(move |subscriber, envelope, client, cx| {
799 let subscriber = subscriber.downcast::<E>().unwrap();
800 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
801 handler(subscriber, *envelope, client, cx).boxed_local()
802 }),
803 );
804 if prev_handler.is_some() {
805 let location = std::panic::Location::caller();
806 panic!(
807 "{}:{} registered handler for the same message {} twice",
808 location.file(),
809 location.line(),
810 std::any::type_name::<M>()
811 );
812 }
813
814 Subscription::Message {
815 client: Arc::downgrade(self),
816 id: message_type_id,
817 }
818 }
819
820 pub fn add_request_handler<M, E, H, F>(
821 self: &Arc<Self>,
822 entity: WeakEntity<E>,
823 handler: H,
824 ) -> Subscription
825 where
826 M: RequestMessage,
827 E: 'static,
828 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
829 F: 'static + Future<Output = Result<M::Response>>,
830 {
831 self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
832 Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
833 })
834 }
835
836 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
837 receipt: Receipt<T>,
838 response: F,
839 client: AnyProtoClient,
840 ) -> Result<()> {
841 match response.await {
842 Ok(response) => {
843 client.send_response(receipt.message_id, response)?;
844 Ok(())
845 }
846 Err(error) => {
847 client.send_response(receipt.message_id, error.to_proto())?;
848 Err(error)
849 }
850 }
851 }
852
853 pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
854 self.credentials_provider
855 .read_credentials(cx)
856 .await
857 .is_some()
858 }
859
    /// Authenticates the client, publishing progress through [`Client::status`].
    ///
    /// Sets `Authenticating` if currently signed out, otherwise `Reauthenticating`. It then
    /// uses the first credentials that work, in this order:
    /// 1. the in-memory credentials, if the Cloud API validates them;
    /// 2. if `try_provider` is true, credentials from the credentials provider
    ///    (deleted from the provider if they fail validation);
    /// 3. credentials from [`Client::authenticate`], which are written to the provider
    ///    unless `ZED_IMPERSONATE` is set.
    ///
    /// On success the credentials are stored on the client and handed to the Cloud API
    /// client, and the status becomes `Authenticated` or `Reauthenticated`.
    ///
    /// # Errors
    ///
    /// - a validation request fails (as opposed to reporting the credentials invalid);
    /// - [`Client::authenticate`] fails (status becomes `AuthenticationError`);
    /// - the status changes while waiting in step 3, which counts as cancellation.
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                // Stale stored credentials: remove them (best effort).
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        if credentials.is_none() {
            let mut status_rx = self.status();
            // Consume the value the receiver currently holds, so the `status_rx.next()`
            // arm below fires only on a later status change (treated as cancellation).
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path that reaches here has set `credentials`.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `as u32` silently truncates user ids above u32::MAX.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
940
941 async fn validate_credentials(
942 self: &Arc<Self>,
943 credentials: &Credentials,
944 cx: &AsyncApp,
945 ) -> Result<bool> {
946 match self
947 .cloud_client
948 .validate_credentials(credentials.user_id as u32, &credentials.access_token)
949 .await
950 {
951 Ok(valid) => Ok(valid),
952 Err(err) => {
953 self.set_status(Status::AuthenticationError, cx);
954 Err(anyhow!("failed to validate credentials: {}", err))
955 }
956 }
957 }
958
959 /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
960 async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
961 let connect_task = cx.update({
962 let cloud_client = self.cloud_client.clone();
963 move |cx| cloud_client.connect(cx)
964 })?;
965 let connection = connect_task.await?;
966
967 let (mut messages, task) = cx.update(|cx| connection.spawn(cx));
968 task.detach();
969
970 cx.spawn({
971 let this = self.clone();
972 async move |cx| {
973 while let Some(message) = messages.next().await {
974 if let Some(message) = message.log_err() {
975 this.handle_message_to_client(message, cx);
976 }
977 }
978 }
979 })
980 .detach();
981
982 Ok(())
983 }
984
    /// Performs a sign-in and also (optionally) connects to Collab.
    ///
    /// Only Zed staff automatically connect to Collab.
    ///
    /// After a successful sign-in, a failed Cloud websocket connection is logged, not
    /// returned. The Collab connection is made in a detached task once feature flags
    /// report staff status, so this returns before that connection exists.
    pub async fn sign_in_with_optional_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<()> {
        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
        if self.status().borrow().is_connected() {
            return Ok(());
        }

        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
        // Wrapped in `Option` so the one-shot sender is used at most once, even if the
        // flags callback fires repeatedly.
        let mut is_staff_tx = Some(is_staff_tx);
        cx.update(|cx| {
            cx.on_flags_ready(move |state, _cx| {
                if let Some(is_staff_tx) = is_staff_tx.take() {
                    is_staff_tx.send(state.is_staff).log_err();
                }
            })
            .detach();
        });

        let credentials = self.sign_in(try_provider, cx).await?;

        self.connect_to_cloud(cx).await.log_err();

        cx.update(move |cx| {
            cx.spawn({
                let client = self.clone();
                async move |cx| {
                    let is_staff = is_staff_rx.await?;
                    if is_staff {
                        match client.connect_with_credentials(credentials, cx).await {
                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
                            ConnectionResult::Result(result) => {
                                result.context("client auth and connect")
                            }
                        }
                    } else {
                        Ok(())
                    }
                }
            })
            .detach_and_log_err(cx);
        });

        Ok(())
    }
1036
    /// Signs in via [`Client::sign_in`] and then connects to Collab.
    ///
    /// Returns `Ok(())` immediately when already `Connected`, `Connecting` or `Reconnecting`.
    /// Returns an `UpgradeRequired` error when the status is `UpgradeRequired`. Otherwise
    /// the status becomes `Connecting` (from `SignedOut`/`Authenticated`) or `Reconnecting`
    /// (from any error, lost-connection or (re)authentication state) before connecting.
    pub async fn connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // `true` means this is a fresh connection rather than a recovery attempt.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut | Status::Authenticated => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating
            | Status::AuthenticationError
            | Status::Reauthenticating
            | Status::Reauthenticated
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
                return ConnectionResult::Result(Ok(()));
            }
            Status::UpgradeRequired => {
                return ConnectionResult::Result(
                    Err(EstablishConnectionError::UpgradeRequired)
                        .context("client auth and connect"),
                );
            }
        };
        let credentials = match self.sign_in(try_provider, cx).await {
            Ok(credentials) => credentials,
            Err(err) => return ConnectionResult::Result(Err(err)),
        };

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        self.connect_with_credentials(credentials, cx).await
    }
1074
    /// Opens a Collab connection with `credentials` and completes the handshake.
    ///
    /// One [`CONNECTION_TIMEOUT`] timer covers both `establish_connection` and
    /// `set_connection`; hitting it in either phase yields [`ConnectionResult::Timeout`].
    /// Failures set the status to `ConnectionError`, except an upgrade-required response,
    /// which sets `UpgradeRequired`.
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // The same (already running) timer keeps bounding the handshake.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1122
    /// Registers `conn` with the peer and waits for the server's `Hello` message.
    ///
    /// On a valid hello, sets the status to `Connected` and spawns two detached tasks:
    /// - one dispatches incoming messages to `handle_message`;
    /// - one watches the connection's I/O. A clean finish while this connection is still
    ///   current sets `SignedOut`; an I/O error sets `ConnectionLost`, which starts
    ///   reconnection.
    ///
    /// # Errors
    ///
    /// Disconnects and returns an error if no hello arrives, the first message is not a
    /// hello, or the hello lacks a peer id.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // The first message on a new connection must be the server hello carrying our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Only sign out if no newer connection has replaced this one.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1206
1207 fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1208 #[cfg(any(test, feature = "test-support"))]
1209 if let Some(callback) = self.authenticate.read().as_ref() {
1210 return callback(cx);
1211 }
1212
1213 self.authenticate_with_browser(cx)
1214 }
1215
1216 fn establish_connection(
1217 self: &Arc<Self>,
1218 credentials: &Credentials,
1219 cx: &AsyncApp,
1220 ) -> Task<Result<Connection, EstablishConnectionError>> {
1221 #[cfg(any(test, feature = "test-support"))]
1222 if let Some(callback) = self.establish_connection.read().as_ref() {
1223 return callback(credentials, cx);
1224 }
1225
1226 self.establish_websocket_connection(credentials, cx)
1227 }
1228
1229 fn rpc_url(
1230 &self,
1231 http: Arc<HttpClientWithUrl>,
1232 release_channel: Option<ReleaseChannel>,
1233 ) -> impl Future<Output = Result<url::Url>> + use<> {
1234 #[cfg(any(test, feature = "test-support"))]
1235 let url_override = self.rpc_url.read().clone();
1236
1237 async move {
1238 #[cfg(any(test, feature = "test-support"))]
1239 if let Some(url) = url_override {
1240 return Ok(url);
1241 }
1242
1243 if let Some(url) = &*ZED_RPC_URL {
1244 return Url::parse(url).context("invalid rpc url");
1245 }
1246
1247 let mut url = http.build_url("/rpc");
1248 if let Some(preview_param) =
1249 release_channel.and_then(|channel| channel.release_query_param())
1250 {
1251 url += "?";
1252 url += preview_param;
1253 }
1254
1255 let response = http.get(&url, Default::default(), false).await?;
1256 anyhow::ensure!(
1257 response.status().is_redirection(),
1258 "unexpected /rpc response status {}",
1259 response.status()
1260 );
1261 let collab_url = response
1262 .headers()
1263 .get("Location")
1264 .context("missing location header in /rpc response")?
1265 .to_str()
1266 .map_err(EstablishConnectionError::other)?
1267 .to_string();
1268 Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1269 }
1270 }
1271
    /// Opens a WebSocket connection to the collab RPC endpoint.
    ///
    /// The steps are:
    /// 1. Resolve the RPC URL (see [`Self::rpc_url`]).
    /// 2. Open a TCP stream to its host on the Tokio runtime, going through the configured
    ///    proxy if there is one.
    /// 3. Switch the URL scheme to `ws`/`wss`.
    /// 4. Perform the WebSocket handshake. It sends the authorization, protocol version, app
    ///    version, and release channel headers, plus the user agent, system id, and metrics id
    ///    headers when those values are available.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the URL cannot be resolved, or its scheme is neither `http` nor `https`;
    /// - the TCP or proxy connection fails;
    /// - a header value is invalid;
    /// - the WebSocket handshake fails.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Everything the spawned task needs is captured up front, so it does not borrow `self`.
        let release_channel = cx.update(|cx| ReleaseChannel::try_global(cx));
        let app_version = cx.update(|cx| AppVersion::global(cx).to_string());

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            // Remember the original scheme so it can be mapped to `ws`/`wss` below.
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // The raw TCP (or proxied) connection is made on the Tokio runtime.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // `http`/`https`/`ws`/`wss` are all "special" schemes, and `Url::set_scheme` allows
            // switching between them, which is why this `unwrap` is expected not to fail.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            // Handshake over the already-open stream, using the shared TLS configuration.
            // Presumably TLS is only negotiated for `wss` URLs — TODO confirm against the
            // async_tungstenite connector behavior.
            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Adapt the WebSocket's error types to `anyhow` for `Connection`.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1373
1374 pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1375 let http = self.http.clone();
1376 let this = self.clone();
1377 cx.spawn(async move |cx| {
1378 let background = cx.background_executor().clone();
1379
1380 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1381 cx.update(|cx| {
1382 cx.spawn(async move |cx| {
1383 if let Ok(url) = open_url_rx.await {
1384 cx.update(|cx| cx.open_url(&url));
1385 }
1386 })
1387 .detach();
1388 });
1389
1390 let credentials = background
1391 .clone()
1392 .spawn(async move {
1393 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1394 // zed server to encrypt the user's access token, so that it can'be intercepted by
1395 // any other app running on the user's device.
1396 let (public_key, private_key) =
1397 rpc::auth::keypair().context("failed to generate keypair for auth")?;
1398 let public_key_string = String::try_from(public_key)
1399 .context("failed to serialize public key for auth")?;
1400
1401 if let Some((login, token)) =
1402 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1403 {
1404 if !*USE_WEB_LOGIN {
1405 eprintln!("authenticate as admin {login}, {token}");
1406
1407 return this
1408 .authenticate_as_admin(http, login.clone(), token.clone())
1409 .await;
1410 }
1411 }
1412
1413 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1414 let server = tiny_http::Server::http("127.0.0.1:0")
1415 .map_err(|e| anyhow!(e).context("failed to bind callback port"))?;
1416 let port = server
1417 .server_addr()
1418 .to_ip()
1419 .context("server not bound to a TCP address")?
1420 .port();
1421
1422 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1423 // that the user is signing in from a Zed app running on the same device.
1424 let url = http.build_url(&format!(
1425 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1426 port, public_key_string
1427 ));
1428
1429 open_url_tx.send(url).log_err();
1430
1431 #[derive(Deserialize)]
1432 struct CallbackParams {
1433 pub user_id: String,
1434 pub access_token: String,
1435 }
1436
1437 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1438 // access token from the query params.
1439 //
1440 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1441 // custom URL scheme instead of this local HTTP server.
1442 let (user_id, access_token) = background
1443 .spawn(async move {
1444 for _ in 0..100 {
1445 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1446 let path = req.url();
1447 let url = Url::parse(&format!("http://example.com{}", path))
1448 .context("failed to parse login notification url")?;
1449 let callback_params: CallbackParams =
1450 serde_urlencoded::from_str(url.query().unwrap_or_default())
1451 .context(
1452 "failed to parse sign-in callback query parameters",
1453 )?;
1454
1455 let post_auth_url =
1456 http.build_url("/native_app_signin_succeeded");
1457 req.respond(
1458 tiny_http::Response::empty(302).with_header(
1459 tiny_http::Header::from_bytes(
1460 &b"Location"[..],
1461 post_auth_url.as_bytes(),
1462 )
1463 .unwrap(),
1464 ),
1465 )
1466 .context("failed to respond to login http request")?;
1467 return Ok((
1468 callback_params.user_id,
1469 callback_params.access_token,
1470 ));
1471 }
1472 }
1473
1474 anyhow::bail!("didn't receive login redirect");
1475 })
1476 .await?;
1477
1478 let access_token = private_key
1479 .decrypt_string(&access_token)
1480 .context("failed to decrypt access token")?;
1481
1482 Ok(Credentials {
1483 user_id: user_id.parse()?,
1484 access_token,
1485 })
1486 })
1487 .await?;
1488
1489 cx.update(|cx| cx.activate(true));
1490 Ok(credentials)
1491 })
1492 }
1493
1494 async fn authenticate_as_admin(
1495 self: &Arc<Self>,
1496 http: Arc<HttpClientWithUrl>,
1497 login: String,
1498 api_token: String,
1499 ) -> Result<Credentials> {
1500 #[derive(Serialize)]
1501 struct ImpersonateUserBody {
1502 github_login: String,
1503 }
1504
1505 #[derive(Deserialize)]
1506 struct ImpersonateUserResponse {
1507 user_id: u64,
1508 access_token: String,
1509 }
1510
1511 let url = self
1512 .http
1513 .build_zed_cloud_url("/internal/users/impersonate")?;
1514 let request = Request::post(url.as_str())
1515 .header("Content-Type", "application/json")
1516 .header("Authorization", format!("Bearer {api_token}"))
1517 .body(
1518 serde_json::to_string(&ImpersonateUserBody {
1519 github_login: login,
1520 })?
1521 .into(),
1522 )?;
1523
1524 let mut response = http.send(request).await?;
1525 let mut body = String::new();
1526 response.body_mut().read_to_string(&mut body).await?;
1527 anyhow::ensure!(
1528 response.status().is_success(),
1529 "admin user request failed {} - {}",
1530 response.status().as_u16(),
1531 body,
1532 );
1533 let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1534
1535 Ok(Credentials {
1536 user_id: response.user_id,
1537 access_token: response.access_token,
1538 })
1539 }
1540
1541 pub async fn acquire_llm_token(
1542 &self,
1543 llm_token: &LlmApiToken,
1544 organization_id: Option<OrganizationId>,
1545 ) -> Result<String> {
1546 let system_id = self.telemetry().system_id().map(|x| x.to_string());
1547 let cloud_client = self.cloud_client();
1548 match llm_token
1549 .acquire(&cloud_client, system_id, organization_id)
1550 .await
1551 {
1552 Ok(token) => Ok(token),
1553 Err(ClientApiError::Unauthorized) => {
1554 self.request_sign_out();
1555 Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
1556 }
1557 Err(err) => Err(anyhow::Error::from(err)),
1558 }
1559 }
1560
1561 pub async fn refresh_llm_token(
1562 &self,
1563 llm_token: &LlmApiToken,
1564 organization_id: Option<OrganizationId>,
1565 ) -> Result<String> {
1566 let system_id = self.telemetry().system_id().map(|x| x.to_string());
1567 let cloud_client = self.cloud_client();
1568 match llm_token
1569 .refresh(&cloud_client, system_id, organization_id)
1570 .await
1571 {
1572 Ok(token) => Ok(token),
1573 Err(ClientApiError::Unauthorized) => {
1574 self.request_sign_out();
1575 return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
1576 }
1577 Err(err) => return Err(anyhow::Error::from(err)),
1578 }
1579 }
1580
1581 pub async fn clear_and_refresh_llm_token(
1582 &self,
1583 llm_token: &LlmApiToken,
1584 organization_id: Option<OrganizationId>,
1585 ) -> Result<String> {
1586 let system_id = self.telemetry().system_id().map(|x| x.to_string());
1587 let cloud_client = self.cloud_client();
1588 match llm_token
1589 .clear_and_refresh(&cloud_client, system_id, organization_id)
1590 .await
1591 {
1592 Ok(token) => Ok(token),
1593 Err(ClientApiError::Unauthorized) => {
1594 self.request_sign_out();
1595 return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
1596 }
1597 Err(err) => return Err(anyhow::Error::from(err)),
1598 }
1599 }
1600
1601 pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1602 self.state.write().credentials = None;
1603 self.cloud_client.clear_credentials();
1604 self.disconnect(cx);
1605
1606 if self.has_credentials(cx).await {
1607 self.credentials_provider
1608 .delete_credentials(cx)
1609 .await
1610 .log_err();
1611 }
1612 }
1613
1614 /// Requests a sign out to be performed asynchronously.
1615 pub fn request_sign_out(&self) {
1616 if let Some(sign_out_tx) = self.sign_out_tx.lock().clone() {
1617 sign_out_tx.unbounded_send(()).ok();
1618 }
1619 }
1620
1621 pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1622 self.peer.teardown();
1623 self.set_status(Status::SignedOut, cx);
1624 }
1625
1626 pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1627 self.peer.teardown();
1628 self.set_status(Status::ConnectionLost, cx);
1629 }
1630
1631 fn connection_id(&self) -> Result<ConnectionId> {
1632 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1633 Ok(connection_id)
1634 } else {
1635 anyhow::bail!("not connected");
1636 }
1637 }
1638
1639 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1640 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1641 self.peer.send(self.connection_id()?, message)
1642 }
1643
1644 pub fn request<T: RequestMessage>(
1645 &self,
1646 request: T,
1647 ) -> impl Future<Output = Result<T::Response>> + use<T> {
1648 self.request_envelope(request)
1649 .map_ok(|envelope| envelope.payload)
1650 }
1651
1652 pub fn request_stream<T: RequestMessage>(
1653 &self,
1654 request: T,
1655 ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1656 let client_id = self.id.load(Ordering::SeqCst);
1657 log::debug!(
1658 "rpc request start. client_id:{}. name:{}",
1659 client_id,
1660 T::NAME
1661 );
1662 let response = self
1663 .connection_id()
1664 .map(|conn_id| self.peer.request_stream(conn_id, request));
1665 async move {
1666 let response = response?.await;
1667 log::debug!(
1668 "rpc request finish. client_id:{}. name:{}",
1669 client_id,
1670 T::NAME
1671 );
1672 response
1673 }
1674 }
1675
1676 pub fn request_envelope<T: RequestMessage>(
1677 &self,
1678 request: T,
1679 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1680 let client_id = self.id();
1681 log::debug!(
1682 "rpc request start. client_id:{}. name:{}",
1683 client_id,
1684 T::NAME
1685 );
1686 let response = self
1687 .connection_id()
1688 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1689 async move {
1690 let response = response?.await;
1691 log::debug!(
1692 "rpc request finish. client_id:{}. name:{}",
1693 client_id,
1694 T::NAME
1695 );
1696 response
1697 }
1698 }
1699
1700 pub fn request_dynamic(
1701 &self,
1702 envelope: proto::Envelope,
1703 request_type: &'static str,
1704 ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1705 let client_id = self.id();
1706 log::debug!(
1707 "rpc request start. client_id:{}. name:{}",
1708 client_id,
1709 request_type
1710 );
1711 let response = self
1712 .connection_id()
1713 .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1714 async move {
1715 let response = response?.await;
1716 log::debug!(
1717 "rpc request finish. client_id:{}. name:{}",
1718 client_id,
1719 request_type
1720 );
1721 Ok(response?.0)
1722 }
1723 }
1724
1725 fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1726 let sender_id = message.sender_id();
1727 let request_id = message.message_id();
1728 let type_name = message.payload_type_name();
1729 let original_sender_id = message.original_sender_id();
1730
1731 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1732 &self.handler_set,
1733 message,
1734 self.clone().into(),
1735 cx.clone(),
1736 ) {
1737 let client_id = self.id();
1738 log::debug!(
1739 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1740 client_id,
1741 original_sender_id,
1742 type_name
1743 );
1744 cx.spawn(async move |_| match future.await {
1745 Ok(()) => {
1746 log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1747 }
1748 Err(error) => {
1749 log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1750 }
1751 })
1752 .detach();
1753 } else {
1754 log::info!("unhandled message {}", type_name);
1755 self.peer
1756 .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1757 .log_err();
1758 }
1759 }
1760
1761 pub fn add_message_to_client_handler(
1762 self: &Arc<Client>,
1763 handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1764 ) {
1765 self.message_to_client_handlers
1766 .lock()
1767 .push(Box::new(handler));
1768 }
1769
1770 fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1771 cx.update(|cx| {
1772 for handler in self.message_to_client_handlers.lock().iter() {
1773 handler(&message, cx);
1774 }
1775 });
1776 }
1777
    /// Returns this client's [`Telemetry`] handle.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1781}
1782
1783impl ProtoClient for Client {
1784 fn request(
1785 &self,
1786 envelope: proto::Envelope,
1787 request_type: &'static str,
1788 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1789 self.request_dynamic(envelope, request_type).boxed()
1790 }
1791
1792 fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1793 log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1794 let connection_id = self.connection_id()?;
1795 self.peer.send_dynamic(connection_id, envelope)
1796 }
1797
1798 fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1799 log::debug!(
1800 "rpc respond. client_id:{}, name:{}",
1801 self.id(),
1802 message_type
1803 );
1804 let connection_id = self.connection_id()?;
1805 self.peer.send_dynamic(connection_id, envelope)
1806 }
1807
1808 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1809 &self.handler_set
1810 }
1811
1812 fn is_via_collab(&self) -> bool {
1813 true
1814 }
1815
1816 fn has_wsl_interop(&self) -> bool {
1817 false
1818 }
1819}
1820
/// Scheme name of `zed://` URLs, without the `://` separator.
pub const ZED_URL_SCHEME: &str = "zed";
1823
/// A parsed Zed link that can be handled internally by the application.
///
/// Produced by [`parse_zed_link`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ZedLink {
    /// Join a channel: `zed.dev/channel/channel-name-123` or `zed://channel/channel-name-123`
    ///
    /// `channel_id` is the number after the last `-` in the channel slug.
    Channel { channel_id: u64 },
    /// Open channel notes: `zed.dev/channel/channel-name-123/notes` or with heading `notes#heading`
    ChannelNotes {
        /// The number after the last `-` in the channel slug.
        channel_id: u64,
        /// The text after `notes#`, or `None` for a plain `notes` segment.
        heading: Option<String>,
    },
}
1835
1836/// Parses the given link into a Zed link.
1837///
1838/// Returns a [`Some`] containing the parsed link if the link is a recognized Zed link
1839/// that should be handled internally by the application.
1840/// Returns [`None`] for links that should be opened in the browser.
1841pub fn parse_zed_link(link: &str, cx: &App) -> Option<ZedLink> {
1842 let server_url = &ClientSettings::get_global(cx).server_url;
1843 let path = link
1844 .strip_prefix(server_url)
1845 .and_then(|result| result.strip_prefix('/'))
1846 .or_else(|| {
1847 link.strip_prefix(ZED_URL_SCHEME)
1848 .and_then(|result| result.strip_prefix("://"))
1849 })?;
1850
1851 let mut parts = path.split('/');
1852
1853 if parts.next() != Some("channel") {
1854 return None;
1855 }
1856
1857 let slug = parts.next()?;
1858 let id_str = slug.split('-').next_back()?;
1859 let channel_id = id_str.parse::<u64>().ok()?;
1860
1861 let Some(next) = parts.next() else {
1862 return Some(ZedLink::Channel { channel_id });
1863 };
1864
1865 if let Some(heading) = next.strip_prefix("notes#") {
1866 return Some(ZedLink::ChannelNotes {
1867 channel_id,
1868 heading: Some(heading.to_string()),
1869 });
1870 }
1871
1872 if next == "notes" {
1873 return Some(ZedLink::ChannelNotes {
1874 channel_id,
1875 heading: None,
1876 });
1877 }
1878
1879 None
1880}
1881
1882#[cfg(test)]
1883mod tests {
1884 use super::*;
1885 use crate::test::{FakeServer, parse_authorization_header};
1886
1887 use clock::FakeSystemClock;
1888 use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
1889 use http_client::FakeHttpClient;
1890 use parking_lot::Mutex;
1891 use proto::TypedEnvelope;
1892 use settings::SettingsStore;
1893 use std::future;
1894
1895 #[test]
1896 fn test_proxy_settings_trims_and_ignores_empty_proxy() {
1897 let mut content = SettingsContent::default();
1898 content.proxy = Some(" ".to_owned());
1899 assert_eq!(ProxySettings::from_settings(&content).proxy, None);
1900
1901 content.proxy = Some("http://127.0.0.1:10809".to_owned());
1902 assert_eq!(
1903 ProxySettings::from_settings(&content).proxy.as_deref(),
1904 Some("http://127.0.0.1:10809")
1905 );
1906 }
1907
1908 #[gpui::test(iterations = 10)]
1909 async fn test_reconnection(cx: &mut TestAppContext) {
1910 init_test(cx);
1911 let user_id = 5;
1912 let client = cx.update(|cx| {
1913 Client::new(
1914 Arc::new(FakeSystemClock::new()),
1915 FakeHttpClient::with_404_response(),
1916 cx,
1917 )
1918 });
1919 let server = FakeServer::for_client(user_id, &client, cx).await;
1920 let mut status = client.status();
1921 assert!(matches!(
1922 status.next().await,
1923 Some(Status::Connected { .. })
1924 ));
1925 assert_eq!(server.auth_count(), 1);
1926
1927 server.forbid_connections();
1928 server.disconnect();
1929 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1930
1931 server.allow_connections();
1932 cx.executor().advance_clock(Duration::from_secs(10));
1933 while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1934 assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1935
1936 server.forbid_connections();
1937 server.disconnect();
1938 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1939
1940 // Clear cached credentials after authentication fails
1941 server.roll_access_token();
1942 server.allow_connections();
1943 cx.executor().run_until_parked();
1944 cx.executor().advance_clock(Duration::from_secs(10));
1945 while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1946 assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1947 }
1948
1949 #[gpui::test(iterations = 10)]
1950 async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
1951 init_test(cx);
1952 let http_client = FakeHttpClient::with_200_response();
1953 let client =
1954 cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
1955 let server = FakeServer::for_client(42, &client, cx).await;
1956 let mut status = client.status();
1957 assert!(matches!(
1958 status.next().await,
1959 Some(Status::Connected { .. })
1960 ));
1961 assert_eq!(server.auth_count(), 1);
1962
1963 // Simulate an auth failure during reconnection.
1964 http_client
1965 .as_fake()
1966 .replace_handler(|_, _request| async move {
1967 Ok(http_client::Response::builder()
1968 .status(503)
1969 .body("".into())
1970 .unwrap())
1971 });
1972 server.disconnect();
1973 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1974
1975 // Restore the ability to authenticate.
1976 http_client
1977 .as_fake()
1978 .replace_handler(|_, _request| async move {
1979 Ok(http_client::Response::builder()
1980 .status(200)
1981 .body("".into())
1982 .unwrap())
1983 });
1984 cx.executor().advance_clock(Duration::from_secs(10));
1985 while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1986 assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1987 }
1988
1989 #[gpui::test(iterations = 10)]
1990 async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
1991 init_test(cx);
1992 let user_id = 5;
1993 let client = cx.update(|cx| {
1994 Client::new(
1995 Arc::new(FakeSystemClock::new()),
1996 FakeHttpClient::with_404_response(),
1997 cx,
1998 )
1999 });
2000 let mut status = client.status();
2001
2002 // Time out when client tries to connect.
2003 client.override_authenticate(move |cx| {
2004 cx.background_spawn(async move {
2005 Ok(Credentials {
2006 user_id,
2007 access_token: "token".into(),
2008 })
2009 })
2010 });
2011 client.override_establish_connection(|_, cx| {
2012 cx.background_spawn(async move {
2013 future::pending::<()>().await;
2014 unreachable!()
2015 })
2016 });
2017 let auth_and_connect = cx.spawn({
2018 let client = client.clone();
2019 |cx| async move { client.connect(false, &cx).await }
2020 });
2021 executor.run_until_parked();
2022 assert!(matches!(status.next().await, Some(Status::Connecting)));
2023
2024 executor.advance_clock(CONNECTION_TIMEOUT);
2025 assert!(matches!(status.next().await, Some(Status::ConnectionError)));
2026 auth_and_connect.await.into_response().unwrap_err();
2027
2028 // Allow the connection to be established.
2029 let server = FakeServer::for_client(user_id, &client, cx).await;
2030 assert!(matches!(
2031 status.next().await,
2032 Some(Status::Connected { .. })
2033 ));
2034
2035 // Disconnect client.
2036 server.forbid_connections();
2037 server.disconnect();
2038 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
2039
2040 // Time out when re-establishing the connection.
2041 server.allow_connections();
2042 client.override_establish_connection(|_, cx| {
2043 cx.background_spawn(async move {
2044 future::pending::<()>().await;
2045 unreachable!()
2046 })
2047 });
2048 executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
2049 assert!(matches!(status.next().await, Some(Status::Reconnecting)));
2050
2051 executor.advance_clock(CONNECTION_TIMEOUT);
2052 assert!(matches!(
2053 status.next().await,
2054 Some(Status::ReconnectionError { .. })
2055 ));
2056 }
2057
2058 #[gpui::test(iterations = 10)]
2059 async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
2060 init_test(cx);
2061 let auth_count = Arc::new(Mutex::new(0));
2062 let http_client = FakeHttpClient::create(|_request| async move {
2063 Ok(http_client::Response::builder()
2064 .status(200)
2065 .body("".into())
2066 .unwrap())
2067 });
2068 let client =
2069 cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
2070 client.override_authenticate({
2071 let auth_count = auth_count.clone();
2072 move |cx| {
2073 let auth_count = auth_count.clone();
2074 cx.background_spawn(async move {
2075 *auth_count.lock() += 1;
2076 Ok(Credentials {
2077 user_id: 1,
2078 access_token: auth_count.lock().to_string(),
2079 })
2080 })
2081 }
2082 });
2083
2084 let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
2085 assert_eq!(*auth_count.lock(), 1);
2086 assert_eq!(credentials.access_token, "1");
2087
2088 // If credentials are still valid, signing in doesn't trigger authentication.
2089 let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
2090 assert_eq!(*auth_count.lock(), 1);
2091 assert_eq!(credentials.access_token, "1");
2092
2093 // If the server is unavailable, signing in doesn't trigger authentication.
2094 http_client
2095 .as_fake()
2096 .replace_handler(|_, _request| async move {
2097 Ok(http_client::Response::builder()
2098 .status(503)
2099 .body("".into())
2100 .unwrap())
2101 });
2102 client.sign_in(false, &cx.to_async()).await.unwrap_err();
2103 assert_eq!(*auth_count.lock(), 1);
2104
2105 // If credentials became invalid, signing in triggers authentication.
2106 http_client
2107 .as_fake()
2108 .replace_handler(|_, request| async move {
2109 let credentials = parse_authorization_header(&request).unwrap();
2110 if credentials.access_token == "2" {
2111 Ok(http_client::Response::builder()
2112 .status(200)
2113 .body("".into())
2114 .unwrap())
2115 } else {
2116 Ok(http_client::Response::builder()
2117 .status(401)
2118 .body("".into())
2119 .unwrap())
2120 }
2121 });
2122 let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
2123 assert_eq!(*auth_count.lock(), 2);
2124 assert_eq!(credentials.access_token, "2");
2125 }
2126
2127 #[gpui::test(iterations = 10)]
2128 async fn test_authenticating_more_than_once(
2129 cx: &mut TestAppContext,
2130 executor: BackgroundExecutor,
2131 ) {
2132 init_test(cx);
2133 let auth_count = Arc::new(Mutex::new(0));
2134 let dropped_auth_count = Arc::new(Mutex::new(0));
2135 let client = cx.update(|cx| {
2136 Client::new(
2137 Arc::new(FakeSystemClock::new()),
2138 FakeHttpClient::with_404_response(),
2139 cx,
2140 )
2141 });
2142 client.override_authenticate({
2143 let auth_count = auth_count.clone();
2144 let dropped_auth_count = dropped_auth_count.clone();
2145 move |cx| {
2146 let auth_count = auth_count.clone();
2147 let dropped_auth_count = dropped_auth_count.clone();
2148 cx.background_spawn(async move {
2149 *auth_count.lock() += 1;
2150 let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
2151 future::pending::<()>().await;
2152 unreachable!()
2153 })
2154 }
2155 });
2156
2157 let _authenticate = cx.spawn({
2158 let client = client.clone();
2159 move |cx| async move { client.connect(false, &cx).await }
2160 });
2161 executor.run_until_parked();
2162 assert_eq!(*auth_count.lock(), 1);
2163 assert_eq!(*dropped_auth_count.lock(), 0);
2164
2165 let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
2166 executor.run_until_parked();
2167 assert_eq!(*auth_count.lock(), 2);
2168 assert_eq!(*dropped_auth_count.lock(), 1);
2169 }
2170
2171 #[gpui::test]
2172 async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
2173 init_test(cx);
2174 let user_id = 5;
2175 let client = cx.update(|cx| {
2176 Client::new(
2177 Arc::new(FakeSystemClock::new()),
2178 FakeHttpClient::with_404_response(),
2179 cx,
2180 )
2181 });
2182 let server = FakeServer::for_client(user_id, &client, cx).await;
2183
2184 let (done_tx1, done_rx1) = smol::channel::unbounded();
2185 let (done_tx2, done_rx2) = smol::channel::unbounded();
2186 AnyProtoClient::from(client.clone()).add_entity_message_handler(
2187 move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
2188 match entity.read_with(&cx, |entity, _| entity.id) {
2189 1 => done_tx1.try_send(()).unwrap(),
2190 2 => done_tx2.try_send(()).unwrap(),
2191 _ => unreachable!(),
2192 }
2193 async { Ok(()) }
2194 },
2195 );
2196 let entity1 = cx.new(|_| TestEntity {
2197 id: 1,
2198 subscription: None,
2199 });
2200 let entity2 = cx.new(|_| TestEntity {
2201 id: 2,
2202 subscription: None,
2203 });
2204 let entity3 = cx.new(|_| TestEntity {
2205 id: 3,
2206 subscription: None,
2207 });
2208
2209 let _subscription1 = client
2210 .subscribe_to_entity(1)
2211 .unwrap()
2212 .set_entity(&entity1, &cx.to_async());
2213 let _subscription2 = client
2214 .subscribe_to_entity(2)
2215 .unwrap()
2216 .set_entity(&entity2, &cx.to_async());
2217 // Ensure dropping a subscription for the same entity type still allows receiving of
2218 // messages for other entity IDs of the same type.
2219 let subscription3 = client
2220 .subscribe_to_entity(3)
2221 .unwrap()
2222 .set_entity(&entity3, &cx.to_async());
2223 drop(subscription3);
2224
2225 server.send(proto::JoinProject {
2226 project_id: 1,
2227 committer_name: None,
2228 committer_email: None,
2229 features: Vec::new(),
2230 });
2231 server.send(proto::JoinProject {
2232 project_id: 2,
2233 committer_name: None,
2234 committer_email: None,
2235 features: Vec::new(),
2236 });
2237 done_rx1.recv().await.unwrap();
2238 done_rx2.recv().await.unwrap();
2239 }
2240
2241 #[gpui::test]
2242 async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
2243 init_test(cx);
2244 let user_id = 5;
2245 let client = cx.update(|cx| {
2246 Client::new(
2247 Arc::new(FakeSystemClock::new()),
2248 FakeHttpClient::with_404_response(),
2249 cx,
2250 )
2251 });
2252 let server = FakeServer::for_client(user_id, &client, cx).await;
2253
2254 let entity = cx.new(|_| TestEntity::default());
2255 let (done_tx1, _done_rx1) = smol::channel::unbounded();
2256 let (done_tx2, done_rx2) = smol::channel::unbounded();
2257 let subscription1 = client.add_message_handler(
2258 entity.downgrade(),
2259 move |_, _: TypedEnvelope<proto::Ping>, _| {
2260 done_tx1.try_send(()).unwrap();
2261 async { Ok(()) }
2262 },
2263 );
2264 drop(subscription1);
2265 let _subscription2 = client.add_message_handler(
2266 entity.downgrade(),
2267 move |_, _: TypedEnvelope<proto::Ping>, _| {
2268 done_tx2.try_send(()).unwrap();
2269 async { Ok(()) }
2270 },
2271 );
2272 server.send(proto::Ping {});
2273 done_rx2.recv().await.unwrap();
2274 }
2275
2276 #[gpui::test]
2277 async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
2278 init_test(cx);
2279 let user_id = 5;
2280 let client = cx.update(|cx| {
2281 Client::new(
2282 Arc::new(FakeSystemClock::new()),
2283 FakeHttpClient::with_404_response(),
2284 cx,
2285 )
2286 });
2287 let server = FakeServer::for_client(user_id, &client, cx).await;
2288
2289 let entity = cx.new(|_| TestEntity::default());
2290 let (done_tx, done_rx) = smol::channel::unbounded();
2291 let subscription = client.add_message_handler(
2292 entity.clone().downgrade(),
2293 move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
2294 entity
2295 .update(&mut cx, |entity, _| entity.subscription.take())
2296 .unwrap();
2297 done_tx.try_send(()).unwrap();
2298 async { Ok(()) }
2299 },
2300 );
2301 entity.update(cx, |entity, _| {
2302 entity.subscription = Some(subscription);
2303 });
2304 server.send(proto::Ping {});
2305 done_rx.recv().await.unwrap();
2306 }
2307
2308 #[derive(Default)]
2309 struct TestEntity {
2310 id: usize,
2311 subscription: Option<Subscription>,
2312 }
2313
2314 fn init_test(cx: &mut TestAppContext) {
2315 cx.update(|cx| {
2316 let settings_store = SettingsStore::test(cx);
2317 cx.set_global(settings_store);
2318 });
2319 }
2320}