client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use serde::{Deserialize, Serialize};
  33use settings::{RegisterSetting, Settings, SettingsContent};
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        Arc, LazyLock, Weak,
  43        atomic::{AtomicU64, Ordering},
  44    },
  45    time::{Duration, Instant},
  46};
  47use std::{cmp, pin::Pin};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use tokio::net::TcpStream;
  51use url::Url;
  52use util::{ConnectionResult, ResultExt};
  53
  54pub use rpc::*;
  55pub use telemetry_events::Event;
  56pub use user::*;
  57
  58static ZED_SERVER_URL: LazyLock<Option<String>> =
  59    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  60static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  61
  62pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  63    std::env::var("ZED_IMPERSONATE")
  64        .ok()
  65        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  66});
  67
  68pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  69
  70pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  71    std::env::var("ZED_ADMIN_API_TOKEN")
  72        .ok()
  73        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  74});
  75
  76pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  77    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  78
  79pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  80    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  81
  82pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  83pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  84pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  85
// Client-level actions; their handlers are registered in `init`. The `///` lines inside the
// macro are macro input (presumably surfaced as the actions' documentation), so they are kept
// verbatim.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  97
  98#[derive(Deserialize, RegisterSetting)]
  99pub struct ClientSettings {
 100    pub server_url: String,
 101}
 102
 103impl Settings for ClientSettings {
 104    fn from_settings(content: &settings::SettingsContent) -> Self {
 105        if let Some(server_url) = &*ZED_SERVER_URL {
 106            return Self {
 107                server_url: server_url.clone(),
 108            };
 109        }
 110        Self {
 111            server_url: content.server_url.clone().unwrap(),
 112        }
 113    }
 114}
 115
 116#[derive(Deserialize, Default, RegisterSetting)]
 117pub struct ProxySettings {
 118    pub proxy: Option<String>,
 119}
 120
 121impl ProxySettings {
 122    pub fn proxy_url(&self) -> Option<Url> {
 123        self.proxy
 124            .as_ref()
 125            .and_then(|input| {
 126                input
 127                    .parse::<Url>()
 128                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 129                    .ok()
 130            })
 131            .or_else(read_proxy_from_env)
 132    }
 133}
 134
 135impl Settings for ProxySettings {
 136    fn from_settings(content: &settings::SettingsContent) -> Self {
 137        Self {
 138            proxy: content.proxy.clone(),
 139        }
 140    }
 141}
 142
 143pub fn init(client: &Arc<Client>, cx: &mut App) {
 144    let client = Arc::downgrade(client);
 145    cx.on_action({
 146        let client = client.clone();
 147        move |_: &SignIn, cx| {
 148            if let Some(client) = client.upgrade() {
 149                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 150                    .detach_and_log_err(cx);
 151            }
 152        }
 153    });
 154
 155    cx.on_action({
 156        let client = client.clone();
 157        move |_: &SignOut, cx| {
 158            if let Some(client) = client.upgrade() {
 159                cx.spawn(async move |cx| {
 160                    client.sign_out(cx).await;
 161                })
 162                .detach();
 163            }
 164        }
 165    });
 166
 167    cx.on_action({
 168        let client = client;
 169        move |_: &Reconnect, cx| {
 170            if let Some(client) = client.upgrade() {
 171                cx.spawn(async move |cx| {
 172                    client.reconnect(cx);
 173                })
 174                .detach();
 175            }
 176        }
 177    });
 178}
 179
 180pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
 181
 182struct GlobalClient(Arc<Client>);
 183
 184impl Global for GlobalClient {}
 185
 186pub struct Client {
 187    id: AtomicU64,
 188    peer: Arc<Peer>,
 189    http: Arc<HttpClientWithUrl>,
 190    cloud_client: Arc<CloudApiClient>,
 191    telemetry: Arc<Telemetry>,
 192    credentials_provider: ClientCredentialsProvider,
 193    state: RwLock<ClientState>,
 194    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
 195    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,
 196
 197    #[allow(clippy::type_complexity)]
 198    #[cfg(any(test, feature = "test-support"))]
 199    authenticate:
 200        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,
 201
 202    #[allow(clippy::type_complexity)]
 203    #[cfg(any(test, feature = "test-support"))]
 204    establish_connection: RwLock<
 205        Option<
 206            Box<
 207                dyn 'static
 208                    + Send
 209                    + Sync
 210                    + Fn(
 211                        &Credentials,
 212                        &AsyncApp,
 213                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 214            >,
 215        >,
 216    >,
 217
 218    #[cfg(any(test, feature = "test-support"))]
 219    rpc_url: RwLock<Option<Url>>,
 220}
 221
 222#[derive(Error, Debug)]
 223pub enum EstablishConnectionError {
 224    #[error("upgrade required")]
 225    UpgradeRequired,
 226    #[error("unauthorized")]
 227    Unauthorized,
 228    #[error("{0}")]
 229    Other(#[from] anyhow::Error),
 230    #[error("{0}")]
 231    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 232    #[error("{0}")]
 233    Io(#[from] std::io::Error),
 234    #[error("{0}")]
 235    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 236}
 237
 238impl From<WebsocketError> for EstablishConnectionError {
 239    fn from(error: WebsocketError) -> Self {
 240        if let WebsocketError::Http(response) = &error {
 241            match response.status() {
 242                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 243                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 244                _ => {}
 245            }
 246        }
 247        EstablishConnectionError::Other(error.into())
 248    }
 249}
 250
 251impl EstablishConnectionError {
 252    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 253        Self::Other(error.into())
 254    }
 255}
 256
 257#[derive(Copy, Clone, Debug, PartialEq)]
 258pub enum Status {
 259    SignedOut,
 260    UpgradeRequired,
 261    Authenticating,
 262    Authenticated,
 263    AuthenticationError,
 264    Connecting,
 265    ConnectionError,
 266    Connected {
 267        peer_id: PeerId,
 268        connection_id: ConnectionId,
 269    },
 270    ConnectionLost,
 271    Reauthenticating,
 272    Reauthenticated,
 273    Reconnecting,
 274    ReconnectionError {
 275        next_reconnection: Instant,
 276    },
 277}
 278
 279impl Status {
 280    pub fn is_connected(&self) -> bool {
 281        matches!(self, Self::Connected { .. })
 282    }
 283
 284    pub fn was_connected(&self) -> bool {
 285        matches!(
 286            self,
 287            Self::ConnectionLost
 288                | Self::Reauthenticating
 289                | Self::Reauthenticated
 290                | Self::Reconnecting
 291        )
 292    }
 293
 294    /// Returns whether the client is currently connected or was connected at some point.
 295    pub fn is_or_was_connected(&self) -> bool {
 296        self.is_connected() || self.was_connected()
 297    }
 298
 299    pub fn is_signing_in(&self) -> bool {
 300        matches!(
 301            self,
 302            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 303        )
 304    }
 305
 306    pub fn is_signed_out(&self) -> bool {
 307        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 308    }
 309}
 310
 311struct ClientState {
 312    credentials: Option<Credentials>,
 313    status: (watch::Sender<Status>, watch::Receiver<Status>),
 314    _reconnect_task: Option<Task<()>>,
 315}
 316
 317#[derive(Clone, Debug, Eq, PartialEq)]
 318pub struct Credentials {
 319    pub user_id: u64,
 320    pub access_token: String,
 321}
 322
 323impl Credentials {
 324    pub fn authorization_header(&self) -> String {
 325        format!("{} {}", self.user_id, self.access_token)
 326    }
 327}
 328
 329pub struct ClientCredentialsProvider {
 330    provider: Arc<dyn CredentialsProvider>,
 331}
 332
 333impl ClientCredentialsProvider {
 334    pub fn new(cx: &App) -> Self {
 335        Self {
 336            provider: <dyn CredentialsProvider>::global(cx),
 337        }
 338    }
 339
 340    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 341        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 342    }
 343
 344    /// Reads the credentials from the provider.
 345    fn read_credentials<'a>(
 346        &'a self,
 347        cx: &'a AsyncApp,
 348    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 349        async move {
 350            if IMPERSONATE_LOGIN.is_some() {
 351                return None;
 352            }
 353
 354            let server_url = self.server_url(cx).ok()?;
 355            let (user_id, access_token) = self
 356                .provider
 357                .read_credentials(&server_url, cx)
 358                .await
 359                .log_err()
 360                .flatten()?;
 361
 362            Some(Credentials {
 363                user_id: user_id.parse().ok()?,
 364                access_token: String::from_utf8(access_token).ok()?,
 365            })
 366        }
 367        .boxed_local()
 368    }
 369
 370    /// Writes the credentials to the provider.
 371    fn write_credentials<'a>(
 372        &'a self,
 373        user_id: u64,
 374        access_token: String,
 375        cx: &'a AsyncApp,
 376    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 377        async move {
 378            let server_url = self.server_url(cx)?;
 379            self.provider
 380                .write_credentials(
 381                    &server_url,
 382                    &user_id.to_string(),
 383                    access_token.as_bytes(),
 384                    cx,
 385                )
 386                .await
 387        }
 388        .boxed_local()
 389    }
 390
 391    /// Deletes the credentials from the provider.
 392    fn delete_credentials<'a>(
 393        &'a self,
 394        cx: &'a AsyncApp,
 395    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 396        async move {
 397            let server_url = self.server_url(cx)?;
 398            self.provider.delete_credentials(&server_url, cx).await
 399        }
 400        .boxed_local()
 401    }
 402}
 403
 404impl Default for ClientState {
 405    fn default() -> Self {
 406        Self {
 407            credentials: None,
 408            status: watch::channel_with(Status::SignedOut),
 409            _reconnect_task: None,
 410        }
 411    }
 412}
 413
 414pub enum Subscription {
 415    Entity {
 416        client: Weak<Client>,
 417        id: (TypeId, u64),
 418    },
 419    Message {
 420        client: Weak<Client>,
 421        id: TypeId,
 422    },
 423}
 424
 425impl Drop for Subscription {
 426    fn drop(&mut self) {
 427        match self {
 428            Subscription::Entity { client, id } => {
 429                if let Some(client) = client.upgrade() {
 430                    let mut state = client.handler_set.lock();
 431                    let _ = state.entities_by_type_and_remote_id.remove(id);
 432                }
 433            }
 434            Subscription::Message { client, id } => {
 435                if let Some(client) = client.upgrade() {
 436                    let mut state = client.handler_set.lock();
 437                    let _ = state.entity_types_by_message_type.remove(id);
 438                    let _ = state.message_handlers.remove(id);
 439                }
 440            }
 441        }
 442    }
 443}
 444
 445pub struct PendingEntitySubscription<T: 'static> {
 446    client: Arc<Client>,
 447    remote_id: u64,
 448    _entity_type: PhantomData<T>,
 449    consumed: bool,
 450}
 451
 452impl<T: 'static> PendingEntitySubscription<T> {
 453    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 454        self.consumed = true;
 455        let mut handlers = self.client.handler_set.lock();
 456        let id = (TypeId::of::<T>(), self.remote_id);
 457        let Some(EntityMessageSubscriber::Pending(messages)) =
 458            handlers.entities_by_type_and_remote_id.remove(&id)
 459        else {
 460            unreachable!()
 461        };
 462
 463        handlers.entities_by_type_and_remote_id.insert(
 464            id,
 465            EntityMessageSubscriber::Entity {
 466                handle: entity.downgrade().into(),
 467            },
 468        );
 469        drop(handlers);
 470        for message in messages {
 471            let client_id = self.client.id();
 472            let type_name = message.payload_type_name();
 473            let sender_id = message.original_sender_id();
 474            log::debug!(
 475                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 476                client_id,
 477                sender_id,
 478                type_name
 479            );
 480            self.client.handle_message(message, cx);
 481        }
 482        Subscription::Entity {
 483            client: Arc::downgrade(&self.client),
 484            id,
 485        }
 486    }
 487}
 488
 489impl<T: 'static> Drop for PendingEntitySubscription<T> {
 490    fn drop(&mut self) {
 491        if !self.consumed {
 492            let mut state = self.client.handler_set.lock();
 493            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 494                .entities_by_type_and_remote_id
 495                .remove(&(TypeId::of::<T>(), self.remote_id))
 496            {
 497                for message in messages {
 498                    log::info!("unhandled message {}", message.payload_type_name());
 499                }
 500            }
 501        }
 502    }
 503}
 504
 505#[derive(Copy, Clone, Deserialize, Debug, RegisterSetting)]
 506pub struct TelemetrySettings {
 507    pub diagnostics: bool,
 508    pub metrics: bool,
 509}
 510
 511impl settings::Settings for TelemetrySettings {
 512    fn from_settings(content: &SettingsContent) -> Self {
 513        Self {
 514            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 515            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 516        }
 517    }
 518}
 519
 520impl Client {
 521    pub fn new(
 522        clock: Arc<dyn SystemClock>,
 523        http: Arc<HttpClientWithUrl>,
 524        cx: &mut App,
 525    ) -> Arc<Self> {
 526        Arc::new(Self {
 527            id: AtomicU64::new(0),
 528            peer: Peer::new(0),
 529            telemetry: Telemetry::new(clock, http.clone(), cx),
 530            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 531            http,
 532            credentials_provider: ClientCredentialsProvider::new(cx),
 533            state: Default::default(),
 534            handler_set: Default::default(),
 535            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 536
 537            #[cfg(any(test, feature = "test-support"))]
 538            authenticate: Default::default(),
 539            #[cfg(any(test, feature = "test-support"))]
 540            establish_connection: Default::default(),
 541            #[cfg(any(test, feature = "test-support"))]
 542            rpc_url: RwLock::default(),
 543        })
 544    }
 545
 546    pub fn production(cx: &mut App) -> Arc<Self> {
 547        let clock = Arc::new(clock::RealSystemClock);
 548        let http = Arc::new(HttpClientWithUrl::new_url(
 549            cx.http_client(),
 550            &ClientSettings::get_global(cx).server_url,
 551            cx.http_client().proxy().cloned(),
 552        ));
 553        Self::new(clock, http, cx)
 554    }
 555
 556    pub fn id(&self) -> u64 {
 557        self.id.load(Ordering::SeqCst)
 558    }
 559
 560    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 561        self.http.clone()
 562    }
 563
 564    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 565        self.cloud_client.clone()
 566    }
 567
 568    pub fn set_id(&self, id: u64) -> &Self {
 569        self.id.store(id, Ordering::SeqCst);
 570        self
 571    }
 572
 573    #[cfg(any(test, feature = "test-support"))]
 574    pub fn teardown(&self) {
 575        let mut state = self.state.write();
 576        state._reconnect_task.take();
 577        self.handler_set.lock().clear();
 578        self.peer.teardown();
 579    }
 580
 581    #[cfg(any(test, feature = "test-support"))]
 582    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 583    where
 584        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 585    {
 586        *self.authenticate.write() = Some(Box::new(authenticate));
 587        self
 588    }
 589
 590    #[cfg(any(test, feature = "test-support"))]
 591    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 592    where
 593        F: 'static
 594            + Send
 595            + Sync
 596            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 597    {
 598        *self.establish_connection.write() = Some(Box::new(connect));
 599        self
 600    }
 601
 602    #[cfg(any(test, feature = "test-support"))]
 603    pub fn override_rpc_url(&self, url: Url) -> &Self {
 604        *self.rpc_url.write() = Some(url);
 605        self
 606    }
 607
 608    pub fn global(cx: &App) -> Arc<Self> {
 609        cx.global::<GlobalClient>().0.clone()
 610    }
 611    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 612        cx.set_global(GlobalClient(client))
 613    }
 614
 615    pub fn user_id(&self) -> Option<u64> {
 616        self.state
 617            .read()
 618            .credentials
 619            .as_ref()
 620            .map(|credentials| credentials.user_id)
 621    }
 622
 623    pub fn peer_id(&self) -> Option<PeerId> {
 624        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 625            Some(*peer_id)
 626        } else {
 627            None
 628        }
 629    }
 630
 631    pub fn status(&self) -> watch::Receiver<Status> {
 632        self.state.read().status.1.clone()
 633    }
 634
 635    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
 636        log::info!("set status on client {}: {:?}", self.id(), status);
 637        let mut state = self.state.write();
 638        *state.status.0.borrow_mut() = status;
 639
 640        match status {
 641            Status::Connected { .. } => {
 642                state._reconnect_task = None;
 643            }
 644            Status::ConnectionLost => {
 645                let client = self.clone();
 646                state._reconnect_task = Some(cx.spawn(async move |cx| {
 647                    #[cfg(any(test, feature = "test-support"))]
 648                    let mut rng = StdRng::seed_from_u64(0);
 649                    #[cfg(not(any(test, feature = "test-support")))]
 650                    let mut rng = StdRng::from_os_rng();
 651
 652                    let mut delay = INITIAL_RECONNECTION_DELAY;
 653                    loop {
 654                        match client.connect(true, cx).await {
 655                            ConnectionResult::Timeout => {
 656                                log::error!("client connect attempt timed out")
 657                            }
 658                            ConnectionResult::ConnectionReset => {
 659                                log::error!("client connect attempt reset")
 660                            }
 661                            ConnectionResult::Result(r) => {
 662                                if let Err(error) = r {
 663                                    log::error!("failed to connect: {error}");
 664                                } else {
 665                                    break;
 666                                }
 667                            }
 668                        }
 669
 670                        if matches!(
 671                            *client.status().borrow(),
 672                            Status::AuthenticationError | Status::ConnectionError
 673                        ) {
 674                            client.set_status(
 675                                Status::ReconnectionError {
 676                                    next_reconnection: Instant::now() + delay,
 677                                },
 678                                cx,
 679                            );
 680                            let jitter = Duration::from_millis(
 681                                rng.random_range(0..delay.as_millis() as u64),
 682                            );
 683                            cx.background_executor().timer(delay + jitter).await;
 684                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
 685                        } else {
 686                            break;
 687                        }
 688                    }
 689                }));
 690            }
 691            Status::SignedOut | Status::UpgradeRequired => {
 692                self.telemetry.set_authenticated_user_info(None, false);
 693                state._reconnect_task.take();
 694            }
 695            _ => {}
 696        }
 697    }
 698
 699    pub fn subscribe_to_entity<T>(
 700        self: &Arc<Self>,
 701        remote_id: u64,
 702    ) -> Result<PendingEntitySubscription<T>>
 703    where
 704        T: 'static,
 705    {
 706        let id = (TypeId::of::<T>(), remote_id);
 707
 708        let mut state = self.handler_set.lock();
 709        anyhow::ensure!(
 710            !state.entities_by_type_and_remote_id.contains_key(&id),
 711            "already subscribed to entity"
 712        );
 713
 714        state
 715            .entities_by_type_and_remote_id
 716            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 717
 718        Ok(PendingEntitySubscription {
 719            client: self.clone(),
 720            remote_id,
 721            consumed: false,
 722            _entity_type: PhantomData,
 723        })
 724    }
 725
 726    #[track_caller]
 727    pub fn add_message_handler<M, E, H, F>(
 728        self: &Arc<Self>,
 729        entity: WeakEntity<E>,
 730        handler: H,
 731    ) -> Subscription
 732    where
 733        M: EnvelopedMessage,
 734        E: 'static,
 735        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 736        F: 'static + Future<Output = Result<()>>,
 737    {
 738        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 739            handler(entity, message, cx)
 740        })
 741    }
 742
 743    fn add_message_handler_impl<M, E, H, F>(
 744        self: &Arc<Self>,
 745        entity: WeakEntity<E>,
 746        handler: H,
 747    ) -> Subscription
 748    where
 749        M: EnvelopedMessage,
 750        E: 'static,
 751        H: 'static
 752            + Sync
 753            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 754            + Send
 755            + Sync,
 756        F: 'static + Future<Output = Result<()>>,
 757    {
 758        let message_type_id = TypeId::of::<M>();
 759        let mut state = self.handler_set.lock();
 760        state
 761            .entities_by_message_type
 762            .insert(message_type_id, entity.into());
 763
 764        let prev_handler = state.message_handlers.insert(
 765            message_type_id,
 766            Arc::new(move |subscriber, envelope, client, cx| {
 767                let subscriber = subscriber.downcast::<E>().unwrap();
 768                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 769                handler(subscriber, *envelope, client, cx).boxed_local()
 770            }),
 771        );
 772        if prev_handler.is_some() {
 773            let location = std::panic::Location::caller();
 774            panic!(
 775                "{}:{} registered handler for the same message {} twice",
 776                location.file(),
 777                location.line(),
 778                std::any::type_name::<M>()
 779            );
 780        }
 781
 782        Subscription::Message {
 783            client: Arc::downgrade(self),
 784            id: message_type_id,
 785        }
 786    }
 787
 788    pub fn add_request_handler<M, E, H, F>(
 789        self: &Arc<Self>,
 790        entity: WeakEntity<E>,
 791        handler: H,
 792    ) -> Subscription
 793    where
 794        M: RequestMessage,
 795        E: 'static,
 796        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 797        F: 'static + Future<Output = Result<M::Response>>,
 798    {
 799        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 800            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 801        })
 802    }
 803
 804    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 805        receipt: Receipt<T>,
 806        response: F,
 807        client: AnyProtoClient,
 808    ) -> Result<()> {
 809        match response.await {
 810            Ok(response) => {
 811                client.send_response(receipt.message_id, response)?;
 812                Ok(())
 813            }
 814            Err(error) => {
 815                client.send_response(receipt.message_id, error.to_proto())?;
 816                Err(error)
 817            }
 818        }
 819    }
 820
 821    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 822        self.credentials_provider
 823            .read_credentials(cx)
 824            .await
 825            .is_some()
 826    }
 827
    /// Signs the client in and returns the credentials that were accepted.
    ///
    /// Sets the status to `Authenticating` when currently signed out (otherwise
    /// `Reauthenticating`), then looks for valid credentials in this order:
    /// 1. credentials already held in memory from a previous sign-in;
    /// 2. credentials stored in the credentials provider, when `try_provider` is true
    ///    (stored credentials that fail validation are deleted);
    /// 3. a fresh `authenticate` call, whose result is persisted unless impersonating.
    ///
    /// On success it updates the client id, the Cloud API client credentials and the in-memory
    /// credentials. The status then becomes `Authenticated` or `Reauthenticated`.
    ///
    /// # Errors
    ///
    /// - validating credentials fails (status becomes `AuthenticationError`);
    /// - `authenticate` fails (status becomes `AuthenticationError`);
    /// - the status changes while `authenticate` is pending ("authentication canceled").
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // 1. Reuse the in-memory credentials if the server still accepts them.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // 2. Fall back to stored credentials; invalid ones are deleted (best effort).
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // 3. Run the authentication flow.
        if credentials.is_none() {
            // Consume the status value already in the channel (presumably yielded immediately
            // by the watch receiver) so the select below only reacts to a later status change.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            // `select_biased!` polls `authenticate` first; any status change before it finishes
            // cancels this sign-in.
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            // Impersonated credentials are never persisted.
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path that reaches this point without returning has set `credentials`.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `user_id as u32` silently truncates ids above `u32::MAX`.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
 908
 909    async fn validate_credentials(
 910        self: &Arc<Self>,
 911        credentials: &Credentials,
 912        cx: &AsyncApp,
 913    ) -> Result<bool> {
 914        match self
 915            .cloud_client
 916            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 917            .await
 918        {
 919            Ok(valid) => Ok(valid),
 920            Err(err) => {
 921                self.set_status(Status::AuthenticationError, cx);
 922                Err(anyhow!("failed to validate credentials: {}", err))
 923            }
 924        }
 925    }
 926
 927    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 928    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 929        let connect_task = cx.update({
 930            let cloud_client = self.cloud_client.clone();
 931            move |cx| cloud_client.connect(cx)
 932        })??;
 933        let connection = connect_task.await?;
 934
 935        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 936        task.detach();
 937
 938        cx.spawn({
 939            let this = self.clone();
 940            async move |cx| {
 941                while let Some(message) = messages.next().await {
 942                    if let Some(message) = message.log_err() {
 943                        this.handle_message_to_client(message, cx);
 944                    }
 945                }
 946            }
 947        })
 948        .detach();
 949
 950        Ok(())
 951    }
 952
 953    /// Performs a sign-in and also (optionally) connects to Collab.
 954    ///
 955    /// Only Zed staff automatically connect to Collab.
 956    pub async fn sign_in_with_optional_connect(
 957        self: &Arc<Self>,
 958        try_provider: bool,
 959        cx: &AsyncApp,
 960    ) -> Result<()> {
 961        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
 962        if self.status().borrow().is_connected() {
 963            return Ok(());
 964        }
 965
 966        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
 967        let mut is_staff_tx = Some(is_staff_tx);
 968        cx.update(|cx| {
 969            cx.on_flags_ready(move |state, _cx| {
 970                if let Some(is_staff_tx) = is_staff_tx.take() {
 971                    is_staff_tx.send(state.is_staff).log_err();
 972                }
 973            })
 974            .detach();
 975        })
 976        .log_err();
 977
 978        let credentials = self.sign_in(try_provider, cx).await?;
 979
 980        self.connect_to_cloud(cx).await.log_err();
 981
 982        cx.update(move |cx| {
 983            cx.spawn({
 984                let client = self.clone();
 985                async move |cx| {
 986                    let is_staff = is_staff_rx.await?;
 987                    if is_staff {
 988                        match client.connect_with_credentials(credentials, cx).await {
 989                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
 990                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
 991                            ConnectionResult::Result(result) => {
 992                                result.context("client auth and connect")
 993                            }
 994                        }
 995                    } else {
 996                        Ok(())
 997                    }
 998                }
 999            })
1000            .detach_and_log_err(cx);
1001        })
1002        .log_err();
1003
1004        Ok(())
1005    }
1006
1007    pub async fn connect(
1008        self: &Arc<Self>,
1009        try_provider: bool,
1010        cx: &AsyncApp,
1011    ) -> ConnectionResult<()> {
1012        let was_disconnected = match *self.status().borrow() {
1013            Status::SignedOut | Status::Authenticated => true,
1014            Status::ConnectionError
1015            | Status::ConnectionLost
1016            | Status::Authenticating
1017            | Status::AuthenticationError
1018            | Status::Reauthenticating
1019            | Status::Reauthenticated
1020            | Status::ReconnectionError { .. } => false,
1021            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1022                return ConnectionResult::Result(Ok(()));
1023            }
1024            Status::UpgradeRequired => {
1025                return ConnectionResult::Result(
1026                    Err(EstablishConnectionError::UpgradeRequired)
1027                        .context("client auth and connect"),
1028                );
1029            }
1030        };
1031        let credentials = match self.sign_in(try_provider, cx).await {
1032            Ok(credentials) => credentials,
1033            Err(err) => return ConnectionResult::Result(Err(err)),
1034        };
1035
1036        if was_disconnected {
1037            self.set_status(Status::Connecting, cx);
1038        } else {
1039            self.set_status(Status::Reconnecting, cx);
1040        }
1041
1042        self.connect_with_credentials(credentials, cx).await
1043    }
1044
    /// Connects to Collab with already-obtained `credentials`, bounded by `CONNECTION_TIMEOUT`.
    ///
    /// A single timer covers two phases: establishing the transport (`establish_connection`)
    /// and the handshake in `set_connection`. The timer firing first yields
    /// `ConnectionResult::Timeout`. Status updates made here:
    /// - `EstablishConnectionError::UpgradeRequired` sets `Status::UpgradeRequired`;
    /// - every other failure, and the timer elapsing, sets `Status::ConnectionError`.
    ///
    /// On success, `set_connection` sets the `Connected` status itself.
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // `select_biased!` needs fused futures; the same timer is also reused by the nested
        // select below, so one deadline bounds both phases.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        // Biased: arms are polled in order, so a ready connection is preferred over the timer.
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Transport is up; now wait for the handshake under the same deadline.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    // The only failure that maps to a distinct status.
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1092
    /// Registers `conn` with the peer and completes the Collab handshake.
    ///
    /// Waits for the server's first message, which must be a `proto::Hello` carrying this
    /// client's `PeerId`. It then sets `Status::Connected` and spawns two detached tasks: one
    /// that dispatches incoming messages via `handle_message`, and one that awaits the
    /// connection's I/O task and updates the status when it ends.
    ///
    /// # Errors
    ///
    /// If no message arrives, the first message is not a `Hello`, or the `Hello` has no peer
    /// id, the connection is disconnected from the peer and the error is returned.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        // The peer is given a closure that creates background-executor timers for a duration.
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Run the connection's I/O future on the background executor; the watcher task spawned
        // at the end of this function awaits its outcome.
        let handle_io = executor.spawn(handle_io);

        // The server's first message must be `Hello`, which tells us our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Handshake failed: drop the half-registered connection from the peer.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch every subsequent incoming message.
        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        // Watch the I/O task. On a clean end, only sign out if the status still describes this
        // exact connection (presumably so a stale connection closing cannot overwrite the
        // status of a newer one). On an I/O error, report the connection as lost.
        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1176
