client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use serde::{Deserialize, Serialize};
  33use settings::{Settings, SettingsContent};
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        Arc, LazyLock, Weak,
  43        atomic::{AtomicU64, Ordering},
  44    },
  45    time::{Duration, Instant},
  46};
  47use std::{cmp, pin::Pin};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use tokio::net::TcpStream;
  51use url::Url;
  52use util::{ConnectionResult, ResultExt};
  53
  54pub use rpc::*;
  55pub use telemetry_events::Event;
  56pub use user::*;
  57
  58static ZED_SERVER_URL: LazyLock<Option<String>> =
  59    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  60static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  61
  62pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  63    std::env::var("ZED_IMPERSONATE")
  64        .ok()
  65        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  66});
  67
  68pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  69
  70pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  71    std::env::var("ZED_ADMIN_API_TOKEN")
  72        .ok()
  73        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  74});
  75
  76pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  77    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  78
  79pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  80    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  81
  82pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  83pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  84pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  85
// Account/connection actions in the `client` namespace; their handlers are registered by
// `init`. The `///` lines below are part of each action's definition, not just comments.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  97
  98#[derive(Deserialize)]
  99pub struct ClientSettings {
 100    pub server_url: String,
 101}
 102
 103impl Settings for ClientSettings {
 104    fn from_settings(content: &settings::SettingsContent) -> Self {
 105        if let Some(server_url) = &*ZED_SERVER_URL {
 106            return Self {
 107                server_url: server_url.clone(),
 108            };
 109        }
 110        Self {
 111            server_url: content.server_url.clone().unwrap(),
 112        }
 113    }
 114}
 115
 116#[derive(Deserialize, Default)]
 117pub struct ProxySettings {
 118    pub proxy: Option<String>,
 119}
 120
 121impl ProxySettings {
 122    pub fn proxy_url(&self) -> Option<Url> {
 123        self.proxy
 124            .as_ref()
 125            .and_then(|input| {
 126                input
 127                    .parse::<Url>()
 128                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 129                    .ok()
 130            })
 131            .or_else(read_proxy_from_env)
 132    }
 133}
 134
 135impl Settings for ProxySettings {
 136    fn from_settings(content: &settings::SettingsContent) -> Self {
 137        Self {
 138            proxy: content.proxy.clone(),
 139        }
 140    }
 141
 142    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut SettingsContent) {
 143        vscode.string_setting("http.proxy", &mut current.proxy);
 144    }
 145}
 146
 147pub fn init_settings(cx: &mut App) {
 148    TelemetrySettings::register(cx);
 149    ClientSettings::register(cx);
 150    ProxySettings::register(cx);
 151}
 152
 153pub fn init(client: &Arc<Client>, cx: &mut App) {
 154    let client = Arc::downgrade(client);
 155    cx.on_action({
 156        let client = client.clone();
 157        move |_: &SignIn, cx| {
 158            if let Some(client) = client.upgrade() {
 159                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 160                    .detach_and_log_err(cx);
 161            }
 162        }
 163    });
 164
 165    cx.on_action({
 166        let client = client.clone();
 167        move |_: &SignOut, cx| {
 168            if let Some(client) = client.upgrade() {
 169                cx.spawn(async move |cx| {
 170                    client.sign_out(cx).await;
 171                })
 172                .detach();
 173            }
 174        }
 175    });
 176
 177    cx.on_action({
 178        let client = client;
 179        move |_: &Reconnect, cx| {
 180            if let Some(client) = client.upgrade() {
 181                cx.spawn(async move |cx| {
 182                    client.reconnect(cx);
 183                })
 184                .detach();
 185            }
 186        }
 187    });
 188}
 189
 190pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
 191
 192struct GlobalClient(Arc<Client>);
 193
 194impl Global for GlobalClient {}
 195
 196pub struct Client {
 197    id: AtomicU64,
 198    peer: Arc<Peer>,
 199    http: Arc<HttpClientWithUrl>,
 200    cloud_client: Arc<CloudApiClient>,
 201    telemetry: Arc<Telemetry>,
 202    credentials_provider: ClientCredentialsProvider,
 203    state: RwLock<ClientState>,
 204    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
 205    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,
 206
 207    #[allow(clippy::type_complexity)]
 208    #[cfg(any(test, feature = "test-support"))]
 209    authenticate:
 210        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,
 211
 212    #[allow(clippy::type_complexity)]
 213    #[cfg(any(test, feature = "test-support"))]
 214    establish_connection: RwLock<
 215        Option<
 216            Box<
 217                dyn 'static
 218                    + Send
 219                    + Sync
 220                    + Fn(
 221                        &Credentials,
 222                        &AsyncApp,
 223                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 224            >,
 225        >,
 226    >,
 227
 228    #[cfg(any(test, feature = "test-support"))]
 229    rpc_url: RwLock<Option<Url>>,
 230}
 231
 232#[derive(Error, Debug)]
 233pub enum EstablishConnectionError {
 234    #[error("upgrade required")]
 235    UpgradeRequired,
 236    #[error("unauthorized")]
 237    Unauthorized,
 238    #[error("{0}")]
 239    Other(#[from] anyhow::Error),
 240    #[error("{0}")]
 241    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 242    #[error("{0}")]
 243    Io(#[from] std::io::Error),
 244    #[error("{0}")]
 245    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 246}
 247
 248impl From<WebsocketError> for EstablishConnectionError {
 249    fn from(error: WebsocketError) -> Self {
 250        if let WebsocketError::Http(response) = &error {
 251            match response.status() {
 252                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 253                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 254                _ => {}
 255            }
 256        }
 257        EstablishConnectionError::Other(error.into())
 258    }
 259}
 260
 261impl EstablishConnectionError {
 262    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 263        Self::Other(error.into())
 264    }
 265}
 266
 267#[derive(Copy, Clone, Debug, PartialEq)]
 268pub enum Status {
 269    SignedOut,
 270    UpgradeRequired,
 271    Authenticating,
 272    Authenticated,
 273    AuthenticationError,
 274    Connecting,
 275    ConnectionError,
 276    Connected {
 277        peer_id: PeerId,
 278        connection_id: ConnectionId,
 279    },
 280    ConnectionLost,
 281    Reauthenticating,
 282    Reauthenticated,
 283    Reconnecting,
 284    ReconnectionError {
 285        next_reconnection: Instant,
 286    },
 287}
 288
 289impl Status {
 290    pub fn is_connected(&self) -> bool {
 291        matches!(self, Self::Connected { .. })
 292    }
 293
 294    pub fn was_connected(&self) -> bool {
 295        matches!(
 296            self,
 297            Self::ConnectionLost
 298                | Self::Reauthenticating
 299                | Self::Reauthenticated
 300                | Self::Reconnecting
 301        )
 302    }
 303
 304    /// Returns whether the client is currently connected or was connected at some point.
 305    pub fn is_or_was_connected(&self) -> bool {
 306        self.is_connected() || self.was_connected()
 307    }
 308
 309    pub fn is_signing_in(&self) -> bool {
 310        matches!(
 311            self,
 312            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 313        )
 314    }
 315
 316    pub fn is_signed_out(&self) -> bool {
 317        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 318    }
 319}
 320
 321struct ClientState {
 322    credentials: Option<Credentials>,
 323    status: (watch::Sender<Status>, watch::Receiver<Status>),
 324    _reconnect_task: Option<Task<()>>,
 325}
 326
 327#[derive(Clone, Debug, Eq, PartialEq)]
 328pub struct Credentials {
 329    pub user_id: u64,
 330    pub access_token: String,
 331}
 332
 333impl Credentials {
 334    pub fn authorization_header(&self) -> String {
 335        format!("{} {}", self.user_id, self.access_token)
 336    }
 337}
 338
 339pub struct ClientCredentialsProvider {
 340    provider: Arc<dyn CredentialsProvider>,
 341}
 342
 343impl ClientCredentialsProvider {
 344    pub fn new(cx: &App) -> Self {
 345        Self {
 346            provider: <dyn CredentialsProvider>::global(cx),
 347        }
 348    }
 349
 350    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 351        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 352    }
 353
 354    /// Reads the credentials from the provider.
 355    fn read_credentials<'a>(
 356        &'a self,
 357        cx: &'a AsyncApp,
 358    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 359        async move {
 360            if IMPERSONATE_LOGIN.is_some() {
 361                return None;
 362            }
 363
 364            let server_url = self.server_url(cx).ok()?;
 365            let (user_id, access_token) = self
 366                .provider
 367                .read_credentials(&server_url, cx)
 368                .await
 369                .log_err()
 370                .flatten()?;
 371
 372            Some(Credentials {
 373                user_id: user_id.parse().ok()?,
 374                access_token: String::from_utf8(access_token).ok()?,
 375            })
 376        }
 377        .boxed_local()
 378    }
 379
 380    /// Writes the credentials to the provider.
 381    fn write_credentials<'a>(
 382        &'a self,
 383        user_id: u64,
 384        access_token: String,
 385        cx: &'a AsyncApp,
 386    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 387        async move {
 388            let server_url = self.server_url(cx)?;
 389            self.provider
 390                .write_credentials(
 391                    &server_url,
 392                    &user_id.to_string(),
 393                    access_token.as_bytes(),
 394                    cx,
 395                )
 396                .await
 397        }
 398        .boxed_local()
 399    }
 400
 401    /// Deletes the credentials from the provider.
 402    fn delete_credentials<'a>(
 403        &'a self,
 404        cx: &'a AsyncApp,
 405    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 406        async move {
 407            let server_url = self.server_url(cx)?;
 408            self.provider.delete_credentials(&server_url, cx).await
 409        }
 410        .boxed_local()
 411    }
 412}
 413
 414impl Default for ClientState {
 415    fn default() -> Self {
 416        Self {
 417            credentials: None,
 418            status: watch::channel_with(Status::SignedOut),
 419            _reconnect_task: None,
 420        }
 421    }
 422}
 423
 424pub enum Subscription {
 425    Entity {
 426        client: Weak<Client>,
 427        id: (TypeId, u64),
 428    },
 429    Message {
 430        client: Weak<Client>,
 431        id: TypeId,
 432    },
 433}
 434
 435impl Drop for Subscription {
 436    fn drop(&mut self) {
 437        match self {
 438            Subscription::Entity { client, id } => {
 439                if let Some(client) = client.upgrade() {
 440                    let mut state = client.handler_set.lock();
 441                    let _ = state.entities_by_type_and_remote_id.remove(id);
 442                }
 443            }
 444            Subscription::Message { client, id } => {
 445                if let Some(client) = client.upgrade() {
 446                    let mut state = client.handler_set.lock();
 447                    let _ = state.entity_types_by_message_type.remove(id);
 448                    let _ = state.message_handlers.remove(id);
 449                }
 450            }
 451        }
 452    }
 453}
 454
 455pub struct PendingEntitySubscription<T: 'static> {
 456    client: Arc<Client>,
 457    remote_id: u64,
 458    _entity_type: PhantomData<T>,
 459    consumed: bool,
 460}
 461
 462impl<T: 'static> PendingEntitySubscription<T> {
 463    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 464        self.consumed = true;
 465        let mut handlers = self.client.handler_set.lock();
 466        let id = (TypeId::of::<T>(), self.remote_id);
 467        let Some(EntityMessageSubscriber::Pending(messages)) =
 468            handlers.entities_by_type_and_remote_id.remove(&id)
 469        else {
 470            unreachable!()
 471        };
 472
 473        handlers.entities_by_type_and_remote_id.insert(
 474            id,
 475            EntityMessageSubscriber::Entity {
 476                handle: entity.downgrade().into(),
 477            },
 478        );
 479        drop(handlers);
 480        for message in messages {
 481            let client_id = self.client.id();
 482            let type_name = message.payload_type_name();
 483            let sender_id = message.original_sender_id();
 484            log::debug!(
 485                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 486                client_id,
 487                sender_id,
 488                type_name
 489            );
 490            self.client.handle_message(message, cx);
 491        }
 492        Subscription::Entity {
 493            client: Arc::downgrade(&self.client),
 494            id,
 495        }
 496    }
 497}
 498
 499impl<T: 'static> Drop for PendingEntitySubscription<T> {
 500    fn drop(&mut self) {
 501        if !self.consumed {
 502            let mut state = self.client.handler_set.lock();
 503            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 504                .entities_by_type_and_remote_id
 505                .remove(&(TypeId::of::<T>(), self.remote_id))
 506            {
 507                for message in messages {
 508                    log::info!("unhandled message {}", message.payload_type_name());
 509                }
 510            }
 511        }
 512    }
 513}
 514
 515#[derive(Copy, Clone, Deserialize, Debug)]
 516pub struct TelemetrySettings {
 517    pub diagnostics: bool,
 518    pub metrics: bool,
 519}
 520
 521impl settings::Settings for TelemetrySettings {
 522    fn from_settings(content: &SettingsContent) -> Self {
 523        Self {
 524            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 525            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 526        }
 527    }
 528
 529    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut SettingsContent) {
 530        let mut telemetry = settings::TelemetrySettingsContent::default();
 531        vscode.enum_setting("telemetry.telemetryLevel", &mut telemetry.metrics, |s| {
 532            Some(s == "all")
 533        });
 534        vscode.enum_setting(
 535            "telemetry.telemetryLevel",
 536            &mut telemetry.diagnostics,
 537            |s| Some(matches!(s, "all" | "error" | "crash")),
 538        );
 539        // we could translate telemetry.telemetryLevel, but just because users didn't want
 540        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 541        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 542        if let Some(diagnostics) = telemetry.diagnostics {
 543            current.telemetry.get_or_insert_default().diagnostics = Some(diagnostics)
 544        }
 545        if let Some(metrics) = telemetry.metrics {
 546            current.telemetry.get_or_insert_default().metrics = Some(metrics)
 547        }
 548    }
 549}
 550
 551impl Client {
 552    pub fn new(
 553        clock: Arc<dyn SystemClock>,
 554        http: Arc<HttpClientWithUrl>,
 555        cx: &mut App,
 556    ) -> Arc<Self> {
 557        Arc::new(Self {
 558            id: AtomicU64::new(0),
 559            peer: Peer::new(0),
 560            telemetry: Telemetry::new(clock, http.clone(), cx),
 561            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 562            http,
 563            credentials_provider: ClientCredentialsProvider::new(cx),
 564            state: Default::default(),
 565            handler_set: Default::default(),
 566            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 567
 568            #[cfg(any(test, feature = "test-support"))]
 569            authenticate: Default::default(),
 570            #[cfg(any(test, feature = "test-support"))]
 571            establish_connection: Default::default(),
 572            #[cfg(any(test, feature = "test-support"))]
 573            rpc_url: RwLock::default(),
 574        })
 575    }
 576
 577    pub fn production(cx: &mut App) -> Arc<Self> {
 578        let clock = Arc::new(clock::RealSystemClock);
 579        let http = Arc::new(HttpClientWithUrl::new_url(
 580            cx.http_client(),
 581            &ClientSettings::get_global(cx).server_url,
 582            cx.http_client().proxy().cloned(),
 583        ));
 584        Self::new(clock, http, cx)
 585    }
 586
 587    pub fn id(&self) -> u64 {
 588        self.id.load(Ordering::SeqCst)
 589    }
 590
 591    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 592        self.http.clone()
 593    }
 594
 595    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 596        self.cloud_client.clone()
 597    }
 598
 599    pub fn set_id(&self, id: u64) -> &Self {
 600        self.id.store(id, Ordering::SeqCst);
 601        self
 602    }
 603
 604    #[cfg(any(test, feature = "test-support"))]
 605    pub fn teardown(&self) {
 606        let mut state = self.state.write();
 607        state._reconnect_task.take();
 608        self.handler_set.lock().clear();
 609        self.peer.teardown();
 610    }
 611
 612    #[cfg(any(test, feature = "test-support"))]
 613    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 614    where
 615        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 616    {
 617        *self.authenticate.write() = Some(Box::new(authenticate));
 618        self
 619    }
 620
 621    #[cfg(any(test, feature = "test-support"))]
 622    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 623    where
 624        F: 'static
 625            + Send
 626            + Sync
 627            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 628    {
 629        *self.establish_connection.write() = Some(Box::new(connect));
 630        self
 631    }
 632
 633    #[cfg(any(test, feature = "test-support"))]
 634    pub fn override_rpc_url(&self, url: Url) -> &Self {
 635        *self.rpc_url.write() = Some(url);
 636        self
 637    }
 638
 639    pub fn global(cx: &App) -> Arc<Self> {
 640        cx.global::<GlobalClient>().0.clone()
 641    }
 642    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 643        cx.set_global(GlobalClient(client))
 644    }
 645
 646    pub fn user_id(&self) -> Option<u64> {
 647        self.state
 648            .read()
 649            .credentials
 650            .as_ref()
 651            .map(|credentials| credentials.user_id)
 652    }
 653
 654    pub fn peer_id(&self) -> Option<PeerId> {
 655        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 656            Some(*peer_id)
 657        } else {
 658            None
 659        }
 660    }
 661
 662    pub fn status(&self) -> watch::Receiver<Status> {
 663        self.state.read().status.1.clone()
 664    }
 665
 666    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
 667        log::info!("set status on client {}: {:?}", self.id(), status);
 668        let mut state = self.state.write();
 669        *state.status.0.borrow_mut() = status;
 670
 671        match status {
 672            Status::Connected { .. } => {
 673                state._reconnect_task = None;
 674            }
 675            Status::ConnectionLost => {
 676                let client = self.clone();
 677                state._reconnect_task = Some(cx.spawn(async move |cx| {
 678                    #[cfg(any(test, feature = "test-support"))]
 679                    let mut rng = StdRng::seed_from_u64(0);
 680                    #[cfg(not(any(test, feature = "test-support")))]
 681                    let mut rng = StdRng::from_os_rng();
 682
 683                    let mut delay = INITIAL_RECONNECTION_DELAY;
 684                    loop {
 685                        match client.connect(true, cx).await {
 686                            ConnectionResult::Timeout => {
 687                                log::error!("client connect attempt timed out")
 688                            }
 689                            ConnectionResult::ConnectionReset => {
 690                                log::error!("client connect attempt reset")
 691                            }
 692                            ConnectionResult::Result(r) => {
 693                                if let Err(error) = r {
 694                                    log::error!("failed to connect: {error}");
 695                                } else {
 696                                    break;
 697                                }
 698                            }
 699                        }
 700
 701                        if matches!(
 702                            *client.status().borrow(),
 703                            Status::AuthenticationError | Status::ConnectionError
 704                        ) {
 705                            client.set_status(
 706                                Status::ReconnectionError {
 707                                    next_reconnection: Instant::now() + delay,
 708                                },
 709                                cx,
 710                            );
 711                            let jitter = Duration::from_millis(
 712                                rng.random_range(0..delay.as_millis() as u64),
 713                            );
 714                            cx.background_executor().timer(delay + jitter).await;
 715                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
 716                        } else {
 717                            break;
 718                        }
 719                    }
 720                }));
 721            }
 722            Status::SignedOut | Status::UpgradeRequired => {
 723                self.telemetry.set_authenticated_user_info(None, false);
 724                state._reconnect_task.take();
 725            }
 726            _ => {}
 727        }
 728    }
 729
 730    pub fn subscribe_to_entity<T>(
 731        self: &Arc<Self>,
 732        remote_id: u64,
 733    ) -> Result<PendingEntitySubscription<T>>
 734    where
 735        T: 'static,
 736    {
 737        let id = (TypeId::of::<T>(), remote_id);
 738
 739        let mut state = self.handler_set.lock();
 740        anyhow::ensure!(
 741            !state.entities_by_type_and_remote_id.contains_key(&id),
 742            "already subscribed to entity"
 743        );
 744
 745        state
 746            .entities_by_type_and_remote_id
 747            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 748
 749        Ok(PendingEntitySubscription {
 750            client: self.clone(),
 751            remote_id,
 752            consumed: false,
 753            _entity_type: PhantomData,
 754        })
 755    }
 756
 757    #[track_caller]
 758    pub fn add_message_handler<M, E, H, F>(
 759        self: &Arc<Self>,
 760        entity: WeakEntity<E>,
 761        handler: H,
 762    ) -> Subscription
 763    where
 764        M: EnvelopedMessage,
 765        E: 'static,
 766        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 767        F: 'static + Future<Output = Result<()>>,
 768    {
 769        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 770            handler(entity, message, cx)
 771        })
 772    }
 773
 774    fn add_message_handler_impl<M, E, H, F>(
 775        self: &Arc<Self>,
 776        entity: WeakEntity<E>,
 777        handler: H,
 778    ) -> Subscription
 779    where
 780        M: EnvelopedMessage,
 781        E: 'static,
 782        H: 'static
 783            + Sync
 784            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 785            + Send
 786            + Sync,
 787        F: 'static + Future<Output = Result<()>>,
 788    {
 789        let message_type_id = TypeId::of::<M>();
 790        let mut state = self.handler_set.lock();
 791        state
 792            .entities_by_message_type
 793            .insert(message_type_id, entity.into());
 794
 795        let prev_handler = state.message_handlers.insert(
 796            message_type_id,
 797            Arc::new(move |subscriber, envelope, client, cx| {
 798                let subscriber = subscriber.downcast::<E>().unwrap();
 799                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 800                handler(subscriber, *envelope, client, cx).boxed_local()
 801            }),
 802        );
 803        if prev_handler.is_some() {
 804            let location = std::panic::Location::caller();
 805            panic!(
 806                "{}:{} registered handler for the same message {} twice",
 807                location.file(),
 808                location.line(),
 809                std::any::type_name::<M>()
 810            );
 811        }
 812
 813        Subscription::Message {
 814            client: Arc::downgrade(self),
 815            id: message_type_id,
 816        }
 817    }
 818
 819    pub fn add_request_handler<M, E, H, F>(
 820        self: &Arc<Self>,
 821        entity: WeakEntity<E>,
 822        handler: H,
 823    ) -> Subscription
 824    where
 825        M: RequestMessage,
 826        E: 'static,
 827        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 828        F: 'static + Future<Output = Result<M::Response>>,
 829    {
 830        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 831            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 832        })
 833    }
 834
 835    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 836        receipt: Receipt<T>,
 837        response: F,
 838        client: AnyProtoClient,
 839    ) -> Result<()> {
 840        match response.await {
 841            Ok(response) => {
 842                client.send_response(receipt.message_id, response)?;
 843                Ok(())
 844            }
 845            Err(error) => {
 846                client.send_response(receipt.message_id, error.to_proto())?;
 847                Err(error)
 848            }
 849        }
 850    }
 851
 852    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 853        self.credentials_provider
 854            .read_credentials(cx)
 855            .await
 856            .is_some()
 857    }
 858
 859    pub async fn sign_in(
 860        self: &Arc<Self>,
 861        try_provider: bool,
 862        cx: &AsyncApp,
 863    ) -> Result<Credentials> {
 864        let is_reauthenticating = if self.status().borrow().is_signed_out() {
 865            self.set_status(Status::Authenticating, cx);
 866            false
 867        } else {
 868            self.set_status(Status::Reauthenticating, cx);
 869            true
 870        };
 871
 872        let mut credentials = None;
 873
 874        let old_credentials = self.state.read().credentials.clone();
 875        if let Some(old_credentials) = old_credentials
 876            && self.validate_credentials(&old_credentials, cx).await?
 877        {
 878            credentials = Some(old_credentials);
 879        }
 880
 881        if credentials.is_none()
 882            && try_provider
 883            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
 884        {
 885            if self.validate_credentials(&stored_credentials, cx).await? {
 886                credentials = Some(stored_credentials);
 887            } else {
 888                self.credentials_provider
 889                    .delete_credentials(cx)
 890                    .await
 891                    .log_err();
 892            }
 893        }
 894
 895        if credentials.is_none() {
 896            let mut status_rx = self.status();
 897            let _ = status_rx.next().await;
 898            futures::select_biased! {
 899                authenticate = self.authenticate(cx).fuse() => {
 900                    match authenticate {
 901                        Ok(creds) => {
 902                            if IMPERSONATE_LOGIN.is_none() {
 903                                self.credentials_provider
 904                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
 905                                    .await
 906                                    .log_err();
 907                            }
 908
 909                            credentials = Some(creds);
 910                        },
 911                        Err(err) => {
 912                            self.set_status(Status::AuthenticationError, cx);
 913                            return Err(err);
 914                        }
 915                    }
 916                }
 917                _ = status_rx.next().fuse() => {
 918                    return Err(anyhow!("authentication canceled"));
 919                }
 920            }
 921        }
 922
 923        let credentials = credentials.unwrap();
 924        self.set_id(credentials.user_id);
 925        self.cloud_client
 926            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
 927        self.state.write().credentials = Some(credentials.clone());
 928        self.set_status(
 929            if is_reauthenticating {
 930                Status::Reauthenticated
 931            } else {
 932                Status::Authenticated
 933            },
 934            cx,
 935        );
 936
 937        Ok(credentials)
 938    }
 939
 940    async fn validate_credentials(
 941        self: &Arc<Self>,
 942        credentials: &Credentials,
 943        cx: &AsyncApp,
 944    ) -> Result<bool> {
 945        match self
 946            .cloud_client
 947            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 948            .await
 949        {
 950            Ok(valid) => Ok(valid),
 951            Err(err) => {
 952                self.set_status(Status::AuthenticationError, cx);
 953                Err(anyhow!("failed to validate credentials: {}", err))
 954            }
 955        }
 956    }
 957
 958    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 959    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 960        let connect_task = cx.update({
 961            let cloud_client = self.cloud_client.clone();
 962            move |cx| cloud_client.connect(cx)
 963        })??;
 964        let connection = connect_task.await?;
 965
 966        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 967        task.detach();
 968
 969        cx.spawn({
 970            let this = self.clone();
 971            async move |cx| {
 972                while let Some(message) = messages.next().await {
 973                    if let Some(message) = message.log_err() {
 974                        this.handle_message_to_client(message, cx);
 975                    }
 976                }
 977            }
 978        })
 979        .detach();
 980
 981        Ok(())
 982    }
 983
 984    /// Performs a sign-in and also (optionally) connects to Collab.
 985    ///
 986    /// Only Zed staff automatically connect to Collab.
 987    pub async fn sign_in_with_optional_connect(
 988        self: &Arc<Self>,
 989        try_provider: bool,
 990        cx: &AsyncApp,
 991    ) -> Result<()> {
 992        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
 993        if self.status().borrow().is_connected() {
 994            return Ok(());
 995        }
 996
 997        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
 998        let mut is_staff_tx = Some(is_staff_tx);
 999        cx.update(|cx| {
1000            cx.on_flags_ready(move |state, _cx| {
1001                if let Some(is_staff_tx) = is_staff_tx.take() {
1002                    is_staff_tx.send(state.is_staff).log_err();
1003                }
1004            })
1005            .detach();
1006        })
1007        .log_err();
1008
1009        let credentials = self.sign_in(try_provider, cx).await?;
1010
1011        self.connect_to_cloud(cx).await.log_err();
1012
1013        cx.update(move |cx| {
1014            cx.spawn({
1015                let client = self.clone();
1016                async move |cx| {
1017                    let is_staff = is_staff_rx.await?;
1018                    if is_staff {
1019                        match client.connect_with_credentials(credentials, cx).await {
1020                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
1021                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
1022                            ConnectionResult::Result(result) => {
1023                                result.context("client auth and connect")
1024                            }
1025                        }
1026                    } else {
1027                        Ok(())
1028                    }
1029                }
1030            })
1031            .detach_and_log_err(cx);
1032        })
1033        .log_err();
1034
1035        Ok(())
1036    }
1037
1038    pub async fn connect(
1039        self: &Arc<Self>,
1040        try_provider: bool,
1041        cx: &AsyncApp,
1042    ) -> ConnectionResult<()> {
1043        let was_disconnected = match *self.status().borrow() {
1044            Status::SignedOut | Status::Authenticated => true,
1045            Status::ConnectionError
1046            | Status::ConnectionLost
1047            | Status::Authenticating
1048            | Status::AuthenticationError
1049            | Status::Reauthenticating
1050            | Status::Reauthenticated
1051            | Status::ReconnectionError { .. } => false,
1052            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1053                return ConnectionResult::Result(Ok(()));
1054            }
1055            Status::UpgradeRequired => {
1056                return ConnectionResult::Result(
1057                    Err(EstablishConnectionError::UpgradeRequired)
1058                        .context("client auth and connect"),
1059                );
1060            }
1061        };
1062        let credentials = match self.sign_in(try_provider, cx).await {
1063            Ok(credentials) => credentials,
1064            Err(err) => return ConnectionResult::Result(Err(err)),
1065        };
1066
1067        if was_disconnected {
1068            self.set_status(Status::Connecting, cx);
1069        } else {
1070            self.set_status(Status::Reconnecting, cx);
1071        }
1072
1073        self.connect_with_credentials(credentials, cx).await
1074    }
1075
    /// Opens a Collab connection with `credentials` and registers it with the peer.
    ///
    /// One `CONNECTION_TIMEOUT` timer bounds the whole attempt: it covers both
    /// `establish_connection` and the following `set_connection` (which waits for the
    /// server's hello). On timeout the status becomes `Status::ConnectionError` and
    /// `ConnectionResult::Timeout` is returned. Every other failure also sets
    /// `Status::ConnectionError`, except `EstablishConnectionError::UpgradeRequired`, which
    /// sets `Status::UpgradeRequired`. Errors are wrapped with "client auth and connect".
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // A single timer is shared by both `select_biased!` invocations below, so time spent
        // establishing the connection also counts against the handshake.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        // `select_biased!` polls its branches in order, so a ready connection wins a tie
        // against the timer.
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    // The server rejected our protocol/app version; record that distinctly.
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            // Timer fired before the connection was established; the pending
            // establish task is dropped.
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1123
    /// Registers `conn` with the peer, waits for the server's `Hello`, and starts the tasks
    /// that service the connection.
    ///
    /// On success the status becomes `Status::Connected`. One detached task forwards
    /// incoming messages to `handle_message`. Another waits on the connection's I/O future:
    /// a clean finish sets `Status::SignedOut`, but only if this connection is still the
    /// current one; an I/O error sets `Status::ConnectionLost`.
    ///
    /// # Errors
    ///
    /// Disconnects the connection and fails if no message arrives, the first message is not
    /// a `Hello`, or the `Hello` has no peer id.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O in the background while we wait for the hello below.
        let handle_io = executor.spawn(handle_io);

        // The first message from the server must be a `Hello` carrying our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Only sign out if no newer connection has replaced this one meanwhile.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1207
