client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use serde::{Deserialize, Serialize};
  33use settings::{Settings, SettingsContent};
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        Arc, LazyLock, Weak,
  43        atomic::{AtomicU64, Ordering},
  44    },
  45    time::{Duration, Instant},
  46};
  47use std::{cmp, pin::Pin};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use tokio::net::TcpStream;
  51use url::Url;
  52use util::{ConnectionResult, ResultExt};
  53
  54pub use rpc::*;
  55pub use telemetry_events::Event;
  56pub use user::*;
  57
  58static ZED_SERVER_URL: LazyLock<Option<String>> =
  59    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  60static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  61
  62pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  63    std::env::var("ZED_IMPERSONATE")
  64        .ok()
  65        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  66});
  67
  68pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  69
  70pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  71    std::env::var("ZED_ADMIN_API_TOKEN")
  72        .ok()
  73        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  74});
  75
  76pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  77    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  78
  79pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  80    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  81
  82pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  83pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  84pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  85
// Declares the `SignIn`, `SignOut` and `Reconnect` actions for the `client`
// group. The handlers for these actions are installed by `init` below.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  97
  98#[derive(Deserialize)]
  99pub struct ClientSettings {
 100    pub server_url: String,
 101}
 102
 103impl Settings for ClientSettings {
 104    fn from_settings(content: &settings::SettingsContent) -> Self {
 105        if let Some(server_url) = &*ZED_SERVER_URL {
 106            return Self {
 107                server_url: server_url.clone(),
 108            };
 109        }
 110        Self {
 111            server_url: content.server_url.clone().unwrap(),
 112        }
 113    }
 114}
 115
 116#[derive(Deserialize, Default)]
 117pub struct ProxySettings {
 118    pub proxy: Option<String>,
 119}
 120
 121impl ProxySettings {
 122    pub fn proxy_url(&self) -> Option<Url> {
 123        self.proxy
 124            .as_ref()
 125            .and_then(|input| {
 126                input
 127                    .parse::<Url>()
 128                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 129                    .ok()
 130            })
 131            .or_else(read_proxy_from_env)
 132    }
 133}
 134
 135impl Settings for ProxySettings {
 136    fn from_settings(content: &settings::SettingsContent) -> Self {
 137        Self {
 138            proxy: content.proxy.clone(),
 139        }
 140    }
 141}
 142
/// Registers the settings types defined in this module (`TelemetrySettings`,
/// `ClientSettings` and `ProxySettings`) with the application.
pub fn init_settings(cx: &mut App) {
    TelemetrySettings::register(cx);
    ClientSettings::register(cx);
    ProxySettings::register(cx);
}
 148
 149pub fn init(client: &Arc<Client>, cx: &mut App) {
 150    let client = Arc::downgrade(client);
 151    cx.on_action({
 152        let client = client.clone();
 153        move |_: &SignIn, cx| {
 154            if let Some(client) = client.upgrade() {
 155                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 156                    .detach_and_log_err(cx);
 157            }
 158        }
 159    });
 160
 161    cx.on_action({
 162        let client = client.clone();
 163        move |_: &SignOut, cx| {
 164            if let Some(client) = client.upgrade() {
 165                cx.spawn(async move |cx| {
 166                    client.sign_out(cx).await;
 167                })
 168                .detach();
 169            }
 170        }
 171    });
 172
 173    cx.on_action({
 174        let client = client;
 175        move |_: &Reconnect, cx| {
 176            if let Some(client) = client.upgrade() {
 177                cx.spawn(async move |cx| {
 178                    client.reconnect(cx);
 179                })
 180                .detach();
 181            }
 182        }
 183    });
 184}
 185
/// Boxed callback that receives a [`MessageToClient`] together with mutable
/// access to the [`App`].
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;

