client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use serde::{Deserialize, Serialize};
  33use settings::{RegisterSetting, Settings, SettingsContent};
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        Arc, LazyLock, Weak,
  43        atomic::{AtomicU64, Ordering},
  44    },
  45    time::{Duration, Instant},
  46};
  47use std::{cmp, pin::Pin};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use tokio::net::TcpStream;
  51use url::Url;
  52use util::{ConnectionResult, ResultExt};
  53
  54pub use rpc::*;
  55pub use telemetry_events::Event;
  56pub use user::*;
  57
  58static ZED_SERVER_URL: LazyLock<Option<String>> =
  59    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  60static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  61
  62pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  63    std::env::var("ZED_IMPERSONATE")
  64        .ok()
  65        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  66});
  67
  68pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  69
  70pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  71    std::env::var("ZED_ADMIN_API_TOKEN")
  72        .ok()
  73        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  74});
  75
  76pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  77    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  78
  79pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  80    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  81
  82pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  83pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  84pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  85
// Account/connection actions in the `client` namespace. Their handlers are registered by
// `init` below. The `///` lines inside the macro are part of the macro input and are kept
// verbatim.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  97
  98#[derive(Deserialize, RegisterSetting)]
  99pub struct ClientSettings {
 100    pub server_url: String,
 101}
 102
 103impl Settings for ClientSettings {
 104    fn from_settings(content: &settings::SettingsContent) -> Self {
 105        if let Some(server_url) = &*ZED_SERVER_URL {
 106            return Self {
 107                server_url: server_url.clone(),
 108            };
 109        }
 110        Self {
 111            server_url: content.server_url.clone().unwrap(),
 112        }
 113    }
 114}
 115
 116#[derive(Deserialize, Default, RegisterSetting)]
 117pub struct ProxySettings {
 118    pub proxy: Option<String>,
 119}
 120
 121impl ProxySettings {
 122    pub fn proxy_url(&self) -> Option<Url> {
 123        self.proxy
 124            .as_ref()
 125            .and_then(|input| {
 126                input
 127                    .parse::<Url>()
 128                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 129                    .ok()
 130            })
 131            .or_else(read_proxy_from_env)
 132    }
 133}
 134
 135impl Settings for ProxySettings {
 136    fn from_settings(content: &settings::SettingsContent) -> Self {
 137        Self {
 138            proxy: content.proxy.clone(),
 139        }
 140    }
 141}
 142
 143pub fn init(client: &Arc<Client>, cx: &mut App) {
 144    let client = Arc::downgrade(client);
 145    cx.on_action({
 146        let client = client.clone();
 147        move |_: &SignIn, cx| {
 148            if let Some(client) = client.upgrade() {
 149                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 150                    .detach_and_log_err(cx);
 151            }
 152        }
 153    })
 154    .on_action({
 155        let client = client.clone();
 156        move |_: &SignOut, cx| {
 157            if let Some(client) = client.upgrade() {
 158                cx.spawn(async move |cx| {
 159                    client.sign_out(cx).await;
 160                })
 161                .detach();
 162            }
 163        }
 164    })
 165    .on_action({
 166        let client = client;
 167        move |_: &Reconnect, cx| {
 168            if let Some(client) = client.upgrade() {
 169                cx.spawn(async move |cx| {
 170                    client.reconnect(cx);
 171                })
 172                .detach();
 173            }
 174        }
 175    });
 176}
 177
/// Boxed callback that receives a [`MessageToClient`] together with mutable access to the app.
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;