1208    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1209        #[cfg(any(test, feature = "test-support"))]
1210        if let Some(callback) = self.authenticate.read().as_ref() {
1211            return callback(cx);
1212        }
1213
1214        self.authenticate_with_browser(cx)
1215    }
1216
1217    fn establish_connection(
1218        self: &Arc<Self>,
1219        credentials: &Credentials,
1220        cx: &AsyncApp,
1221    ) -> Task<Result<Connection, EstablishConnectionError>> {
1222        #[cfg(any(test, feature = "test-support"))]
1223        if let Some(callback) = self.establish_connection.read().as_ref() {
1224            return callback(credentials, cx);
1225        }
1226
1227        self.establish_websocket_connection(credentials, cx)
1228    }
1229
1230    fn rpc_url(
1231        &self,
1232        http: Arc<HttpClientWithUrl>,
1233        release_channel: Option<ReleaseChannel>,
1234    ) -> impl Future<Output = Result<url::Url>> + use<> {
1235        #[cfg(any(test, feature = "test-support"))]
1236        let url_override = self.rpc_url.read().clone();
1237
1238        async move {
1239            #[cfg(any(test, feature = "test-support"))]
1240            if let Some(url) = url_override {
1241                return Ok(url);
1242            }
1243
1244            if let Some(url) = &*ZED_RPC_URL {
1245                return Url::parse(url).context("invalid rpc url");
1246            }
1247
1248            let mut url = http.build_url("/rpc");
1249            if let Some(preview_param) =
1250                release_channel.and_then(|channel| channel.release_query_param())
1251            {
1252                url += "?";
1253                url += preview_param;
1254            }
1255
1256            let response = http.get(&url, Default::default(), false).await?;
1257            anyhow::ensure!(
1258                response.status().is_redirection(),
1259                "unexpected /rpc response status {}",
1260                response.status()
1261            );
1262            let collab_url = response
1263                .headers()
1264                .get("Location")
1265                .context("missing location header in /rpc response")?
1266                .to_str()
1267                .map_err(EstablishConnectionError::other)?
1268                .to_string();
1269            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1270        }
1271    }
1272
    /// Opens an authenticated WebSocket connection to the Collab RPC endpoint.
    ///
    /// The following are captured before spawning:
    /// - release channel and app version;
    /// - HTTP proxy and user agent;
    /// - telemetry system/metrics ids.
    ///
    /// The spawned task then:
    /// 1. resolves the RPC URL (`rpc_url`);
    /// 2. opens a TCP stream on the Tokio runtime, through the proxy when one is configured;
    /// 3. switches the URL scheme from `http(s)` to `ws(s)`;
    /// 4. performs the WebSocket handshake with the `Authorization` and `x-zed-*` headers.
    ///
    /// The TLS connector is built from `http_client_tls::tls_config()`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Globals may be unavailable (e.g. app shutting down); fall back to "none"/empty.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // Open the raw TCP (or proxied) stream on the Tokio runtime.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })?
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // `Url::set_scheme` only rejects switching between special and non-special
            // schemes; http/https/ws/wss are all special, so this cannot fail.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Erase the WebSocket error types into `anyhow` for the `Connection` wrapper.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1380
