client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, FutureExt as _, Global, Task, Timeout, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use serde::{Deserialize, Serialize};
  33use settings::{Settings, SettingsContent};
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        Arc, LazyLock, Weak,
  43        atomic::{AtomicU64, Ordering},
  44    },
  45    time::{Duration, Instant},
  46};
  47use std::{cmp, pin::Pin};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use tokio::net::TcpStream;
  51use url::Url;
  52use util::{ConnectionResult, ResultExt};
  53
  54pub use rpc::*;
  55pub use telemetry_events::Event;
  56pub use user::*;
  57
  58static ZED_SERVER_URL: LazyLock<Option<String>> =
  59    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  60static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  61
  62pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  63    std::env::var("ZED_IMPERSONATE")
  64        .ok()
  65        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  66});
  67
  68pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  69
  70pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  71    std::env::var("ZED_ADMIN_API_TOKEN")
  72        .ok()
  73        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  74});
  75
  76pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  77    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  78
  79pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  80    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  81
  82pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  83pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  84pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  85
// Declares the `SignIn`, `SignOut` and `Reconnect` actions under the `client`
// namespace. Handlers for them are registered in `init`.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  97
/// Client-level settings.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// URL of the server the client talks to. Overridden by the
    /// `ZED_SERVER_URL` environment variable when that is set.
    pub server_url: String,
}
 102
 103impl Settings for ClientSettings {
 104    fn from_settings(content: &settings::SettingsContent) -> Self {
 105        if let Some(server_url) = &*ZED_SERVER_URL {
 106            return Self {
 107                server_url: server_url.clone(),
 108            };
 109        }
 110        Self {
 111            server_url: content.server_url.clone().unwrap(),
 112        }
 113    }
 114}
 115
/// Proxy configuration as read from settings.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Raw proxy URL string, if configured. Parsed by [`ProxySettings::proxy_url`].
    pub proxy: Option<String>,
}
 120
 121impl ProxySettings {
 122    pub fn proxy_url(&self) -> Option<Url> {
 123        self.proxy
 124            .as_ref()
 125            .and_then(|input| {
 126                input
 127                    .parse::<Url>()
 128                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 129                    .ok()
 130            })
 131            .or_else(read_proxy_from_env)
 132    }
 133}
 134
 135impl Settings for ProxySettings {
 136    fn from_settings(content: &settings::SettingsContent) -> Self {
 137        Self {
 138            proxy: content.proxy.clone(),
 139        }
 140    }
 141
 142    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut SettingsContent) {
 143        vscode.string_setting("http.proxy", &mut current.proxy);
 144    }
 145}
 146
/// Registers the settings types defined in this module
/// ([`TelemetrySettings`], [`ClientSettings`], [`ProxySettings`]) with the app.
pub fn init_settings(cx: &mut App) {
    TelemetrySettings::register(cx);
    ClientSettings::register(cx);
    ProxySettings::register(cx);
}
 152
 153pub fn init(client: &Arc<Client>, cx: &mut App) {
 154    let client = Arc::downgrade(client);
 155    cx.on_action({
 156        let client = client.clone();
 157        move |_: &SignIn, cx| {
 158            if let Some(client) = client.upgrade() {
 159                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 160                    .detach_and_log_err(cx);
 161            }
 162        }
 163    });
 164
 165    cx.on_action({
 166        let client = client.clone();
 167        move |_: &SignOut, cx| {
 168            if let Some(client) = client.upgrade() {
 169                cx.spawn(async move |cx| {
 170                    client.sign_out(cx).await;
 171                })
 172                .detach();
 173            }
 174        }
 175    });
 176
 177    cx.on_action({
 178        let client = client;
 179        move |_: &Reconnect, cx| {
 180            if let Some(client) = client.upgrade() {
 181                cx.spawn(async move |cx| {
 182                    client.reconnect(cx);
 183                })
 184                .detach();
 185            }
 186        }
 187    });
 188}
 189
/// Boxed callback invoked with a [`MessageToClient`] and the app context.
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
 191
/// Wrapper that lets the shared [`Client`] be stored as a gpui global
/// (see `Client::global` / `Client::set_global`).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 195
/// The client: holds the RPC peer, HTTP clients, telemetry, credentials access,
/// connection state and the registered message handlers.
pub struct Client {
    /// Client identifier; initially 0 and set to the user id during sign-in.
    id: AtomicU64,
    /// RPC peer (created with `Peer::new(0)`).
    peer: Arc<Peer>,
    /// HTTP client bound to the server URL.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client, built on top of the same HTTP client as `http`.
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    /// Reads, writes and deletes the stored credentials for the server URL.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status channel and reconnect task.
    state: RwLock<ClientState>,
    /// Registered message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Callbacks for [`MessageToClient`] messages.
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    /// Test-only authentication override (set via `override_authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only connection override (set via `override_establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL override (set via `override_rpc_url`).
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 231
/// Errors that can occur while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other error, including websocket errors that are not 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A request header value could not be constructed.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// An I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// An error from the `http` types re-exported by tungstenite.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 247
 248impl From<WebsocketError> for EstablishConnectionError {
 249    fn from(error: WebsocketError) -> Self {
 250        if let WebsocketError::Http(response) = &error {
 251            match response.status() {
 252                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 253                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 254                _ => {}
 255            }
 256        }
 257        EstablishConnectionError::Other(error.into())
 258    }
 259}
 260
impl EstablishConnectionError {
    /// Wraps any error convertible into [`anyhow::Error`] in the `Other` variant.
    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
        Self::Other(error.into())
    }
}
 266
/// Connection/authentication status of the client, published on a watch
/// channel (see `Client::status`).
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial state of a freshly created client.
    SignedOut,
    UpgradeRequired,
    Authenticating,
    Authenticated,
    AuthenticationError,
    Connecting,
    ConnectionError,
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// Setting this status starts the automatic reconnection loop.
    ConnectionLost,
    Reauthenticating,
    Reauthenticated,
    Reconnecting,
    /// A reconnection attempt failed; another one is scheduled.
    ReconnectionError {
        /// Earliest time of the next attempt (computed before jitter is added).
        next_reconnection: Instant,
    },
}
 288
impl Status {
    /// Returns `true` only for [`Status::Connected`].
    pub fn is_connected(&self) -> bool {
        matches!(self, Self::Connected { .. })
    }

    /// Returns `true` for the states that follow a lost connection:
    /// `ConnectionLost`, `Reauthenticating`, `Reauthenticated` and
    /// `Reconnecting`. Note that `ReconnectionError` is not included.
    pub fn was_connected(&self) -> bool {
        matches!(
            self,
            Self::ConnectionLost
                | Self::Reauthenticating
                | Self::Reauthenticated
                | Self::Reconnecting
        )
    }

    /// Returns whether the client is currently connected or is in one of the
    /// states reported by [`Self::was_connected`].
    pub fn is_or_was_connected(&self) -> bool {
        self.is_connected() || self.was_connected()
    }

    /// Returns `true` while an authentication or connection attempt is in
    /// progress (`Authenticating`, `Reauthenticating`, `Connecting`,
    /// `Reconnecting`).
    pub fn is_signing_in(&self) -> bool {
        matches!(
            self,
            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
        )
    }

    /// Returns `true` for `SignedOut` and `UpgradeRequired`.
    pub fn is_signed_out(&self) -> bool {
        matches!(self, Self::SignedOut | Self::UpgradeRequired)
    }
}
 320
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials of the signed-in user, if any.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnection loop spawned when the status becomes `ConnectionLost`;
    /// cleared on `Connected`, `SignedOut` and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
}
 326
 327#[derive(Clone, Debug, Eq, PartialEq)]
 328pub struct Credentials {
 329    pub user_id: u64,
 330    pub access_token: String,
 331}
 332
 333impl Credentials {
 334    pub fn authorization_header(&self) -> String {
 335        format!("{} {}", self.user_id, self.access_token)
 336    }
 337}
 338