/// Newtype used to store the application-wide [`Client`] as a gpui global
/// (read by `Client::global`, written by `Client::set_global`).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 183
/// Zed server client. It bundles the RPC peer, HTTP and Cloud API clients, telemetry,
/// credential storage, connection/authentication state, and registered message handlers.
pub struct Client {
    /// Id written by `set_id`; `sign_in` stores the signed-in user's id here. Used in log output.
    id: AtomicU64,
    /// RPC peer (constructed with `Peer::new(0)`).
    peer: Arc<Peer>,
    /// HTTP client bound to the configured server URL.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client; receives credentials on successful sign-in.
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    /// Persists credentials keyed by the server URL.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status channel, and the reconnect task.
    state: RwLock<ClientState>,
    /// Registered message handlers and entity subscriptions (see `Subscription`).
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Handlers for Cloud `MessageToClient` messages; presumably invoked by
    /// `handle_message_to_client` — confirm.
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    /// Test-only replacement for the authentication flow (set via `override_authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only replacement for connection establishment (set via
    /// `override_establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL override (set via `override_rpc_url`).
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 219
 220#[derive(Error, Debug)]
 221pub enum EstablishConnectionError {
 222    #[error("upgrade required")]
 223    UpgradeRequired,
 224    #[error("unauthorized")]
 225    Unauthorized,
 226    #[error("{0}")]
 227    Other(#[from] anyhow::Error),
 228    #[error("{0}")]
 229    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 230    #[error("{0}")]
 231    Io(#[from] std::io::Error),
 232    #[error("{0}")]
 233    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 234}
 235
 236impl From<WebsocketError> for EstablishConnectionError {
 237    fn from(error: WebsocketError) -> Self {
 238        if let WebsocketError::Http(response) = &error {
 239            match response.status() {
 240                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 241                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 242                _ => {}
 243            }
 244        }
 245        EstablishConnectionError::Other(error.into())
 246    }
 247}
 248
 249impl EstablishConnectionError {
 250    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 251        Self::Other(error.into())
 252    }
 253}
 254
 255#[derive(Copy, Clone, Debug, PartialEq)]
 256pub enum Status {
 257    SignedOut,
 258    UpgradeRequired,
 259    Authenticating,
 260    Authenticated,
 261    AuthenticationError,
 262    Connecting,
 263    ConnectionError,
 264    Connected {
 265        peer_id: PeerId,
 266        connection_id: ConnectionId,
 267    },
 268    ConnectionLost,
 269    Reauthenticating,
 270    Reauthenticated,
 271    Reconnecting,
 272    ReconnectionError {
 273        next_reconnection: Instant,
 274    },
 275}
 276
 277impl Status {
 278    pub fn is_connected(&self) -> bool {
 279        matches!(self, Self::Connected { .. })
 280    }
 281
 282    pub fn was_connected(&self) -> bool {
 283        matches!(
 284            self,
 285            Self::ConnectionLost
 286                | Self::Reauthenticating
 287                | Self::Reauthenticated
 288                | Self::Reconnecting
 289        )
 290    }
 291
 292    /// Returns whether the client is currently connected or was connected at some point.
 293    pub fn is_or_was_connected(&self) -> bool {
 294        self.is_connected() || self.was_connected()
 295    }
 296
 297    pub fn is_signing_in(&self) -> bool {
 298        matches!(
 299            self,
 300            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 301        )
 302    }
 303
 304    pub fn is_signed_out(&self) -> bool {
 305        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 306    }
 307}
 308
/// Mutable state of a [`Client`], stored behind its `state` lock.
struct ClientState {
    /// Credentials stored by the most recent successful `sign_in`, if any.
    credentials: Option<Credentials>,
    /// Status watch channel. `Client::status` hands out clones of the receiver, and
    /// `Client::set_status` writes through the sender.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect loop spawned by `set_status(ConnectionLost)`. It is cleared on `Connected`,
    /// `SignedOut`, and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
}
 314
 315#[derive(Clone, Debug, Eq, PartialEq)]
 316pub struct Credentials {
 317    pub user_id: u64,
 318    pub access_token: String,
 319}
 320
 321impl Credentials {
 322    pub fn authorization_header(&self) -> String {
 323        format!("{} {}", self.user_id, self.access_token)
 324    }
 325}
 326
 327pub struct ClientCredentialsProvider {
 328    provider: Arc<dyn CredentialsProvider>,
 329}
 330
 331impl ClientCredentialsProvider {
 332    pub fn new(cx: &App) -> Self {
 333        Self {
 334            provider: <dyn CredentialsProvider>::global(cx),
 335        }
 336    }
 337
 338    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 339        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 340    }
 341
 342    /// Reads the credentials from the provider.
 343    fn read_credentials<'a>(
 344        &'a self,
 345        cx: &'a AsyncApp,
 346    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 347        async move {
 348            if IMPERSONATE_LOGIN.is_some() {
 349                return None;
 350            }
 351
 352            let server_url = self.server_url(cx).ok()?;
 353            let (user_id, access_token) = self
 354                .provider
 355                .read_credentials(&server_url, cx)
 356                .await
 357                .log_err()
 358                .flatten()?;
 359
 360            Some(Credentials {
 361                user_id: user_id.parse().ok()?,
 362                access_token: String::from_utf8(access_token).ok()?,
 363            })
 364        }
 365        .boxed_local()
 366    }
 367
 368    /// Writes the credentials to the provider.
 369    fn write_credentials<'a>(
 370        &'a self,
 371        user_id: u64,
 372        access_token: String,
 373        cx: &'a AsyncApp,
 374    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 375        async move {
 376            let server_url = self.server_url(cx)?;
 377            self.provider
 378                .write_credentials(
 379                    &server_url,
 380                    &user_id.to_string(),
 381                    access_token.as_bytes(),
 382                    cx,
 383                )
 384                .await
 385        }
 386        .boxed_local()
 387    }
 388
 389    /// Deletes the credentials from the provider.
 390    fn delete_credentials<'a>(
 391        &'a self,
 392        cx: &'a AsyncApp,
 393    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 394        async move {
 395            let server_url = self.server_url(cx)?;
 396            self.provider.delete_credentials(&server_url, cx).await
 397        }
 398        .boxed_local()
 399    }
 400}
 401
 402impl Default for ClientState {
 403    fn default() -> Self {
 404        Self {
 405            credentials: None,
 406            status: watch::channel_with(Status::SignedOut),
 407            _reconnect_task: None,
 408        }
 409    }
 410}
 411
 412pub enum Subscription {
 413    Entity {
 414        client: Weak<Client>,
 415        id: (TypeId, u64),
 416    },
 417    Message {
 418        client: Weak<Client>,
 419        id: TypeId,
 420    },
 421}
 422
 423impl Drop for Subscription {
 424    fn drop(&mut self) {
 425        match self {
 426            Subscription::Entity { client, id } => {
 427                if let Some(client) = client.upgrade() {
 428                    let mut state = client.handler_set.lock();
 429                    let _ = state.entities_by_type_and_remote_id.remove(id);
 430                }
 431            }
 432            Subscription::Message { client, id } => {
 433                if let Some(client) = client.upgrade() {
 434                    let mut state = client.handler_set.lock();
 435                    let _ = state.entity_types_by_message_type.remove(id);
 436                    let _ = state.message_handlers.remove(id);
 437                }
 438            }
 439        }
 440    }
 441}
 442
 443pub struct PendingEntitySubscription<T: 'static> {
 444    client: Arc<Client>,
 445    remote_id: u64,
 446    _entity_type: PhantomData<T>,
 447    consumed: bool,
 448}
 449
 450impl<T: 'static> PendingEntitySubscription<T> {
 451    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 452        self.consumed = true;
 453        let mut handlers = self.client.handler_set.lock();
 454        let id = (TypeId::of::<T>(), self.remote_id);
 455        let Some(EntityMessageSubscriber::Pending(messages)) =
 456            handlers.entities_by_type_and_remote_id.remove(&id)
 457        else {
 458            unreachable!()
 459        };
 460
 461        handlers.entities_by_type_and_remote_id.insert(
 462            id,
 463            EntityMessageSubscriber::Entity {
 464                handle: entity.downgrade().into(),
 465            },
 466        );
 467        drop(handlers);
 468        for message in messages {
 469            let client_id = self.client.id();
 470            let type_name = message.payload_type_name();
 471            let sender_id = message.original_sender_id();
 472            log::debug!(
 473                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 474                client_id,
 475                sender_id,
 476                type_name
 477            );
 478            self.client.handle_message(message, cx);
 479        }
 480        Subscription::Entity {
 481            client: Arc::downgrade(&self.client),
 482            id,
 483        }
 484    }
 485}
 486
 487impl<T: 'static> Drop for PendingEntitySubscription<T> {
 488    fn drop(&mut self) {
 489        if !self.consumed {
 490            let mut state = self.client.handler_set.lock();
 491            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 492                .entities_by_type_and_remote_id
 493                .remove(&(TypeId::of::<T>(), self.remote_id))
 494            {
 495                for message in messages {
 496                    log::info!("unhandled message {}", message.payload_type_name());
 497                }
 498            }
 499        }
 500    }
 501}
 502
 503#[derive(Copy, Clone, Deserialize, Debug, RegisterSetting)]
 504pub struct TelemetrySettings {
 505    pub diagnostics: bool,
 506    pub metrics: bool,
 507}
 508
 509impl settings::Settings for TelemetrySettings {
 510    fn from_settings(content: &SettingsContent) -> Self {
 511        Self {
 512            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 513            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 514        }
 515    }
 516}
 517
 518impl Client {
 519    pub fn new(
 520        clock: Arc<dyn SystemClock>,
 521        http: Arc<HttpClientWithUrl>,
 522        cx: &mut App,
 523    ) -> Arc<Self> {
 524        Arc::new(Self {
 525            id: AtomicU64::new(0),
 526            peer: Peer::new(0),
 527            telemetry: Telemetry::new(clock, http.clone(), cx),
 528            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 529            http,
 530            credentials_provider: ClientCredentialsProvider::new(cx),
 531            state: Default::default(),
 532            handler_set: Default::default(),
 533            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 534
 535            #[cfg(any(test, feature = "test-support"))]
 536            authenticate: Default::default(),
 537            #[cfg(any(test, feature = "test-support"))]
 538            establish_connection: Default::default(),
 539            #[cfg(any(test, feature = "test-support"))]
 540            rpc_url: RwLock::default(),
 541        })
 542    }
 543
 544    pub fn production(cx: &mut App) -> Arc<Self> {
 545        let clock = Arc::new(clock::RealSystemClock);
 546        let http = Arc::new(HttpClientWithUrl::new_url(
 547            cx.http_client(),
 548            &ClientSettings::get_global(cx).server_url,
 549            cx.http_client().proxy().cloned(),
 550        ));
 551        Self::new(clock, http, cx)
 552    }
 553
 554    pub fn id(&self) -> u64 {
 555        self.id.load(Ordering::SeqCst)
 556    }
 557
 558    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 559        self.http.clone()
 560    }
 561
 562    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 563        self.cloud_client.clone()
 564    }
 565
 566    pub fn set_id(&self, id: u64) -> &Self {
 567        self.id.store(id, Ordering::SeqCst);
 568        self
 569    }
 570
 571    #[cfg(any(test, feature = "test-support"))]
 572    pub fn teardown(&self) {
 573        let mut state = self.state.write();
 574        state._reconnect_task.take();
 575        self.handler_set.lock().clear();
 576        self.peer.teardown();
 577    }
 578
 579    #[cfg(any(test, feature = "test-support"))]
 580    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 581    where
 582        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 583    {
 584        *self.authenticate.write() = Some(Box::new(authenticate));
 585        self
 586    }
 587
 588    #[cfg(any(test, feature = "test-support"))]
 589    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 590    where
 591        F: 'static
 592            + Send
 593            + Sync
 594            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 595    {
 596        *self.establish_connection.write() = Some(Box::new(connect));
 597        self
 598    }
 599
 600    #[cfg(any(test, feature = "test-support"))]
 601    pub fn override_rpc_url(&self, url: Url) -> &Self {
 602        *self.rpc_url.write() = Some(url);
 603        self
 604    }
 605
 606    pub fn global(cx: &App) -> Arc<Self> {
 607        cx.global::<GlobalClient>().0.clone()
 608    }
 609    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 610        cx.set_global(GlobalClient(client))
 611    }
 612
 613    pub fn user_id(&self) -> Option<u64> {
 614        self.state
 615            .read()
 616            .credentials
 617            .as_ref()
 618            .map(|credentials| credentials.user_id)
 619    }
 620
 621    pub fn peer_id(&self) -> Option<PeerId> {
 622        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 623            Some(*peer_id)
 624        } else {
 625            None
 626        }
 627    }
 628
 629    pub fn status(&self) -> watch::Receiver<Status> {
 630        self.state.read().status.1.clone()
 631    }
 632
 633    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
 634        log::info!("set status on client {}: {:?}", self.id(), status);
 635        let mut state = self.state.write();
 636        *state.status.0.borrow_mut() = status;
 637
 638        match status {
 639            Status::Connected { .. } => {
 640                state._reconnect_task = None;
 641            }
 642            Status::ConnectionLost => {
 643                let client = self.clone();
 644                state._reconnect_task = Some(cx.spawn(async move |cx| {
 645                    #[cfg(any(test, feature = "test-support"))]
 646                    let mut rng = StdRng::seed_from_u64(0);
 647                    #[cfg(not(any(test, feature = "test-support")))]
 648                    let mut rng = StdRng::from_os_rng();
 649
 650                    let mut delay = INITIAL_RECONNECTION_DELAY;
 651                    loop {
 652                        match client.connect(true, cx).await {
 653                            ConnectionResult::Timeout => {
 654                                log::error!("client connect attempt timed out")
 655                            }
 656                            ConnectionResult::ConnectionReset => {
 657                                log::error!("client connect attempt reset")
 658                            }
 659                            ConnectionResult::Result(r) => {
 660                                if let Err(error) = r {
 661                                    log::error!("failed to connect: {error}");
 662                                } else {
 663                                    break;
 664                                }
 665                            }
 666                        }
 667
 668                        if matches!(
 669                            *client.status().borrow(),
 670                            Status::AuthenticationError | Status::ConnectionError
 671                        ) {
 672                            client.set_status(
 673                                Status::ReconnectionError {
 674                                    next_reconnection: Instant::now() + delay,
 675                                },
 676                                cx,
 677                            );
 678                            let jitter = Duration::from_millis(
 679                                rng.random_range(0..delay.as_millis() as u64),
 680                            );
 681                            cx.background_executor().timer(delay + jitter).await;
 682                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
 683                        } else {
 684                            break;
 685                        }
 686                    }
 687                }));
 688            }
 689            Status::SignedOut | Status::UpgradeRequired => {
 690                self.telemetry.set_authenticated_user_info(None, false);
 691                state._reconnect_task.take();
 692            }
 693            _ => {}
 694        }
 695    }
 696
 697    pub fn subscribe_to_entity<T>(
 698        self: &Arc<Self>,
 699        remote_id: u64,
 700    ) -> Result<PendingEntitySubscription<T>>
 701    where
 702        T: 'static,
 703    {
 704        let id = (TypeId::of::<T>(), remote_id);
 705
 706        let mut state = self.handler_set.lock();
 707        anyhow::ensure!(
 708            !state.entities_by_type_and_remote_id.contains_key(&id),
 709            "already subscribed to entity"
 710        );
 711
 712        state
 713            .entities_by_type_and_remote_id
 714            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 715
 716        Ok(PendingEntitySubscription {
 717            client: self.clone(),
 718            remote_id,
 719            consumed: false,
 720            _entity_type: PhantomData,
 721        })
 722    }
 723
 724    #[track_caller]
 725    pub fn add_message_handler<M, E, H, F>(
 726        self: &Arc<Self>,
 727        entity: WeakEntity<E>,
 728        handler: H,
 729    ) -> Subscription
 730    where
 731        M: EnvelopedMessage,
 732        E: 'static,
 733        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 734        F: 'static + Future<Output = Result<()>>,
 735    {
 736        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 737            handler(entity, message, cx)
 738        })
 739    }
 740
 741    fn add_message_handler_impl<M, E, H, F>(
 742        self: &Arc<Self>,
 743        entity: WeakEntity<E>,
 744        handler: H,
 745    ) -> Subscription
 746    where
 747        M: EnvelopedMessage,
 748        E: 'static,
 749        H: 'static
 750            + Sync
 751            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 752            + Send
 753            + Sync,
 754        F: 'static + Future<Output = Result<()>>,
 755    {
 756        let message_type_id = TypeId::of::<M>();
 757        let mut state = self.handler_set.lock();
 758        state
 759            .entities_by_message_type
 760            .insert(message_type_id, entity.into());
 761
 762        let prev_handler = state.message_handlers.insert(
 763            message_type_id,
 764            Arc::new(move |subscriber, envelope, client, cx| {
 765                let subscriber = subscriber.downcast::<E>().unwrap();
 766                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 767                handler(subscriber, *envelope, client, cx).boxed_local()
 768            }),
 769        );
 770        if prev_handler.is_some() {
 771            let location = std::panic::Location::caller();
 772            panic!(
 773                "{}:{} registered handler for the same message {} twice",
 774                location.file(),
 775                location.line(),
 776                std::any::type_name::<M>()
 777            );
 778        }
 779
 780        Subscription::Message {
 781            client: Arc::downgrade(self),
 782            id: message_type_id,
 783        }
 784    }
 785
 786    pub fn add_request_handler<M, E, H, F>(
 787        self: &Arc<Self>,
 788        entity: WeakEntity<E>,
 789        handler: H,
 790    ) -> Subscription
 791    where
 792        M: RequestMessage,
 793        E: 'static,
 794        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 795        F: 'static + Future<Output = Result<M::Response>>,
 796    {
 797        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 798            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 799        })
 800    }
 801
 802    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 803        receipt: Receipt<T>,
 804        response: F,
 805        client: AnyProtoClient,
 806    ) -> Result<()> {
 807        match response.await {
 808            Ok(response) => {
 809                client.send_response(receipt.message_id, response)?;
 810                Ok(())
 811            }
 812            Err(error) => {
 813                client.send_response(receipt.message_id, error.to_proto())?;
 814                Err(error)
 815            }
 816        }
 817    }
 818
 819    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 820        self.credentials_provider
 821            .read_credentials(cx)
 822            .await
 823            .is_some()
 824    }
 825
    /// Authenticates this client and returns the credentials that were accepted.
    ///
    /// The status is first set to `Authenticating` (when currently signed out) or to
    /// `Reauthenticating`. Credentials are then tried in order:
    /// 1. the credentials already held in memory, if `validate_credentials` accepts them;
    /// 2. when `try_provider` is true, credentials read from the credentials provider;
    ///    stored credentials that fail validation are deleted from the provider;
    /// 3. otherwise `Self::authenticate`. On success its credentials are written to the
    ///    provider, unless `IMPERSONATE_LOGIN` is set.
    ///
    /// On success, the following are updated, and the status becomes `Authenticated` or
    /// `Reauthenticated`:
    /// - the client id;
    /// - the Cloud API client's credentials;
    /// - the stored credentials.
    ///
    /// # Errors
    ///
    /// - Fails if validating credentials errors (the status becomes `AuthenticationError`).
    /// - Fails if `authenticate` errors (the status also becomes `AuthenticationError`).
    /// - Fails with "authentication canceled" if the status watch yields again while
    ///   `authenticate` is still running.
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // 1. Reuse the in-memory credentials if they still validate.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // 2. Try the provider's stored credentials; remove them if they are rejected.
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // 3. Fall back to `authenticate`, raced against the status watch so that a status
        //    update during the flow cancels this sign-in.
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably this consumes the watch's current value, so that only a
            // later status change completes the `status_rx.next()` branch below — confirm
            // against postage `watch` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path above that leaves `credentials` empty has already returned.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
 906
 907    async fn validate_credentials(
 908        self: &Arc<Self>,
 909        credentials: &Credentials,
 910        cx: &AsyncApp,
 911    ) -> Result<bool> {
 912        match self
 913            .cloud_client
 914            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 915            .await
 916        {
 917            Ok(valid) => Ok(valid),
 918            Err(err) => {
 919                self.set_status(Status::AuthenticationError, cx);
 920                Err(anyhow!("failed to validate credentials: {}", err))
 921            }
 922        }
 923    }
 924
 925    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 926    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 927        let connect_task = cx.update({
 928            let cloud_client = self.cloud_client.clone();
 929            move |cx| cloud_client.connect(cx)
 930        })??;
 931        let connection = connect_task.await?;
 932
 933        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 934        task.detach();
 935
 936        cx.spawn({
 937            let this = self.clone();
 938            async move |cx| {
 939                while let Some(message) = messages.next().await {
 940                    if let Some(message) = message.log_err() {
 941                        this.handle_message_to_client(message, cx);
 942                    }
 943                }
 944            }
 945        })
 946        .detach();
 947
 948        Ok(())
 949    }
 950
    /// Performs a sign-in and also (optionally) connects to Collab.
    ///
    /// Only Zed staff automatically connect to Collab.
    ///
    /// Returns `Ok(())` immediately if the client is already connected. Otherwise signs in via
    /// `sign_in` (its error is propagated), then attempts to connect to Cloud on a best-effort
    /// basis (errors are only logged). Whether to connect to Collab is decided in a detached
    /// task once feature flags are ready. Errors from that task are logged rather than
    /// returned, so `Ok(())` from this function does not mean Collab is connected.
    pub async fn sign_in_with_optional_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<()> {
        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
        if self.status().borrow().is_connected() {
            return Ok(());
        }

        // Forward the `is_staff` flag from the `on_flags_ready` callback into a oneshot channel.
        // The sender lives in an `Option` and is `take`n, so at most one value is ever sent.
        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
        let mut is_staff_tx = Some(is_staff_tx);
        cx.update(|cx| {
            cx.on_flags_ready(move |state, _cx| {
                if let Some(is_staff_tx) = is_staff_tx.take() {
                    is_staff_tx.send(state.is_staff).log_err();
                }
            })
            .detach();
        })
        .log_err();

        let credentials = self.sign_in(try_provider, cx).await?;

        // A failed Cloud connection is logged but does not fail the sign-in.
        self.connect_to_cloud(cx).await.log_err();

        // Wait for the staff flag in the background; only staff go on to connect to Collab.
        cx.update(move |cx| {
            cx.spawn({
                let client = self.clone();
                async move |cx| {
                    let is_staff = is_staff_rx.await?;
                    if is_staff {
                        match client.connect_with_credentials(credentials, cx).await {
                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
                            ConnectionResult::Result(result) => {
                                result.context("client auth and connect")
                            }
                        }
                    } else {
                        Ok(())
                    }
                }
            })
            .detach_and_log_err(cx);
        })
        .log_err();

        Ok(())
    }
