client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use settings::{Settings, SettingsContent, SettingsKey, SettingsUi};
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        Arc, LazyLock, Weak,
  44        atomic::{AtomicU64, Ordering},
  45    },
  46    time::{Duration, Instant},
  47};
  48use std::{cmp, pin::Pin};
  49use telemetry::Telemetry;
  50use thiserror::Error;
  51use tokio::net::TcpStream;
  52use url::Url;
  53use util::{ConnectionResult, MergeFrom, ResultExt};
  54
  55pub use rpc::*;
  56pub use telemetry_events::Event;
  57pub use user::*;
  58
  59static ZED_SERVER_URL: LazyLock<Option<String>> =
  60    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  61static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  62
  63pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  64    std::env::var("ZED_IMPERSONATE")
  65        .ok()
  66        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  67});
  68
  69pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  70
  71pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  72    std::env::var("ZED_ADMIN_API_TOKEN")
  73        .ok()
  74        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  75});
  76
  77pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  78    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  79
  80pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  81    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  82
  83pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  84pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  85pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  86
// Actions in the `client` namespace. Their handlers are installed by `init` below.
// The `///` lines inside the macro are passed to `actions!` together with each action,
// so treat them as part of the action definitions rather than ordinary comments.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  98
  99#[derive(Deserialize)]
 100pub struct ClientSettings {
 101    pub server_url: String,
 102}
 103
 104impl Settings for ClientSettings {
 105    fn from_defaults(content: &settings::SettingsContent, _cx: &mut App) -> Self {
 106        if let Some(server_url) = &*ZED_SERVER_URL {
 107            return Self {
 108                server_url: server_url.clone(),
 109            };
 110        }
 111        Self {
 112            server_url: content.server_url.clone().unwrap(),
 113        }
 114    }
 115
 116    fn refine(&mut self, content: &settings::SettingsContent, _: &mut App) {
 117        if ZED_SERVER_URL.is_some() {
 118            return;
 119        }
 120        if let Some(server_url) = content.server_url.clone() {
 121            self.server_url = server_url;
 122        }
 123    }
 124
 125    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut SettingsContent) {}
 126}
 127
// Serializable shape of the `proxy` setting (presumably the settings-file representation;
// confirm against the settings crate). Plain `//` comments are used on purpose: schemars
// turns doc comments into JSON-schema descriptions, and these notes should not change it.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)]
#[settings_key(None)]
pub struct ProxySettingsContent {
    // Raw, unparsed proxy string, if one is configured.
    proxy: Option<String>,
}
 133
 134#[derive(Deserialize, Default)]
 135pub struct ProxySettings {
 136    pub proxy: Option<String>,
 137}
 138
 139impl ProxySettings {
 140    pub fn proxy_url(&self) -> Option<Url> {
 141        self.proxy
 142            .as_ref()
 143            .and_then(|input| {
 144                input
 145                    .parse::<Url>()
 146                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 147                    .ok()
 148            })
 149            .or_else(read_proxy_from_env)
 150    }
 151}
 152
 153impl Settings for ProxySettings {
 154    fn from_defaults(content: &settings::SettingsContent, _cx: &mut App) -> Self {
 155        Self {
 156            proxy: content.proxy.clone(),
 157        }
 158    }
 159
 160    fn refine(&mut self, content: &settings::SettingsContent, _: &mut App) {
 161        if let Some(proxy) = content.proxy.clone() {
 162            self.proxy = Some(proxy)
 163        }
 164    }
 165
 166    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut SettingsContent) {
 167        vscode.string_setting("http.proxy", &mut current.proxy);
 168    }
 169}
 170
 171pub fn init_settings(cx: &mut App) {
 172    TelemetrySettings::register(cx);
 173    ClientSettings::register(cx);
 174    ProxySettings::register(cx);
 175}
 176
 177pub fn init(client: &Arc<Client>, cx: &mut App) {
 178    let client = Arc::downgrade(client);
 179    cx.on_action({
 180        let client = client.clone();
 181        move |_: &SignIn, cx| {
 182            if let Some(client) = client.upgrade() {
 183                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 184                    .detach_and_log_err(cx);
 185            }
 186        }
 187    });
 188
 189    cx.on_action({
 190        let client = client.clone();
 191        move |_: &SignOut, cx| {
 192            if let Some(client) = client.upgrade() {
 193                cx.spawn(async move |cx| {
 194                    client.sign_out(cx).await;
 195                })
 196                .detach();
 197            }
 198        }
 199    });
 200
 201    cx.on_action({
 202        let client = client;
 203        move |_: &Reconnect, cx| {
 204            if let Some(client) = client.upgrade() {
 205                cx.spawn(async move |cx| {
 206                    client.reconnect(cx);
 207                })
 208                .detach();
 209            }
 210        }
 211    });
 212}
 213
 214pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
 215
 216struct GlobalClient(Arc<Client>);
 217
 218impl Global for GlobalClient {}
 219
/// Connection to Zed's servers. Owns the RPC peer, the HTTP and cloud API clients,
/// telemetry, credential storage, the connection status, and the registered message
/// handlers.
pub struct Client {
    /// Client id returned by `id()` and logged on every status change; `sign_in` sets it
    /// to the signed-in user's id.
    id: AtomicU64,
    /// RPC peer; `Client::new` creates it with `Peer::new(0)`.
    peer: Arc<Peer>,
    /// HTTP client bound to a base URL (`ClientSettings::server_url` in `production`).
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client built from `http`; `sign_in` hands it the user's credentials.
    cloud_client: Arc<CloudApiClient>,
    /// Telemetry; its authenticated-user info is cleared when the status becomes
    /// `SignedOut` or `UpgradeRequired`.
    telemetry: Arc<Telemetry>,
    /// Reads, writes, and deletes persisted credentials keyed by the server URL.
    credentials_provider: ClientCredentialsProvider,
    /// Current credentials, status watch channel, and reconnection task.
    state: RwLock<ClientState>,
    /// Entity and message-handler registrations, cleaned up by dropping `Subscription`s.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Handlers for [`MessageToClient`] messages.
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    /// Test-only hook installed by `override_authenticate`.
    // The boxed closure type is spelled out inline, hence the `type_complexity` allow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only hook installed by `override_establish_connection`.
    // The boxed closure type is spelled out inline, hence the `type_complexity` allow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL installed by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 255
/// Errors produced while establishing a server connection.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced from a websocket HTTP 426 (Upgrade Required) response by the
    /// `From<WebsocketError>` impl.
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced from a websocket HTTP 401 (Unauthorized) response by the
    /// `From<WebsocketError>` impl.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors other than 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An HTTP header value could not be constructed.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite its name, this wraps the `http` crate error re-exported by tungstenite;
    /// websocket errors themselves go through `From<WebsocketError>`.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 271
 272impl From<WebsocketError> for EstablishConnectionError {
 273    fn from(error: WebsocketError) -> Self {
 274        if let WebsocketError::Http(response) = &error {
 275            match response.status() {
 276                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 277                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 278                _ => {}
 279            }
 280        }
 281        EstablishConnectionError::Other(error.into())
 282    }
 283}
 284
 285impl EstablishConnectionError {
 286    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 287        Self::Other(error.into())
 288    }
 289}
 290