/// Maximum time to wait for the credentials provider when reading credentials.
pub const CREDENTIAL_READ_TIMEOUT: Duration = Duration::from_secs(10);
/// Thin wrapper around a [`CredentialsProvider`] that keys credentials by the
/// configured server URL.
pub struct ClientCredentialsProvider {
    provider: Arc<dyn CredentialsProvider>,
}
 343
 344impl ClientCredentialsProvider {
 345    pub fn new(cx: &App) -> Self {
 346        Self {
 347            provider: <dyn CredentialsProvider>::global(cx),
 348        }
 349    }
 350
 351    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 352        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 353    }
 354
 355    /// Reads the credentials from the provider.
 356    ///
 357    /// Times out after [`CREDENTIAL_READ_TIMEOUT`] seconds returning None and
 358    /// logging the timeout as error
 359    fn read_credentials<'a>(
 360        &'a self,
 361        cx: &'a AsyncApp,
 362    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 363        async move {
 364            let executor = cx.background_executor();
 365            if IMPERSONATE_LOGIN.is_some() {
 366                return None;
 367            }
 368
 369            let server_url = self.server_url(cx).ok()?;
 370            let (user_id, access_token) = self
 371                .provider
 372                .read_credentials(&server_url, cx)
 373                .with_timeout(CREDENTIAL_READ_TIMEOUT, executor)
 374                .await
 375                .map_err(|_| anyhow!("Timed out waiting for credentials from OS"))
 376                .flatten()
 377                .log_err()
 378                .flatten()?;
 379
 380            Some(Credentials {
 381                user_id: user_id.parse().ok()?,
 382                access_token: String::from_utf8(access_token).ok()?,
 383            })
 384        }
 385        .boxed_local()
 386    }
 387
 388    /// Writes the credentials to the provider.
 389    fn write_credentials<'a>(
 390        &'a self,
 391        user_id: u64,
 392        access_token: String,
 393        cx: &'a AsyncApp,
 394    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 395        async move {
 396            let server_url = self.server_url(cx)?;
 397            self.provider
 398                .write_credentials(
 399                    &server_url,
 400                    &user_id.to_string(),
 401                    access_token.as_bytes(),
 402                    cx,
 403                )
 404                .await
 405        }
 406        .boxed_local()
 407    }
 408
 409    /// Deletes the credentials from the provider.
 410    fn delete_credentials<'a>(
 411        &'a self,
 412        cx: &'a AsyncApp,
 413    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 414        async move {
 415            let server_url = self.server_url(cx)?;
 416            self.provider.delete_credentials(&server_url, cx).await
 417        }
 418        .boxed_local()
 419    }
 420}
 421
impl Default for ClientState {
    /// Starts signed out, with no credentials and no reconnect task.
    fn default() -> Self {
        Self {
            credentials: None,
            status: watch::channel_with(Status::SignedOut),
            _reconnect_task: None,
        }
    }
}
 431
/// Handle for a registration in the client's handler set. Dropping it removes
/// the registration (see the `Drop` impl).
pub enum Subscription {
    /// Subscription for a specific remote entity.
    Entity {
        client: Weak<Client>,
        /// Entity type and remote id the subscription is keyed by.
        id: (TypeId, u64),
    },
    /// Registration of a handler for a message type.
    Message {
        client: Weak<Client>,
        /// Type id of the handled message.
        id: TypeId,
    },
}
 442
 443impl Drop for Subscription {
 444    fn drop(&mut self) {
 445        match self {
 446            Subscription::Entity { client, id } => {
 447                if let Some(client) = client.upgrade() {
 448                    let mut state = client.handler_set.lock();
 449                    let _ = state.entities_by_type_and_remote_id.remove(id);
 450                }
 451            }
 452            Subscription::Message { client, id } => {
 453                if let Some(client) = client.upgrade() {
 454                    let mut state = client.handler_set.lock();
 455                    let _ = state.entity_types_by_message_type.remove(id);
 456                    let _ = state.message_handlers.remove(id);
 457                }
 458            }
 459        }
 460    }
 461}
 462
/// A subscription to a remote entity whose local [`Entity`] has not been
/// provided yet. Created by `Client::subscribe_to_entity`; call
/// [`PendingEntitySubscription::set_entity`] to complete it.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set once `set_entity` has run, so that `Drop` leaves the registration alone.
    consumed: bool,
}
 469
 470impl<T: 'static> PendingEntitySubscription<T> {
 471    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 472        self.consumed = true;
 473        let mut handlers = self.client.handler_set.lock();
 474        let id = (TypeId::of::<T>(), self.remote_id);
 475        let Some(EntityMessageSubscriber::Pending(messages)) =
 476            handlers.entities_by_type_and_remote_id.remove(&id)
 477        else {
 478            unreachable!()
 479        };
 480
 481        handlers.entities_by_type_and_remote_id.insert(
 482            id,
 483            EntityMessageSubscriber::Entity {
 484                handle: entity.downgrade().into(),
 485            },
 486        );
 487        drop(handlers);
 488        for message in messages {
 489            let client_id = self.client.id();
 490            let type_name = message.payload_type_name();
 491            let sender_id = message.original_sender_id();
 492            log::debug!(
 493                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 494                client_id,
 495                sender_id,
 496                type_name
 497            );
 498            self.client.handle_message(message, cx);
 499        }
 500        Subscription::Entity {
 501            client: Arc::downgrade(&self.client),
 502            id,
 503        }
 504    }
 505}
 506
 507impl<T: 'static> Drop for PendingEntitySubscription<T> {
 508    fn drop(&mut self) {
 509        if !self.consumed {
 510            let mut state = self.client.handler_set.lock();
 511            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 512                .entities_by_type_and_remote_id
 513                .remove(&(TypeId::of::<T>(), self.remote_id))
 514            {
 515                for message in messages {
 516                    log::info!("unhandled message {}", message.payload_type_name());
 517                }
 518            }
 519        }
 520    }
 521}
 522