1004
1005    pub async fn connect(
1006        self: &Arc<Self>,
1007        try_provider: bool,
1008        cx: &AsyncApp,
1009    ) -> ConnectionResult<()> {
1010        let was_disconnected = match *self.status().borrow() {
1011            Status::SignedOut | Status::Authenticated => true,
1012            Status::ConnectionError
1013            | Status::ConnectionLost
1014            | Status::Authenticating
1015            | Status::AuthenticationError
1016            | Status::Reauthenticating
1017            | Status::Reauthenticated
1018            | Status::ReconnectionError { .. } => false,
1019            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1020                return ConnectionResult::Result(Ok(()));
1021            }
1022            Status::UpgradeRequired => {
1023                return ConnectionResult::Result(
1024                    Err(EstablishConnectionError::UpgradeRequired)
1025                        .context("client auth and connect"),
1026                );
1027            }
1028        };
1029        let credentials = match self.sign_in(try_provider, cx).await {
1030            Ok(credentials) => credentials,
1031            Err(err) => return ConnectionResult::Result(Err(err)),
1032        };
1033
1034        if was_disconnected {
1035            self.set_status(Status::Connecting, cx);
1036        } else {
1037            self.set_status(Status::Reconnecting, cx);
1038        }
1039
1040        self.connect_with_credentials(credentials, cx).await
1041    }
1042
    /// Opens a Collab connection using already-obtained `credentials`.
    ///
    /// Establishing the transport and the `set_connection` handshake both race against one
    /// shared `CONNECTION_TIMEOUT` timer, so the whole operation is bounded by that duration.
    /// On failure the status is updated before returning:
    /// - [`Status::UpgradeRequired`] for [`EstablishConnectionError::UpgradeRequired`];
    /// - [`Status::ConnectionError`] for timeouts and every other error (including
    ///   `Unauthorized`).
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // One timer shared by both phases below. It is fused because `select_biased!` requires
        // `FusedFuture`s.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        // `select_biased!` polls its arms in order, so the connection arm is checked before the
        // timer on every poll.
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // The handshake gets whatever time remains on the same timer.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1090
    /// Registers `conn` with the peer, performs the `Hello` handshake, and starts the tasks
    /// that service the connection.
    ///
    /// The first incoming message must be a `proto::Hello` carrying a peer id. If it is
    /// missing, of another type, or has no peer id, the connection is disconnected from the
    /// peer and an error is returned.
    ///
    /// On success the status becomes [`Status::Connected`] and two detached tasks are spawned:
    /// - one dispatches every further incoming message through `handle_message`;
    /// - one waits for the connection's I/O future to finish, then sets the status to:
    ///   - `SignedOut` when it finishes cleanly and this connection is still the current one;
    ///   - `ConnectionLost` when it finishes with an error.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        // The peer is given a timer factory backed by the background executor.
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor.
        let handle_io = executor.spawn(handle_io);

        // The server's first message must be `Hello`, which tells us our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Don't leave a half-initialized connection registered with the peer.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Only sign out if no other connection has replaced this one meanwhile.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1174