/// Connection lifecycle of a [`Client`], published through `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial status of a new client (see `ClientState::default`).
    SignedOut,
    /// Presumably set when the server demands a newer client; `is_signed_out` treats it
    /// like `SignedOut`.
    UpgradeRequired,
    /// `sign_in` is authenticating a client that was signed out.
    Authenticating,
    Authenticated,
    /// Authentication failed (set by `sign_in`); the reconnection loop keeps retrying
    /// while the client is in this status.
    AuthenticationError,
    Connecting,
    /// Presumably a failed connection attempt; the reconnection loop keeps retrying
    /// while the client is in this status.
    ConnectionError,
    /// Connected, with the peer and connection ids of this connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection dropped; setting this status spawns the reconnection loop.
    ConnectionLost,
    /// `sign_in` is re-authenticating a client that was not signed out.
    Reauthenticating,
    Reauthenticated,
    Reconnecting,
    /// The reconnection loop is waiting before its next attempt.
    ReconnectionError {
        /// The time the reconnection loop reports for its next attempt.
        next_reconnection: Instant,
    },
}
 312
 313impl Status {
 314    pub fn is_connected(&self) -> bool {
 315        matches!(self, Self::Connected { .. })
 316    }
 317
 318    pub fn was_connected(&self) -> bool {
 319        matches!(
 320            self,
 321            Self::ConnectionLost
 322                | Self::Reauthenticating
 323                | Self::Reauthenticated
 324                | Self::Reconnecting
 325        )
 326    }
 327
 328    /// Returns whether the client is currently connected or was connected at some point.
 329    pub fn is_or_was_connected(&self) -> bool {
 330        self.is_connected() || self.was_connected()
 331    }
 332
 333    pub fn is_signing_in(&self) -> bool {
 334        matches!(
 335            self,
 336            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 337        )
 338    }
 339
 340    pub fn is_signed_out(&self) -> bool {
 341        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 342    }
 343}
 344
/// Mutable connection state, guarded by `Client::state`.
struct ClientState {
    /// Credentials stored by the last successful `sign_in`, if any.
    credentials: Option<Credentials>,
    /// Watch channel carrying the current [`Status`]; `Client::status` hands out clones
    /// of the receiver half.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnection task spawned on `ConnectionLost`; cleared on `Connected`,
    /// `SignedOut`, and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
}
 350
 351#[derive(Clone, Debug, Eq, PartialEq)]
 352pub struct Credentials {
 353    pub user_id: u64,
 354    pub access_token: String,
 355}
 356
 357impl Credentials {
 358    pub fn authorization_header(&self) -> String {
 359        format!("{} {}", self.user_id, self.access_token)
 360    }
 361}
 362
 363pub struct ClientCredentialsProvider {
 364    provider: Arc<dyn CredentialsProvider>,
 365}
 366
 367impl ClientCredentialsProvider {
 368    pub fn new(cx: &App) -> Self {
 369        Self {
 370            provider: <dyn CredentialsProvider>::global(cx),
 371        }
 372    }
 373
 374    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 375        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 376    }
 377
 378    /// Reads the credentials from the provider.
 379    fn read_credentials<'a>(
 380        &'a self,
 381        cx: &'a AsyncApp,
 382    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 383        async move {
 384            if IMPERSONATE_LOGIN.is_some() {
 385                return None;
 386            }
 387
 388            let server_url = self.server_url(cx).ok()?;
 389            let (user_id, access_token) = self
 390                .provider
 391                .read_credentials(&server_url, cx)
 392                .await
 393                .log_err()
 394                .flatten()?;
 395
 396            Some(Credentials {
 397                user_id: user_id.parse().ok()?,
 398                access_token: String::from_utf8(access_token).ok()?,
 399            })
 400        }
 401        .boxed_local()
 402    }
 403
 404    /// Writes the credentials to the provider.
 405    fn write_credentials<'a>(
 406        &'a self,
 407        user_id: u64,
 408        access_token: String,
 409        cx: &'a AsyncApp,
 410    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 411        async move {
 412            let server_url = self.server_url(cx)?;
 413            self.provider
 414                .write_credentials(
 415                    &server_url,
 416                    &user_id.to_string(),
 417                    access_token.as_bytes(),
 418                    cx,
 419                )
 420                .await
 421        }
 422        .boxed_local()
 423    }
 424
 425    /// Deletes the credentials from the provider.
 426    fn delete_credentials<'a>(
 427        &'a self,
 428        cx: &'a AsyncApp,
 429    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 430        async move {
 431            let server_url = self.server_url(cx)?;
 432            self.provider.delete_credentials(&server_url, cx).await
 433        }
 434        .boxed_local()
 435    }
 436}
 437
 438impl Default for ClientState {
 439    fn default() -> Self {
 440        Self {
 441            credentials: None,
 442            status: watch::channel_with(Status::SignedOut),
 443            _reconnect_task: None,
 444        }
 445    }
 446}
 447