/// Newtype that allows the shared [`Client`] to be stored as an application
/// global (see `Client::global` and `Client::set_global`).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 191
/// Client-side state for talking to the server: credentials, connection
/// status, HTTP/cloud clients and the registered message handlers.
pub struct Client {
    /// Numeric id of this client; starts at 0 and is set to the user id on sign-in.
    id: AtomicU64,
    /// RPC peer used by this client.
    peer: Arc<Peer>,
    /// HTTP client handed out by `http_client`.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client; used to validate credentials and to open the cloud connection.
    cloud_client: Arc<CloudApiClient>,
    /// Telemetry handle; the authenticated user info is cleared on sign-out.
    telemetry: Arc<Telemetry>,
    /// Reads, writes and deletes the stored credentials.
    credentials_provider: ClientCredentialsProvider,
    /// Current credentials, status channel and reconnect task.
    state: RwLock<ClientState>,
    /// Registered message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Callbacks for `MessageToClient` values — presumably invoked for messages
    /// arriving over the cloud connection; confirm against `handle_message_to_client`.
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    /// Test-only override for authentication, set via `override_authenticate`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only override for establishing the connection, set via
    /// `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only override for the RPC URL, set via `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 227
 228#[derive(Error, Debug)]
 229pub enum EstablishConnectionError {
 230    #[error("upgrade required")]
 231    UpgradeRequired,
 232    #[error("unauthorized")]
 233    Unauthorized,
 234    #[error("{0}")]
 235    Other(#[from] anyhow::Error),
 236    #[error("{0}")]
 237    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 238    #[error("{0}")]
 239    Io(#[from] std::io::Error),
 240    #[error("{0}")]
 241    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 242}
 243
 244impl From<WebsocketError> for EstablishConnectionError {
 245    fn from(error: WebsocketError) -> Self {
 246        if let WebsocketError::Http(response) = &error {
 247            match response.status() {
 248                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 249                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 250                _ => {}
 251            }
 252        }
 253        EstablishConnectionError::Other(error.into())
 254    }
 255}
 256
 257impl EstablishConnectionError {
 258    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 259        Self::Other(error.into())
 260    }
 261}
 262
 263#[derive(Copy, Clone, Debug, PartialEq)]
 264pub enum Status {
 265    SignedOut,
 266    UpgradeRequired,
 267    Authenticating,
 268    Authenticated,
 269    AuthenticationError,
 270    Connecting,
 271    ConnectionError,
 272    Connected {
 273        peer_id: PeerId,
 274        connection_id: ConnectionId,
 275    },
 276    ConnectionLost,
 277    Reauthenticating,
 278    Reauthenticated,
 279    Reconnecting,
 280    ReconnectionError {
 281        next_reconnection: Instant,
 282    },
 283}
 284
 285impl Status {
 286    pub fn is_connected(&self) -> bool {
 287        matches!(self, Self::Connected { .. })
 288    }
 289
 290    pub fn was_connected(&self) -> bool {
 291        matches!(
 292            self,
 293            Self::ConnectionLost
 294                | Self::Reauthenticating
 295                | Self::Reauthenticated
 296                | Self::Reconnecting
 297        )
 298    }
 299
 300    /// Returns whether the client is currently connected or was connected at some point.
 301    pub fn is_or_was_connected(&self) -> bool {
 302        self.is_connected() || self.was_connected()
 303    }
 304
 305    pub fn is_signing_in(&self) -> bool {
 306        matches!(
 307            self,
 308            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 309        )
 310    }
 311
 312    pub fn is_signed_out(&self) -> bool {
 313        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 314    }
 315}
 316
/// Mutable state of a [`Client`], kept behind its `state` lock.
struct ClientState {
    /// Credentials of the signed-in user, if any.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the watch channel that carries the current [`Status`].
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Task driving reconnection attempts after the connection is lost;
    /// cleared once connected or signed out (see `Client::set_status`).
    _reconnect_task: Option<Task<()>>,
}
 322
 323#[derive(Clone, Debug, Eq, PartialEq)]
 324pub struct Credentials {
 325    pub user_id: u64,
 326    pub access_token: String,
 327}
 328
 329impl Credentials {
 330    pub fn authorization_header(&self) -> String {
 331        format!("{} {}", self.user_id, self.access_token)
 332    }
 333}
 334
 335pub struct ClientCredentialsProvider {
 336    provider: Arc<dyn CredentialsProvider>,
 337}
 338
 339impl ClientCredentialsProvider {
 340    pub fn new(cx: &App) -> Self {
 341        Self {
 342            provider: <dyn CredentialsProvider>::global(cx),
 343        }
 344    }
 345
 346    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 347        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 348    }
 349
 350    /// Reads the credentials from the provider.
 351    fn read_credentials<'a>(
 352        &'a self,
 353        cx: &'a AsyncApp,
 354    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 355        async move {
 356            if IMPERSONATE_LOGIN.is_some() {
 357                return None;
 358            }
 359
 360            let server_url = self.server_url(cx).ok()?;
 361            let (user_id, access_token) = self
 362                .provider
 363                .read_credentials(&server_url, cx)
 364                .await
 365                .log_err()
 366                .flatten()?;
 367
 368            Some(Credentials {
 369                user_id: user_id.parse().ok()?,
 370                access_token: String::from_utf8(access_token).ok()?,
 371            })
 372        }
 373        .boxed_local()
 374    }
 375
 376    /// Writes the credentials to the provider.
 377    fn write_credentials<'a>(
 378        &'a self,
 379        user_id: u64,
 380        access_token: String,
 381        cx: &'a AsyncApp,
 382    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 383        async move {
 384            let server_url = self.server_url(cx)?;
 385            self.provider
 386                .write_credentials(
 387                    &server_url,
 388                    &user_id.to_string(),
 389                    access_token.as_bytes(),
 390                    cx,
 391                )
 392                .await
 393        }
 394        .boxed_local()
 395    }
 396
 397    /// Deletes the credentials from the provider.
 398    fn delete_credentials<'a>(
 399        &'a self,
 400        cx: &'a AsyncApp,
 401    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 402        async move {
 403            let server_url = self.server_url(cx)?;
 404            self.provider.delete_credentials(&server_url, cx).await
 405        }
 406        .boxed_local()
 407    }
 408}
 409
 410impl Default for ClientState {
 411    fn default() -> Self {
 412        Self {
 413            credentials: None,
 414            status: watch::channel_with(Status::SignedOut),
 415            _reconnect_task: None,
 416        }
 417    }
 418}
 419
 420pub enum Subscription {
 421    Entity {
 422        client: Weak<Client>,
 423        id: (TypeId, u64),
 424    },
 425    Message {
 426        client: Weak<Client>,
 427        id: TypeId,
 428    },
 429}
 430
 431impl Drop for Subscription {
 432    fn drop(&mut self) {
 433        match self {
 434            Subscription::Entity { client, id } => {
 435                if let Some(client) = client.upgrade() {
 436                    let mut state = client.handler_set.lock();
 437                    let _ = state.entities_by_type_and_remote_id.remove(id);
 438                }
 439            }
 440            Subscription::Message { client, id } => {
 441                if let Some(client) = client.upgrade() {
 442                    let mut state = client.handler_set.lock();
 443                    let _ = state.entity_types_by_message_type.remove(id);
 444                    let _ = state.message_handlers.remove(id);
 445                }
 446            }
 447        }
 448    }
 449}
 450
 451pub struct PendingEntitySubscription<T: 'static> {
 452    client: Arc<Client>,
 453    remote_id: u64,
 454    _entity_type: PhantomData<T>,
 455    consumed: bool,
 456}
 457
 458impl<T: 'static> PendingEntitySubscription<T> {
 459    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 460        self.consumed = true;
 461        let mut handlers = self.client.handler_set.lock();
 462        let id = (TypeId::of::<T>(), self.remote_id);
 463        let Some(EntityMessageSubscriber::Pending(messages)) =
 464            handlers.entities_by_type_and_remote_id.remove(&id)
 465        else {
 466            unreachable!()
 467        };
 468
 469        handlers.entities_by_type_and_remote_id.insert(
 470            id,
 471            EntityMessageSubscriber::Entity {
 472                handle: entity.downgrade().into(),
 473            },
 474        );
 475        drop(handlers);
 476        for message in messages {
 477            let client_id = self.client.id();
 478            let type_name = message.payload_type_name();
 479            let sender_id = message.original_sender_id();
 480            log::debug!(
 481                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 482                client_id,
 483                sender_id,
 484                type_name
 485            );
 486            self.client.handle_message(message, cx);
 487        }
 488        Subscription::Entity {
 489            client: Arc::downgrade(&self.client),
 490            id,
 491        }
 492    }
 493}
 494
 495impl<T: 'static> Drop for PendingEntitySubscription<T> {
 496    fn drop(&mut self) {
 497        if !self.consumed {
 498            let mut state = self.client.handler_set.lock();
 499            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 500                .entities_by_type_and_remote_id
 501                .remove(&(TypeId::of::<T>(), self.remote_id))
 502            {
 503                for message in messages {
 504                    log::info!("unhandled message {}", message.payload_type_name());
 505                }
 506            }
 507        }
 508    }
 509}
 510
 511#[derive(Copy, Clone, Deserialize, Debug)]
 512pub struct TelemetrySettings {
 513    pub diagnostics: bool,
 514    pub metrics: bool,
 515}
 516
 517impl settings::Settings for TelemetrySettings {
 518    fn from_settings(content: &SettingsContent) -> Self {
 519        Self {
 520            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 521            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 522        }
 523    }
 524}
 525
 526impl Client {
 527    pub fn new(
 528        clock: Arc<dyn SystemClock>,
 529        http: Arc<HttpClientWithUrl>,
 530        cx: &mut App,
 531    ) -> Arc<Self> {
 532        Arc::new(Self {
 533            id: AtomicU64::new(0),
 534            peer: Peer::new(0),
 535            telemetry: Telemetry::new(clock, http.clone(), cx),
 536            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 537            http,
 538            credentials_provider: ClientCredentialsProvider::new(cx),
 539            state: Default::default(),
 540            handler_set: Default::default(),
 541            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 542
 543            #[cfg(any(test, feature = "test-support"))]
 544            authenticate: Default::default(),
 545            #[cfg(any(test, feature = "test-support"))]
 546            establish_connection: Default::default(),
 547            #[cfg(any(test, feature = "test-support"))]
 548            rpc_url: RwLock::default(),
 549        })
 550    }
 551
 552    pub fn production(cx: &mut App) -> Arc<Self> {
 553        let clock = Arc::new(clock::RealSystemClock);
 554        let http = Arc::new(HttpClientWithUrl::new_url(
 555            cx.http_client(),
 556            &ClientSettings::get_global(cx).server_url,
 557            cx.http_client().proxy().cloned(),
 558        ));
 559        Self::new(clock, http, cx)
 560    }
 561
 562    pub fn id(&self) -> u64 {
 563        self.id.load(Ordering::SeqCst)
 564    }
 565
 566    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 567        self.http.clone()
 568    }
 569
 570    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 571        self.cloud_client.clone()
 572    }
 573
 574    pub fn set_id(&self, id: u64) -> &Self {
 575        self.id.store(id, Ordering::SeqCst);
 576        self
 577    }
 578
 579    #[cfg(any(test, feature = "test-support"))]
 580    pub fn teardown(&self) {
 581        let mut state = self.state.write();
 582        state._reconnect_task.take();
 583        self.handler_set.lock().clear();
 584        self.peer.teardown();
 585    }
 586
 587    #[cfg(any(test, feature = "test-support"))]
 588    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 589    where
 590        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 591    {
 592        *self.authenticate.write() = Some(Box::new(authenticate));
 593        self
 594    }
 595
 596    #[cfg(any(test, feature = "test-support"))]
 597    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 598    where
 599        F: 'static
 600            + Send
 601            + Sync
 602            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 603    {
 604        *self.establish_connection.write() = Some(Box::new(connect));
 605        self
 606    }
 607
 608    #[cfg(any(test, feature = "test-support"))]
 609    pub fn override_rpc_url(&self, url: Url) -> &Self {
 610        *self.rpc_url.write() = Some(url);
 611        self
 612    }
 613
 614    pub fn global(cx: &App) -> Arc<Self> {
 615        cx.global::<GlobalClient>().0.clone()
 616    }
 617    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 618        cx.set_global(GlobalClient(client))
 619    }
 620
 621    pub fn user_id(&self) -> Option<u64> {
 622        self.state
 623            .read()
 624            .credentials
 625            .as_ref()
 626            .map(|credentials| credentials.user_id)
 627    }
 628
 629    pub fn peer_id(&self) -> Option<PeerId> {
 630        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 631            Some(*peer_id)
 632        } else {
 633            None
 634        }
 635    }
 636
 637    pub fn status(&self) -> watch::Receiver<Status> {
 638        self.state.read().status.1.clone()
 639    }
 640
    /// Publishes `status` on the status channel and applies its side effects.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a task that retries `connect` with a delay
    ///   that doubles after each failure (capped at `MAX_RECONNECTION_DELAY`)
    ///   plus random jitter.
    /// - `SignedOut` / `UpgradeRequired`: clears the authenticated user info in
    ///   telemetry and drops any pending reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        // The state write lock is held for the rest of the function.
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // Test builds use a fixed seed so the jitter is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_os_rng();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        // A successful connect ends the loop; every failure is logged.
                        match client.connect(true, cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    break;
                                }
                            }
                        }

                        // Only keep retrying while the failed attempt left the client
                        // in an error status; any other status ends the loop.
                        if matches!(
                            *client.status().borrow(),
                            Status::AuthenticationError | Status::ConnectionError
                        ) {
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                cx,
                            );
                            // Wait `delay` plus a random jitter in `[0, delay)`, then
                            // double the delay up to `MAX_RECONNECTION_DELAY`.
                            let jitter = Duration::from_millis(
                                rng.random_range(0..delay.as_millis() as u64),
                            );
                            cx.background_executor().timer(delay + jitter).await;
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 704
 705    pub fn subscribe_to_entity<T>(
 706        self: &Arc<Self>,
 707        remote_id: u64,
 708    ) -> Result<PendingEntitySubscription<T>>
 709    where
 710        T: 'static,
 711    {
 712        let id = (TypeId::of::<T>(), remote_id);
 713
 714        let mut state = self.handler_set.lock();
 715        anyhow::ensure!(
 716            !state.entities_by_type_and_remote_id.contains_key(&id),
 717            "already subscribed to entity"
 718        );
 719
 720        state
 721            .entities_by_type_and_remote_id
 722            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 723
 724        Ok(PendingEntitySubscription {
 725            client: self.clone(),
 726            remote_id,
 727            consumed: false,
 728            _entity_type: PhantomData,
 729        })
 730    }
 731
 732    #[track_caller]
 733    pub fn add_message_handler<M, E, H, F>(
 734        self: &Arc<Self>,
 735        entity: WeakEntity<E>,
 736        handler: H,
 737    ) -> Subscription
 738    where
 739        M: EnvelopedMessage,
 740        E: 'static,
 741        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 742        F: 'static + Future<Output = Result<()>>,
 743    {
 744        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 745            handler(entity, message, cx)
 746        })
 747    }
 748
 749    fn add_message_handler_impl<M, E, H, F>(
 750        self: &Arc<Self>,
 751        entity: WeakEntity<E>,
 752        handler: H,
 753    ) -> Subscription
 754    where
 755        M: EnvelopedMessage,
 756        E: 'static,
 757        H: 'static
 758            + Sync
 759            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 760            + Send
 761            + Sync,
 762        F: 'static + Future<Output = Result<()>>,
 763    {
 764        let message_type_id = TypeId::of::<M>();
 765        let mut state = self.handler_set.lock();
 766        state
 767            .entities_by_message_type
 768            .insert(message_type_id, entity.into());
 769
 770        let prev_handler = state.message_handlers.insert(
 771            message_type_id,
 772            Arc::new(move |subscriber, envelope, client, cx| {
 773                let subscriber = subscriber.downcast::<E>().unwrap();
 774                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 775                handler(subscriber, *envelope, client, cx).boxed_local()
 776            }),
 777        );
 778        if prev_handler.is_some() {
 779            let location = std::panic::Location::caller();
 780            panic!(
 781                "{}:{} registered handler for the same message {} twice",
 782                location.file(),
 783                location.line(),
 784                std::any::type_name::<M>()
 785            );
 786        }
 787
 788        Subscription::Message {
 789            client: Arc::downgrade(self),
 790            id: message_type_id,
 791        }
 792    }
 793
 794    pub fn add_request_handler<M, E, H, F>(
 795        self: &Arc<Self>,
 796        entity: WeakEntity<E>,
 797        handler: H,
 798    ) -> Subscription
 799    where
 800        M: RequestMessage,
 801        E: 'static,
 802        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 803        F: 'static + Future<Output = Result<M::Response>>,
 804    {
 805        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 806            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 807        })
 808    }
 809
 810    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 811        receipt: Receipt<T>,
 812        response: F,
 813        client: AnyProtoClient,
 814    ) -> Result<()> {
 815        match response.await {
 816            Ok(response) => {
 817                client.send_response(receipt.message_id, response)?;
 818                Ok(())
 819            }
 820            Err(error) => {
 821                client.send_response(receipt.message_id, error.to_proto())?;
 822                Err(error)
 823            }
 824        }
 825    }
 826
 827    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 828        self.credentials_provider
 829            .read_credentials(cx)
 830            .await
 831            .is_some()
 832    }
 833
    /// Obtains valid credentials and records them on the client.
    ///
    /// Credentials are looked for in this order: the credentials already held
    /// in memory, the credentials provider (only when `try_provider` is true),
    /// and finally an interactive `authenticate` call. The first two are only
    /// accepted if they pass `validate_credentials`.
    ///
    /// On success the status becomes `Authenticated` (or `Reauthenticated` when
    /// the client was not signed out beforehand).
    ///
    /// # Errors
    ///
    /// Fails if validation fails, if `authenticate` fails, or if the status
    /// changes while authentication is in flight ("authentication canceled").
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // First choice: the credentials already held in memory, if still valid.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // Second choice: stored credentials; invalid ones are deleted from the provider.
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // Last resort: run the authentication flow.
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably consumes the current status so that the
            // `next()` in the select below only resolves on a later change —
            // confirm against the `postage::watch` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            // Credentials obtained while impersonating are not persisted.
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path that leaves `credentials` empty has returned above.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `as u32` silently truncates user ids above `u32::MAX`.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
 914
 915    async fn validate_credentials(
 916        self: &Arc<Self>,
 917        credentials: &Credentials,
 918        cx: &AsyncApp,
 919    ) -> Result<bool> {
 920        match self
 921            .cloud_client
 922            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 923            .await
 924        {
 925            Ok(valid) => Ok(valid),
 926            Err(err) => {
 927                self.set_status(Status::AuthenticationError, cx);
 928                Err(anyhow!("failed to validate credentials: {}", err))
 929            }
 930        }
 931    }
 932
 933    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 934    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 935        let connect_task = cx.update({
 936            let cloud_client = self.cloud_client.clone();
 937            move |cx| cloud_client.connect(cx)
 938        })??;
 939        let connection = connect_task.await?;
 940
 941        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 942        task.detach();
 943
 944        cx.spawn({
 945            let this = self.clone();
 946            async move |cx| {
 947                while let Some(message) = messages.next().await {
 948                    if let Some(message) = message.log_err() {
 949                        this.handle_message_to_client(message, cx);
 950                    }
 951                }
 952            }
 953        })
 954        .detach();
 955
 956        Ok(())
 957    }
 958
 959    /// Performs a sign-in and also (optionally) connects to Collab.
 960    ///
 961    /// Only Zed staff automatically connect to Collab.
 962    pub async fn sign_in_with_optional_connect(
 963        self: &Arc<Self>,
 964        try_provider: bool,
 965        cx: &AsyncApp,
 966    ) -> Result<()> {
 967        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
 968        if self.status().borrow().is_connected() {
 969            return Ok(());
 970        }
 971
 972        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
 973        let mut is_staff_tx = Some(is_staff_tx);
 974        cx.update(|cx| {
 975            cx.on_flags_ready(move |state, _cx| {
 976                if let Some(is_staff_tx) = is_staff_tx.take() {
 977                    is_staff_tx.send(state.is_staff).log_err();
 978                }
 979            })
 980            .detach();
 981        })
 982        .log_err();
 983
 984        let credentials = self.sign_in(try_provider, cx).await?;
 985
 986        self.connect_to_cloud(cx).await.log_err();
 987
 988        cx.update(move |cx| {
 989            cx.spawn({
 990                let client = self.clone();
 991                async move |cx| {
 992                    let is_staff = is_staff_rx.await?;
 993                    if is_staff {
 994                        match client.connect_with_credentials(credentials, cx).await {
 995                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
 996                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
 997                            ConnectionResult::Result(result) => {
 998                                result.context("client auth and connect")
 999                            }
1000                        }
1001                    } else {
1002                        Ok(())
1003                    }
1004                }
1005            })
1006            .detach_and_log_err(cx);
1007        })
1008        .log_err();
1009
1010        Ok(())
1011    }
1012
1013    pub async fn connect(
1014        self: &Arc<Self>,
1015        try_provider: bool,
1016        cx: &AsyncApp,
1017    ) -> ConnectionResult<()> {
1018        let was_disconnected = match *self.status().borrow() {
1019            Status::SignedOut | Status::Authenticated => true,
1020            Status::ConnectionError
1021            | Status::ConnectionLost
1022            | Status::Authenticating
1023            | Status::AuthenticationError
1024            | Status::Reauthenticating
1025            | Status::Reauthenticated
1026            | Status::ReconnectionError { .. } => false,
1027            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1028                return ConnectionResult::Result(Ok(()));
1029            }
1030            Status::UpgradeRequired => {
1031                return ConnectionResult::Result(
1032                    Err(EstablishConnectionError::UpgradeRequired)
1033                        .context("client auth and connect"),
1034                );
1035            }
1036        };
1037        let credentials = match self.sign_in(try_provider, cx).await {
1038            Ok(credentials) => credentials,
1039            Err(err) => return ConnectionResult::Result(Err(err)),
1040        };
1041
1042        if was_disconnected {
1043            self.set_status(Status::Connecting, cx);
1044        } else {
1045            self.set_status(Status::Reconnecting, cx);
1046        }
1047
1048        self.connect_with_credentials(credentials, cx).await
1049    }
1050
1051    async fn connect_with_credentials(
1052        self: &Arc<Self>,
1053        credentials: Credentials,
1054        cx: &AsyncApp,
1055    ) -> ConnectionResult<()> {
1056        let mut timeout =
1057            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
1058        futures::select_biased! {
1059            connection = self.establish_connection(&credentials, cx).fuse() => {
1060                match connection {
1061                    Ok(conn) => {
1062                        futures::select_biased! {
1063                            result = self.set_connection(conn, cx).fuse() => {
1064                                match result.context("client auth and connect") {
1065                                    Ok(()) => ConnectionResult::Result(Ok(())),
1066                                    Err(err) => {
1067                                        self.set_status(Status::ConnectionError, cx);
1068                                        ConnectionResult::Result(Err(err))
1069                                    },
1070                                }
1071                            },
1072                            _ = timeout => {
1073                                self.set_status(Status::ConnectionError, cx);
1074                                ConnectionResult::Timeout
1075                            }
1076                        }
1077                    }
1078                    Err(EstablishConnectionError::Unauthorized) => {
1079                        self.set_status(Status::ConnectionError, cx);
1080                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
1081                    }
1082                    Err(EstablishConnectionError::UpgradeRequired) => {
1083                        self.set_status(Status::UpgradeRequired, cx);
1084                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
1085                    }
1086                    Err(error) => {
1087                        self.set_status(Status::ConnectionError, cx);
1088                        ConnectionResult::Result(Err(error).context("client auth and connect"))
1089                    }
1090                }
1091            }
1092            _ = &mut timeout => {
1093                self.set_status(Status::ConnectionError, cx);
1094                ConnectionResult::Timeout
1095            }
1096        }
1097    }
1098
1099    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
1100        let executor = cx.background_executor();
1101        log::debug!("add connection to peer");
1102        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
1103            let executor = executor.clone();
1104            move |duration| executor.timer(duration)
1105        });
1106        let handle_io = executor.spawn(handle_io);
1107
1108        let peer_id = async {
1109            log::debug!("waiting for server hello");
1110            let message = incoming.next().await.context("no hello message received")?;
1111            log::debug!("got server hello");
1112            let hello_message_type_name = message.payload_type_name().to_string();
1113            let hello = message
1114                .into_any()
1115                .downcast::<TypedEnvelope<proto::Hello>>()
1116                .map_err(|_| {
1117                    anyhow!(
1118                        "invalid hello message received: {:?}",
1119                        hello_message_type_name
1120                    )
1121                })?;
1122            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
1123            Ok(peer_id)
1124        };
1125
1126        let peer_id = match peer_id.await {
1127            Ok(peer_id) => peer_id,
1128            Err(error) => {
1129                self.peer.disconnect(connection_id);
1130                return Err(error);
1131            }
1132        };
1133
1134        log::debug!(
1135            "set status to connected (connection id: {:?}, peer id: {:?})",
1136            connection_id,
1137            peer_id
1138        );
1139        self.set_status(
1140            Status::Connected {
1141                peer_id,
1142                connection_id,
1143            },
1144            cx,
1145        );
1146
1147        cx.spawn({
1148            let this = self.clone();
1149            async move |cx| {
1150                while let Some(message) = incoming.next().await {
1151                    this.handle_message(message, cx);
1152                    // Don't starve the main thread when receiving lots of messages at once.
1153                    smol::future::yield_now().await;
1154                }
1155            }
1156        })
1157        .detach();
1158
1159        cx.spawn({
1160            let this = self.clone();
1161            async move |cx| match handle_io.await {
1162                Ok(()) => {
1163                    if *this.status().borrow()
1164                        == (Status::Connected {
1165                            connection_id,
1166                            peer_id,
1167                        })
1168                    {
1169                        this.set_status(Status::SignedOut, cx);
1170                    }
1171                }
1172                Err(err) => {
1173                    log::error!("connection error: {:?}", err);
1174                    this.set_status(Status::ConnectionLost, cx);
1175                }
1176            }
1177        })
1178        .detach();
1179
1180        Ok(())
1181    }
1182
1183    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1184        #[cfg(any(test, feature = "test-support"))]
1185        if let Some(callback) = self.authenticate.read().as_ref() {
1186            return callback(cx);
1187        }
1188
1189        self.authenticate_with_browser(cx)
1190    }
1191
1192    fn establish_connection(
1193        self: &Arc<Self>,
1194        credentials: &Credentials,
1195        cx: &AsyncApp,
1196    ) -> Task<Result<Connection, EstablishConnectionError>> {
1197        #[cfg(any(test, feature = "test-support"))]
1198        if let Some(callback) = self.establish_connection.read().as_ref() {
1199            return callback(credentials, cx);
1200        }
1201
1202        self.establish_websocket_connection(credentials, cx)
1203    }
1204
1205    fn rpc_url(
1206        &self,
1207        http: Arc<HttpClientWithUrl>,
1208        release_channel: Option<ReleaseChannel>,
1209    ) -> impl Future<Output = Result<url::Url>> + use<> {
1210        #[cfg(any(test, feature = "test-support"))]
1211        let url_override = self.rpc_url.read().clone();
1212
1213        async move {
1214            #[cfg(any(test, feature = "test-support"))]
1215            if let Some(url) = url_override {
1216                return Ok(url);
1217            }
1218
1219            if let Some(url) = &*ZED_RPC_URL {
1220                return Url::parse(url).context("invalid rpc url");
1221            }
1222
1223            let mut url = http.build_url("/rpc");
1224            if let Some(preview_param) =
1225                release_channel.and_then(|channel| channel.release_query_param())
1226            {
1227                url += "?";
1228                url += preview_param;
1229            }
1230
1231            let response = http.get(&url, Default::default(), false).await?;
1232            anyhow::ensure!(
1233                response.status().is_redirection(),
1234                "unexpected /rpc response status {}",
1235                response.status()
1236            );
1237            let collab_url = response
1238                .headers()
1239                .get("Location")
1240                .context("missing location header in /rpc response")?
1241                .to_str()
1242                .map_err(EstablishConnectionError::other)?
1243                .to_string();
1244            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1245        }
1246    }
1247
1248    fn establish_websocket_connection(
1249        self: &Arc<Self>,
1250        credentials: &Credentials,
1251        cx: &AsyncApp,
1252    ) -> Task<Result<Connection, EstablishConnectionError>> {
1253        let release_channel = cx
1254            .update(|cx| ReleaseChannel::try_global(cx))
1255            .ok()
1256            .flatten();
1257        let app_version = cx
1258            .update(|cx| AppVersion::global(cx).to_string())
1259            .ok()
1260            .unwrap_or_default();
1261
1262        let http = self.http.clone();
1263        let proxy = http.proxy().cloned();
1264        let user_agent = http.user_agent().cloned();
1265        let credentials = credentials.clone();
1266        let rpc_url = self.rpc_url(http, release_channel);
1267        let system_id = self.telemetry.system_id();
1268        let metrics_id = self.telemetry.metrics_id();
1269        cx.spawn(async move |cx| {
1270            use HttpOrHttps::*;
1271
1272            #[derive(Debug)]
1273            enum HttpOrHttps {
1274                Http,
1275                Https,
1276            }
1277
1278            let mut rpc_url = rpc_url.await?;
1279            let url_scheme = match rpc_url.scheme() {
1280                "https" => Https,
1281                "http" => Http,
1282                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1283            };
1284
1285            let stream = gpui_tokio::Tokio::spawn_result(cx, {
1286                let rpc_url = rpc_url.clone();
1287                async move {
1288                    let rpc_host = rpc_url
1289                        .host_str()
1290                        .zip(rpc_url.port_or_known_default())
1291                        .context("missing host in rpc url")?;
1292                    Ok(match proxy {
1293                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
1294                        None => Box::new(TcpStream::connect(rpc_host).await?),
1295                    })
1296                }
1297            })?
1298            .await?;
1299
1300            log::info!("connected to rpc endpoint {}", rpc_url);
1301
1302            rpc_url
1303                .set_scheme(match url_scheme {
1304                    Https => "wss",
1305                    Http => "ws",
1306                })
1307                .unwrap();
1308
1309            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
1310            // for us from the RPC URL.
1311            //
1312            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
1313            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
1314
1315            // We then modify the request to add our desired headers.
1316            let request_headers = request.headers_mut();
1317            request_headers.insert(
1318                http::header::AUTHORIZATION,
1319                HeaderValue::from_str(&credentials.authorization_header())?,
1320            );
1321            request_headers.insert(
1322                "x-zed-protocol-version",
1323                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
1324            );
1325            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
1326            request_headers.insert(
1327                "x-zed-release-channel",
1328                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
1329            );
1330            if let Some(user_agent) = user_agent {
1331                request_headers.insert(http::header::USER_AGENT, user_agent);
1332            }
1333            if let Some(system_id) = system_id {
1334                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
1335            }
1336            if let Some(metrics_id) = metrics_id {
1337                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
1338            }
1339
1340            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
1341                request,
1342                stream,
1343                Some(Arc::new(http_client_tls::tls_config()).into()),
1344                None,
1345            )
1346            .await?;
1347
1348            Ok(Connection::new(
1349                stream
1350                    .map_err(|error| anyhow!(error))
1351                    .sink_map_err(|error| anyhow!(error)),
1352            ))
1353        })
1354    }
1355
1356    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1357        let http = self.http.clone();
1358        let this = self.clone();
1359        cx.spawn(async move |cx| {
1360            let background = cx.background_executor().clone();
1361
1362            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1363            cx.update(|cx| {
1364                cx.spawn(async move |cx| {
1365                    let url = open_url_rx.await?;
1366                    cx.update(|cx| cx.open_url(&url))
1367                })
1368                .detach_and_log_err(cx);
1369            })
1370            .log_err();
1371
1372            let credentials = background
1373                .clone()
1374                .spawn(async move {
1375                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1376                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1377                    // any other app running on the user's device.
1378                    let (public_key, private_key) =
1379                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1380                    let public_key_string = String::try_from(public_key)
1381                        .expect("failed to serialize public key for auth");
1382
1383                    if let Some((login, token)) =
1384                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1385                    {
1386                        if !*USE_WEB_LOGIN {
1387                            eprintln!("authenticate as admin {login}, {token}");
1388
1389                            return this
1390                                .authenticate_as_admin(http, login.clone(), token.clone())
1391                                .await;
1392                        }
1393                    }
1394
1395                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1396                    let server =
1397                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1398                    let port = server.server_addr().port();
1399
1400                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1401                    // that the user is signing in from a Zed app running on the same device.
1402                    let mut url = http.build_url(&format!(
1403                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1404                        port, public_key_string
1405                    ));
1406
1407                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1408                        log::info!("impersonating user @{}", impersonate_login);
1409                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1410                    }
1411
1412                    open_url_tx.send(url).log_err();
1413
1414                    #[derive(Deserialize)]
1415                    struct CallbackParams {
1416                        pub user_id: String,
1417                        pub access_token: String,
1418                    }
1419
1420                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1421                    // access token from the query params.
1422                    //
1423                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1424                    // custom URL scheme instead of this local HTTP server.
1425                    let (user_id, access_token) = background
1426                        .spawn(async move {
1427                            for _ in 0..100 {
1428                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1429                                    let path = req.url();
1430                                    let url = Url::parse(&format!("http://example.com{}", path))
1431                                        .context("failed to parse login notification url")?;
1432                                    let callback_params: CallbackParams =
1433                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1434                                            .context(
1435                                                "failed to parse sign-in callback query parameters",
1436                                            )?;
1437
1438                                    let post_auth_url =
1439                                        http.build_url("/native_app_signin_succeeded");
1440                                    req.respond(
1441                                        tiny_http::Response::empty(302).with_header(
1442                                            tiny_http::Header::from_bytes(
1443                                                &b"Location"[..],
1444                                                post_auth_url.as_bytes(),
1445                                            )
1446                                            .unwrap(),
1447                                        ),
1448                                    )
1449                                    .context("failed to respond to login http request")?;
1450                                    return Ok((
1451                                        callback_params.user_id,
1452                                        callback_params.access_token,
1453                                    ));
1454                                }
1455                            }
1456
1457                            anyhow::bail!("didn't receive login redirect");
1458                        })
1459                        .await?;
1460
1461                    let access_token = private_key
1462                        .decrypt_string(&access_token)
1463                        .context("failed to decrypt access token")?;
1464
1465                    Ok(Credentials {
1466                        user_id: user_id.parse()?,
1467                        access_token,
1468                    })
1469                })
1470                .await?;
1471
1472            cx.update(|cx| cx.activate(true))?;
1473            Ok(credentials)
1474        })
1475    }
1476
1477    async fn authenticate_as_admin(
1478        self: &Arc<Self>,
1479        http: Arc<HttpClientWithUrl>,
1480        login: String,
1481        api_token: String,
1482    ) -> Result<Credentials> {
1483        #[derive(Serialize)]
1484        struct ImpersonateUserBody {
1485            github_login: String,
1486        }
1487
1488        #[derive(Deserialize)]
1489        struct ImpersonateUserResponse {
1490            user_id: u64,
1491            access_token: String,
1492        }
1493
1494        let url = self
1495            .http
1496            .build_zed_cloud_url("/internal/users/impersonate", &[])?;
1497        let request = Request::post(url.as_str())
1498            .header("Content-Type", "application/json")
1499            .header("Authorization", format!("Bearer {api_token}"))
1500            .body(
1501                serde_json::to_string(&ImpersonateUserBody {
1502                    github_login: login,
1503                })?
1504                .into(),
1505            )?;
1506
1507        let mut response = http.send(request).await?;
1508        let mut body = String::new();
1509        response.body_mut().read_to_string(&mut body).await?;
1510        anyhow::ensure!(
1511            response.status().is_success(),
1512            "admin user request failed {} - {}",
1513            response.status().as_u16(),
1514            body,
1515        );
1516        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1517
1518        Ok(Credentials {
1519            user_id: response.user_id,
1520            access_token: response.access_token,
1521        })
1522    }
1523
1524    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1525        self.state.write().credentials = None;
1526        self.cloud_client.clear_credentials();
1527        self.disconnect(cx);
1528
1529        if self.has_credentials(cx).await {
1530            self.credentials_provider
1531                .delete_credentials(cx)
1532                .await
1533                .log_err();
1534        }
1535    }
1536
1537    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1538        self.peer.teardown();
1539        self.set_status(Status::SignedOut, cx);
1540    }
1541
1542    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1543        self.peer.teardown();
1544        self.set_status(Status::ConnectionLost, cx);
1545    }
1546
1547    fn connection_id(&self) -> Result<ConnectionId> {
1548        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1549            Ok(connection_id)
1550        } else {
1551            anyhow::bail!("not connected");
1552        }
1553    }
1554
1555    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1556        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1557        self.peer.send(self.connection_id()?, message)
1558    }
1559
1560    pub fn request<T: RequestMessage>(
1561        &self,
1562        request: T,
1563    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1564        self.request_envelope(request)
1565            .map_ok(|envelope| envelope.payload)
1566    }
1567
1568    pub fn request_stream<T: RequestMessage>(
1569        &self,
1570        request: T,
1571    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1572        let client_id = self.id.load(Ordering::SeqCst);
1573        log::debug!(
1574            "rpc request start. client_id:{}. name:{}",
1575            client_id,
1576            T::NAME
1577        );
1578        let response = self
1579            .connection_id()
1580            .map(|conn_id| self.peer.request_stream(conn_id, request));
1581        async move {
1582            let response = response?.await;
1583            log::debug!(
1584                "rpc request finish. client_id:{}. name:{}",
1585                client_id,
1586                T::NAME
1587            );
1588            response
1589        }
1590    }
1591
1592    pub fn request_envelope<T: RequestMessage>(
1593        &self,
1594        request: T,
1595    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1596        let client_id = self.id();
1597        log::debug!(
1598            "rpc request start. client_id:{}. name:{}",
1599            client_id,
1600            T::NAME
1601        );
1602        let response = self
1603            .connection_id()
1604            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1605        async move {
1606            let response = response?.await;
1607            log::debug!(
1608                "rpc request finish. client_id:{}. name:{}",
1609                client_id,
1610                T::NAME
1611            );
1612            response
1613        }
1614    }
1615
1616    pub fn request_dynamic(
1617        &self,
1618        envelope: proto::Envelope,
1619        request_type: &'static str,
1620    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1621        let client_id = self.id();
1622        log::debug!(
1623            "rpc request start. client_id:{}. name:{}",
1624            client_id,
1625            request_type
1626        );
1627        let response = self
1628            .connection_id()
1629            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1630        async move {
1631            let response = response?.await;
1632            log::debug!(
1633                "rpc request finish. client_id:{}. name:{}",
1634                client_id,
1635                request_type
1636            );
1637            Ok(response?.0)
1638        }
1639    }
1640
1641    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1642        let sender_id = message.sender_id();
1643        let request_id = message.message_id();
1644        let type_name = message.payload_type_name();
1645        let original_sender_id = message.original_sender_id();
1646
1647        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1648            &self.handler_set,
1649            message,
1650            self.clone().into(),
1651            cx.clone(),
1652        ) {
1653            let client_id = self.id();
1654            log::debug!(
1655                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1656                client_id,
1657                original_sender_id,
1658                type_name
1659            );
1660            cx.spawn(async move |_| match future.await {
1661                Ok(()) => {
1662                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1663                }
1664                Err(error) => {
1665                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1666                }
1667            })
1668            .detach();
1669        } else {
1670            log::info!("unhandled message {}", type_name);
1671            self.peer
1672                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1673                .log_err();
1674        }
1675    }
1676
1677    pub fn add_message_to_client_handler(
1678        self: &Arc<Client>,
1679        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1680    ) {
1681        self.message_to_client_handlers
1682            .lock()
1683            .push(Box::new(handler));
1684    }
1685
1686    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1687        cx.update(|cx| {
1688            for handler in self.message_to_client_handlers.lock().iter() {
1689                handler(&message, cx);
1690            }
1691        })
1692        .ok();
1693    }
1694
    /// Returns a reference to the shared [`Telemetry`] handle stored on this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1698}