1175    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1176        #[cfg(any(test, feature = "test-support"))]
1177        if let Some(callback) = self.authenticate.read().as_ref() {
1178            return callback(cx);
1179        }
1180
1181        self.authenticate_with_browser(cx)
1182    }
1183
1184    fn establish_connection(
1185        self: &Arc<Self>,
1186        credentials: &Credentials,
1187        cx: &AsyncApp,
1188    ) -> Task<Result<Connection, EstablishConnectionError>> {
1189        #[cfg(any(test, feature = "test-support"))]
1190        if let Some(callback) = self.establish_connection.read().as_ref() {
1191            return callback(credentials, cx);
1192        }
1193
1194        self.establish_websocket_connection(credentials, cx)
1195    }
1196
1197    fn rpc_url(
1198        &self,
1199        http: Arc<HttpClientWithUrl>,
1200        release_channel: Option<ReleaseChannel>,
1201    ) -> impl Future<Output = Result<url::Url>> + use<> {
1202        #[cfg(any(test, feature = "test-support"))]
1203        let url_override = self.rpc_url.read().clone();
1204
1205        async move {
1206            #[cfg(any(test, feature = "test-support"))]
1207            if let Some(url) = url_override {
1208                return Ok(url);
1209            }
1210
1211            if let Some(url) = &*ZED_RPC_URL {
1212                return Url::parse(url).context("invalid rpc url");
1213            }
1214
1215            let mut url = http.build_url("/rpc");
1216            if let Some(preview_param) =
1217                release_channel.and_then(|channel| channel.release_query_param())
1218            {
1219                url += "?";
1220                url += preview_param;
1221            }
1222
1223            let response = http.get(&url, Default::default(), false).await?;
1224            anyhow::ensure!(
1225                response.status().is_redirection(),
1226                "unexpected /rpc response status {}",
1227                response.status()
1228            );
1229            let collab_url = response
1230                .headers()
1231                .get("Location")
1232                .context("missing location header in /rpc response")?
1233                .to_str()
1234                .map_err(EstablishConnectionError::other)?
1235                .to_string();
1236            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1237        }
1238    }
1239
    /// Opens a WebSocket connection to the Collab RPC endpoint.
    ///
    /// Steps:
    /// 1. Resolve the RPC URL via `rpc_url`. Only `http` and `https` URLs are accepted;
    ///    anything else is an error.
    /// 2. Open a TCP stream to its host and port on the Tokio runtime, through the HTTP
    ///    client's proxy when one is configured.
    /// 3. Perform the WebSocket handshake with
    ///    `async_tungstenite::tokio::client_async_tls_with_connector_and_config`, using Zed's
    ///    TLS config.
    ///
    /// The handshake request carries these headers:
    /// - the credentials' authorization header;
    /// - the RPC protocol version, app version and release channel;
    /// - the user agent, and the telemetry system and metrics ids, when available.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Read app globals up front; missing values fall back to `None` / an empty string.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // `tokio::net::TcpStream` (and the proxy connector) need the Tokio runtime.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })?
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // http(s) -> ws(s) only switches between "special" URL schemes, which
            // `Url::set_scheme` permits, so this unwrap cannot fail.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Adapt the WebSocket's stream and sink error types to `anyhow::Error`.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1347