/// Telemetry opt-in flags, read from the `telemetry` settings section.
#[derive(Copy, Clone, Deserialize, Debug)]
pub struct TelemetrySettings {
    pub diagnostics: bool,
    pub metrics: bool,
}
 528
 529impl settings::Settings for TelemetrySettings {
 530    fn from_settings(content: &SettingsContent) -> Self {
 531        Self {
 532            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 533            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 534        }
 535    }
 536
 537    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut SettingsContent) {
 538        let mut telemetry = settings::TelemetrySettingsContent::default();
 539        vscode.enum_setting("telemetry.telemetryLevel", &mut telemetry.metrics, |s| {
 540            Some(s == "all")
 541        });
 542        vscode.enum_setting(
 543            "telemetry.telemetryLevel",
 544            &mut telemetry.diagnostics,
 545            |s| Some(matches!(s, "all" | "error" | "crash")),
 546        );
 547        // we could translate telemetry.telemetryLevel, but just because users didn't want
 548        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 549        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 550        if let Some(diagnostics) = telemetry.diagnostics {
 551            current.telemetry.get_or_insert_default().diagnostics = Some(diagnostics)
 552        }
 553        if let Some(metrics) = telemetry.metrics {
 554            current.telemetry.get_or_insert_default().metrics = Some(metrics)
 555        }
 556    }
 557}
 558
 559impl Client {
 560    pub fn new(
 561        clock: Arc<dyn SystemClock>,
 562        http: Arc<HttpClientWithUrl>,
 563        cx: &mut App,
 564    ) -> Arc<Self> {
 565        Arc::new(Self {
 566            id: AtomicU64::new(0),
 567            peer: Peer::new(0),
 568            telemetry: Telemetry::new(clock, http.clone(), cx),
 569            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 570            http,
 571            credentials_provider: ClientCredentialsProvider::new(cx),
 572            state: Default::default(),
 573            handler_set: Default::default(),
 574            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 575
 576            #[cfg(any(test, feature = "test-support"))]
 577            authenticate: Default::default(),
 578            #[cfg(any(test, feature = "test-support"))]
 579            establish_connection: Default::default(),
 580            #[cfg(any(test, feature = "test-support"))]
 581            rpc_url: RwLock::default(),
 582        })
 583    }
 584
 585    pub fn production(cx: &mut App) -> Arc<Self> {
 586        let clock = Arc::new(clock::RealSystemClock);
 587        let http = Arc::new(HttpClientWithUrl::new_url(
 588            cx.http_client(),
 589            &ClientSettings::get_global(cx).server_url,
 590            cx.http_client().proxy().cloned(),
 591        ));
 592        Self::new(clock, http, cx)
 593    }
 594
    /// Returns the current client id (see [`Self::set_id`]).
    pub fn id(&self) -> u64 {
        self.id.load(Ordering::SeqCst)
    }
 598
 599    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 600        self.http.clone()
 601    }
 602
 603    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 604        self.cloud_client.clone()
 605    }
 606
    /// Stores `id` as the client id and returns `self` so calls can be chained.
    pub fn set_id(&self, id: u64) -> &Self {
        self.id.store(id, Ordering::SeqCst);
        self
    }
 611
 612    #[cfg(any(test, feature = "test-support"))]
 613    pub fn teardown(&self) {
 614        let mut state = self.state.write();
 615        state._reconnect_task.take();
 616        self.handler_set.lock().clear();
 617        self.peer.teardown();
 618    }
 619
 620    #[cfg(any(test, feature = "test-support"))]
 621    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 622    where
 623        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 624    {
 625        *self.authenticate.write() = Some(Box::new(authenticate));
 626        self
 627    }
 628
 629    #[cfg(any(test, feature = "test-support"))]
 630    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 631    where
 632        F: 'static
 633            + Send
 634            + Sync
 635            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 636    {
 637        *self.establish_connection.write() = Some(Box::new(connect));
 638        self
 639    }
 640
 641    #[cfg(any(test, feature = "test-support"))]
 642    pub fn override_rpc_url(&self, url: Url) -> &Self {
 643        *self.rpc_url.write() = Some(url);
 644        self
 645    }
 646
 647    pub fn global(cx: &App) -> Arc<Self> {
 648        cx.global::<GlobalClient>().0.clone()
 649    }
    /// Stores `client` as the app-wide global client (see [`Self::global`]).
    pub fn set_global(client: Arc<Client>, cx: &mut App) {
        cx.set_global(GlobalClient(client))
    }
 653
 654    pub fn user_id(&self) -> Option<u64> {
 655        self.state
 656            .read()
 657            .credentials
 658            .as_ref()
 659            .map(|credentials| credentials.user_id)
 660    }
 661
 662    pub fn peer_id(&self) -> Option<PeerId> {
 663        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 664            Some(*peer_id)
 665        } else {
 666            None
 667        }
 668    }
 669
    /// Returns a receiver for the client's status watch channel.
    pub fn status(&self) -> watch::Receiver<Status> {
        self.state.read().status.1.clone()
    }
 673
    /// Publishes `status` on the status channel and performs the bookkeeping
    /// tied to it:
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns the reconnection loop and stores its task.
    /// - `SignedOut` / `UpgradeRequired`: clears the authenticated user info in
    ///   telemetry and drops any reconnect task.
    ///
    /// The state write lock is held until this function returns.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // Fixed seed under test so the jitter is reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_os_rng();

                    // Exponential backoff: starts at INITIAL_RECONNECTION_DELAY and
                    // doubles after each failed attempt, capped at MAX_RECONNECTION_DELAY.
                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        match client.connect(true, cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    // Connected successfully; stop retrying.
                                    break;
                                }
                            }
                        }

                        // Only retry while the failed attempt left the client in an
                        // authentication or connection error state; any other status
                        // ends the loop.
                        if matches!(
                            *client.status().borrow(),
                            Status::AuthenticationError | Status::ConnectionError
                        ) {
                            // Note: the advertised time excludes the jitter added below.
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                cx,
                            );
                            // Random extra wait in [0, delay) milliseconds.
                            let jitter = Duration::from_millis(
                                rng.random_range(0..delay.as_millis() as u64),
                            );
                            cx.background_executor().timer(delay + jitter).await;
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 737
 738    pub fn subscribe_to_entity<T>(
 739        self: &Arc<Self>,
 740        remote_id: u64,
 741    ) -> Result<PendingEntitySubscription<T>>
 742    where
 743        T: 'static,
 744    {
 745        let id = (TypeId::of::<T>(), remote_id);
 746
 747        let mut state = self.handler_set.lock();
 748        anyhow::ensure!(
 749            !state.entities_by_type_and_remote_id.contains_key(&id),
 750            "already subscribed to entity"
 751        );
 752
 753        state
 754            .entities_by_type_and_remote_id
 755            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 756
 757        Ok(PendingEntitySubscription {
 758            client: self.clone(),
 759            remote_id,
 760            consumed: false,
 761            _entity_type: PhantomData,
 762        })
 763    }
 764
    /// Registers `handler` for messages of type `M`, to be called with the
    /// entity the `entity` handle refers to.
    ///
    /// Returns a [`Subscription`]; dropping it removes the handler.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` is already registered.
    #[track_caller]
    pub fn add_message_handler<M, E, H, F>(
        self: &Arc<Self>,
        entity: WeakEntity<E>,
        handler: H,
    ) -> Subscription
    where
        M: EnvelopedMessage,
        E: 'static,
        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
        F: 'static + Future<Output = Result<()>>,
    {
        // Adapt to the internal handler shape, ignoring the client argument.
        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
            handler(entity, message, cx)
        })
    }
 781
 782    fn add_message_handler_impl<M, E, H, F>(
 783        self: &Arc<Self>,
 784        entity: WeakEntity<E>,
 785        handler: H,
 786    ) -> Subscription
 787    where
 788        M: EnvelopedMessage,
 789        E: 'static,
 790        H: 'static
 791            + Sync
 792            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 793            + Send
 794            + Sync,
 795        F: 'static + Future<Output = Result<()>>,
 796    {
 797        let message_type_id = TypeId::of::<M>();
 798        let mut state = self.handler_set.lock();
 799        state
 800            .entities_by_message_type
 801            .insert(message_type_id, entity.into());
 802
 803        let prev_handler = state.message_handlers.insert(
 804            message_type_id,
 805            Arc::new(move |subscriber, envelope, client, cx| {
 806                let subscriber = subscriber.downcast::<E>().unwrap();
 807                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 808                handler(subscriber, *envelope, client, cx).boxed_local()
 809            }),
 810        );
 811        if prev_handler.is_some() {
 812            let location = std::panic::Location::caller();
 813            panic!(
 814                "{}:{} registered handler for the same message {} twice",
 815                location.file(),
 816                location.line(),
 817                std::any::type_name::<M>()
 818            );
 819        }
 820
 821        Subscription::Message {
 822            client: Arc::downgrade(self),
 823            id: message_type_id,
 824        }
 825    }
 826
    /// Registers `handler` for requests of type `M`. The handler's result is
    /// sent back to the requester: the response on success, or the error
    /// converted with `to_proto` on failure.
    ///
    /// Returns a [`Subscription`]; dropping it removes the handler.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` is already registered.
    pub fn add_request_handler<M, E, H, F>(
        self: &Arc<Self>,
        entity: WeakEntity<E>,
        handler: H,
    ) -> Subscription
    where
        M: RequestMessage,
        E: 'static,
        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
        F: 'static + Future<Output = Result<M::Response>>,
    {
        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
            // The receipt is taken before `envelope` is moved into the handler.
            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
        })
    }
 842
 843    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 844        receipt: Receipt<T>,
 845        response: F,
 846        client: AnyProtoClient,
 847    ) -> Result<()> {
 848        match response.await {
 849            Ok(response) => {
 850                client.send_response(receipt.message_id, response)?;
 851                Ok(())
 852            }
 853            Err(error) => {
 854                client.send_response(receipt.message_id, error.to_proto())?;
 855                Err(error)
 856            }
 857        }
 858    }
 859
 860    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 861        self.credentials_provider
 862            .read_credentials(cx)
 863            .await
 864            .is_some()
 865    }
 866
 867    pub async fn sign_in(
 868        self: &Arc<Self>,
 869        try_provider: bool,
 870        cx: &AsyncApp,
 871    ) -> Result<Credentials> {
 872        let is_reauthenticating = if self.status().borrow().is_signed_out() {
 873            self.set_status(Status::Authenticating, cx);
 874            false
 875        } else {
 876            self.set_status(Status::Reauthenticating, cx);
 877            true
 878        };
 879
 880        let mut credentials = None;
 881
 882        let old_credentials = self.state.read().credentials.clone();
 883        if let Some(old_credentials) = old_credentials
 884            && self.validate_credentials(&old_credentials, cx).await?
 885        {
 886            credentials = Some(old_credentials);
 887        }
 888
 889        let executor = cx.background_executor();
 890        if credentials.is_none() && try_provider {
 891            match self
 892                .credentials_provider
 893                .read_credentials(cx)
 894                .with_timeout(
 895                    CREDENTIAL_READ_TIMEOUT - Duration::from_millis(200),
 896                    executor,
 897                )
 898                .await
 899            {
 900                Ok(None) => (),
 901                Ok(Some(stored_credentials)) => {
 902                    if self.validate_credentials(&stored_credentials, cx).await? {
 903                        credentials = Some(stored_credentials);
 904                    } else {
 905                        self.credentials_provider
 906                            .delete_credentials(cx)
 907                            .await
 908                            .log_err();
 909                    }
 910                }
 911                Err(Timeout) => {
 912                    // Todo do not time out but show toast after 10s. Hide toast
 913                    // when keyring unlocks.
 914                    return Err(anyhow!(
 915                        "Timed out waiting for credentials from OS, did you unlock the keyring?"
 916                    ));
 917                }
 918            }
 919        }
 920
 921        if credentials.is_none() {
 922            let mut status_rx = self.status();
 923            let _ = status_rx.next().await;
 924            futures::select_biased! {
 925                authenticate = self.authenticate(cx).fuse() => {
 926                    match authenticate {
 927                        Ok(creds) => {
 928                            if IMPERSONATE_LOGIN.is_none() {
 929                                self.credentials_provider
 930                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
 931                                    .await
 932                                    .log_err();
 933                            }
 934
 935                            credentials = Some(creds);
 936                        },
 937                        Err(err) => {
 938                            self.set_status(Status::AuthenticationError, cx);
 939                            return Err(err);
 940                        }
 941                    }
 942                }
 943                _ = status_rx.next().fuse() => {
 944                    return Err(anyhow!("authentication canceled"));
 945                }
 946            }
 947        }
 948
 949        let credentials = credentials.unwrap();
 950        self.set_id(credentials.user_id);
 951        self.cloud_client
 952            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
 953        self.state.write().credentials = Some(credentials.clone());
 954        self.set_status(
 955            if is_reauthenticating {
 956                Status::Reauthenticated
 957            } else {
 958                Status::Authenticated
 959            },
 960            cx,
 961        );
 962
 963        Ok(credentials)
 964    }
 965
 966    async fn validate_credentials(
 967        self: &Arc<Self>,
 968        credentials: &Credentials,
 969        cx: &AsyncApp,
 970    ) -> Result<bool> {
 971        match self
 972            .cloud_client
 973            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 974            .await
 975        {
 976            Ok(valid) => Ok(valid),
 977            Err(err) => {
 978                self.set_status(Status::AuthenticationError, cx);
 979                Err(anyhow!("failed to validate credentials: {}", err))
 980            }
 981        }
 982    }
 983
 984    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 985    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 986        let connect_task = cx.update({
 987            let cloud_client = self.cloud_client.clone();
 988            move |cx| cloud_client.connect(cx)
 989        })??;
 990        let connection = connect_task.await?;
 991
 992        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 993        task.detach();
 994
 995        cx.spawn({
 996            let this = self.clone();
 997            async move |cx| {
 998                while let Some(message) = messages.next().await {
 999                    if let Some(message) = message.log_err() {
1000                        this.handle_message_to_client(message, cx);
1001                    }
1002                }
1003            }
1004        })
1005        .detach();
1006
1007        Ok(())
1008    }
1009
1010    /// Performs a sign-in and also (optionally) connects to Collab.
1011    ///
1012    /// Only Zed staff automatically connect to Collab.
1013    pub async fn sign_in_with_optional_connect(
1014        self: &Arc<Self>,
1015        try_provider: bool,
1016        cx: &AsyncApp,
1017    ) -> Result<()> {
1018        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
1019        if self.status().borrow().is_connected() {
1020            return Ok(());
1021        }
1022
1023        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
1024        let mut is_staff_tx = Some(is_staff_tx);
1025        cx.update(|cx| {
1026            cx.on_flags_ready(move |state, _cx| {
1027                if let Some(is_staff_tx) = is_staff_tx.take() {
1028                    is_staff_tx.send(state.is_staff).log_err();
1029                }
1030            })
1031            .detach();
1032        })
1033        .log_err();
1034
1035        let credentials = self.sign_in(try_provider, cx).await?;
1036
1037        self.connect_to_cloud(cx).await.log_err();
1038
1039        cx.update(move |cx| {
1040            cx.spawn({
1041                let client = self.clone();
1042                async move |cx| {
1043                    let is_staff = is_staff_rx.await?;
1044                    if is_staff {
1045                        match client.connect_with_credentials(credentials, cx).await {
1046                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
1047                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
1048                            ConnectionResult::Result(result) => {
1049                                result.context("client auth and connect")
1050                            }
1051                        }
1052                    } else {
1053                        Ok(())
1054                    }
1055                }
1056            })
1057            .detach_and_log_err(cx);
1058        })
1059        .log_err();
1060
1061        Ok(())
1062    }
1063
1064    pub async fn connect(
1065        self: &Arc<Self>,
1066        try_provider: bool,
1067        cx: &AsyncApp,
1068    ) -> ConnectionResult<()> {
1069        let was_disconnected = match *self.status().borrow() {
1070            Status::SignedOut | Status::Authenticated => true,
1071            Status::ConnectionError
1072            | Status::ConnectionLost
1073            | Status::Authenticating
1074            | Status::AuthenticationError
1075            | Status::Reauthenticating
1076            | Status::Reauthenticated
1077            | Status::ReconnectionError { .. } => false,
1078            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1079                return ConnectionResult::Result(Ok(()));
1080            }
1081            Status::UpgradeRequired => {
1082                return ConnectionResult::Result(
1083                    Err(EstablishConnectionError::UpgradeRequired)
1084                        .context("client auth and connect"),
1085                );
1086            }
1087        };
1088        let credentials = match self.sign_in(try_provider, cx).await {
1089            Ok(credentials) => credentials,
1090            Err(err) => return ConnectionResult::Result(Err(err)),
1091        };
1092
1093        if was_disconnected {
1094            self.set_status(Status::Connecting, cx);
1095        } else {
1096            self.set_status(Status::Reconnecting, cx);
1097        }
1098
1099        self.connect_with_credentials(credentials, cx).await
1100    }
1101
1102    async fn connect_with_credentials(
1103        self: &Arc<Self>,
1104        credentials: Credentials,
1105        cx: &AsyncApp,
1106    ) -> ConnectionResult<()> {
1107        let mut timeout =
1108            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
1109        futures::select_biased! {
1110            connection = self.establish_connection(&credentials, cx).fuse() => {
1111                match connection {
1112                    Ok(conn) => {
1113                        futures::select_biased! {
1114                            result = self.set_connection(conn, cx).fuse() => {
1115                                match result.context("client auth and connect") {
1116                                    Ok(()) => ConnectionResult::Result(Ok(())),
1117                                    Err(err) => {
1118                                        self.set_status(Status::ConnectionError, cx);
1119                                        ConnectionResult::Result(Err(err))
1120                                    },
1121                                }
1122                            },
1123                            _ = timeout => {
1124                                self.set_status(Status::ConnectionError, cx);
1125                                ConnectionResult::Timeout
1126                            }
1127                        }
1128                    }
1129                    Err(EstablishConnectionError::Unauthorized) => {
1130                        self.set_status(Status::ConnectionError, cx);
1131                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
1132                    }
1133                    Err(EstablishConnectionError::UpgradeRequired) => {
1134                        self.set_status(Status::UpgradeRequired, cx);
1135                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
1136                    }
1137                    Err(error) => {
1138                        self.set_status(Status::ConnectionError, cx);
1139                        ConnectionResult::Result(Err(error).context("client auth and connect"))
1140                    }
1141                }
1142            }
1143            _ = &mut timeout => {
1144                self.set_status(Status::ConnectionError, cx);
1145                ConnectionResult::Timeout
1146            }
1147        }
1148    }
1149
1150    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
1151        let executor = cx.background_executor();
1152        log::debug!("add connection to peer");
1153        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
1154            let executor = executor.clone();
1155            move |duration| executor.timer(duration)
1156        });
1157        let handle_io = executor.spawn(handle_io);
1158
1159        let peer_id = async {
1160            log::debug!("waiting for server hello");
1161            let message = incoming.next().await.context("no hello message received")?;
1162            log::debug!("got server hello");
1163            let hello_message_type_name = message.payload_type_name().to_string();
1164            let hello = message
1165                .into_any()
1166                .downcast::<TypedEnvelope<proto::Hello>>()
1167                .map_err(|_| {
1168                    anyhow!(
1169                        "invalid hello message received: {:?}",
1170                        hello_message_type_name
1171                    )
1172                })?;
1173            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
1174            Ok(peer_id)
1175        };
1176
1177        let peer_id = match peer_id.await {
1178            Ok(peer_id) => peer_id,
1179            Err(error) => {
1180                self.peer.disconnect(connection_id);
1181                return Err(error);
1182            }
1183        };
1184
1185        log::debug!(
1186            "set status to connected (connection id: {:?}, peer id: {:?})",
1187            connection_id,
1188            peer_id
1189        );
1190        self.set_status(
1191            Status::Connected {
1192                peer_id,
1193                connection_id,
1194            },
1195            cx,
1196        );
1197
1198        cx.spawn({
1199            let this = self.clone();
1200            async move |cx| {
1201                while let Some(message) = incoming.next().await {
1202                    this.handle_message(message, cx);
1203                    // Don't starve the main thread when receiving lots of messages at once.
1204                    smol::future::yield_now().await;
1205                }
1206            }
1207        })
1208        .detach();
1209
1210        cx.spawn({
1211            let this = self.clone();
1212            async move |cx| match handle_io.await {
1213                Ok(()) => {
1214                    if *this.status().borrow()
1215                        == (Status::Connected {
1216                            connection_id,
1217                            peer_id,
1218                        })
1219                    {
1220                        this.set_status(Status::SignedOut, cx);
1221                    }
1222                }
1223                Err(err) => {
1224                    log::error!("connection error: {:?}", err);
1225                    this.set_status(Status::ConnectionLost, cx);
1226                }
1227            }
1228        })
1229        .detach();
1230
1231        Ok(())
1232    }
1233
1234    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1235        #[cfg(any(test, feature = "test-support"))]
1236        if let Some(callback) = self.authenticate.read().as_ref() {
1237            return callback(cx);
1238        }
1239
1240        self.authenticate_with_browser(cx)
1241    }
1242
1243    fn establish_connection(
1244        self: &Arc<Self>,
1245        credentials: &Credentials,
1246        cx: &AsyncApp,
1247    ) -> Task<Result<Connection, EstablishConnectionError>> {
1248        #[cfg(any(test, feature = "test-support"))]
1249        if let Some(callback) = self.establish_connection.read().as_ref() {
1250            return callback(credentials, cx);
1251        }
1252
1253        self.establish_websocket_connection(credentials, cx)
1254    }
1255
1256    fn rpc_url(
1257        &self,
1258        http: Arc<HttpClientWithUrl>,
1259        release_channel: Option<ReleaseChannel>,
1260    ) -> impl Future<Output = Result<url::Url>> + use<> {
1261        #[cfg(any(test, feature = "test-support"))]
1262        let url_override = self.rpc_url.read().clone();
1263
1264        async move {
1265            #[cfg(any(test, feature = "test-support"))]
1266            if let Some(url) = url_override {
1267                return Ok(url);
1268            }
1269
1270            if let Some(url) = &*ZED_RPC_URL {
1271                return Url::parse(url).context("invalid rpc url");
1272            }
1273
1274            let mut url = http.build_url("/rpc");
1275            if let Some(preview_param) =
1276                release_channel.and_then(|channel| channel.release_query_param())
1277            {
1278                url += "?";
1279                url += preview_param;
1280            }
1281
1282            let response = http.get(&url, Default::default(), false).await?;
1283            anyhow::ensure!(
1284                response.status().is_redirection(),
1285                "unexpected /rpc response status {}",
1286                response.status()
1287            );
1288            let collab_url = response
1289                .headers()
1290                .get("Location")
1291                .context("missing location header in /rpc response")?
1292                .to_str()
1293                .map_err(EstablishConnectionError::other)?
1294                .to_string();
1295            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1296        }
1297    }
1298
1299    fn establish_websocket_connection(
1300        self: &Arc<Self>,
1301        credentials: &Credentials,
1302        cx: &AsyncApp,
1303    ) -> Task<Result<Connection, EstablishConnectionError>> {
1304        let release_channel = cx
1305            .update(|cx| ReleaseChannel::try_global(cx))
1306            .ok()
1307            .flatten();
1308        let app_version = cx
1309            .update(|cx| AppVersion::global(cx).to_string())
1310            .ok()
1311            .unwrap_or_default();
1312
1313        let http = self.http.clone();
1314        let proxy = http.proxy().cloned();
1315        let user_agent = http.user_agent().cloned();
1316        let credentials = credentials.clone();
1317        let rpc_url = self.rpc_url(http, release_channel);
1318        let system_id = self.telemetry.system_id();
1319        let metrics_id = self.telemetry.metrics_id();
1320        cx.spawn(async move |cx| {
1321            use HttpOrHttps::*;
1322
1323            #[derive(Debug)]
1324            enum HttpOrHttps {
1325                Http,
1326                Https,
1327            }
1328
1329            let mut rpc_url = rpc_url.await?;
1330            let url_scheme = match rpc_url.scheme() {
1331                "https" => Https,
1332                "http" => Http,
1333                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1334            };
1335
1336            let stream = gpui_tokio::Tokio::spawn_result(cx, {
1337                let rpc_url = rpc_url.clone();
1338                async move {
1339                    let rpc_host = rpc_url
1340                        .host_str()
1341                        .zip(rpc_url.port_or_known_default())
1342                        .context("missing host in rpc url")?;
1343                    Ok(match proxy {
1344                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
1345                        None => Box::new(TcpStream::connect(rpc_host).await?),
1346                    })
1347                }
1348            })?
1349            .await?;
1350
1351            log::info!("connected to rpc endpoint {}", rpc_url);
1352
1353            rpc_url
1354                .set_scheme(match url_scheme {
1355                    Https => "wss",
1356                    Http => "ws",
1357                })
1358                .unwrap();
1359
1360            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
1361            // for us from the RPC URL.
1362            //
1363            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
1364            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
1365
1366            // We then modify the request to add our desired headers.
1367            let request_headers = request.headers_mut();
1368            request_headers.insert(
1369                http::header::AUTHORIZATION,
1370                HeaderValue::from_str(&credentials.authorization_header())?,
1371            );
1372            request_headers.insert(
1373                "x-zed-protocol-version",
1374                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
1375            );
1376            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
1377            request_headers.insert(
1378                "x-zed-release-channel",
1379                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
1380            );
1381            if let Some(user_agent) = user_agent {
1382                request_headers.insert(http::header::USER_AGENT, user_agent);
1383            }
1384            if let Some(system_id) = system_id {
1385                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
1386            }
1387            if let Some(metrics_id) = metrics_id {
1388                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
1389            }
1390
1391            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
1392                request,
1393                stream,
1394                Some(Arc::new(http_client_tls::tls_config()).into()),
1395                None,
1396            )
1397            .await?;
1398
1399            Ok(Connection::new(
1400                stream
1401                    .map_err(|error| anyhow!(error))
1402                    .sink_map_err(|error| anyhow!(error)),
1403            ))
1404        })
1405    }
1406
1407    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1408        let http = self.http.clone();
1409        let this = self.clone();
1410        cx.spawn(async move |cx| {
1411            let background = cx.background_executor().clone();
1412
1413            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1414            cx.update(|cx| {
1415                cx.spawn(async move |cx| {
1416                    let url = open_url_rx.await?;
1417                    cx.update(|cx| cx.open_url(&url))
1418                })
1419                .detach_and_log_err(cx);
1420            })
1421            .log_err();
1422
1423            let credentials = background
1424                .clone()
1425                .spawn(async move {
1426                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1427                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1428                    // any other app running on the user's device.
1429                    let (public_key, private_key) =
1430                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1431                    let public_key_string = String::try_from(public_key)
1432                        .expect("failed to serialize public key for auth");
1433
1434                    if let Some((login, token)) =
1435                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1436                    {
1437                        if !*USE_WEB_LOGIN {
1438                            eprintln!("authenticate as admin {login}, {token}");
1439
1440                            return this
1441                                .authenticate_as_admin(http, login.clone(), token.clone())
1442                                .await;
1443                        }
1444                    }
1445
1446                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1447                    let server =
1448                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1449                    let port = server.server_addr().port();
1450
1451                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1452                    // that the user is signing in from a Zed app running on the same device.
1453                    let mut url = http.build_url(&format!(
1454                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1455                        port, public_key_string
1456                    ));
1457
1458                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1459                        log::info!("impersonating user @{}", impersonate_login);
1460                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1461                    }
1462
1463                    open_url_tx.send(url).log_err();
1464
1465                    #[derive(Deserialize)]
1466                    struct CallbackParams {
1467                        pub user_id: String,
1468                        pub access_token: String,
1469                    }
1470
1471                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1472                    // access token from the query params.
1473                    //
1474                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1475                    // custom URL scheme instead of this local HTTP server.
1476                    let (user_id, access_token) = background
1477                        .spawn(async move {
1478                            for _ in 0..100 {
1479                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1480                                    let path = req.url();
1481                                    let url = Url::parse(&format!("http://example.com{}", path))
1482                                        .context("failed to parse login notification url")?;
1483                                    let callback_params: CallbackParams =
1484                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1485                                            .context(
1486                                                "failed to parse sign-in callback query parameters",
1487                                            )?;
1488
1489                                    let post_auth_url =
1490                                        http.build_url("/native_app_signin_succeeded");
1491                                    req.respond(
1492                                        tiny_http::Response::empty(302).with_header(
1493                                            tiny_http::Header::from_bytes(
1494                                                &b"Location"[..],
1495                                                post_auth_url.as_bytes(),
1496                                            )
1497                                            .unwrap(),
1498                                        ),
1499                                    )
1500                                    .context("failed to respond to login http request")?;
1501                                    return Ok((
1502                                        callback_params.user_id,
1503                                        callback_params.access_token,
1504                                    ));
1505                                }
1506                            }
1507
1508                            anyhow::bail!("didn't receive login redirect");
1509                        })
1510                        .await?;
1511
1512                    let access_token = private_key
1513                        .decrypt_string(&access_token)
1514                        .context("failed to decrypt access token")?;
1515
1516                    Ok(Credentials {
1517                        user_id: user_id.parse()?,
1518                        access_token,
1519                    })
1520                })
1521                .await?;
1522
1523            cx.update(|cx| cx.activate(true))?;
1524            Ok(credentials)
1525        })
1526    }
1527
1528    async fn authenticate_as_admin(
1529        self: &Arc<Self>,
1530        http: Arc<HttpClientWithUrl>,
1531        login: String,
1532        api_token: String,
1533    ) -> Result<Credentials> {
1534        #[derive(Serialize)]
1535        struct ImpersonateUserBody {
1536            github_login: String,
1537        }
1538
1539        #[derive(Deserialize)]
1540        struct ImpersonateUserResponse {
1541            user_id: u64,
1542            access_token: String,
1543        }
1544
1545        let url = self
1546            .http
1547            .build_zed_cloud_url("/internal/users/impersonate", &[])?;
1548        let request = Request::post(url.as_str())
1549            .header("Content-Type", "application/json")
1550            .header("Authorization", format!("Bearer {api_token}"))
1551            .body(
1552                serde_json::to_string(&ImpersonateUserBody {
1553                    github_login: login,
1554                })?
1555                .into(),
1556            )?;
1557
1558        let mut response = http.send(request).await?;
1559        let mut body = String::new();
1560        response.body_mut().read_to_string(&mut body).await?;
1561        anyhow::ensure!(
1562            response.status().is_success(),
1563            "admin user request failed {} - {}",
1564            response.status().as_u16(),
1565            body,
1566        );
1567        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1568
1569        Ok(Credentials {
1570            user_id: response.user_id,
1571            access_token: response.access_token,
1572        })
1573    }
1574
1575    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1576        self.state.write().credentials = None;
1577        self.cloud_client.clear_credentials();
1578        self.disconnect(cx);
1579
1580        if self.has_credentials(cx).await {
1581            self.credentials_provider
1582                .delete_credentials(cx)
1583                .await
1584                .log_err();
1585        }
1586    }
1587
1588    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1589        self.peer.teardown();
1590        self.set_status(Status::SignedOut, cx);
1591    }
1592
1593    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1594        self.peer.teardown();
1595        self.set_status(Status::ConnectionLost, cx);
1596    }
1597
1598    fn connection_id(&self) -> Result<ConnectionId> {
1599        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1600            Ok(connection_id)
1601        } else {
1602            anyhow::bail!("not connected");
1603        }
1604    }
1605
1606    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1607        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1608        self.peer.send(self.connection_id()?, message)
1609    }
1610
1611    pub fn request<T: RequestMessage>(
1612        &self,
1613        request: T,
1614    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1615        self.request_envelope(request)
1616            .map_ok(|envelope| envelope.payload)
1617    }
1618
1619    pub fn request_stream<T: RequestMessage>(
1620        &self,
1621        request: T,
1622    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1623        let client_id = self.id.load(Ordering::SeqCst);
1624        log::debug!(
1625            "rpc request start. client_id:{}. name:{}",
1626            client_id,
1627            T::NAME
1628        );
1629        let response = self
1630            .connection_id()
1631            .map(|conn_id| self.peer.request_stream(conn_id, request));
1632        async move {
1633            let response = response?.await;
1634            log::debug!(
1635                "rpc request finish. client_id:{}. name:{}",
1636                client_id,
1637                T::NAME
1638            );
1639            response
1640        }
1641    }
1642
1643    pub fn request_envelope<T: RequestMessage>(
1644        &self,
1645        request: T,
1646    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1647        let client_id = self.id();
1648        log::debug!(
1649            "rpc request start. client_id:{}. name:{}",
1650            client_id,
1651            T::NAME
1652        );
1653        let response = self
1654            .connection_id()
1655            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1656        async move {
1657            let response = response?.await;
1658            log::debug!(
1659                "rpc request finish. client_id:{}. name:{}",
1660                client_id,
1661                T::NAME
1662            );
1663            response
1664        }
1665    }
1666
1667    pub fn request_dynamic(
1668        &self,
1669        envelope: proto::Envelope,
1670        request_type: &'static str,
1671    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1672        let client_id = self.id();
1673        log::debug!(
1674            "rpc request start. client_id:{}. name:{}",
1675            client_id,
1676            request_type
1677        );
1678        let response = self
1679            .connection_id()
1680            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1681        async move {
1682            let response = response?.await;
1683            log::debug!(
1684                "rpc request finish. client_id:{}. name:{}",
1685                client_id,
1686                request_type
1687            );
1688            Ok(response?.0)
1689        }
1690    }
1691
1692    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1693        let sender_id = message.sender_id();
1694        let request_id = message.message_id();
1695        let type_name = message.payload_type_name();
1696        let original_sender_id = message.original_sender_id();
1697
1698        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1699            &self.handler_set,
1700            message,
1701            self.clone().into(),
1702            cx.clone(),
1703        ) {
1704            let client_id = self.id();
1705            log::debug!(
1706                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1707                client_id,
1708                original_sender_id,
1709                type_name
1710            );
1711            cx.spawn(async move |_| match future.await {
1712                Ok(()) => {
1713                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1714                }
1715                Err(error) => {
1716                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1717                }
1718            })
1719            .detach();
1720        } else {
1721            log::info!("unhandled message {}", type_name);
1722            self.peer
1723                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1724                .log_err();
1725        }
1726    }
1727
1728    pub fn add_message_to_client_handler(
1729        self: &Arc<Client>,
1730        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1731    ) {
1732        self.message_to_client_handlers
1733            .lock()
1734            .push(Box::new(handler));
1735    }
1736
1737    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1738        cx.update(|cx| {
1739            for handler in self.message_to_client_handlers.lock().iter() {
1740                handler(&message, cx);
1741            }
1742        })
1743        .ok();
1744    }
1745
    /// Returns a reference to the shared [`Telemetry`] handle held by this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1749}