/// Keeps a handler registration alive; dropping it removes the registration from the
/// client's handler set. The client is held weakly, so a subscription does not keep it
/// alive.
pub enum Subscription {
    /// Registration for one entity, keyed by the entity's type and remote id.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for one message type, keyed by that type's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 458
 459impl Drop for Subscription {
 460    fn drop(&mut self) {
 461        match self {
 462            Subscription::Entity { client, id } => {
 463                if let Some(client) = client.upgrade() {
 464                    let mut state = client.handler_set.lock();
 465                    let _ = state.entities_by_type_and_remote_id.remove(id);
 466                }
 467            }
 468            Subscription::Message { client, id } => {
 469                if let Some(client) = client.upgrade() {
 470                    let mut state = client.handler_set.lock();
 471                    let _ = state.entity_types_by_message_type.remove(id);
 472                    let _ = state.message_handlers.remove(id);
 473                }
 474            }
 475        }
 476    }
 477}
 478
/// Reserved entity subscription returned by [`Client::subscribe_to_entity`]; it becomes a
/// real [`Subscription`] once [`PendingEntitySubscription::set_entity`] is called.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    /// Remote id of the entity this subscription is reserved for.
    remote_id: u64,
    /// Ties the subscription to the entity type `T` without storing a `T`.
    _entity_type: PhantomData<T>,
    /// Set by `set_entity`. If still `false` on drop, the pending entry is removed and its
    /// queued messages are logged as unhandled.
    consumed: bool,
}
 485
 486impl<T: 'static> PendingEntitySubscription<T> {
 487    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 488        self.consumed = true;
 489        let mut handlers = self.client.handler_set.lock();
 490        let id = (TypeId::of::<T>(), self.remote_id);
 491        let Some(EntityMessageSubscriber::Pending(messages)) =
 492            handlers.entities_by_type_and_remote_id.remove(&id)
 493        else {
 494            unreachable!()
 495        };
 496
 497        handlers.entities_by_type_and_remote_id.insert(
 498            id,
 499            EntityMessageSubscriber::Entity {
 500                handle: entity.downgrade().into(),
 501            },
 502        );
 503        drop(handlers);
 504        for message in messages {
 505            let client_id = self.client.id();
 506            let type_name = message.payload_type_name();
 507            let sender_id = message.original_sender_id();
 508            log::debug!(
 509                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 510                client_id,
 511                sender_id,
 512                type_name
 513            );
 514            self.client.handle_message(message, cx);
 515        }
 516        Subscription::Entity {
 517            client: Arc::downgrade(&self.client),
 518            id,
 519        }
 520    }
 521}
 522
 523impl<T: 'static> Drop for PendingEntitySubscription<T> {
 524    fn drop(&mut self) {
 525        if !self.consumed {
 526            let mut state = self.client.handler_set.lock();
 527            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 528                .entities_by_type_and_remote_id
 529                .remove(&(TypeId::of::<T>(), self.remote_id))
 530            {
 531                for message in messages {
 532                    log::info!("unhandled message {}", message.payload_type_name());
 533                }
 534            }
 535        }
 536    }
 537}
 538
 539#[derive(Copy, Clone, Deserialize, Debug)]
 540pub struct TelemetrySettings {
 541    pub diagnostics: bool,
 542    pub metrics: bool,
 543}
 544
 545impl settings::Settings for TelemetrySettings {
 546    fn from_defaults(content: &SettingsContent, _cx: &mut App) -> Self {
 547        Self {
 548            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 549            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 550        }
 551    }
 552
 553    fn refine(&mut self, content: &SettingsContent, _cx: &mut App) {
 554        let Some(telemetry) = &content.telemetry else {
 555            return;
 556        };
 557        self.diagnostics.merge_from(&telemetry.diagnostics);
 558        self.metrics.merge_from(&telemetry.metrics);
 559    }
 560
 561    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut SettingsContent) {
 562        let mut telemetry = settings::TelemetrySettingsContent::default();
 563        vscode.enum_setting("telemetry.telemetryLevel", &mut telemetry.metrics, |s| {
 564            Some(s == "all")
 565        });
 566        vscode.enum_setting(
 567            "telemetry.telemetryLevel",
 568            &mut telemetry.diagnostics,
 569            |s| Some(matches!(s, "all" | "error" | "crash")),
 570        );
 571        // we could translate telemetry.telemetryLevel, but just because users didn't want
 572        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 573        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 574        if let Some(diagnostics) = telemetry.diagnostics {
 575            current.telemetry.get_or_insert_default().diagnostics = Some(diagnostics)
 576        }
 577        if let Some(metrics) = telemetry.metrics {
 578            current.telemetry.get_or_insert_default().metrics = Some(metrics)
 579        }
 580    }
 581}
 582
 583impl Client {
 584    pub fn new(
 585        clock: Arc<dyn SystemClock>,
 586        http: Arc<HttpClientWithUrl>,
 587        cx: &mut App,
 588    ) -> Arc<Self> {
 589        Arc::new(Self {
 590            id: AtomicU64::new(0),
 591            peer: Peer::new(0),
 592            telemetry: Telemetry::new(clock, http.clone(), cx),
 593            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 594            http,
 595            credentials_provider: ClientCredentialsProvider::new(cx),
 596            state: Default::default(),
 597            handler_set: Default::default(),
 598            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 599
 600            #[cfg(any(test, feature = "test-support"))]
 601            authenticate: Default::default(),
 602            #[cfg(any(test, feature = "test-support"))]
 603            establish_connection: Default::default(),
 604            #[cfg(any(test, feature = "test-support"))]
 605            rpc_url: RwLock::default(),
 606        })
 607    }
 608
 609    pub fn production(cx: &mut App) -> Arc<Self> {
 610        let clock = Arc::new(clock::RealSystemClock);
 611        let http = Arc::new(HttpClientWithUrl::new_url(
 612            cx.http_client(),
 613            &ClientSettings::get_global(cx).server_url,
 614            cx.http_client().proxy().cloned(),
 615        ));
 616        Self::new(clock, http, cx)
 617    }
 618
 619    pub fn id(&self) -> u64 {
 620        self.id.load(Ordering::SeqCst)
 621    }
 622
 623    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 624        self.http.clone()
 625    }
 626
 627    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 628        self.cloud_client.clone()
 629    }
 630
 631    pub fn set_id(&self, id: u64) -> &Self {
 632        self.id.store(id, Ordering::SeqCst);
 633        self
 634    }
 635
 636    #[cfg(any(test, feature = "test-support"))]
 637    pub fn teardown(&self) {
 638        let mut state = self.state.write();
 639        state._reconnect_task.take();
 640        self.handler_set.lock().clear();
 641        self.peer.teardown();
 642    }
 643
 644    #[cfg(any(test, feature = "test-support"))]
 645    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 646    where
 647        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 648    {
 649        *self.authenticate.write() = Some(Box::new(authenticate));
 650        self
 651    }
 652
 653    #[cfg(any(test, feature = "test-support"))]
 654    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 655    where
 656        F: 'static
 657            + Send
 658            + Sync
 659            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 660    {
 661        *self.establish_connection.write() = Some(Box::new(connect));
 662        self
 663    }
 664
 665    #[cfg(any(test, feature = "test-support"))]
 666    pub fn override_rpc_url(&self, url: Url) -> &Self {
 667        *self.rpc_url.write() = Some(url);
 668        self
 669    }
 670
 671    pub fn global(cx: &App) -> Arc<Self> {
 672        cx.global::<GlobalClient>().0.clone()
 673    }
 674    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 675        cx.set_global(GlobalClient(client))
 676    }
 677
 678    pub fn user_id(&self) -> Option<u64> {
 679        self.state
 680            .read()
 681            .credentials
 682            .as_ref()
 683            .map(|credentials| credentials.user_id)
 684    }
 685
 686    pub fn peer_id(&self) -> Option<PeerId> {
 687        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 688            Some(*peer_id)
 689        } else {
 690            None
 691        }
 692    }
 693
 694    pub fn status(&self) -> watch::Receiver<Status> {
 695        self.state.read().status.1.clone()
 696    }
 697
 698    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
 699        log::info!("set status on client {}: {:?}", self.id(), status);
 700        let mut state = self.state.write();
 701        *state.status.0.borrow_mut() = status;
 702
 703        match status {
 704            Status::Connected { .. } => {
 705                state._reconnect_task = None;
 706            }
 707            Status::ConnectionLost => {
 708                let client = self.clone();
 709                state._reconnect_task = Some(cx.spawn(async move |cx| {
 710                    #[cfg(any(test, feature = "test-support"))]
 711                    let mut rng = StdRng::seed_from_u64(0);
 712                    #[cfg(not(any(test, feature = "test-support")))]
 713                    let mut rng = StdRng::from_os_rng();
 714
 715                    let mut delay = INITIAL_RECONNECTION_DELAY;
 716                    loop {
 717                        match client.connect(true, cx).await {
 718                            ConnectionResult::Timeout => {
 719                                log::error!("client connect attempt timed out")
 720                            }
 721                            ConnectionResult::ConnectionReset => {
 722                                log::error!("client connect attempt reset")
 723                            }
 724                            ConnectionResult::Result(r) => {
 725                                if let Err(error) = r {
 726                                    log::error!("failed to connect: {error}");
 727                                } else {
 728                                    break;
 729                                }
 730                            }
 731                        }
 732
 733                        if matches!(
 734                            *client.status().borrow(),
 735                            Status::AuthenticationError | Status::ConnectionError
 736                        ) {
 737                            client.set_status(
 738                                Status::ReconnectionError {
 739                                    next_reconnection: Instant::now() + delay,
 740                                },
 741                                cx,
 742                            );
 743                            let jitter = Duration::from_millis(
 744                                rng.random_range(0..delay.as_millis() as u64),
 745                            );
 746                            cx.background_executor().timer(delay + jitter).await;
 747                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
 748                        } else {
 749                            break;
 750                        }
 751                    }
 752                }));
 753            }
 754            Status::SignedOut | Status::UpgradeRequired => {
 755                self.telemetry.set_authenticated_user_info(None, false);
 756                state._reconnect_task.take();
 757            }
 758            _ => {}
 759        }
 760    }
 761
 762    pub fn subscribe_to_entity<T>(
 763        self: &Arc<Self>,
 764        remote_id: u64,
 765    ) -> Result<PendingEntitySubscription<T>>
 766    where
 767        T: 'static,
 768    {
 769        let id = (TypeId::of::<T>(), remote_id);
 770
 771        let mut state = self.handler_set.lock();
 772        anyhow::ensure!(
 773            !state.entities_by_type_and_remote_id.contains_key(&id),
 774            "already subscribed to entity"
 775        );
 776
 777        state
 778            .entities_by_type_and_remote_id
 779            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 780
 781        Ok(PendingEntitySubscription {
 782            client: self.clone(),
 783            remote_id,
 784            consumed: false,
 785            _entity_type: PhantomData,
 786        })
 787    }
 788
 789    #[track_caller]
 790    pub fn add_message_handler<M, E, H, F>(
 791        self: &Arc<Self>,
 792        entity: WeakEntity<E>,
 793        handler: H,
 794    ) -> Subscription
 795    where
 796        M: EnvelopedMessage,
 797        E: 'static,
 798        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 799        F: 'static + Future<Output = Result<()>>,
 800    {
 801        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 802            handler(entity, message, cx)
 803        })
 804    }
 805
 806    fn add_message_handler_impl<M, E, H, F>(
 807        self: &Arc<Self>,
 808        entity: WeakEntity<E>,
 809        handler: H,
 810    ) -> Subscription
 811    where
 812        M: EnvelopedMessage,
 813        E: 'static,
 814        H: 'static
 815            + Sync
 816            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 817            + Send
 818            + Sync,
 819        F: 'static + Future<Output = Result<()>>,
 820    {
 821        let message_type_id = TypeId::of::<M>();
 822        let mut state = self.handler_set.lock();
 823        state
 824            .entities_by_message_type
 825            .insert(message_type_id, entity.into());
 826
 827        let prev_handler = state.message_handlers.insert(
 828            message_type_id,
 829            Arc::new(move |subscriber, envelope, client, cx| {
 830                let subscriber = subscriber.downcast::<E>().unwrap();
 831                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 832                handler(subscriber, *envelope, client, cx).boxed_local()
 833            }),
 834        );
 835        if prev_handler.is_some() {
 836            let location = std::panic::Location::caller();
 837            panic!(
 838                "{}:{} registered handler for the same message {} twice",
 839                location.file(),
 840                location.line(),
 841                std::any::type_name::<M>()
 842            );
 843        }
 844
 845        Subscription::Message {
 846            client: Arc::downgrade(self),
 847            id: message_type_id,
 848        }
 849    }
 850
 851    pub fn add_request_handler<M, E, H, F>(
 852        self: &Arc<Self>,
 853        entity: WeakEntity<E>,
 854        handler: H,
 855    ) -> Subscription
 856    where
 857        M: RequestMessage,
 858        E: 'static,
 859        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 860        F: 'static + Future<Output = Result<M::Response>>,
 861    {
 862        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 863            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 864        })
 865    }
 866
 867    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 868        receipt: Receipt<T>,
 869        response: F,
 870        client: AnyProtoClient,
 871    ) -> Result<()> {
 872        match response.await {
 873            Ok(response) => {
 874                client.send_response(receipt.message_id, response)?;
 875                Ok(())
 876            }
 877            Err(error) => {
 878                client.send_response(receipt.message_id, error.to_proto())?;
 879                Err(error)
 880            }
 881        }
 882    }
 883
 884    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 885        self.credentials_provider
 886            .read_credentials(cx)
 887            .await
 888            .is_some()
 889    }
 890
    /// Signs the user in and returns the credentials that were ultimately used.
    ///
    /// Credentials are tried in this order:
    /// 1. the credentials already held in `self.state`, if they still validate;
    /// 2. when `try_provider` is true, the credentials stored in the credentials provider
    ///    (stored credentials that fail validation are deleted from the provider);
    /// 3. an interactive `authenticate` call, whose result is written back to the credentials
    ///    provider unless `IMPERSONATE_LOGIN` is set.
    ///
    /// On entry, the status becomes `Authenticating` if the client is signed out and
    /// `Reauthenticating` otherwise. On success, it becomes `Authenticated` or `Reauthenticated`
    /// respectively, and the credentials are recorded on `self` and on the cloud client.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - a validation request itself fails;
    /// - interactive authentication fails (the status is set to `AuthenticationError`);
    /// - the status changes while interactive authentication is pending ("authentication
    ///   canceled").
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        // Remember whether this is a first sign-in or a re-authentication, so the final
        // status can mirror it.
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // 1. Reuse the in-memory credentials if the server still accepts them.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // 2. Fall back to credentials persisted by the credentials provider.
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                // Stored credentials were rejected: delete them (best effort, errors are logged).
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // 3. Nothing usable yet: run interactive authentication.
        if credentials.is_none() {
            // NOTE(review): this assumes a fresh watch receiver first yields the current status.
            // Consuming that value here means the cancellation arm below only fires on a
            // subsequent status change.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            // Biased: if both arms are ready, a finished authentication wins over cancellation.
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            // Only persist the credentials when not impersonating another user.
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path above has either set `credentials` or returned early.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
 971
 972    async fn validate_credentials(
 973        self: &Arc<Self>,
 974        credentials: &Credentials,
 975        cx: &AsyncApp,
 976    ) -> Result<bool> {
 977        match self
 978            .cloud_client
 979            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 980            .await
 981        {
 982            Ok(valid) => Ok(valid),
 983            Err(err) => {
 984                self.set_status(Status::AuthenticationError, cx);
 985                Err(anyhow!("failed to validate credentials: {}", err))
 986            }
 987        }
 988    }
 989
 990    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 991    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 992        let connect_task = cx.update({
 993            let cloud_client = self.cloud_client.clone();
 994            move |cx| cloud_client.connect(cx)
 995        })??;
 996        let connection = connect_task.await?;
 997
 998        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 999        task.detach();
1000
1001        cx.spawn({
1002            let this = self.clone();
1003            async move |cx| {
1004                while let Some(message) = messages.next().await {
1005                    if let Some(message) = message.log_err() {
1006                        this.handle_message_to_client(message, cx);
1007                    }
1008                }
1009            }
1010        })
1011        .detach();
1012
1013        Ok(())
1014    }
1015
1016    /// Performs a sign-in and also (optionally) connects to Collab.
1017    ///
1018    /// Only Zed staff automatically connect to Collab.
1019    pub async fn sign_in_with_optional_connect(
1020        self: &Arc<Self>,
1021        try_provider: bool,
1022        cx: &AsyncApp,
1023    ) -> Result<()> {
1024        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
1025        if self.status().borrow().is_connected() {
1026            return Ok(());
1027        }
1028
1029        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
1030        let mut is_staff_tx = Some(is_staff_tx);
1031        cx.update(|cx| {
1032            cx.on_flags_ready(move |state, _cx| {
1033                if let Some(is_staff_tx) = is_staff_tx.take() {
1034                    is_staff_tx.send(state.is_staff).log_err();
1035                }
1036            })
1037            .detach();
1038        })
1039        .log_err();
1040
1041        let credentials = self.sign_in(try_provider, cx).await?;
1042
1043        self.connect_to_cloud(cx).await.log_err();
1044
1045        cx.update(move |cx| {
1046            cx.spawn({
1047                let client = self.clone();
1048                async move |cx| {
1049                    let is_staff = is_staff_rx.await?;
1050                    if is_staff {
1051                        match client.connect_with_credentials(credentials, cx).await {
1052                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
1053                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
1054                            ConnectionResult::Result(result) => {
1055                                result.context("client auth and connect")
1056                            }
1057                        }
1058                    } else {
1059                        Ok(())
1060                    }
1061                }
1062            })
1063            .detach_and_log_err(cx);
1064        })
1065        .log_err();
1066
1067        Ok(())
1068    }
1069
1070    pub async fn connect(
1071        self: &Arc<Self>,
1072        try_provider: bool,
1073        cx: &AsyncApp,
1074    ) -> ConnectionResult<()> {
1075        let was_disconnected = match *self.status().borrow() {
1076            Status::SignedOut | Status::Authenticated => true,
1077            Status::ConnectionError
1078            | Status::ConnectionLost
1079            | Status::Authenticating
1080            | Status::AuthenticationError
1081            | Status::Reauthenticating
1082            | Status::Reauthenticated
1083            | Status::ReconnectionError { .. } => false,
1084            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1085                return ConnectionResult::Result(Ok(()));
1086            }
1087            Status::UpgradeRequired => {
1088                return ConnectionResult::Result(
1089                    Err(EstablishConnectionError::UpgradeRequired)
1090                        .context("client auth and connect"),
1091                );
1092            }
1093        };
1094        let credentials = match self.sign_in(try_provider, cx).await {
1095            Ok(credentials) => credentials,
1096            Err(err) => return ConnectionResult::Result(Err(err)),
1097        };
1098
1099        if was_disconnected {
1100            self.set_status(Status::Connecting, cx);
1101        } else {
1102            self.set_status(Status::Reconnecting, cx);
1103        }
1104
1105        self.connect_with_credentials(credentials, cx).await
1106    }
1107
    /// Connects to Collab with the given `credentials`, bounded by `CONNECTION_TIMEOUT`.
    ///
    /// A single timer covers both `establish_connection` and the handshake performed by
    /// `set_connection`. On timeout, the status becomes `ConnectionError` and
    /// `ConnectionResult::Timeout` is returned. Any failure also sets `ConnectionError`, except
    /// `EstablishConnectionError::UpgradeRequired`, which sets `Status::UpgradeRequired`.
    /// Errors are wrapped with the context "client auth and connect".
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // Created once and used by both `select_biased!`s below, so the deadline covers the
        // whole attempt instead of restarting for the handshake.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Handshake phase, still racing the same timer.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1155
    /// Registers `conn` with the peer, waits for the server's `Hello`, and marks the client
    /// `Connected`.
    ///
    /// After the handshake, two detached tasks are spawned:
    /// - one feeds incoming messages to `handle_message`, yielding between messages;
    /// - one awaits the connection's I/O task. On a clean finish it sets `SignedOut`, but only if
    ///   this connection is still the current one; on an I/O error it sets `ConnectionLost`.
    ///
    /// # Errors
    ///
    /// Fails if the first message is missing, is not a `proto::Hello`, or carries no peer id.
    /// In that case the connection is disconnected from the peer before the error is returned.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        // Gives the peer a way to create timers on the background executor
        // (NOTE(review): what the peer uses them for is inside `Peer::add_connection`).
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor.
        let handle_io = executor.spawn(handle_io);

        // The first message from the server must be a `Hello` carrying our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Don't leave a half-initialized connection registered with the peer.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Only sign out if no newer connection has replaced this one in the meantime.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1239