1348    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1349        let http = self.http.clone();
1350        let this = self.clone();
1351        cx.spawn(async move |cx| {
1352            let background = cx.background_executor().clone();
1353
1354            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1355            cx.update(|cx| {
1356                cx.spawn(async move |cx| {
1357                    let url = open_url_rx.await?;
1358                    cx.update(|cx| cx.open_url(&url))
1359                })
1360                .detach_and_log_err(cx);
1361            })
1362            .log_err();
1363
1364            let credentials = background
1365                .clone()
1366                .spawn(async move {
1367                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1368                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1369                    // any other app running on the user's device.
1370                    let (public_key, private_key) =
1371                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1372                    let public_key_string = String::try_from(public_key)
1373                        .expect("failed to serialize public key for auth");
1374
1375                    if let Some((login, token)) =
1376                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1377                    {
1378                        if !*USE_WEB_LOGIN {
1379                            eprintln!("authenticate as admin {login}, {token}");
1380
1381                            return this
1382                                .authenticate_as_admin(http, login.clone(), token.clone())
1383                                .await;
1384                        }
1385                    }
1386
1387                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1388                    let server =
1389                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1390                    let port = server.server_addr().port();
1391
1392                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1393                    // that the user is signing in from a Zed app running on the same device.
1394                    let mut url = http.build_url(&format!(
1395                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1396                        port, public_key_string
1397                    ));
1398
1399                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1400                        log::info!("impersonating user @{}", impersonate_login);
1401                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1402                    }
1403
1404                    open_url_tx.send(url).log_err();
1405
1406                    #[derive(Deserialize)]
1407                    struct CallbackParams {
1408                        pub user_id: String,
1409                        pub access_token: String,
1410                    }
1411
1412                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1413                    // access token from the query params.
1414                    //
1415                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1416                    // custom URL scheme instead of this local HTTP server.
1417                    let (user_id, access_token) = background
1418                        .spawn(async move {
1419                            for _ in 0..100 {
1420                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1421                                    let path = req.url();
1422                                    let url = Url::parse(&format!("http://example.com{}", path))
1423                                        .context("failed to parse login notification url")?;
1424                                    let callback_params: CallbackParams =
1425                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1426                                            .context(
1427                                                "failed to parse sign-in callback query parameters",
1428                                            )?;
1429
1430                                    let post_auth_url =
1431                                        http.build_url("/native_app_signin_succeeded");
1432                                    req.respond(
1433                                        tiny_http::Response::empty(302).with_header(
1434                                            tiny_http::Header::from_bytes(
1435                                                &b"Location"[..],
1436                                                post_auth_url.as_bytes(),
1437                                            )
1438                                            .unwrap(),
1439                                        ),
1440                                    )
1441                                    .context("failed to respond to login http request")?;
1442                                    return Ok((
1443                                        callback_params.user_id,
1444                                        callback_params.access_token,
1445                                    ));
1446                                }
1447                            }
1448
1449                            anyhow::bail!("didn't receive login redirect");
1450                        })
1451                        .await?;
1452
1453                    let access_token = private_key
1454                        .decrypt_string(&access_token)
1455                        .context("failed to decrypt access token")?;
1456
1457                    Ok(Credentials {
1458                        user_id: user_id.parse()?,
1459                        access_token,
1460                    })
1461                })
1462                .await?;
1463
1464            cx.update(|cx| cx.activate(true))?;
1465            Ok(credentials)
1466        })
1467    }
1468
1469    async fn authenticate_as_admin(
1470        self: &Arc<Self>,
1471        http: Arc<HttpClientWithUrl>,
1472        login: String,
1473        api_token: String,
1474    ) -> Result<Credentials> {
1475        #[derive(Serialize)]
1476        struct ImpersonateUserBody {
1477            github_login: String,
1478        }
1479
1480        #[derive(Deserialize)]
1481        struct ImpersonateUserResponse {
1482            user_id: u64,
1483            access_token: String,
1484        }
1485
1486        let url = self
1487            .http
1488            .build_zed_cloud_url("/internal/users/impersonate")?;
1489        let request = Request::post(url.as_str())
1490            .header("Content-Type", "application/json")
1491            .header("Authorization", format!("Bearer {api_token}"))
1492            .body(
1493                serde_json::to_string(&ImpersonateUserBody {
1494                    github_login: login,
1495                })?
1496                .into(),
1497            )?;
1498
1499        let mut response = http.send(request).await?;
1500        let mut body = String::new();
1501        response.body_mut().read_to_string(&mut body).await?;
1502        anyhow::ensure!(
1503            response.status().is_success(),
1504            "admin user request failed {} - {}",
1505            response.status().as_u16(),
1506            body,
1507        );
1508        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1509
1510        Ok(Credentials {
1511            user_id: response.user_id,
1512            access_token: response.access_token,
1513        })
1514    }
1515
1516    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1517        self.state.write().credentials = None;
1518        self.cloud_client.clear_credentials();
1519        self.disconnect(cx);
1520
1521        if self.has_credentials(cx).await {
1522            self.credentials_provider
1523                .delete_credentials(cx)
1524                .await
1525                .log_err();
1526        }
1527    }
1528
    /// Tears down the peer (`peer.teardown()`) and then sets the status to
    /// [`Status::SignedOut`].
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1533
    /// Tears down the peer (`peer.teardown()`) and then sets the status to
    /// [`Status::ConnectionLost`].
    ///
    /// NOTE(review): this does not open a new connection itself. It presumably relies on a
    /// status observer reacting to `ConnectionLost` — confirm.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1538