1699
1700impl ProtoClient for Client {
1701    fn request(
1702        &self,
1703        envelope: proto::Envelope,
1704        request_type: &'static str,
1705    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1706        self.request_dynamic(envelope, request_type).boxed()
1707    }
1708
1709    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1710        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1711        let connection_id = self.connection_id()?;
1712        self.peer.send_dynamic(connection_id, envelope)
1713    }
1714
1715    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1716        log::debug!(
1717            "rpc respond. client_id:{}, name:{}",
1718            self.id(),
1719            message_type
1720        );
1721        let connection_id = self.connection_id()?;
1722        self.peer.send_dynamic(connection_id, envelope)
1723    }
1724
1725    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1726        &self.handler_set
1727    }
1728
1729    fn is_via_collab(&self) -> bool {
1730        true
1731    }
1732}
1733
/// The URL scheme of `zed://` links, without the trailing `://` separator.
pub const ZED_URL_SCHEME: &str = "zed";
1736
1737/// Parses the given link into a Zed link.
1738///
1739/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1740/// Returns [`None`] otherwise.
1741pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1742    let server_url = &ClientSettings::get_global(cx).server_url;
1743    if let Some(stripped) = link
1744        .strip_prefix(server_url)
1745        .and_then(|result| result.strip_prefix('/'))
1746    {
1747        return Some(stripped);
1748    }
1749    if let Some(stripped) = link
1750        .strip_prefix(ZED_URL_SCHEME)
1751        .and_then(|result| result.strip_prefix("://"))
1752    {
1753        return Some(stripped);
1754    }
1755
1756    None
1757}
1758
1759#[cfg(test)]
1760mod tests {
1761    use super::*;
1762    use crate::test::{FakeServer, parse_authorization_header};
1763
1764    use clock::FakeSystemClock;
1765    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
1766    use http_client::FakeHttpClient;
1767    use parking_lot::Mutex;
1768    use proto::TypedEnvelope;
1769    use settings::SettingsStore;
1770    use std::future;
1771
1772    #[gpui::test(iterations = 10)]
1773    async fn test_reconnection(cx: &mut TestAppContext) {
1774        init_test(cx);
1775        let user_id = 5;
1776        let client = cx.update(|cx| {
1777            Client::new(
1778                Arc::new(FakeSystemClock::new()),
1779                FakeHttpClient::with_404_response(),
1780                cx,
1781            )
1782        });
1783        let server = FakeServer::for_client(user_id, &client, cx).await;
1784        let mut status = client.status();
1785        assert!(matches!(
1786            status.next().await,
1787            Some(Status::Connected { .. })
1788        ));
1789        assert_eq!(server.auth_count(), 1);
1790
1791        server.forbid_connections();
1792        server.disconnect();
1793        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1794
1795        server.allow_connections();
1796        cx.executor().advance_clock(Duration::from_secs(10));
1797        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1798        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1799
1800        server.forbid_connections();
1801        server.disconnect();
1802        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1803
1804        // Clear cached credentials after authentication fails
1805        server.roll_access_token();
1806        server.allow_connections();
1807        cx.executor().run_until_parked();
1808        cx.executor().advance_clock(Duration::from_secs(10));
1809        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1810        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1811    }
1812
1813    #[gpui::test(iterations = 10)]
1814    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
1815        init_test(cx);
1816        let http_client = FakeHttpClient::with_200_response();
1817        let client =
1818            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
1819        let server = FakeServer::for_client(42, &client, cx).await;
1820        let mut status = client.status();
1821        assert!(matches!(
1822            status.next().await,
1823            Some(Status::Connected { .. })
1824        ));
1825        assert_eq!(server.auth_count(), 1);
1826
1827        // Simulate an auth failure during reconnection.
1828        http_client
1829            .as_fake()
1830            .replace_handler(|_, _request| async move {
1831                Ok(http_client::Response::builder()
1832                    .status(503)
1833                    .body("".into())
1834                    .unwrap())
1835            });
1836        server.disconnect();
1837        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1838
1839        // Restore the ability to authenticate.
1840        http_client
1841            .as_fake()
1842            .replace_handler(|_, _request| async move {
1843                Ok(http_client::Response::builder()
1844                    .status(200)
1845                    .body("".into())
1846                    .unwrap())
1847            });
1848        cx.executor().advance_clock(Duration::from_secs(10));
1849        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1850        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1851    }
1852
1853    #[gpui::test(iterations = 10)]
1854    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
1855        init_test(cx);
1856        let user_id = 5;
1857        let client = cx.update(|cx| {
1858            Client::new(
1859                Arc::new(FakeSystemClock::new()),
1860                FakeHttpClient::with_404_response(),
1861                cx,
1862            )
1863        });
1864        let mut status = client.status();
1865
1866        // Time out when client tries to connect.
1867        client.override_authenticate(move |cx| {
1868            cx.background_spawn(async move {
1869                Ok(Credentials {
1870                    user_id,
1871                    access_token: "token".into(),
1872                })
1873            })
1874        });
1875        client.override_establish_connection(|_, cx| {
1876            cx.background_spawn(async move {
1877                future::pending::<()>().await;
1878                unreachable!()
1879            })
1880        });
1881        let auth_and_connect = cx.spawn({
1882            let client = client.clone();
1883            |cx| async move { client.connect(false, &cx).await }
1884        });
1885        executor.run_until_parked();
1886        assert!(matches!(status.next().await, Some(Status::Connecting)));
1887
1888        executor.advance_clock(CONNECTION_TIMEOUT);
1889        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
1890        auth_and_connect.await.into_response().unwrap_err();
1891
1892        // Allow the connection to be established.
1893        let server = FakeServer::for_client(user_id, &client, cx).await;
1894        assert!(matches!(
1895            status.next().await,
1896            Some(Status::Connected { .. })
1897        ));
1898
1899        // Disconnect client.
1900        server.forbid_connections();
1901        server.disconnect();
1902        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1903
1904        // Time out when re-establishing the connection.
1905        server.allow_connections();
1906        client.override_establish_connection(|_, cx| {
1907            cx.background_spawn(async move {
1908                future::pending::<()>().await;
1909                unreachable!()
1910            })
1911        });
1912        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1913        assert!(matches!(status.next().await, Some(Status::Reconnecting)));
1914
1915        executor.advance_clock(CONNECTION_TIMEOUT);
1916        assert!(matches!(
1917            status.next().await,
1918            Some(Status::ReconnectionError { .. })
1919        ));
1920    }
1921
1922    #[gpui::test(iterations = 10)]
1923    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
1924        init_test(cx);
1925        let auth_count = Arc::new(Mutex::new(0));
1926        let http_client = FakeHttpClient::create(|_request| async move {
1927            Ok(http_client::Response::builder()
1928                .status(200)
1929                .body("".into())
1930                .unwrap())
1931        });
1932        let client =
1933            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
1934        client.override_authenticate({
1935            let auth_count = auth_count.clone();
1936            move |cx| {
1937                let auth_count = auth_count.clone();
1938                cx.background_spawn(async move {
1939                    *auth_count.lock() += 1;
1940                    Ok(Credentials {
1941                        user_id: 1,
1942                        access_token: auth_count.lock().to_string(),
1943                    })
1944                })
1945            }
1946        });
1947
1948        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
1949        assert_eq!(*auth_count.lock(), 1);
1950        assert_eq!(credentials.access_token, "1");
1951
1952        // If credentials are still valid, signing in doesn't trigger authentication.
1953        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
1954        assert_eq!(*auth_count.lock(), 1);
1955        assert_eq!(credentials.access_token, "1");
1956
1957        // If the server is unavailable, signing in doesn't trigger authentication.
1958        http_client
1959            .as_fake()
1960            .replace_handler(|_, _request| async move {
1961                Ok(http_client::Response::builder()
1962                    .status(503)
1963                    .body("".into())
1964                    .unwrap())
1965            });
1966        client.sign_in(false, &cx.to_async()).await.unwrap_err();
1967        assert_eq!(*auth_count.lock(), 1);
1968
1969        // If credentials became invalid, signing in triggers authentication.
1970        http_client
1971            .as_fake()
1972            .replace_handler(|_, request| async move {
1973                let credentials = parse_authorization_header(&request).unwrap();
1974                if credentials.access_token == "2" {
1975                    Ok(http_client::Response::builder()
1976                        .status(200)
1977                        .body("".into())
1978                        .unwrap())
1979                } else {
1980                    Ok(http_client::Response::builder()
1981                        .status(401)
1982                        .body("".into())
1983                        .unwrap())
1984                }
1985            });
1986        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
1987        assert_eq!(*auth_count.lock(), 2);
1988        assert_eq!(credentials.access_token, "2");
1989    }
1990
1991    #[gpui::test(iterations = 10)]
1992    async fn test_authenticating_more_than_once(
1993        cx: &mut TestAppContext,
1994        executor: BackgroundExecutor,
1995    ) {
1996        init_test(cx);
1997        let auth_count = Arc::new(Mutex::new(0));
1998        let dropped_auth_count = Arc::new(Mutex::new(0));
1999        let client = cx.update(|cx| {
2000            Client::new(
2001                Arc::new(FakeSystemClock::new()),
2002                FakeHttpClient::with_404_response(),
2003                cx,
2004            )
2005        });
2006        client.override_authenticate({
2007            let auth_count = auth_count.clone();
2008            let dropped_auth_count = dropped_auth_count.clone();
2009            move |cx| {
2010                let auth_count = auth_count.clone();
2011                let dropped_auth_count = dropped_auth_count.clone();
2012                cx.background_spawn(async move {
2013                    *auth_count.lock() += 1;
2014                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
2015                    future::pending::<()>().await;
2016                    unreachable!()
2017                })
2018            }
2019        });
2020
2021        let _authenticate = cx.spawn({
2022            let client = client.clone();
2023            move |cx| async move { client.connect(false, &cx).await }
2024        });
2025        executor.run_until_parked();
2026        assert_eq!(*auth_count.lock(), 1);
2027        assert_eq!(*dropped_auth_count.lock(), 0);
2028
2029        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
2030        executor.run_until_parked();
2031        assert_eq!(*auth_count.lock(), 2);
2032        assert_eq!(*dropped_auth_count.lock(), 1);
2033    }
2034
2035    #[gpui::test]
2036    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
2037        init_test(cx);
2038        let user_id = 5;
2039        let client = cx.update(|cx| {
2040            Client::new(
2041                Arc::new(FakeSystemClock::new()),
2042                FakeHttpClient::with_404_response(),
2043                cx,
2044            )
2045        });
2046        let server = FakeServer::for_client(user_id, &client, cx).await;
2047
2048        let (done_tx1, done_rx1) = smol::channel::unbounded();
2049        let (done_tx2, done_rx2) = smol::channel::unbounded();
2050        AnyProtoClient::from(client.clone()).add_entity_message_handler(
2051            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
2052                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
2053                    1 => done_tx1.try_send(()).unwrap(),
2054                    2 => done_tx2.try_send(()).unwrap(),
2055                    _ => unreachable!(),
2056                }
2057                async { Ok(()) }
2058            },
2059        );
2060        let entity1 = cx.new(|_| TestEntity {
2061            id: 1,
2062            subscription: None,
2063        });
2064        let entity2 = cx.new(|_| TestEntity {
2065            id: 2,
2066            subscription: None,
2067        });
2068        let entity3 = cx.new(|_| TestEntity {
2069            id: 3,
2070            subscription: None,
2071        });
2072
2073        let _subscription1 = client
2074            .subscribe_to_entity(1)
2075            .unwrap()
2076            .set_entity(&entity1, &cx.to_async());
2077        let _subscription2 = client
2078            .subscribe_to_entity(2)
2079            .unwrap()
2080            .set_entity(&entity2, &cx.to_async());
2081        // Ensure dropping a subscription for the same entity type still allows receiving of
2082        // messages for other entity IDs of the same type.
2083        let subscription3 = client
2084            .subscribe_to_entity(3)
2085            .unwrap()
2086            .set_entity(&entity3, &cx.to_async());
2087        drop(subscription3);
2088
2089        server.send(proto::JoinProject {
2090            project_id: 1,
2091            committer_name: None,
2092            committer_email: None,
2093        });
2094        server.send(proto::JoinProject {
2095            project_id: 2,
2096            committer_name: None,
2097            committer_email: None,
2098        });
2099        done_rx1.recv().await.unwrap();
2100        done_rx2.recv().await.unwrap();
2101    }
2102
2103    #[gpui::test]
2104    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
2105        init_test(cx);
2106        let user_id = 5;
2107        let client = cx.update(|cx| {
2108            Client::new(
2109                Arc::new(FakeSystemClock::new()),
2110                FakeHttpClient::with_404_response(),
2111                cx,
2112            )
2113        });
2114        let server = FakeServer::for_client(user_id, &client, cx).await;
2115
2116        let entity = cx.new(|_| TestEntity::default());
2117        let (done_tx1, _done_rx1) = smol::channel::unbounded();
2118        let (done_tx2, done_rx2) = smol::channel::unbounded();
2119        let subscription1 = client.add_message_handler(
2120            entity.downgrade(),
2121            move |_, _: TypedEnvelope<proto::Ping>, _| {
2122                done_tx1.try_send(()).unwrap();
2123                async { Ok(()) }
2124            },
2125        );
2126        drop(subscription1);
2127        let _subscription2 = client.add_message_handler(
2128            entity.downgrade(),
2129            move |_, _: TypedEnvelope<proto::Ping>, _| {
2130                done_tx2.try_send(()).unwrap();
2131                async { Ok(()) }
2132            },
2133        );
2134        server.send(proto::Ping {});
2135        done_rx2.recv().await.unwrap();
2136    }
2137
2138    #[gpui::test]
2139    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
2140        init_test(cx);
2141        let user_id = 5;
2142        let client = cx.update(|cx| {
2143            Client::new(
2144                Arc::new(FakeSystemClock::new()),
2145                FakeHttpClient::with_404_response(),
2146                cx,
2147            )
2148        });
2149        let server = FakeServer::for_client(user_id, &client, cx).await;
2150
2151        let entity = cx.new(|_| TestEntity::default());
2152        let (done_tx, done_rx) = smol::channel::unbounded();
2153        let subscription = client.add_message_handler(
2154            entity.clone().downgrade(),
2155            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
2156                entity
2157                    .update(&mut cx, |entity, _| entity.subscription.take())
2158                    .unwrap();
2159                done_tx.try_send(()).unwrap();
2160                async { Ok(()) }
2161            },
2162        );
2163        entity.update(cx, |entity, _| {
2164            entity.subscription = Some(subscription);
2165        });
2166        server.send(proto::Ping {});
2167        done_rx.recv().await.unwrap();
2168    }
2169
    // Minimal entity type used as the handler target in the subscription tests.
    #[derive(Default)]
    struct TestEntity {
        // Identifier the `JoinProject` handler in `test_subscribing_to_entity` reads to tell
        // the entities apart.
        id: usize,
        // Slot for a message-handler `Subscription`; `test_dropping_subscription_in_handler`
        // stores the subscription here and takes it back out from inside the handler.
        subscription: Option<Subscription>,
    }
2175
2176    fn init_test(cx: &mut TestAppContext) {
2177        cx.update(|cx| {
2178            let settings_store = SettingsStore::test(cx);
2179            cx.set_global(settings_store);
2180            init_settings(cx);
2181        });
2182    }
2183}