1177    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1178        #[cfg(any(test, feature = "test-support"))]
1179        if let Some(callback) = self.authenticate.read().as_ref() {
1180            return callback(cx);
1181        }
1182
1183        self.authenticate_with_browser(cx)
1184    }
1185
1186    fn establish_connection(
1187        self: &Arc<Self>,
1188        credentials: &Credentials,
1189        cx: &AsyncApp,
1190    ) -> Task<Result<Connection, EstablishConnectionError>> {
1191        #[cfg(any(test, feature = "test-support"))]
1192        if let Some(callback) = self.establish_connection.read().as_ref() {
1193            return callback(credentials, cx);
1194        }
1195
1196        self.establish_websocket_connection(credentials, cx)
1197    }
1198
1199    fn rpc_url(
1200        &self,
1201        http: Arc<HttpClientWithUrl>,
1202        release_channel: Option<ReleaseChannel>,
1203    ) -> impl Future<Output = Result<url::Url>> + use<> {
1204        #[cfg(any(test, feature = "test-support"))]
1205        let url_override = self.rpc_url.read().clone();
1206
1207        async move {
1208            #[cfg(any(test, feature = "test-support"))]
1209            if let Some(url) = url_override {
1210                return Ok(url);
1211            }
1212
1213            if let Some(url) = &*ZED_RPC_URL {
1214                return Url::parse(url).context("invalid rpc url");
1215            }
1216
1217            let mut url = http.build_url("/rpc");
1218            if let Some(preview_param) =
1219                release_channel.and_then(|channel| channel.release_query_param())
1220            {
1221                url += "?";
1222                url += preview_param;
1223            }
1224
1225            let response = http.get(&url, Default::default(), false).await?;
1226            anyhow::ensure!(
1227                response.status().is_redirection(),
1228                "unexpected /rpc response status {}",
1229                response.status()
1230            );
1231            let collab_url = response
1232                .headers()
1233                .get("Location")
1234                .context("missing location header in /rpc response")?
1235                .to_str()
1236                .map_err(EstablishConnectionError::other)?
1237                .to_string();
1238            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1239        }
1240    }
1241
    /// Opens a WebSocket connection to the Collab RPC endpoint, authenticated with
    /// `credentials`.
    ///
    /// Several values are captured up front: the release channel, app version, HTTP proxy and
    /// user agent, and the telemetry system/metrics ids. A spawned task then does the work:
    /// 1. resolves the RPC URL (`rpc_url`);
    /// 2. opens a TCP stream on the Tokio runtime, going through the proxy when one is
    ///    configured;
    /// 3. rewrites the scheme to `ws`/`wss`;
    /// 4. performs the WebSocket handshake with Zed's identifying headers.
    ///
    /// Only `http` and `https` RPC URLs are accepted; any other scheme is an error.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Missing app globals degrade to "no release channel" / an empty version string.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            // Remembers whether the RPC URL was secure so the WebSocket URL can match it.
            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // Open the raw socket on the Tokio runtime (the stream type is `tokio::net::TcpStream`
            // or the proxy stream).
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })?
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // http->ws and https->wss are switches between "special" schemes, which
            // `Url::set_scheme` permits, so this `unwrap` is not expected to fire.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            // Optional headers are only sent when a value is available.
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            // Perform the WebSocket handshake over the already-open stream, passing the shared
            // TLS config from `http_client_tls::tls_config()`.
            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // `Connection` works with `anyhow` errors on both the stream and sink halves.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1349