1750
1751impl ProtoClient for Client {
1752    fn request(
1753        &self,
1754        envelope: proto::Envelope,
1755        request_type: &'static str,
1756    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1757        self.request_dynamic(envelope, request_type).boxed()
1758    }
1759
1760    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1761        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1762        let connection_id = self.connection_id()?;
1763        self.peer.send_dynamic(connection_id, envelope)
1764    }
1765
1766    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1767        log::debug!(
1768            "rpc respond. client_id:{}, name:{}",
1769            self.id(),
1770            message_type
1771        );
1772        let connection_id = self.connection_id()?;
1773        self.peer.send_dynamic(connection_id, envelope)
1774    }
1775
1776    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1777        &self.handler_set
1778    }
1779
1780    fn is_via_collab(&self) -> bool {
1781        true
1782    }
1783}
1784
/// URL scheme of Zed links, without the `://` separator (links have the form `zed://...`).
pub const ZED_URL_SCHEME: &str = "zed";
1787
1788/// Parses the given link into a Zed link.
1789///
1790/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1791/// Returns [`None`] otherwise.
1792pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1793    let server_url = &ClientSettings::get_global(cx).server_url;
1794    if let Some(stripped) = link
1795        .strip_prefix(server_url)
1796        .and_then(|result| result.strip_prefix('/'))
1797    {
1798        return Some(stripped);
1799    }
1800    if let Some(stripped) = link
1801        .strip_prefix(ZED_URL_SCHEME)
1802        .and_then(|result| result.strip_prefix("://"))
1803    {
1804        return Some(stripped);
1805    }
1806
1807    None
1808}
1809
1810#[cfg(test)]
1811mod tests {
1812    use super::*;
1813    use crate::test::{FakeServer, parse_authorization_header};
1814
1815    use clock::FakeSystemClock;
1816    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
1817    use http_client::FakeHttpClient;
1818    use parking_lot::Mutex;
1819    use proto::TypedEnvelope;
1820    use settings::SettingsStore;
1821    use std::future;
1822
1823    #[gpui::test(iterations = 10)]
1824    async fn test_reconnection(cx: &mut TestAppContext) {
1825        init_test(cx);
1826        let user_id = 5;
1827        let client = cx.update(|cx| {
1828            Client::new(
1829                Arc::new(FakeSystemClock::new()),
1830                FakeHttpClient::with_404_response(),
1831                cx,
1832            )
1833        });
1834        let server = FakeServer::for_client(user_id, &client, cx).await;
1835        let mut status = client.status();
1836        assert!(matches!(
1837            status.next().await,
1838            Some(Status::Connected { .. })
1839        ));
1840        assert_eq!(server.auth_count(), 1);
1841
1842        server.forbid_connections();
1843        server.disconnect();
1844        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1845
1846        server.allow_connections();
1847        cx.executor().advance_clock(Duration::from_secs(10));
1848        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1849        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1850
1851        server.forbid_connections();
1852        server.disconnect();
1853        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1854
1855        // Clear cached credentials after authentication fails
1856        server.roll_access_token();
1857        server.allow_connections();
1858        cx.executor().run_until_parked();
1859        cx.executor().advance_clock(Duration::from_secs(10));
1860        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1861        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1862    }
1863
1864    #[gpui::test(iterations = 10)]
1865    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
1866        init_test(cx);
1867        let http_client = FakeHttpClient::with_200_response();
1868        let client =
1869            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
1870        let server = FakeServer::for_client(42, &client, cx).await;
1871        let mut status = client.status();
1872        assert!(matches!(
1873            status.next().await,
1874            Some(Status::Connected { .. })
1875        ));
1876        assert_eq!(server.auth_count(), 1);
1877
1878        // Simulate an auth failure during reconnection.
1879        http_client
1880            .as_fake()
1881            .replace_handler(|_, _request| async move {
1882                Ok(http_client::Response::builder()
1883                    .status(503)
1884                    .body("".into())
1885                    .unwrap())
1886            });
1887        server.disconnect();
1888        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1889
1890        // Restore the ability to authenticate.
1891        http_client
1892            .as_fake()
1893            .replace_handler(|_, _request| async move {
1894                Ok(http_client::Response::builder()
1895                    .status(200)
1896                    .body("".into())
1897                    .unwrap())
1898            });
1899        cx.executor().advance_clock(Duration::from_secs(10));
1900        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1901        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1902    }
1903
1904    #[gpui::test(iterations = 10)]
1905    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
1906        init_test(cx);
1907        let user_id = 5;
1908        let client = cx.update(|cx| {
1909            Client::new(
1910                Arc::new(FakeSystemClock::new()),
1911                FakeHttpClient::with_404_response(),
1912                cx,
1913            )
1914        });
1915        let mut status = client.status();
1916
1917        // Time out when client tries to connect.
1918        client.override_authenticate(move |cx| {
1919            cx.background_spawn(async move {
1920                Ok(Credentials {
1921                    user_id,
1922                    access_token: "token".into(),
1923                })
1924            })
1925        });
1926        client.override_establish_connection(|_, cx| {
1927            cx.background_spawn(async move {
1928                future::pending::<()>().await;
1929                unreachable!()
1930            })
1931        });
1932        let auth_and_connect = cx.spawn({
1933            let client = client.clone();
1934            |cx| async move { client.connect(false, &cx).await }
1935        });
1936        executor.run_until_parked();
1937        assert!(matches!(status.next().await, Some(Status::Connecting)));
1938
1939        executor.advance_clock(CONNECTION_TIMEOUT);
1940        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
1941        auth_and_connect.await.into_response().unwrap_err();
1942
1943        // Allow the connection to be established.
1944        let server = FakeServer::for_client(user_id, &client, cx).await;
1945        assert!(matches!(
1946            status.next().await,
1947            Some(Status::Connected { .. })
1948        ));
1949
1950        // Disconnect client.
1951        server.forbid_connections();
1952        server.disconnect();
1953        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1954
1955        // Time out when re-establishing the connection.
1956        server.allow_connections();
1957        client.override_establish_connection(|_, cx| {
1958            cx.background_spawn(async move {
1959                future::pending::<()>().await;
1960                unreachable!()
1961            })
1962        });
1963        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1964        assert!(matches!(status.next().await, Some(Status::Reconnecting)));
1965
1966        executor.advance_clock(CONNECTION_TIMEOUT);
1967        assert!(matches!(
1968            status.next().await,
1969            Some(Status::ReconnectionError { .. })
1970        ));
1971    }
1972
1973    #[gpui::test(iterations = 10)]
1974    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
1975        init_test(cx);
1976        let auth_count = Arc::new(Mutex::new(0));
1977        let http_client = FakeHttpClient::create(|_request| async move {
1978            Ok(http_client::Response::builder()
1979                .status(200)
1980                .body("".into())
1981                .unwrap())
1982        });
1983        let client =
1984            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
1985        client.override_authenticate({
1986            let auth_count = auth_count.clone();
1987            move |cx| {
1988                let auth_count = auth_count.clone();
1989                cx.background_spawn(async move {
1990                    *auth_count.lock() += 1;
1991                    Ok(Credentials {
1992                        user_id: 1,
1993                        access_token: auth_count.lock().to_string(),
1994                    })
1995                })
1996            }
1997        });
1998
1999        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
2000        assert_eq!(*auth_count.lock(), 1);
2001        assert_eq!(credentials.access_token, "1");
2002
2003        // If credentials are still valid, signing in doesn't trigger authentication.
2004        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
2005        assert_eq!(*auth_count.lock(), 1);
2006        assert_eq!(credentials.access_token, "1");
2007
2008        // If the server is unavailable, signing in doesn't trigger authentication.
2009        http_client
2010            .as_fake()
2011            .replace_handler(|_, _request| async move {
2012                Ok(http_client::Response::builder()
2013                    .status(503)
2014                    .body("".into())
2015                    .unwrap())
2016            });
2017        client.sign_in(false, &cx.to_async()).await.unwrap_err();
2018        assert_eq!(*auth_count.lock(), 1);
2019
2020        // If credentials became invalid, signing in triggers authentication.
2021        http_client
2022            .as_fake()
2023            .replace_handler(|_, request| async move {
2024                let credentials = parse_authorization_header(&request).unwrap();
2025                if credentials.access_token == "2" {
2026                    Ok(http_client::Response::builder()
2027                        .status(200)
2028                        .body("".into())
2029                        .unwrap())
2030                } else {
2031                    Ok(http_client::Response::builder()
2032                        .status(401)
2033                        .body("".into())
2034                        .unwrap())
2035                }
2036            });
2037        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
2038        assert_eq!(*auth_count.lock(), 2);
2039        assert_eq!(credentials.access_token, "2");
2040    }
2041
2042    #[gpui::test(iterations = 10)]
2043    async fn test_authenticating_more_than_once(
2044        cx: &mut TestAppContext,
2045        executor: BackgroundExecutor,
2046    ) {
2047        init_test(cx);
2048        let auth_count = Arc::new(Mutex::new(0));
2049        let dropped_auth_count = Arc::new(Mutex::new(0));
2050        let client = cx.update(|cx| {
2051            Client::new(
2052                Arc::new(FakeSystemClock::new()),
2053                FakeHttpClient::with_404_response(),
2054                cx,
2055            )
2056        });
2057        client.override_authenticate({
2058            let auth_count = auth_count.clone();
2059            let dropped_auth_count = dropped_auth_count.clone();
2060            move |cx| {
2061                let auth_count = auth_count.clone();
2062                let dropped_auth_count = dropped_auth_count.clone();
2063                cx.background_spawn(async move {
2064                    *auth_count.lock() += 1;
2065                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
2066                    future::pending::<()>().await;
2067                    unreachable!()
2068                })
2069            }
2070        });
2071
2072        let _authenticate = cx.spawn({
2073            let client = client.clone();
2074            move |cx| async move { client.connect(false, &cx).await }
2075        });
2076        executor.run_until_parked();
2077        assert_eq!(*auth_count.lock(), 1);
2078        assert_eq!(*dropped_auth_count.lock(), 0);
2079
2080        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
2081        executor.run_until_parked();
2082        assert_eq!(*auth_count.lock(), 2);
2083        assert_eq!(*dropped_auth_count.lock(), 1);
2084    }
2085
2086    #[gpui::test]
2087    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
2088        init_test(cx);
2089        let user_id = 5;
2090        let client = cx.update(|cx| {
2091            Client::new(
2092                Arc::new(FakeSystemClock::new()),
2093                FakeHttpClient::with_404_response(),
2094                cx,
2095            )
2096        });
2097        let server = FakeServer::for_client(user_id, &client, cx).await;
2098
2099        let (done_tx1, done_rx1) = smol::channel::unbounded();
2100        let (done_tx2, done_rx2) = smol::channel::unbounded();
2101        AnyProtoClient::from(client.clone()).add_entity_message_handler(
2102            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
2103                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
2104                    1 => done_tx1.try_send(()).unwrap(),
2105                    2 => done_tx2.try_send(()).unwrap(),
2106                    _ => unreachable!(),
2107                }
2108                async { Ok(()) }
2109            },
2110        );
2111        let entity1 = cx.new(|_| TestEntity {
2112            id: 1,
2113            subscription: None,
2114        });
2115        let entity2 = cx.new(|_| TestEntity {
2116            id: 2,
2117            subscription: None,
2118        });
2119        let entity3 = cx.new(|_| TestEntity {
2120            id: 3,
2121            subscription: None,
2122        });
2123
2124        let _subscription1 = client
2125            .subscribe_to_entity(1)
2126            .unwrap()
2127            .set_entity(&entity1, &cx.to_async());
2128        let _subscription2 = client
2129            .subscribe_to_entity(2)
2130            .unwrap()
2131            .set_entity(&entity2, &cx.to_async());
2132        // Ensure dropping a subscription for the same entity type still allows receiving of
2133        // messages for other entity IDs of the same type.
2134        let subscription3 = client
2135            .subscribe_to_entity(3)
2136            .unwrap()
2137            .set_entity(&entity3, &cx.to_async());
2138        drop(subscription3);
2139
2140        server.send(proto::JoinProject {
2141            project_id: 1,
2142            committer_name: None,
2143            committer_email: None,
2144        });
2145        server.send(proto::JoinProject {
2146            project_id: 2,
2147            committer_name: None,
2148            committer_email: None,
2149        });
2150        done_rx1.recv().await.unwrap();
2151        done_rx2.recv().await.unwrap();
2152    }
2153
2154    #[gpui::test]
2155    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
2156        init_test(cx);
2157        let user_id = 5;
2158        let client = cx.update(|cx| {
2159            Client::new(
2160                Arc::new(FakeSystemClock::new()),
2161                FakeHttpClient::with_404_response(),
2162                cx,
2163            )
2164        });
2165        let server = FakeServer::for_client(user_id, &client, cx).await;
2166
2167        let entity = cx.new(|_| TestEntity::default());
2168        let (done_tx1, _done_rx1) = smol::channel::unbounded();
2169        let (done_tx2, done_rx2) = smol::channel::unbounded();
2170        let subscription1 = client.add_message_handler(
2171            entity.downgrade(),
2172            move |_, _: TypedEnvelope<proto::Ping>, _| {
2173                done_tx1.try_send(()).unwrap();
2174                async { Ok(()) }
2175            },
2176        );
2177        drop(subscription1);
2178        let _subscription2 = client.add_message_handler(
2179            entity.downgrade(),
2180            move |_, _: TypedEnvelope<proto::Ping>, _| {
2181                done_tx2.try_send(()).unwrap();
2182                async { Ok(()) }
2183            },
2184        );
2185        server.send(proto::Ping {});
2186        done_rx2.recv().await.unwrap();
2187    }
2188
2189    #[gpui::test]
2190    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
2191        init_test(cx);
2192        let user_id = 5;
2193        let client = cx.update(|cx| {
2194            Client::new(
2195                Arc::new(FakeSystemClock::new()),
2196                FakeHttpClient::with_404_response(),
2197                cx,
2198            )
2199        });
2200        let server = FakeServer::for_client(user_id, &client, cx).await;
2201
2202        let entity = cx.new(|_| TestEntity::default());
2203        let (done_tx, done_rx) = smol::channel::unbounded();
2204        let subscription = client.add_message_handler(
2205            entity.clone().downgrade(),
2206            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
2207                entity
2208                    .update(&mut cx, |entity, _| entity.subscription.take())
2209                    .unwrap();
2210                done_tx.try_send(()).unwrap();
2211                async { Ok(()) }
2212            },
2213        );
2214        entity.update(cx, |entity, _| {
2215            entity.subscription = Some(subscription);
2216        });
2217        server.send(proto::Ping {});
2218        done_rx.recv().await.unwrap();
2219    }
2220
    // Minimal entity type that the tests in this module attach message handlers to.
    #[derive(Default)]
    struct TestEntity {
        // Identifier the entity-message handler test reads to tell instances apart.
        id: usize,
        // Slot where a test can park a handler's `Subscription` so that the handler
        // can take and drop it from inside its own callback.
        subscription: Option<Subscription>,
    }
2226
2227    fn init_test(cx: &mut TestAppContext) {
2228        cx.update(|cx| {
2229            let settings_store = SettingsStore::test(cx);
2230            cx.set_global(settings_store);
2231            init_settings(cx);
2232        });
2233    }
2234}