1539    fn connection_id(&self) -> Result<ConnectionId> {
1540        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1541            Ok(connection_id)
1542        } else {
1543            anyhow::bail!("not connected");
1544        }
1545    }
1546
1547    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1548        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1549        self.peer.send(self.connection_id()?, message)
1550    }
1551
1552    pub fn request<T: RequestMessage>(
1553        &self,
1554        request: T,
1555    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1556        self.request_envelope(request)
1557            .map_ok(|envelope| envelope.payload)
1558    }
1559
1560    pub fn request_stream<T: RequestMessage>(
1561        &self,
1562        request: T,
1563    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1564        let client_id = self.id.load(Ordering::SeqCst);
1565        log::debug!(
1566            "rpc request start. client_id:{}. name:{}",
1567            client_id,
1568            T::NAME
1569        );
1570        let response = self
1571            .connection_id()
1572            .map(|conn_id| self.peer.request_stream(conn_id, request));
1573        async move {
1574            let response = response?.await;
1575            log::debug!(
1576                "rpc request finish. client_id:{}. name:{}",
1577                client_id,
1578                T::NAME
1579            );
1580            response
1581        }
1582    }
1583
1584    pub fn request_envelope<T: RequestMessage>(
1585        &self,
1586        request: T,
1587    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1588        let client_id = self.id();
1589        log::debug!(
1590            "rpc request start. client_id:{}. name:{}",
1591            client_id,
1592            T::NAME
1593        );
1594        let response = self
1595            .connection_id()
1596            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1597        async move {
1598            let response = response?.await;
1599            log::debug!(
1600                "rpc request finish. client_id:{}. name:{}",
1601                client_id,
1602                T::NAME
1603            );
1604            response
1605        }
1606    }
1607
1608    pub fn request_dynamic(
1609        &self,
1610        envelope: proto::Envelope,
1611        request_type: &'static str,
1612    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1613        let client_id = self.id();
1614        log::debug!(
1615            "rpc request start. client_id:{}. name:{}",
1616            client_id,
1617            request_type
1618        );
1619        let response = self
1620            .connection_id()
1621            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1622        async move {
1623            let response = response?.await;
1624            log::debug!(
1625                "rpc request finish. client_id:{}. name:{}",
1626                client_id,
1627                request_type
1628            );
1629            Ok(response?.0)
1630        }
1631    }
1632
1633    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1634        let sender_id = message.sender_id();
1635        let request_id = message.message_id();
1636        let type_name = message.payload_type_name();
1637        let original_sender_id = message.original_sender_id();
1638
1639        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1640            &self.handler_set,
1641            message,
1642            self.clone().into(),
1643            cx.clone(),
1644        ) {
1645            let client_id = self.id();
1646            log::debug!(
1647                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1648                client_id,
1649                original_sender_id,
1650                type_name
1651            );
1652            cx.spawn(async move |_| match future.await {
1653                Ok(()) => {
1654                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1655                }
1656                Err(error) => {
1657                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1658                }
1659            })
1660            .detach();
1661        } else {
1662            log::info!("unhandled message {}", type_name);
1663            self.peer
1664                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1665                .log_err();
1666        }
1667    }
1668
1669    pub fn add_message_to_client_handler(
1670        self: &Arc<Client>,
1671        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1672    ) {
1673        self.message_to_client_handlers
1674            .lock()
1675            .push(Box::new(handler));
1676    }
1677
1678    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1679        cx.update(|cx| {
1680            for handler in self.message_to_client_handlers.lock().iter() {
1681                handler(&message, cx);
1682            }
1683        })
1684        .ok();
1685    }
1686
    /// Returns this client's shared [`Telemetry`] handle.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1690}