1350    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1351        let http = self.http.clone();
1352        let this = self.clone();
1353        cx.spawn(async move |cx| {
1354            let background = cx.background_executor().clone();
1355
1356            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1357            cx.update(|cx| {
1358                cx.spawn(async move |cx| {
1359                    let url = open_url_rx.await?;
1360                    cx.update(|cx| cx.open_url(&url))
1361                })
1362                .detach_and_log_err(cx);
1363            })
1364            .log_err();
1365
1366            let credentials = background
1367                .clone()
1368                .spawn(async move {
1369                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1370                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1371                    // any other app running on the user's device.
1372                    let (public_key, private_key) =
1373                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1374                    let public_key_string = String::try_from(public_key)
1375                        .expect("failed to serialize public key for auth");
1376
1377                    if let Some((login, token)) =
1378                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1379                    {
1380                        if !*USE_WEB_LOGIN {
1381                            eprintln!("authenticate as admin {login}, {token}");
1382
1383                            return this
1384                                .authenticate_as_admin(http, login.clone(), token.clone())
1385                                .await;
1386                        }
1387                    }
1388
1389                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1390                    let server =
1391                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1392                    let port = server.server_addr().port();
1393
1394                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1395                    // that the user is signing in from a Zed app running on the same device.
1396                    let mut url = http.build_url(&format!(
1397                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1398                        port, public_key_string
1399                    ));
1400
1401                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1402                        log::info!("impersonating user @{}", impersonate_login);
1403                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1404                    }
1405
1406                    open_url_tx.send(url).log_err();
1407
1408                    #[derive(Deserialize)]
1409                    struct CallbackParams {
1410                        pub user_id: String,
1411                        pub access_token: String,
1412                    }
1413
1414                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1415                    // access token from the query params.
1416                    //
1417                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1418                    // custom URL scheme instead of this local HTTP server.
1419                    let (user_id, access_token) = background
1420                        .spawn(async move {
1421                            for _ in 0..100 {
1422                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1423                                    let path = req.url();
1424                                    let url = Url::parse(&format!("http://example.com{}", path))
1425                                        .context("failed to parse login notification url")?;
1426                                    let callback_params: CallbackParams =
1427                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1428                                            .context(
1429                                                "failed to parse sign-in callback query parameters",
1430                                            )?;
1431
1432                                    let post_auth_url =
1433                                        http.build_url("/native_app_signin_succeeded");
1434                                    req.respond(
1435                                        tiny_http::Response::empty(302).with_header(
1436                                            tiny_http::Header::from_bytes(
1437                                                &b"Location"[..],
1438                                                post_auth_url.as_bytes(),
1439                                            )
1440                                            .unwrap(),
1441                                        ),
1442                                    )
1443                                    .context("failed to respond to login http request")?;
1444                                    return Ok((
1445                                        callback_params.user_id,
1446                                        callback_params.access_token,
1447                                    ));
1448                                }
1449                            }
1450
1451                            anyhow::bail!("didn't receive login redirect");
1452                        })
1453                        .await?;
1454
1455                    let access_token = private_key
1456                        .decrypt_string(&access_token)
1457                        .context("failed to decrypt access token")?;
1458
1459                    Ok(Credentials {
1460                        user_id: user_id.parse()?,
1461                        access_token,
1462                    })
1463                })
1464                .await?;
1465
1466            cx.update(|cx| cx.activate(true))?;
1467            Ok(credentials)
1468        })
1469    }
1470
1471    async fn authenticate_as_admin(
1472        self: &Arc<Self>,
1473        http: Arc<HttpClientWithUrl>,
1474        login: String,
1475        api_token: String,
1476    ) -> Result<Credentials> {
1477        #[derive(Serialize)]
1478        struct ImpersonateUserBody {
1479            github_login: String,
1480        }
1481
1482        #[derive(Deserialize)]
1483        struct ImpersonateUserResponse {
1484            user_id: u64,
1485            access_token: String,
1486        }
1487
1488        let url = self
1489            .http
1490            .build_zed_cloud_url("/internal/users/impersonate")?;
1491        let request = Request::post(url.as_str())
1492            .header("Content-Type", "application/json")
1493            .header("Authorization", format!("Bearer {api_token}"))
1494            .body(
1495                serde_json::to_string(&ImpersonateUserBody {
1496                    github_login: login,
1497                })?
1498                .into(),
1499            )?;
1500
1501        let mut response = http.send(request).await?;
1502        let mut body = String::new();
1503        response.body_mut().read_to_string(&mut body).await?;
1504        anyhow::ensure!(
1505            response.status().is_success(),
1506            "admin user request failed {} - {}",
1507            response.status().as_u16(),
1508            body,
1509        );
1510        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1511
1512        Ok(Credentials {
1513            user_id: response.user_id,
1514            access_token: response.access_token,
1515        })
1516    }
1517
1518    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1519        self.state.write().credentials = None;
1520        self.cloud_client.clear_credentials();
1521        self.disconnect(cx);
1522
1523        if self.has_credentials(cx).await {
1524            self.credentials_provider
1525                .delete_credentials(cx)
1526                .await
1527                .log_err();
1528        }
1529    }
1530
1531    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1532        self.peer.teardown();
1533        self.set_status(Status::SignedOut, cx);
1534    }
1535
1536    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1537        self.peer.teardown();
1538        self.set_status(Status::ConnectionLost, cx);
1539    }
1540
1541    fn connection_id(&self) -> Result<ConnectionId> {
1542        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1543            Ok(connection_id)
1544        } else {
1545            anyhow::bail!("not connected");
1546        }
1547    }
1548
1549    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1550        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1551        self.peer.send(self.connection_id()?, message)
1552    }
1553
1554    pub fn request<T: RequestMessage>(
1555        &self,
1556        request: T,
1557    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1558        self.request_envelope(request)
1559            .map_ok(|envelope| envelope.payload)
1560    }
1561
1562    pub fn request_stream<T: RequestMessage>(
1563        &self,
1564        request: T,
1565    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1566        let client_id = self.id.load(Ordering::SeqCst);
1567        log::debug!(
1568            "rpc request start. client_id:{}. name:{}",
1569            client_id,
1570            T::NAME
1571        );
1572        let response = self
1573            .connection_id()
1574            .map(|conn_id| self.peer.request_stream(conn_id, request));
1575        async move {
1576            let response = response?.await;
1577            log::debug!(
1578                "rpc request finish. client_id:{}. name:{}",
1579                client_id,
1580                T::NAME
1581            );
1582            response
1583        }
1584    }
1585
1586    pub fn request_envelope<T: RequestMessage>(
1587        &self,
1588        request: T,
1589    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1590        let client_id = self.id();
1591        log::debug!(
1592            "rpc request start. client_id:{}. name:{}",
1593            client_id,
1594            T::NAME
1595        );
1596        let response = self
1597            .connection_id()
1598            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1599        async move {
1600            let response = response?.await;
1601            log::debug!(
1602                "rpc request finish. client_id:{}. name:{}",
1603                client_id,
1604                T::NAME
1605            );
1606            response
1607        }
1608    }
1609
1610    pub fn request_dynamic(
1611        &self,
1612        envelope: proto::Envelope,
1613        request_type: &'static str,
1614    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1615        let client_id = self.id();
1616        log::debug!(
1617            "rpc request start. client_id:{}. name:{}",
1618            client_id,
1619            request_type
1620        );
1621        let response = self
1622            .connection_id()
1623            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1624        async move {
1625            let response = response?.await;
1626            log::debug!(
1627                "rpc request finish. client_id:{}. name:{}",
1628                client_id,
1629                request_type
1630            );
1631            Ok(response?.0)
1632        }
1633    }
1634
1635    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1636        let sender_id = message.sender_id();
1637        let request_id = message.message_id();
1638        let type_name = message.payload_type_name();
1639        let original_sender_id = message.original_sender_id();
1640
1641        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1642            &self.handler_set,
1643            message,
1644            self.clone().into(),
1645            cx.clone(),
1646        ) {
1647            let client_id = self.id();
1648            log::debug!(
1649                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1650                client_id,
1651                original_sender_id,
1652                type_name
1653            );
1654            cx.spawn(async move |_| match future.await {
1655                Ok(()) => {
1656                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1657                }
1658                Err(error) => {
1659                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1660                }
1661            })
1662            .detach();
1663        } else {
1664            log::info!("unhandled message {}", type_name);
1665            self.peer
1666                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1667                .log_err();
1668        }
1669    }
1670
1671    pub fn add_message_to_client_handler(
1672        self: &Arc<Client>,
1673        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1674    ) {
1675        self.message_to_client_handlers
1676            .lock()
1677            .push(Box::new(handler));
1678    }
1679
1680    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1681        cx.update(|cx| {
1682            for handler in self.message_to_client_handlers.lock().iter() {
1683                handler(&message, cx);
1684            }
1685        })
1686        .ok();
1687    }
1688
1689    pub fn telemetry(&self) -> &Arc<Telemetry> {
1690        &self.telemetry
1691    }
1692}
1693
1694impl ProtoClient for Client {
1695    fn request(
1696        &self,
1697        envelope: proto::Envelope,
1698        request_type: &'static str,
1699    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1700        self.request_dynamic(envelope, request_type).boxed()
1701    }
1702
1703    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1704        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1705        let connection_id = self.connection_id()?;
1706        self.peer.send_dynamic(connection_id, envelope)
1707    }
1708
1709    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1710        log::debug!(
1711            "rpc respond. client_id:{}, name:{}",
1712            self.id(),
1713            message_type
1714        );
1715        let connection_id = self.connection_id()?;
1716        self.peer.send_dynamic(connection_id, envelope)
1717    }
1718
1719    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1720        &self.handler_set
1721    }
1722
1723    fn is_via_collab(&self) -> bool {
1724        true
1725    }
1726}
1727
/// Scheme name used by Zed deep links (`zed://…`).
///
/// Holds only the scheme itself; the `://` separator is not included.
pub const ZED_URL_SCHEME: &'static str = "zed";
1730
1731/// Parses the given link into a Zed link.
1732///
1733/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1734/// Returns [`None`] otherwise.
1735pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1736    let server_url = &ClientSettings::get_global(cx).server_url;
1737    if let Some(stripped) = link
1738        .strip_prefix(server_url)
1739        .and_then(|result| result.strip_prefix('/'))
1740    {
1741        return Some(stripped);
1742    }
1743    if let Some(stripped) = link
1744        .strip_prefix(ZED_URL_SCHEME)
1745        .and_then(|result| result.strip_prefix("://"))
1746    {
1747        return Some(stripped);
1748    }
1749
1750    None
1751}
1752
#[cfg(test)]
mod tests {
    // Tests for `Client` connection/authentication lifecycle and for routing
    // incoming protobuf messages to registered handlers, using `FakeServer`
    // and `FakeHttpClient` in place of the real server.
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // After a server-side disconnect, the client reconnects using its cached
    // credentials (auth_count stays 1). Once the server rolls the access
    // token, the client re-authenticates on reconnect (auth_count becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, then wait until the
        // client reports a failed reconnection attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Advance the fake clock so the client's next reconnection attempt runs.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A 503 from the HTTP client during reconnection yields a
    // `ReconnectionError`. Once responses return 200 again, the client
    // reconnects without re-authenticating (auth_count stays 1).
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // Both the initial connect and a later reconnect fail with an error status
    // once CONNECTION_TIMEOUT elapses while establishing the connection never
    // completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // The overridden connection future never resolves, so only the
        // client's timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // `sign_in` invokes the authenticate override only when the server
    // rejects the current credentials with 401. It does not do so when the
    // credentials are still accepted (200) or when the server is
    // unavailable (503).
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // Each authentication bumps the counter and uses the new count as the
        // access token, so the token reveals how many times auth ran.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Starting a second `connect` while the first authentication is still
    // pending starts a new authentication (auth_count == 2). It also causes
    // the first, still-pending authentication future to be dropped
    // (dropped_auth_count == 1).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this never-completing auth future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing does not drop the first task handle, so the drop counted
        // below comes from the client itself rather than from this binding.
        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity-routed messages reach the entity subscribed for that entity id.
    // Dropping the subscription for a different id (3) leaves ids 1 and 2
    // working.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        // Signals the channel matching the receiving entity's id; any other
        // id reaching the handler is a test failure.
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // After dropping a message-handler subscription, a new handler can be
    // registered for the same message type, and it receives the next message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // A handler may take its own `Subscription` out of the entity (dropping
    // it) while handling a message, and the message is still delivered.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a message-handler target: an id to tell
    // instances apart, plus an optional slot for holding a subscription.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global before each test builds a
    // `Client`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}