1240    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1241        #[cfg(any(test, feature = "test-support"))]
1242        if let Some(callback) = self.authenticate.read().as_ref() {
1243            return callback(cx);
1244        }
1245
1246        self.authenticate_with_browser(cx)
1247    }
1248
1249    fn establish_connection(
1250        self: &Arc<Self>,
1251        credentials: &Credentials,
1252        cx: &AsyncApp,
1253    ) -> Task<Result<Connection, EstablishConnectionError>> {
1254        #[cfg(any(test, feature = "test-support"))]
1255        if let Some(callback) = self.establish_connection.read().as_ref() {
1256            return callback(credentials, cx);
1257        }
1258
1259        self.establish_websocket_connection(credentials, cx)
1260    }
1261
1262    fn rpc_url(
1263        &self,
1264        http: Arc<HttpClientWithUrl>,
1265        release_channel: Option<ReleaseChannel>,
1266    ) -> impl Future<Output = Result<url::Url>> + use<> {
1267        #[cfg(any(test, feature = "test-support"))]
1268        let url_override = self.rpc_url.read().clone();
1269
1270        async move {
1271            #[cfg(any(test, feature = "test-support"))]
1272            if let Some(url) = url_override {
1273                return Ok(url);
1274            }
1275
1276            if let Some(url) = &*ZED_RPC_URL {
1277                return Url::parse(url).context("invalid rpc url");
1278            }
1279
1280            let mut url = http.build_url("/rpc");
1281            if let Some(preview_param) =
1282                release_channel.and_then(|channel| channel.release_query_param())
1283            {
1284                url += "?";
1285                url += preview_param;
1286            }
1287
1288            let response = http.get(&url, Default::default(), false).await?;
1289            anyhow::ensure!(
1290                response.status().is_redirection(),
1291                "unexpected /rpc response status {}",
1292                response.status()
1293            );
1294            let collab_url = response
1295                .headers()
1296                .get("Location")
1297                .context("missing location header in /rpc response")?
1298                .to_str()
1299                .map_err(EstablishConnectionError::other)?
1300                .to_string();
1301            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1302        }
1303    }
1304
    /// Opens a WebSocket connection to the Collab RPC endpoint.
    ///
    /// Steps:
    /// 1. Resolve the endpoint via `rpc_url`; only `http`/`https` URLs are accepted.
    /// 2. Open a TCP stream to its host on the Tokio runtime, going through the configured
    ///    HTTP proxy when there is one.
    /// 3. Perform the WebSocket handshake over that stream with `ws`/`wss`. The request carries
    ///    the authorization, protocol-version, app-version and release-channel headers, plus
    ///    user-agent, system-id and metrics-id when those are known.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Gather everything needed from the app and `self` up front; the spawned task only
        // owns these values.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            // Remember the scheme so the matching WebSocket scheme can be chosen below.
            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // Connect the raw TCP stream on the Tokio runtime, tunnelling through the proxy if
            // one is configured.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })?
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // The scheme is http/https at this point, so switching to ws/wss stays within the
            // "special" schemes that `Url::set_scheme` accepts.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            // WebSocket handshake over the already-connected stream, using the shared TLS config.
            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Adapt the socket's stream and sink errors to `anyhow` for `Connection`.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1412