1381    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1382        let http = self.http.clone();
1383        let this = self.clone();
1384        cx.spawn(async move |cx| {
1385            let background = cx.background_executor().clone();
1386
1387            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1388            cx.update(|cx| {
1389                cx.spawn(async move |cx| {
1390                    let url = open_url_rx.await?;
1391                    cx.update(|cx| cx.open_url(&url))
1392                })
1393                .detach_and_log_err(cx);
1394            })
1395            .log_err();
1396
1397            let credentials = background
1398                .clone()
1399                .spawn(async move {
1400                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1401                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1402                    // any other app running on the user's device.
1403                    let (public_key, private_key) =
1404                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1405                    let public_key_string = String::try_from(public_key)
1406                        .expect("failed to serialize public key for auth");
1407
1408                    if let Some((login, token)) =
1409                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1410                    {
1411                        if !*USE_WEB_LOGIN {
1412                            eprintln!("authenticate as admin {login}, {token}");
1413
1414                            return this
1415                                .authenticate_as_admin(http, login.clone(), token.clone())
1416                                .await;
1417                        }
1418                    }
1419
1420                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1421                    let server =
1422                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1423                    let port = server.server_addr().port();
1424
1425                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1426                    // that the user is signing in from a Zed app running on the same device.
1427                    let mut url = http.build_url(&format!(
1428                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1429                        port, public_key_string
1430                    ));
1431
1432                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1433                        log::info!("impersonating user @{}", impersonate_login);
1434                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1435                    }
1436
1437                    open_url_tx.send(url).log_err();
1438
1439                    #[derive(Deserialize)]
1440                    struct CallbackParams {
1441                        pub user_id: String,
1442                        pub access_token: String,
1443                    }
1444
1445                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1446                    // access token from the query params.
1447                    //
1448                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1449                    // custom URL scheme instead of this local HTTP server.
1450                    let (user_id, access_token) = background
1451                        .spawn(async move {
1452                            for _ in 0..100 {
1453                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1454                                    let path = req.url();
1455                                    let url = Url::parse(&format!("http://example.com{}", path))
1456                                        .context("failed to parse login notification url")?;
1457                                    let callback_params: CallbackParams =
1458                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1459                                            .context(
1460                                                "failed to parse sign-in callback query parameters",
1461                                            )?;
1462
1463                                    let post_auth_url =
1464                                        http.build_url("/native_app_signin_succeeded");
1465                                    req.respond(
1466                                        tiny_http::Response::empty(302).with_header(
1467                                            tiny_http::Header::from_bytes(
1468                                                &b"Location"[..],
1469                                                post_auth_url.as_bytes(),
1470                                            )
1471                                            .unwrap(),
1472                                        ),
1473                                    )
1474                                    .context("failed to respond to login http request")?;
1475                                    return Ok((
1476                                        callback_params.user_id,
1477                                        callback_params.access_token,
1478                                    ));
1479                                }
1480                            }
1481
1482                            anyhow::bail!("didn't receive login redirect");
1483                        })
1484                        .await?;
1485
1486                    let access_token = private_key
1487                        .decrypt_string(&access_token)
1488                        .context("failed to decrypt access token")?;
1489
1490                    Ok(Credentials {
1491                        user_id: user_id.parse()?,
1492                        access_token,
1493                    })
1494                })
1495                .await?;
1496
1497            cx.update(|cx| cx.activate(true))?;
1498            Ok(credentials)
1499        })
1500    }
1501
1502    async fn authenticate_as_admin(
1503        self: &Arc<Self>,
1504        http: Arc<HttpClientWithUrl>,
1505        login: String,
1506        api_token: String,
1507    ) -> Result<Credentials> {
1508        #[derive(Serialize)]
1509        struct ImpersonateUserBody {
1510            github_login: String,
1511        }
1512
1513        #[derive(Deserialize)]
1514        struct ImpersonateUserResponse {
1515            user_id: u64,
1516            access_token: String,
1517        }
1518
1519        let url = self
1520            .http
1521            .build_zed_cloud_url("/internal/users/impersonate", &[])?;
1522        let request = Request::post(url.as_str())
1523            .header("Content-Type", "application/json")
1524            .header("Authorization", format!("Bearer {api_token}"))
1525            .body(
1526                serde_json::to_string(&ImpersonateUserBody {
1527                    github_login: login,
1528                })?
1529                .into(),
1530            )?;
1531
1532        let mut response = http.send(request).await?;
1533        let mut body = String::new();
1534        response.body_mut().read_to_string(&mut body).await?;
1535        anyhow::ensure!(
1536            response.status().is_success(),
1537            "admin user request failed {} - {}",
1538            response.status().as_u16(),
1539            body,
1540        );
1541        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1542
1543        Ok(Credentials {
1544            user_id: response.user_id,
1545            access_token: response.access_token,
1546        })
1547    }
1548
1549    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1550        self.state.write().credentials = None;
1551        self.cloud_client.clear_credentials();
1552        self.disconnect(cx);
1553
1554        if self.has_credentials(cx).await {
1555            self.credentials_provider
1556                .delete_credentials(cx)
1557                .await
1558                .log_err();
1559        }
1560    }
1561
1562    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1563        self.peer.teardown();
1564        self.set_status(Status::SignedOut, cx);
1565    }
1566
1567    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1568        self.peer.teardown();
1569        self.set_status(Status::ConnectionLost, cx);
1570    }
1571
1572    fn connection_id(&self) -> Result<ConnectionId> {
1573        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1574            Ok(connection_id)
1575        } else {
1576            anyhow::bail!("not connected");
1577        }
1578    }
1579
1580    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1581        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1582        self.peer.send(self.connection_id()?, message)
1583    }
1584
1585    pub fn request<T: RequestMessage>(
1586        &self,
1587        request: T,
1588    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1589        self.request_envelope(request)
1590            .map_ok(|envelope| envelope.payload)
1591    }
1592
1593    pub fn request_stream<T: RequestMessage>(
1594        &self,
1595        request: T,
1596    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1597        let client_id = self.id.load(Ordering::SeqCst);
1598        log::debug!(
1599            "rpc request start. client_id:{}. name:{}",
1600            client_id,
1601            T::NAME
1602        );
1603        let response = self
1604            .connection_id()
1605            .map(|conn_id| self.peer.request_stream(conn_id, request));
1606        async move {
1607            let response = response?.await;
1608            log::debug!(
1609                "rpc request finish. client_id:{}. name:{}",
1610                client_id,
1611                T::NAME
1612            );
1613            response
1614        }
1615    }
1616
1617    pub fn request_envelope<T: RequestMessage>(
1618        &self,
1619        request: T,
1620    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1621        let client_id = self.id();
1622        log::debug!(
1623            "rpc request start. client_id:{}. name:{}",
1624            client_id,
1625            T::NAME
1626        );
1627        let response = self
1628            .connection_id()
1629            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1630        async move {
1631            let response = response?.await;
1632            log::debug!(
1633                "rpc request finish. client_id:{}. name:{}",
1634                client_id,
1635                T::NAME
1636            );
1637            response
1638        }
1639    }
1640
1641    pub fn request_dynamic(
1642        &self,
1643        envelope: proto::Envelope,
1644        request_type: &'static str,
1645    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1646        let client_id = self.id();
1647        log::debug!(
1648            "rpc request start. client_id:{}. name:{}",
1649            client_id,
1650            request_type
1651        );
1652        let response = self
1653            .connection_id()
1654            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1655        async move {
1656            let response = response?.await;
1657            log::debug!(
1658                "rpc request finish. client_id:{}. name:{}",
1659                client_id,
1660                request_type
1661            );
1662            Ok(response?.0)
1663        }
1664    }
1665
1666    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1667        let sender_id = message.sender_id();
1668        let request_id = message.message_id();
1669        let type_name = message.payload_type_name();
1670        let original_sender_id = message.original_sender_id();
1671
1672        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1673            &self.handler_set,
1674            message,
1675            self.clone().into(),
1676            cx.clone(),
1677        ) {
1678            let client_id = self.id();
1679            log::debug!(
1680                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1681                client_id,
1682                original_sender_id,
1683                type_name
1684            );
1685            cx.spawn(async move |_| match future.await {
1686                Ok(()) => {
1687                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1688                }
1689                Err(error) => {
1690                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1691                }
1692            })
1693            .detach();
1694        } else {
1695            log::info!("unhandled message {}", type_name);
1696            self.peer
1697                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1698                .log_err();
1699        }
1700    }
1701
1702    pub fn add_message_to_client_handler(
1703        self: &Arc<Client>,
1704        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1705    ) {
1706        self.message_to_client_handlers
1707            .lock()
1708            .push(Box::new(handler));
1709    }
1710
1711    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1712        cx.update(|cx| {
1713            for handler in self.message_to_client_handlers.lock().iter() {
1714                handler(&message, cx);
1715            }
1716        })
1717        .ok();
1718    }
1719
1720    pub fn telemetry(&self) -> &Arc<Telemetry> {
1721        &self.telemetry
1722    }
1723}
1724
1725impl ProtoClient for Client {
1726    fn request(
1727        &self,
1728        envelope: proto::Envelope,
1729        request_type: &'static str,
1730    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1731        self.request_dynamic(envelope, request_type).boxed()
1732    }
1733
1734    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1735        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1736        let connection_id = self.connection_id()?;
1737        self.peer.send_dynamic(connection_id, envelope)
1738    }
1739
1740    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1741        log::debug!(
1742            "rpc respond. client_id:{}, name:{}",
1743            self.id(),
1744            message_type
1745        );
1746        let connection_id = self.connection_id()?;
1747        self.peer.send_dynamic(connection_id, envelope)
1748    }
1749
1750    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1751        &self.handler_set
1752    }
1753
1754    fn is_via_collab(&self) -> bool {
1755        true
1756    }
1757}
1758
/// Scheme name used by Zed deep links such as `zed://...`.
///
/// This is the bare scheme, without the `://` separator; callers match the
/// separator themselves.
pub const ZED_URL_SCHEME: &str = "zed";
1761
1762/// Parses the given link into a Zed link.
1763///
1764/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1765/// Returns [`None`] otherwise.
1766pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1767    let server_url = &ClientSettings::get_global(cx).server_url;
1768    if let Some(stripped) = link
1769        .strip_prefix(server_url)
1770        .and_then(|result| result.strip_prefix('/'))
1771    {
1772        return Some(stripped);
1773    }
1774    if let Some(stripped) = link
1775        .strip_prefix(ZED_URL_SCHEME)
1776        .and_then(|result| result.strip_prefix("://"))
1777    {
1778        return Some(stripped);
1779    }
1780
1781    None
1782}
1783
// Unit tests for `Client`: reconnection and re-authentication, connection timeouts,
// and routing of incoming messages to entity / message-handler subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // After the server drops the connection, the client keeps retrying and reconnects
    // using its cached credentials. Once the server's access token is rolled, the cached
    // credentials no longer work and the client has to authenticate again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Advance the fake clock; presumably enough to cover the reconnection delay.
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // An HTTP 503 from the server during reconnection is reported as a
    // `ReconnectionError`. Once the server answers 200 again, the client reconnects
    // without re-authenticating, i.e. the 503 did not discard its cached credentials.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // A connection attempt that never completes is abandoned after `CONNECTION_TIMEOUT`.
    // This is checked for both the initial connect (`ConnectionError`) and a later
    // reconnect (`ReconnectionError`).
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // This connection future never resolves, so only the timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // `sign_in` only runs the authentication flow again when the server rejects the
    // stored credentials with 401. Valid credentials are reused, and a 503 makes
    // `sign_in` fail without re-authenticating.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // Each authentication bumps the counter and uses its new value as the token,
        // so the tests can tell which authentication produced a given credential.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Starting a second `connect` while the first is still authenticating drops the
    // first authentication task. The drop is observed through the `defer` guard
    // incrementing `dropped_auth_count`.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs only when this never-finishing authentication future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity messages are routed to the entity subscribed under the matching id
    // (here `JoinProject.project_id`). Dropping the subscription for one id must not
    // stop delivery to other ids of the same entity type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // After a message-handler subscription is dropped, a new handler for the same
    // message type can be registered and receives subsequent messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // A handler can drop its own subscription (stored on the entity) from inside the
    // handler and still run to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a subscription target in these tests: `id` lets handlers
    // tell entities apart, and `subscription` lets a test keep or drop a subscription
    // through the entity.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global, then calls `init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}