1691
1692impl ProtoClient for Client {
1693    fn request(
1694        &self,
1695        envelope: proto::Envelope,
1696        request_type: &'static str,
1697    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1698        self.request_dynamic(envelope, request_type).boxed()
1699    }
1700
1701    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1702        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1703        let connection_id = self.connection_id()?;
1704        self.peer.send_dynamic(connection_id, envelope)
1705    }
1706
1707    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1708        log::debug!(
1709            "rpc respond. client_id:{}, name:{}",
1710            self.id(),
1711            message_type
1712        );
1713        let connection_id = self.connection_id()?;
1714        self.peer.send_dynamic(connection_id, envelope)
1715    }
1716
1717    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1718        &self.handler_set
1719    }
1720
1721    fn is_via_collab(&self) -> bool {
1722        true
1723    }
1724
1725    fn has_wsl_interop(&self) -> bool {
1726        false
1727    }
1728}
1729
/// Scheme name used by `zed://` links. It holds only the scheme itself, without the `://` separator.
pub const ZED_URL_SCHEME: &str = "zed";
1732
1733/// Parses the given link into a Zed link.
1734///
1735/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1736/// Returns [`None`] otherwise.
1737pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1738    let server_url = &ClientSettings::get_global(cx).server_url;
1739    if let Some(stripped) = link
1740        .strip_prefix(server_url)
1741        .and_then(|result| result.strip_prefix('/'))
1742    {
1743        return Some(stripped);
1744    }
1745    if let Some(stripped) = link
1746        .strip_prefix(ZED_URL_SCHEME)
1747        .and_then(|result| result.strip_prefix("://"))
1748    {
1749        return Some(stripped);
1750    }
1751
1752    None
1753}
1754
#[cfg(test)]
mod tests {
    // Tests for `Client`'s connection lifecycle (connect, timeout, reconnect),
    // credential caching / re-authentication, and routing of incoming messages
    // to handlers and entity subscriptions. They run against `FakeServer` and
    // `FakeHttpClient` with a `FakeSystemClock`. Time is advanced explicitly
    // through the test executor.
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // After a disconnect the client should reconnect using its cached
    // credentials. Only once the server rolls the access token (invalidating
    // the cached one) should it authenticate again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, so the client's
        // reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock (10s) so the
        // pending reconnection attempt can run.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // An HTTP failure (503) while reconnecting should surface as a
    // reconnection error without discarding the cached credentials. Once the
    // HTTP endpoint answers 200 again, the client reconnects without
    // authenticating a second time.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // A connection attempt that never completes must be abandoned after
    // `CONNECTION_TIMEOUT`. On the initial connect this yields
    // `Status::ConnectionError`; on a reconnect it yields
    // `Status::ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        // Authentication succeeds immediately, but establishing the connection
        // awaits a future that never resolves, so only the timeout can end
        // the attempt.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Advance far enough for the client to start its next reconnection
        // attempt. The doubled delay is presumably headroom for any
        // randomized backoff; confirm against the reconnection logic.
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // `sign_in` should reuse cached credentials unless the server explicitly
    // rejects them with 401. A 503 produces an error without triggering
    // re-authentication.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // The fake authenticator issues access tokens "1", "2", ..., so the
        // returned token shows how many times authentication has run.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        // Only token "2" is accepted from here on; anything else gets a 401.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Starting a second `connect` while an authentication attempt is still
    // pending should drop (cancel) the first attempt's future.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        // Each authentication attempt never completes. `dropped_auth_count`
        // is bumped by the deferred closure when the attempt's future is
        // dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing `_authenticate` does not drop the first task handle, so
        // the cancellation asserted below comes from the client itself rather
        // than from dropping that handle.
        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity-scoped messages must be delivered to the entity subscribed under
    // the message's remote id (here `JoinProject::project_id`). Dropping one
    // subscription must not affect others of the same entity type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // After dropping the subscription for a message handler, registering a new
    // handler for the same message type must succeed, and the new handler must
    // receive the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // A handler may drop its own `Subscription` (stored on the entity) while it
    // is running, and the message must still be handled to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a message-handler target. `id` lets routing tests
    // tell entities apart, and `subscription` lets a handler drop its own
    // registration.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global. This is presumably required
    // because the client reads settings (e.g. `ClientSettings`) from the global
    // store.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}