1413    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1414        let http = self.http.clone();
1415        let this = self.clone();
1416        cx.spawn(async move |cx| {
1417            let background = cx.background_executor().clone();
1418
1419            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1420            cx.update(|cx| {
1421                cx.spawn(async move |cx| {
1422                    let url = open_url_rx.await?;
1423                    cx.update(|cx| cx.open_url(&url))
1424                })
1425                .detach_and_log_err(cx);
1426            })
1427            .log_err();
1428
1429            let credentials = background
1430                .clone()
1431                .spawn(async move {
1432                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1433                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1434                    // any other app running on the user's device.
1435                    let (public_key, private_key) =
1436                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1437                    let public_key_string = String::try_from(public_key)
1438                        .expect("failed to serialize public key for auth");
1439
1440                    if let Some((login, token)) =
1441                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1442                    {
1443                        if !*USE_WEB_LOGIN {
1444                            eprintln!("authenticate as admin {login}, {token}");
1445
1446                            return this
1447                                .authenticate_as_admin(http, login.clone(), token.clone())
1448                                .await;
1449                        }
1450                    }
1451
1452                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1453                    let server =
1454                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1455                    let port = server.server_addr().port();
1456
1457                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1458                    // that the user is signing in from a Zed app running on the same device.
1459                    let mut url = http.build_url(&format!(
1460                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1461                        port, public_key_string
1462                    ));
1463
1464                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1465                        log::info!("impersonating user @{}", impersonate_login);
1466                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1467                    }
1468
1469                    open_url_tx.send(url).log_err();
1470
1471                    #[derive(Deserialize)]
1472                    struct CallbackParams {
1473                        pub user_id: String,
1474                        pub access_token: String,
1475                    }
1476
1477                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1478                    // access token from the query params.
1479                    //
1480                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1481                    // custom URL scheme instead of this local HTTP server.
1482                    let (user_id, access_token) = background
1483                        .spawn(async move {
1484                            for _ in 0..100 {
1485                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1486                                    let path = req.url();
1487                                    let url = Url::parse(&format!("http://example.com{}", path))
1488                                        .context("failed to parse login notification url")?;
1489                                    let callback_params: CallbackParams =
1490                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1491                                            .context(
1492                                                "failed to parse sign-in callback query parameters",
1493                                            )?;
1494
1495                                    let post_auth_url =
1496                                        http.build_url("/native_app_signin_succeeded");
1497                                    req.respond(
1498                                        tiny_http::Response::empty(302).with_header(
1499                                            tiny_http::Header::from_bytes(
1500                                                &b"Location"[..],
1501                                                post_auth_url.as_bytes(),
1502                                            )
1503                                            .unwrap(),
1504                                        ),
1505                                    )
1506                                    .context("failed to respond to login http request")?;
1507                                    return Ok((
1508                                        callback_params.user_id,
1509                                        callback_params.access_token,
1510                                    ));
1511                                }
1512                            }
1513
1514                            anyhow::bail!("didn't receive login redirect");
1515                        })
1516                        .await?;
1517
1518                    let access_token = private_key
1519                        .decrypt_string(&access_token)
1520                        .context("failed to decrypt access token")?;
1521
1522                    Ok(Credentials {
1523                        user_id: user_id.parse()?,
1524                        access_token,
1525                    })
1526                })
1527                .await?;
1528
1529            cx.update(|cx| cx.activate(true))?;
1530            Ok(credentials)
1531        })
1532    }
1533
1534    async fn authenticate_as_admin(
1535        self: &Arc<Self>,
1536        http: Arc<HttpClientWithUrl>,
1537        login: String,
1538        api_token: String,
1539    ) -> Result<Credentials> {
1540        #[derive(Serialize)]
1541        struct ImpersonateUserBody {
1542            github_login: String,
1543        }
1544
1545        #[derive(Deserialize)]
1546        struct ImpersonateUserResponse {
1547            user_id: u64,
1548            access_token: String,
1549        }
1550
1551        let url = self
1552            .http
1553            .build_zed_cloud_url("/internal/users/impersonate", &[])?;
1554        let request = Request::post(url.as_str())
1555            .header("Content-Type", "application/json")
1556            .header("Authorization", format!("Bearer {api_token}"))
1557            .body(
1558                serde_json::to_string(&ImpersonateUserBody {
1559                    github_login: login,
1560                })?
1561                .into(),
1562            )?;
1563
1564        let mut response = http.send(request).await?;
1565        let mut body = String::new();
1566        response.body_mut().read_to_string(&mut body).await?;
1567        anyhow::ensure!(
1568            response.status().is_success(),
1569            "admin user request failed {} - {}",
1570            response.status().as_u16(),
1571            body,
1572        );
1573        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1574
1575        Ok(Credentials {
1576            user_id: response.user_id,
1577            access_token: response.access_token,
1578        })
1579    }
1580
1581    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1582        self.state.write().credentials = None;
1583        self.cloud_client.clear_credentials();
1584        self.disconnect(cx);
1585
1586        if self.has_credentials(cx).await {
1587            self.credentials_provider
1588                .delete_credentials(cx)
1589                .await
1590                .log_err();
1591        }
1592    }
1593
1594    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1595        self.peer.teardown();
1596        self.set_status(Status::SignedOut, cx);
1597    }
1598
1599    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1600        self.peer.teardown();
1601        self.set_status(Status::ConnectionLost, cx);
1602    }
1603
1604    fn connection_id(&self) -> Result<ConnectionId> {
1605        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1606            Ok(connection_id)
1607        } else {
1608            anyhow::bail!("not connected");
1609        }
1610    }
1611
1612    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1613        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1614        self.peer.send(self.connection_id()?, message)
1615    }
1616
1617    pub fn request<T: RequestMessage>(
1618        &self,
1619        request: T,
1620    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1621        self.request_envelope(request)
1622            .map_ok(|envelope| envelope.payload)
1623    }
1624
1625    pub fn request_stream<T: RequestMessage>(
1626        &self,
1627        request: T,
1628    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1629        let client_id = self.id.load(Ordering::SeqCst);
1630        log::debug!(
1631            "rpc request start. client_id:{}. name:{}",
1632            client_id,
1633            T::NAME
1634        );
1635        let response = self
1636            .connection_id()
1637            .map(|conn_id| self.peer.request_stream(conn_id, request));
1638        async move {
1639            let response = response?.await;
1640            log::debug!(
1641                "rpc request finish. client_id:{}. name:{}",
1642                client_id,
1643                T::NAME
1644            );
1645            response
1646        }
1647    }
1648
1649    pub fn request_envelope<T: RequestMessage>(
1650        &self,
1651        request: T,
1652    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1653        let client_id = self.id();
1654        log::debug!(
1655            "rpc request start. client_id:{}. name:{}",
1656            client_id,
1657            T::NAME
1658        );
1659        let response = self
1660            .connection_id()
1661            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1662        async move {
1663            let response = response?.await;
1664            log::debug!(
1665                "rpc request finish. client_id:{}. name:{}",
1666                client_id,
1667                T::NAME
1668            );
1669            response
1670        }
1671    }
1672
1673    pub fn request_dynamic(
1674        &self,
1675        envelope: proto::Envelope,
1676        request_type: &'static str,
1677    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1678        let client_id = self.id();
1679        log::debug!(
1680            "rpc request start. client_id:{}. name:{}",
1681            client_id,
1682            request_type
1683        );
1684        let response = self
1685            .connection_id()
1686            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1687        async move {
1688            let response = response?.await;
1689            log::debug!(
1690                "rpc request finish. client_id:{}. name:{}",
1691                client_id,
1692                request_type
1693            );
1694            Ok(response?.0)
1695        }
1696    }
1697
1698    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1699        let sender_id = message.sender_id();
1700        let request_id = message.message_id();
1701        let type_name = message.payload_type_name();
1702        let original_sender_id = message.original_sender_id();
1703
1704        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1705            &self.handler_set,
1706            message,
1707            self.clone().into(),
1708            cx.clone(),
1709        ) {
1710            let client_id = self.id();
1711            log::debug!(
1712                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1713                client_id,
1714                original_sender_id,
1715                type_name
1716            );
1717            cx.spawn(async move |_| match future.await {
1718                Ok(()) => {
1719                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1720                }
1721                Err(error) => {
1722                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1723                }
1724            })
1725            .detach();
1726        } else {
1727            log::info!("unhandled message {}", type_name);
1728            self.peer
1729                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1730                .log_err();
1731        }
1732    }
1733
1734    pub fn add_message_to_client_handler(
1735        self: &Arc<Client>,
1736        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1737    ) {
1738        self.message_to_client_handlers
1739            .lock()
1740            .push(Box::new(handler));
1741    }
1742
1743    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1744        cx.update(|cx| {
1745            for handler in self.message_to_client_handlers.lock().iter() {
1746                handler(&message, cx);
1747            }
1748        })
1749        .ok();
1750    }
1751
1752    pub fn telemetry(&self) -> &Arc<Telemetry> {
1753        &self.telemetry
1754    }
1755}
1756
1757impl ProtoClient for Client {
1758    fn request(
1759        &self,
1760        envelope: proto::Envelope,
1761        request_type: &'static str,
1762    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1763        self.request_dynamic(envelope, request_type).boxed()
1764    }
1765
1766    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1767        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1768        let connection_id = self.connection_id()?;
1769        self.peer.send_dynamic(connection_id, envelope)
1770    }
1771
1772    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1773        log::debug!(
1774            "rpc respond. client_id:{}, name:{}",
1775            self.id(),
1776            message_type
1777        );
1778        let connection_id = self.connection_id()?;
1779        self.peer.send_dynamic(connection_id, envelope)
1780    }
1781
1782    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1783        &self.handler_set
1784    }
1785
1786    fn is_via_collab(&self) -> bool {
1787        true
1788    }
1789}
1790
/// The URL scheme used by Zed links, without the `://` separator; links of
/// the form `zed://...` are recognized by [`parse_zed_link`].
pub const ZED_URL_SCHEME: &str = "zed";
1793
1794/// Parses the given link into a Zed link.
1795///
1796/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1797/// Returns [`None`] otherwise.
1798pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1799    let server_url = &ClientSettings::get_global(cx).server_url;
1800    if let Some(stripped) = link
1801        .strip_prefix(server_url)
1802        .and_then(|result| result.strip_prefix('/'))
1803    {
1804        return Some(stripped);
1805    }
1806    if let Some(stripped) = link
1807        .strip_prefix(ZED_URL_SCHEME)
1808        .and_then(|result| result.strip_prefix("://"))
1809    {
1810        return Some(stripped);
1811    }
1812
1813    None
1814}
1815
#[cfg(test)]
mod tests {
    // Unit tests for `Client`:
    // - connection, reconnection and authentication behavior, driven through
    //   `FakeServer` / `FakeHttpClient`;
    // - dispatch of message handlers and entity subscriptions.
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // Checks that reconnecting after a disconnect reuses the cached credentials,
    // and that the client re-authenticates after the server rolls the access
    // token (i.e. the cached token stops being accepted).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        // Drain status updates until the failed reconnection attempt is reported.
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Advance the fake clock, presumably past the reconnection delay, so that
        // another attempt is made.
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Checks two things about an HTTP failure (503) during reconnection:
    // - it is reported as `ReconnectionError`;
    // - once the endpoint recovers, the client reconnects with its cached
    //   credentials instead of re-authenticating.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // Checks that both the initial connection and a later reconnection attempt
    // fail once `CONNECTION_TIMEOUT` elapses when establishing the connection
    // never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // A connection future that never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Checks that `sign_in` keeps using the stored credentials while the server
    // accepts them (200) or is merely unavailable (503). It only runs
    // authentication again once the server rejects them with 401.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // Each authentication issues a token equal to the running authentication count.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        // Only token "2", issued by the second authentication, is accepted.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Checks that calling `connect` again while the first authentication is still
    // pending drops that first authentication future. The first future's drop
    // counter reaches 1, and a second authentication is started.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this authentication future is dropped before completing.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Checks that entity messages reach the entity subscribed under the matching
    // id (here, presumably the message's `project_id`). It also checks that
    // dropping one subscription leaves the others working.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // Checks that after one handler's subscription is dropped, a new handler can
    // be registered for the same message type, and that the new handler receives
    // the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // Checks that a handler can take and drop its own subscription, which is
    // stored on the entity, while it is running, and still run to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a handler target in the tests above:
    // - `id` is what the handlers match on;
    // - `subscription` lets a test keep a handler subscription on the entity
    //   itself, so the handler can later drop it.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global and then calls
    // `init_settings` on it.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}