client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow};
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use clock::SystemClock;
  16use cloud_api_client::CloudApiClient;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use credentials_provider::CredentialsProvider;
  19use feature_flags::FeatureFlagAppExt as _;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use settings::{Settings, SettingsKey, SettingsSources, SettingsUi};
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        Arc, LazyLock, Weak,
  44        atomic::{AtomicU64, Ordering},
  45    },
  46    time::{Duration, Instant},
  47};
  48use std::{cmp, pin::Pin};
  49use telemetry::Telemetry;
  50use thiserror::Error;
  51use tokio::net::TcpStream;
  52use url::Url;
  53use util::{ConnectionResult, ResultExt};
  54
  55pub use rpc::*;
  56pub use telemetry_events::Event;
  57pub use user::*;
  58
  59static ZED_SERVER_URL: LazyLock<Option<String>> =
  60    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  61static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  62
  63pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  64    std::env::var("ZED_IMPERSONATE")
  65        .ok()
  66        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  67});
  68
  69pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  70
  71pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  72    std::env::var("ZED_ADMIN_API_TOKEN")
  73        .ok()
  74        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  75});
  76
  77pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  78    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  79
  80pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  81    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  82
  83pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  84pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  85pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  86
// Actions defined by this crate. `init` registers a handler for each of them, and each
// handler holds only a weak reference to the client.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  98
// Serialized form of the client settings. `server_url` is optional in every settings
// source. `ClientSettings::load` merges the sources with `json_merge` and then applies
// the `ZED_SERVER_URL` override.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)]
#[settings_key(None)]
pub struct ClientSettingsContent {
    server_url: Option<String>,
}
 104
/// Resolved client settings, produced by `<ClientSettings as Settings>::load`.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Server URL. `Client::production` uses it as the base URL of its HTTP client, and
    /// stored credentials are keyed by it. When `ZED_SERVER_URL` is set, it overrides
    /// whatever the settings files say.
    pub server_url: String,
}
 109
 110impl Settings for ClientSettings {
 111    type FileContent = ClientSettingsContent;
 112
 113    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 114        let mut result = sources.json_merge::<Self>()?;
 115        if let Some(server_url) = &*ZED_SERVER_URL {
 116            result.server_url.clone_from(server_url)
 117        }
 118        Ok(result)
 119    }
 120
 121    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
 122}
 123
// Serialized form of the proxy setting. `proxy` is kept as a raw string here and is only
// parsed as a URL by `ProxySettings::proxy_url`. VS Code's `http.proxy` setting is
// imported into it.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, SettingsUi, SettingsKey)]
#[settings_key(None)]
pub struct ProxySettingsContent {
    proxy: Option<String>,
}
 129
/// Resolved proxy settings.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Proxy URL exactly as written in settings. It is not validated here; use
    /// [`ProxySettings::proxy_url`] to get a parsed URL.
    pub proxy: Option<String>,
}
 134
 135impl ProxySettings {
 136    pub fn proxy_url(&self) -> Option<Url> {
 137        self.proxy
 138            .as_ref()
 139            .and_then(|input| {
 140                input
 141                    .parse::<Url>()
 142                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 143                    .ok()
 144            })
 145            .or_else(read_proxy_from_env)
 146    }
 147}
 148
 149impl Settings for ProxySettings {
 150    type FileContent = ProxySettingsContent;
 151
 152    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 153        Ok(Self {
 154            proxy: sources
 155                .user
 156                .or(sources.server)
 157                .and_then(|value| value.proxy.clone())
 158                .or(sources.default.proxy.clone()),
 159        })
 160    }
 161
 162    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 163        vscode.string_setting("http.proxy", &mut current.proxy);
 164    }
 165}
 166
 167pub fn init_settings(cx: &mut App) {
 168    TelemetrySettings::register(cx);
 169    ClientSettings::register(cx);
 170    ProxySettings::register(cx);
 171}
 172
 173pub fn init(client: &Arc<Client>, cx: &mut App) {
 174    let client = Arc::downgrade(client);
 175    cx.on_action({
 176        let client = client.clone();
 177        move |_: &SignIn, cx| {
 178            if let Some(client) = client.upgrade() {
 179                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 180                    .detach_and_log_err(cx);
 181            }
 182        }
 183    });
 184
 185    cx.on_action({
 186        let client = client.clone();
 187        move |_: &SignOut, cx| {
 188            if let Some(client) = client.upgrade() {
 189                cx.spawn(async move |cx| {
 190                    client.sign_out(cx).await;
 191                })
 192                .detach();
 193            }
 194        }
 195    });
 196
 197    cx.on_action({
 198        let client = client;
 199        move |_: &Reconnect, cx| {
 200            if let Some(client) = client.upgrade() {
 201                cx.spawn(async move |cx| {
 202                    client.reconnect(cx);
 203                })
 204                .detach();
 205            }
 206        }
 207    });
 208}
 209
/// Callback that receives a [`MessageToClient`] together with `&mut App`. Instances are
/// stored in `Client::message_to_client_handlers`.
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
 211
/// Stores a [`Client`] as a gpui global; accessed through `Client::global` and
/// `Client::set_global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 215
/// Client for the Zed collaboration server. It holds the RPC peer, the HTTP and cloud API
/// clients, telemetry, credential storage, connection state, and registered message handlers.
pub struct Client {
    /// Client id; `sign_in` sets it to the signed-in user's id via `set_id`.
    id: AtomicU64,
    /// RPC peer, created with `Peer::new(0)` in `Client::new`.
    peer: Arc<Peer>,
    /// HTTP client bound to a base URL (`Client::production` uses the configured `server_url`).
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client; `sign_in` hands it the credentials it obtains.
    cloud_client: Arc<CloudApiClient>,
    /// Telemetry; its authenticated-user info is cleared when the status becomes signed out.
    telemetry: Arc<Telemetry>,
    /// Persistent credential storage, keyed by the configured server URL.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status watch channel, and reconnect task.
    state: RwLock<ClientState>,
    /// Registered proto message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Handlers for [`MessageToClient`] messages.
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    /// Test-only replacement for the authentication flow; set by `override_authenticate`.
    // The boxed closure type is spelled out inline, which trips clippy's type_complexity lint.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only replacement for establishing the connection; set by
    /// `override_establish_connection`.
    // The boxed closure type is spelled out inline, which trips clippy's type_complexity lint.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL override; set by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 251
/// Errors returned while establishing a [`Connection`].
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// An HTTP 426 response to the websocket request (see the `From<WebsocketError>` impl).
    #[error("upgrade required")]
    UpgradeRequired,
    /// An HTTP 401 response to the websocket request (see the `From<WebsocketError>` impl).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, carried as an [`anyhow::Error`]. Websocket errors other than
    /// 401/426 end up here.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A header value could not be constructed.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from the `http` crate re-exported by tungstenite. Despite the variant name, this
    /// is not a tungstenite websocket error.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 267
 268impl From<WebsocketError> for EstablishConnectionError {
 269    fn from(error: WebsocketError) -> Self {
 270        if let WebsocketError::Http(response) = &error {
 271            match response.status() {
 272                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 273                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 274                _ => {}
 275            }
 276        }
 277        EstablishConnectionError::Other(error.into())
 278    }
 279}
 280
 281impl EstablishConnectionError {
 282    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 283        Self::Other(error.into())
 284    }
 285}
 286
/// Connection and authentication state of a [`Client`], published through the watch
/// channel returned by `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No user is signed in. This is also the initial state (`ClientState::default`).
    SignedOut,
    /// An upgrade is required; `is_signed_out` treats this like `SignedOut`.
    UpgradeRequired,
    /// `sign_in` is running and started from a signed-out state.
    Authenticating,
    /// `sign_in` finished after starting from `Authenticating`.
    Authenticated,
    /// The authentication flow in `sign_in` failed. The reconnect loop retries while in
    /// this state.
    AuthenticationError,
    /// Counts as signing in (`is_signing_in`).
    Connecting,
    /// A connection attempt failed. The reconnect loop retries while in this state.
    ConnectionError,
    /// Connected, with the ids for this connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// Setting this status spawns the reconnect loop (see `Client::set_status`).
    ConnectionLost,
    /// `sign_in` is running and did not start from a signed-out state.
    Reauthenticating,
    /// `sign_in` finished after starting from `Reauthenticating`.
    Reauthenticated,
    /// Counts as both signing in and previously connected.
    Reconnecting,
    /// A reconnect attempt failed and the loop is waiting before the next one.
    ReconnectionError {
        /// `now + delay` at the time of the failure. The loop actually waits `delay` plus
        /// random jitter, so this is the earliest possible retry time.
        next_reconnection: Instant,
    },
}
 308
 309impl Status {
 310    pub fn is_connected(&self) -> bool {
 311        matches!(self, Self::Connected { .. })
 312    }
 313
 314    pub fn was_connected(&self) -> bool {
 315        matches!(
 316            self,
 317            Self::ConnectionLost
 318                | Self::Reauthenticating
 319                | Self::Reauthenticated
 320                | Self::Reconnecting
 321        )
 322    }
 323
 324    /// Returns whether the client is currently connected or was connected at some point.
 325    pub fn is_or_was_connected(&self) -> bool {
 326        self.is_connected() || self.was_connected()
 327    }
 328
 329    pub fn is_signing_in(&self) -> bool {
 330        matches!(
 331            self,
 332            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 333        )
 334    }
 335
 336    pub fn is_signed_out(&self) -> bool {
 337        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 338    }
 339}
 340
/// Mutable client state, guarded by `Client::state`.
struct ClientState {
    /// Credentials from the last successful `sign_in`, if any.
    credentials: Option<Credentials>,
    /// Watch channel pair. `set_status` writes through the sender, and `Client::status`
    /// hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect loop spawned when the status becomes `ConnectionLost`. `set_status` clears
    /// it on `Connected`, `SignedOut`, and `UpgradeRequired`. NOTE(review): this relies on
    /// dropping a gpui `Task` cancelling it — confirm.
    _reconnect_task: Option<Task<()>>,
}
 346
 347#[derive(Clone, Debug, Eq, PartialEq)]
 348pub struct Credentials {
 349    pub user_id: u64,
 350    pub access_token: String,
 351}
 352
 353impl Credentials {
 354    pub fn authorization_header(&self) -> String {
 355        format!("{} {}", self.user_id, self.access_token)
 356    }
 357}
 358
/// Wrapper around the app-global [`CredentialsProvider`]. It keys stored credentials by
/// the configured server URL.
pub struct ClientCredentialsProvider {
    provider: Arc<dyn CredentialsProvider>,
}
 362
 363impl ClientCredentialsProvider {
 364    pub fn new(cx: &App) -> Self {
 365        Self {
 366            provider: <dyn CredentialsProvider>::global(cx),
 367        }
 368    }
 369
 370    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 371        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 372    }
 373
 374    /// Reads the credentials from the provider.
 375    fn read_credentials<'a>(
 376        &'a self,
 377        cx: &'a AsyncApp,
 378    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 379        async move {
 380            if IMPERSONATE_LOGIN.is_some() {
 381                return None;
 382            }
 383
 384            let server_url = self.server_url(cx).ok()?;
 385            let (user_id, access_token) = self
 386                .provider
 387                .read_credentials(&server_url, cx)
 388                .await
 389                .log_err()
 390                .flatten()?;
 391
 392            Some(Credentials {
 393                user_id: user_id.parse().ok()?,
 394                access_token: String::from_utf8(access_token).ok()?,
 395            })
 396        }
 397        .boxed_local()
 398    }
 399
 400    /// Writes the credentials to the provider.
 401    fn write_credentials<'a>(
 402        &'a self,
 403        user_id: u64,
 404        access_token: String,
 405        cx: &'a AsyncApp,
 406    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 407        async move {
 408            let server_url = self.server_url(cx)?;
 409            self.provider
 410                .write_credentials(
 411                    &server_url,
 412                    &user_id.to_string(),
 413                    access_token.as_bytes(),
 414                    cx,
 415                )
 416                .await
 417        }
 418        .boxed_local()
 419    }
 420
 421    /// Deletes the credentials from the provider.
 422    fn delete_credentials<'a>(
 423        &'a self,
 424        cx: &'a AsyncApp,
 425    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 426        async move {
 427            let server_url = self.server_url(cx)?;
 428            self.provider.delete_credentials(&server_url, cx).await
 429        }
 430        .boxed_local()
 431    }
 432}
 433
 434impl Default for ClientState {
 435    fn default() -> Self {
 436        Self {
 437            credentials: None,
 438            status: watch::channel_with(Status::SignedOut),
 439            _reconnect_task: None,
 440        }
 441    }
 442}
 443
/// Handle for a registration in the client's handler set. Dropping it removes the
/// registration. It holds only a weak reference, so it does not keep the client alive.
pub enum Subscription {
    /// Entity subscription keyed by entity type and remote id.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler subscription keyed by the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 454
 455impl Drop for Subscription {
 456    fn drop(&mut self) {
 457        match self {
 458            Subscription::Entity { client, id } => {
 459                if let Some(client) = client.upgrade() {
 460                    let mut state = client.handler_set.lock();
 461                    let _ = state.entities_by_type_and_remote_id.remove(id);
 462                }
 463            }
 464            Subscription::Message { client, id } => {
 465                if let Some(client) = client.upgrade() {
 466                    let mut state = client.handler_set.lock();
 467                    let _ = state.entity_types_by_message_type.remove(id);
 468                    let _ = state.message_handlers.remove(id);
 469                }
 470            }
 471        }
 472    }
 473}
 474
/// An entity subscription whose entity is not known yet. Messages for it are held in a
/// `Pending` entry of the handler set until [`PendingEntitySubscription::set_entity`]
/// replays them. If the subscription is dropped before that, the held messages are
/// discarded (each one is logged).
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    // Ties the subscription to entity type `T` without storing a `T`.
    _entity_type: PhantomData<T>,
    // Set by `set_entity` so that `Drop` leaves the handler set alone.
    consumed: bool,
}
 481
 482impl<T: 'static> PendingEntitySubscription<T> {
 483    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 484        self.consumed = true;
 485        let mut handlers = self.client.handler_set.lock();
 486        let id = (TypeId::of::<T>(), self.remote_id);
 487        let Some(EntityMessageSubscriber::Pending(messages)) =
 488            handlers.entities_by_type_and_remote_id.remove(&id)
 489        else {
 490            unreachable!()
 491        };
 492
 493        handlers.entities_by_type_and_remote_id.insert(
 494            id,
 495            EntityMessageSubscriber::Entity {
 496                handle: entity.downgrade().into(),
 497            },
 498        );
 499        drop(handlers);
 500        for message in messages {
 501            let client_id = self.client.id();
 502            let type_name = message.payload_type_name();
 503            let sender_id = message.original_sender_id();
 504            log::debug!(
 505                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 506                client_id,
 507                sender_id,
 508                type_name
 509            );
 510            self.client.handle_message(message, cx);
 511        }
 512        Subscription::Entity {
 513            client: Arc::downgrade(&self.client),
 514            id,
 515        }
 516    }
 517}
 518
 519impl<T: 'static> Drop for PendingEntitySubscription<T> {
 520    fn drop(&mut self) {
 521        if !self.consumed {
 522            let mut state = self.client.handler_set.lock();
 523            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 524                .entities_by_type_and_remote_id
 525                .remove(&(TypeId::of::<T>(), self.remote_id))
 526            {
 527                for message in messages {
 528                    log::info!("unhandled message {}", message.payload_type_name());
 529                }
 530            }
 531        }
 532    }
 533}
 534
/// Resolved telemetry preferences.
#[derive(Copy, Clone, Deserialize, Debug)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
 540
// The fields are optional so that any settings source can leave them unset.
// `TelemetrySettings::load` combines the sources with `json_merge`.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)]
#[settings_key(key = "telemetry")]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 554
 555impl settings::Settings for TelemetrySettings {
 556    type FileContent = TelemetrySettingsContent;
 557
 558    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 559        sources.json_merge()
 560    }
 561
 562    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 563        vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
 564            Some(s == "all")
 565        });
 566        vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
 567            Some(matches!(s, "all" | "error" | "crash"))
 568        });
 569        // we could translate telemetry.telemetryLevel, but just because users didn't want
 570        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 571        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 572    }
 573}
 574
 575impl Client {
 576    pub fn new(
 577        clock: Arc<dyn SystemClock>,
 578        http: Arc<HttpClientWithUrl>,
 579        cx: &mut App,
 580    ) -> Arc<Self> {
 581        Arc::new(Self {
 582            id: AtomicU64::new(0),
 583            peer: Peer::new(0),
 584            telemetry: Telemetry::new(clock, http.clone(), cx),
 585            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 586            http,
 587            credentials_provider: ClientCredentialsProvider::new(cx),
 588            state: Default::default(),
 589            handler_set: Default::default(),
 590            message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
 591
 592            #[cfg(any(test, feature = "test-support"))]
 593            authenticate: Default::default(),
 594            #[cfg(any(test, feature = "test-support"))]
 595            establish_connection: Default::default(),
 596            #[cfg(any(test, feature = "test-support"))]
 597            rpc_url: RwLock::default(),
 598        })
 599    }
 600
 601    pub fn production(cx: &mut App) -> Arc<Self> {
 602        let clock = Arc::new(clock::RealSystemClock);
 603        let http = Arc::new(HttpClientWithUrl::new_url(
 604            cx.http_client(),
 605            &ClientSettings::get_global(cx).server_url,
 606            cx.http_client().proxy().cloned(),
 607        ));
 608        Self::new(clock, http, cx)
 609    }
 610
 611    pub fn id(&self) -> u64 {
 612        self.id.load(Ordering::SeqCst)
 613    }
 614
 615    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 616        self.http.clone()
 617    }
 618
 619    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 620        self.cloud_client.clone()
 621    }
 622
 623    pub fn set_id(&self, id: u64) -> &Self {
 624        self.id.store(id, Ordering::SeqCst);
 625        self
 626    }
 627
 628    #[cfg(any(test, feature = "test-support"))]
 629    pub fn teardown(&self) {
 630        let mut state = self.state.write();
 631        state._reconnect_task.take();
 632        self.handler_set.lock().clear();
 633        self.peer.teardown();
 634    }
 635
 636    #[cfg(any(test, feature = "test-support"))]
 637    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 638    where
 639        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 640    {
 641        *self.authenticate.write() = Some(Box::new(authenticate));
 642        self
 643    }
 644
 645    #[cfg(any(test, feature = "test-support"))]
 646    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 647    where
 648        F: 'static
 649            + Send
 650            + Sync
 651            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 652    {
 653        *self.establish_connection.write() = Some(Box::new(connect));
 654        self
 655    }
 656
 657    #[cfg(any(test, feature = "test-support"))]
 658    pub fn override_rpc_url(&self, url: Url) -> &Self {
 659        *self.rpc_url.write() = Some(url);
 660        self
 661    }
 662
 663    pub fn global(cx: &App) -> Arc<Self> {
 664        cx.global::<GlobalClient>().0.clone()
 665    }
 666    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 667        cx.set_global(GlobalClient(client))
 668    }
 669
 670    pub fn user_id(&self) -> Option<u64> {
 671        self.state
 672            .read()
 673            .credentials
 674            .as_ref()
 675            .map(|credentials| credentials.user_id)
 676    }
 677
 678    pub fn peer_id(&self) -> Option<PeerId> {
 679        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 680            Some(*peer_id)
 681        } else {
 682            None
 683        }
 684    }
 685
 686    pub fn status(&self) -> watch::Receiver<Status> {
 687        self.state.read().status.1.clone()
 688    }
 689
    /// Publishes `status` to all status watchers and applies the side effects tied to it:
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop.
    ///   - It calls `connect` repeatedly, with a backoff that starts at
    ///     `INITIAL_RECONNECTION_DELAY`, doubles after each failure, and is capped at
    ///     `MAX_RECONNECTION_DELAY`; random jitter is added to each wait.
    ///   - Before each wait it publishes `ReconnectionError`.
    ///   - It keeps retrying only while the status is `AuthenticationError` or
    ///     `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired`: clears telemetry's authenticated user and drops the
    ///   reconnect task.
    ///
    /// The state write lock is held for the whole call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // Tests use a fixed seed so the jitter is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_os_rng();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        // Only a successful `Result` ends the loop here. Timeouts, resets and
                        // errors are logged, then fall through to the status check below.
                        match client.connect(true, cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    break;
                                }
                            }
                        }

                        // Retry only while the status after the attempt is one of these two
                        // failure states; any other status ends the loop.
                        if matches!(
                            *client.status().borrow(),
                            Status::AuthenticationError | Status::ConnectionError
                        ) {
                            // `ReconnectionError` hits the `_` arm of this function, so
                            // publishing it here leaves `_reconnect_task` (this very task)
                            // in place.
                            // NOTE: `next_reconnection` is `now + delay`, but the wait below
                            // is `delay + jitter`, so the retry happens at or after that time.
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                cx,
                            );
                            let jitter = Duration::from_millis(
                                rng.random_range(0..delay.as_millis() as u64),
                            );
                            cx.background_executor().timer(delay + jitter).await;
                            // Exponential backoff, capped at MAX_RECONNECTION_DELAY.
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 753
 754    pub fn subscribe_to_entity<T>(
 755        self: &Arc<Self>,
 756        remote_id: u64,
 757    ) -> Result<PendingEntitySubscription<T>>
 758    where
 759        T: 'static,
 760    {
 761        let id = (TypeId::of::<T>(), remote_id);
 762
 763        let mut state = self.handler_set.lock();
 764        anyhow::ensure!(
 765            !state.entities_by_type_and_remote_id.contains_key(&id),
 766            "already subscribed to entity"
 767        );
 768
 769        state
 770            .entities_by_type_and_remote_id
 771            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 772
 773        Ok(PendingEntitySubscription {
 774            client: self.clone(),
 775            remote_id,
 776            consumed: false,
 777            _entity_type: PhantomData,
 778        })
 779    }
 780
 781    #[track_caller]
 782    pub fn add_message_handler<M, E, H, F>(
 783        self: &Arc<Self>,
 784        entity: WeakEntity<E>,
 785        handler: H,
 786    ) -> Subscription
 787    where
 788        M: EnvelopedMessage,
 789        E: 'static,
 790        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 791        F: 'static + Future<Output = Result<()>>,
 792    {
 793        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 794            handler(entity, message, cx)
 795        })
 796    }
 797
 798    fn add_message_handler_impl<M, E, H, F>(
 799        self: &Arc<Self>,
 800        entity: WeakEntity<E>,
 801        handler: H,
 802    ) -> Subscription
 803    where
 804        M: EnvelopedMessage,
 805        E: 'static,
 806        H: 'static
 807            + Sync
 808            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 809            + Send
 810            + Sync,
 811        F: 'static + Future<Output = Result<()>>,
 812    {
 813        let message_type_id = TypeId::of::<M>();
 814        let mut state = self.handler_set.lock();
 815        state
 816            .entities_by_message_type
 817            .insert(message_type_id, entity.into());
 818
 819        let prev_handler = state.message_handlers.insert(
 820            message_type_id,
 821            Arc::new(move |subscriber, envelope, client, cx| {
 822                let subscriber = subscriber.downcast::<E>().unwrap();
 823                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 824                handler(subscriber, *envelope, client, cx).boxed_local()
 825            }),
 826        );
 827        if prev_handler.is_some() {
 828            let location = std::panic::Location::caller();
 829            panic!(
 830                "{}:{} registered handler for the same message {} twice",
 831                location.file(),
 832                location.line(),
 833                std::any::type_name::<M>()
 834            );
 835        }
 836
 837        Subscription::Message {
 838            client: Arc::downgrade(self),
 839            id: message_type_id,
 840        }
 841    }
 842
 843    pub fn add_request_handler<M, E, H, F>(
 844        self: &Arc<Self>,
 845        entity: WeakEntity<E>,
 846        handler: H,
 847    ) -> Subscription
 848    where
 849        M: RequestMessage,
 850        E: 'static,
 851        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 852        F: 'static + Future<Output = Result<M::Response>>,
 853    {
 854        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 855            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 856        })
 857    }
 858
 859    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 860        receipt: Receipt<T>,
 861        response: F,
 862        client: AnyProtoClient,
 863    ) -> Result<()> {
 864        match response.await {
 865            Ok(response) => {
 866                client.send_response(receipt.message_id, response)?;
 867                Ok(())
 868            }
 869            Err(error) => {
 870                client.send_response(receipt.message_id, error.to_proto())?;
 871                Err(error)
 872            }
 873        }
 874    }
 875
 876    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 877        self.credentials_provider
 878            .read_credentials(cx)
 879            .await
 880            .is_some()
 881    }
 882
    /// Authenticates the client and returns the resulting credentials.
    ///
    /// Credentials are tried in this order:
    /// 1. the credentials already held in memory, if `validate_credentials` accepts them;
    /// 2. when `try_provider` is true, credentials read from the credentials provider. They
    ///    are deleted from the provider if they fail validation;
    /// 3. a fresh `authenticate` flow. Its result is written to the provider unless
    ///    `ZED_IMPERSONATE` is set.
    ///
    /// Status transitions:
    /// - Start: `Authenticating` (when signed out) or `Reauthenticating` (otherwise).
    /// - Success: `Authenticated` or `Reauthenticated` to match the start.
    /// - Failed authenticate flow: `AuthenticationError`.
    ///
    /// On success it also updates the client id, the cloud client's credentials, and the
    /// stored in-memory credentials.
    ///
    /// # Errors
    ///
    /// - Any error from `validate_credentials`.
    /// - Any error from the authenticate flow.
    /// - `"authentication canceled"` if the status changes while that flow runs.
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        // Remember which flavor of sign-in this is so the final status matches it.
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // 1. Reuse the in-memory credentials if they still validate.
        // NOTE(review): an `Err` from `validate_credentials` returns early via `?` without
        // changing the status, leaving it at `Authenticating`/`Reauthenticating`. Confirm
        // whether `validate_credentials` resets the status itself.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // 2. Fall back to the provider's stored credentials. Invalid ones are deleted so
        //    they are not tried again.
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // 3. Run the full authentication flow; abandon it if the status changes meanwhile.
        if credentials.is_none() {
            // Consume the current status (set above), so that the `status_rx.next()` arm
            // below fires only on a later change. This assumes the watch receiver yields
            // its current value first.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            // `select_biased!` polls the authenticate arm first.
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path above either set `credentials` or returned.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `as u32` silently truncates user ids above `u32::MAX`. Presumably
        // ids fit in 32 bits; confirm.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
 963
 964    async fn validate_credentials(
 965        self: &Arc<Self>,
 966        credentials: &Credentials,
 967        cx: &AsyncApp,
 968    ) -> Result<bool> {
 969        match self
 970            .cloud_client
 971            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 972            .await
 973        {
 974            Ok(valid) => Ok(valid),
 975            Err(err) => {
 976                self.set_status(Status::AuthenticationError, cx);
 977                Err(anyhow!("failed to validate credentials: {}", err))
 978            }
 979        }
 980    }
 981
 982    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 983    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 984        let connect_task = cx.update({
 985            let cloud_client = self.cloud_client.clone();
 986            move |cx| cloud_client.connect(cx)
 987        })??;
 988        let connection = connect_task.await?;
 989
 990        let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
 991        task.detach();
 992
 993        cx.spawn({
 994            let this = self.clone();
 995            async move |cx| {
 996                while let Some(message) = messages.next().await {
 997                    if let Some(message) = message.log_err() {
 998                        this.handle_message_to_client(message, cx);
 999                    }
1000                }
1001            }
1002        })
1003        .detach();
1004
1005        Ok(())
1006    }
1007
1008    /// Performs a sign-in and also (optionally) connects to Collab.
1009    ///
1010    /// Only Zed staff automatically connect to Collab.
1011    pub async fn sign_in_with_optional_connect(
1012        self: &Arc<Self>,
1013        try_provider: bool,
1014        cx: &AsyncApp,
1015    ) -> Result<()> {
1016        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
1017        if self.status().borrow().is_connected() {
1018            return Ok(());
1019        }
1020
1021        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
1022        let mut is_staff_tx = Some(is_staff_tx);
1023        cx.update(|cx| {
1024            cx.on_flags_ready(move |state, _cx| {
1025                if let Some(is_staff_tx) = is_staff_tx.take() {
1026                    is_staff_tx.send(state.is_staff).log_err();
1027                }
1028            })
1029            .detach();
1030        })
1031        .log_err();
1032
1033        let credentials = self.sign_in(try_provider, cx).await?;
1034
1035        self.connect_to_cloud(cx).await.log_err();
1036
1037        cx.update(move |cx| {
1038            cx.spawn({
1039                let client = self.clone();
1040                async move |cx| {
1041                    let is_staff = is_staff_rx.await?;
1042                    if is_staff {
1043                        match client.connect_with_credentials(credentials, cx).await {
1044                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
1045                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
1046                            ConnectionResult::Result(result) => {
1047                                result.context("client auth and connect")
1048                            }
1049                        }
1050                    } else {
1051                        Ok(())
1052                    }
1053                }
1054            })
1055            .detach_and_log_err(cx);
1056        })
1057        .log_err();
1058
1059        Ok(())
1060    }
1061
1062    pub async fn connect(
1063        self: &Arc<Self>,
1064        try_provider: bool,
1065        cx: &AsyncApp,
1066    ) -> ConnectionResult<()> {
1067        let was_disconnected = match *self.status().borrow() {
1068            Status::SignedOut | Status::Authenticated => true,
1069            Status::ConnectionError
1070            | Status::ConnectionLost
1071            | Status::Authenticating
1072            | Status::AuthenticationError
1073            | Status::Reauthenticating
1074            | Status::Reauthenticated
1075            | Status::ReconnectionError { .. } => false,
1076            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1077                return ConnectionResult::Result(Ok(()));
1078            }
1079            Status::UpgradeRequired => {
1080                return ConnectionResult::Result(
1081                    Err(EstablishConnectionError::UpgradeRequired)
1082                        .context("client auth and connect"),
1083                );
1084            }
1085        };
1086        let credentials = match self.sign_in(try_provider, cx).await {
1087            Ok(credentials) => credentials,
1088            Err(err) => return ConnectionResult::Result(Err(err)),
1089        };
1090
1091        if was_disconnected {
1092            self.set_status(Status::Connecting, cx);
1093        } else {
1094            self.set_status(Status::Reconnecting, cx);
1095        }
1096
1097        self.connect_with_credentials(credentials, cx).await
1098    }
1099
    /// Establishes a Collab connection using `credentials`, bounded by `CONNECTION_TIMEOUT`.
    ///
    /// Two phases race against a single timer:
    /// 1. `establish_connection` (the WebSocket handshake);
    /// 2. `set_connection` (waiting for the server hello and installing the connection).
    ///
    /// Status outcomes:
    /// - The timer fires in either phase: status becomes `ConnectionError` and
    ///   `ConnectionResult::Timeout` is returned.
    /// - `EstablishConnectionError::UpgradeRequired`: status becomes `UpgradeRequired`.
    /// - Any other failure: status becomes `ConnectionError`.
    ///
    /// All errors are returned with "client auth and connect" context.
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // A single fused timer is shared by both selects, so the timeout covers the whole
        // attempt rather than restarting for the second phase.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1147
    /// Installs a freshly-established `conn` as the client's active Collab connection.
    ///
    /// Steps:
    /// 1. Registers the connection with `self.peer` and spawns its I/O future on the background
    ///    executor.
    /// 2. Waits for the first incoming message, which must be a `proto::Hello` carrying a peer id.
    ///    If it is missing or malformed, the connection is disconnected and the error returned.
    /// 3. Sets the status to `Connected { peer_id, connection_id }`.
    /// 4. Spawns two detached tasks:
    ///    - one dispatches incoming messages to `handle_message`;
    ///    - one watches the I/O future. When that future ends cleanly, it sets `SignedOut`, but
    ///      only if this connection is still the current one. When it fails, it sets
    ///      `ConnectionLost`.
    ///
    /// # Errors
    ///
    /// Fails if no hello message arrives, the first message is not a hello, or the hello
    /// lacks a peer id.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        // The peer receives a timer factory backed by the background executor.
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // The server must greet us with a `Hello` that assigns our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch loop for all messages after the hello.
        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        // Watch the I/O future. The status-equality guard keeps a stale connection's clean shutdown
        // from overwriting the status of a newer connection.
        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1231
1232    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1233        #[cfg(any(test, feature = "test-support"))]
1234        if let Some(callback) = self.authenticate.read().as_ref() {
1235            return callback(cx);
1236        }
1237
1238        self.authenticate_with_browser(cx)
1239    }
1240
1241    fn establish_connection(
1242        self: &Arc<Self>,
1243        credentials: &Credentials,
1244        cx: &AsyncApp,
1245    ) -> Task<Result<Connection, EstablishConnectionError>> {
1246        #[cfg(any(test, feature = "test-support"))]
1247        if let Some(callback) = self.establish_connection.read().as_ref() {
1248            return callback(credentials, cx);
1249        }
1250
1251        self.establish_websocket_connection(credentials, cx)
1252    }
1253
1254    fn rpc_url(
1255        &self,
1256        http: Arc<HttpClientWithUrl>,
1257        release_channel: Option<ReleaseChannel>,
1258    ) -> impl Future<Output = Result<url::Url>> + use<> {
1259        #[cfg(any(test, feature = "test-support"))]
1260        let url_override = self.rpc_url.read().clone();
1261
1262        async move {
1263            #[cfg(any(test, feature = "test-support"))]
1264            if let Some(url) = url_override {
1265                return Ok(url);
1266            }
1267
1268            if let Some(url) = &*ZED_RPC_URL {
1269                return Url::parse(url).context("invalid rpc url");
1270            }
1271
1272            let mut url = http.build_url("/rpc");
1273            if let Some(preview_param) =
1274                release_channel.and_then(|channel| channel.release_query_param())
1275            {
1276                url += "?";
1277                url += preview_param;
1278            }
1279
1280            let response = http.get(&url, Default::default(), false).await?;
1281            anyhow::ensure!(
1282                response.status().is_redirection(),
1283                "unexpected /rpc response status {}",
1284                response.status()
1285            );
1286            let collab_url = response
1287                .headers()
1288                .get("Location")
1289                .context("missing location header in /rpc response")?
1290                .to_str()
1291                .map_err(EstablishConnectionError::other)?
1292                .to_string();
1293            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1294        }
1295    }
1296
    /// Opens an authenticated WebSocket connection to the Collab RPC endpoint.
    ///
    /// Steps:
    /// 1. Captures the release channel and app version from the app context; either falls back
    ///    to a default if the context is unavailable.
    /// 2. In a spawned task, resolves the RPC URL via `rpc_url`. Only `http` and `https` schemes
    ///    are accepted.
    /// 3. Opens a TCP stream to the URL's host/port on the Tokio runtime, through the HTTP
    ///    client's proxy when one is configured.
    /// 4. Rewrites the scheme to `ws`/`wss`.
    /// 5. Builds the WebSocket request with authorization, protocol version, app version,
    ///    release channel, and optional user-agent/system-id/metrics-id headers.
    /// 6. Performs the (TLS) WebSocket handshake and wraps the stream in a `Connection`.
    ///
    /// # Errors
    ///
    /// Each failing step is returned as an `EstablishConnectionError`, converted via `?`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        // Everything the spawned task needs is captured up front so the task does not borrow `self`.
        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            // `Err(..)?` (rather than `bail!`) routes the anyhow error through the `?` conversion
            // into `EstablishConnectionError`.
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // The raw TCP (or proxied) connection is opened on the Tokio runtime.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })?
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // http/https -> ws/wss. Both sides are special schemes, so `set_scheme` accepts the change.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            // The WebSocket handshake runs over the already-open stream, using the shared TLS config.
            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1404
1405    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1406        let http = self.http.clone();
1407        let this = self.clone();
1408        cx.spawn(async move |cx| {
1409            let background = cx.background_executor().clone();
1410
1411            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1412            cx.update(|cx| {
1413                cx.spawn(async move |cx| {
1414                    let url = open_url_rx.await?;
1415                    cx.update(|cx| cx.open_url(&url))
1416                })
1417                .detach_and_log_err(cx);
1418            })
1419            .log_err();
1420
1421            let credentials = background
1422                .clone()
1423                .spawn(async move {
1424                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1425                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1426                    // any other app running on the user's device.
1427                    let (public_key, private_key) =
1428                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1429                    let public_key_string = String::try_from(public_key)
1430                        .expect("failed to serialize public key for auth");
1431
1432                    if let Some((login, token)) =
1433                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1434                    {
1435                        if !*USE_WEB_LOGIN {
1436                            eprintln!("authenticate as admin {login}, {token}");
1437
1438                            return this
1439                                .authenticate_as_admin(http, login.clone(), token.clone())
1440                                .await;
1441                        }
1442                    }
1443
1444                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1445                    let server =
1446                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1447                    let port = server.server_addr().port();
1448
1449                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1450                    // that the user is signing in from a Zed app running on the same device.
1451                    let mut url = http.build_url(&format!(
1452                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1453                        port, public_key_string
1454                    ));
1455
1456                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1457                        log::info!("impersonating user @{}", impersonate_login);
1458                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1459                    }
1460
1461                    open_url_tx.send(url).log_err();
1462
1463                    #[derive(Deserialize)]
1464                    struct CallbackParams {
1465                        pub user_id: String,
1466                        pub access_token: String,
1467                    }
1468
1469                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1470                    // access token from the query params.
1471                    //
1472                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1473                    // custom URL scheme instead of this local HTTP server.
1474                    let (user_id, access_token) = background
1475                        .spawn(async move {
1476                            for _ in 0..100 {
1477                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1478                                    let path = req.url();
1479                                    let url = Url::parse(&format!("http://example.com{}", path))
1480                                        .context("failed to parse login notification url")?;
1481                                    let callback_params: CallbackParams =
1482                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1483                                            .context(
1484                                                "failed to parse sign-in callback query parameters",
1485                                            )?;
1486
1487                                    let post_auth_url =
1488                                        http.build_url("/native_app_signin_succeeded");
1489                                    req.respond(
1490                                        tiny_http::Response::empty(302).with_header(
1491                                            tiny_http::Header::from_bytes(
1492                                                &b"Location"[..],
1493                                                post_auth_url.as_bytes(),
1494                                            )
1495                                            .unwrap(),
1496                                        ),
1497                                    )
1498                                    .context("failed to respond to login http request")?;
1499                                    return Ok((
1500                                        callback_params.user_id,
1501                                        callback_params.access_token,
1502                                    ));
1503                                }
1504                            }
1505
1506                            anyhow::bail!("didn't receive login redirect");
1507                        })
1508                        .await?;
1509
1510                    let access_token = private_key
1511                        .decrypt_string(&access_token)
1512                        .context("failed to decrypt access token")?;
1513
1514                    Ok(Credentials {
1515                        user_id: user_id.parse()?,
1516                        access_token,
1517                    })
1518                })
1519                .await?;
1520
1521            cx.update(|cx| cx.activate(true))?;
1522            Ok(credentials)
1523        })
1524    }
1525
1526    async fn authenticate_as_admin(
1527        self: &Arc<Self>,
1528        http: Arc<HttpClientWithUrl>,
1529        login: String,
1530        api_token: String,
1531    ) -> Result<Credentials> {
1532        #[derive(Serialize)]
1533        struct ImpersonateUserBody {
1534            github_login: String,
1535        }
1536
1537        #[derive(Deserialize)]
1538        struct ImpersonateUserResponse {
1539            user_id: u64,
1540            access_token: String,
1541        }
1542
1543        let url = self
1544            .http
1545            .build_zed_cloud_url("/internal/users/impersonate", &[])?;
1546        let request = Request::post(url.as_str())
1547            .header("Content-Type", "application/json")
1548            .header("Authorization", format!("Bearer {api_token}"))
1549            .body(
1550                serde_json::to_string(&ImpersonateUserBody {
1551                    github_login: login,
1552                })?
1553                .into(),
1554            )?;
1555
1556        let mut response = http.send(request).await?;
1557        let mut body = String::new();
1558        response.body_mut().read_to_string(&mut body).await?;
1559        anyhow::ensure!(
1560            response.status().is_success(),
1561            "admin user request failed {} - {}",
1562            response.status().as_u16(),
1563            body,
1564        );
1565        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1566
1567        Ok(Credentials {
1568            user_id: response.user_id,
1569            access_token: response.access_token,
1570        })
1571    }
1572
1573    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1574        self.state.write().credentials = None;
1575        self.cloud_client.clear_credentials();
1576        self.disconnect(cx);
1577
1578        if self.has_credentials(cx).await {
1579            self.credentials_provider
1580                .delete_credentials(cx)
1581                .await
1582                .log_err();
1583        }
1584    }
1585
1586    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1587        self.peer.teardown();
1588        self.set_status(Status::SignedOut, cx);
1589    }
1590
1591    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1592        self.peer.teardown();
1593        self.set_status(Status::ConnectionLost, cx);
1594    }
1595
1596    fn connection_id(&self) -> Result<ConnectionId> {
1597        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1598            Ok(connection_id)
1599        } else {
1600            anyhow::bail!("not connected");
1601        }
1602    }
1603
1604    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1605        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1606        self.peer.send(self.connection_id()?, message)
1607    }
1608
1609    pub fn request<T: RequestMessage>(
1610        &self,
1611        request: T,
1612    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1613        self.request_envelope(request)
1614            .map_ok(|envelope| envelope.payload)
1615    }
1616
1617    pub fn request_stream<T: RequestMessage>(
1618        &self,
1619        request: T,
1620    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1621        let client_id = self.id.load(Ordering::SeqCst);
1622        log::debug!(
1623            "rpc request start. client_id:{}. name:{}",
1624            client_id,
1625            T::NAME
1626        );
1627        let response = self
1628            .connection_id()
1629            .map(|conn_id| self.peer.request_stream(conn_id, request));
1630        async move {
1631            let response = response?.await;
1632            log::debug!(
1633                "rpc request finish. client_id:{}. name:{}",
1634                client_id,
1635                T::NAME
1636            );
1637            response
1638        }
1639    }
1640
1641    pub fn request_envelope<T: RequestMessage>(
1642        &self,
1643        request: T,
1644    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1645        let client_id = self.id();
1646        log::debug!(
1647            "rpc request start. client_id:{}. name:{}",
1648            client_id,
1649            T::NAME
1650        );
1651        let response = self
1652            .connection_id()
1653            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1654        async move {
1655            let response = response?.await;
1656            log::debug!(
1657                "rpc request finish. client_id:{}. name:{}",
1658                client_id,
1659                T::NAME
1660            );
1661            response
1662        }
1663    }
1664
1665    pub fn request_dynamic(
1666        &self,
1667        envelope: proto::Envelope,
1668        request_type: &'static str,
1669    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1670        let client_id = self.id();
1671        log::debug!(
1672            "rpc request start. client_id:{}. name:{}",
1673            client_id,
1674            request_type
1675        );
1676        let response = self
1677            .connection_id()
1678            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1679        async move {
1680            let response = response?.await;
1681            log::debug!(
1682                "rpc request finish. client_id:{}. name:{}",
1683                client_id,
1684                request_type
1685            );
1686            Ok(response?.0)
1687        }
1688    }
1689
1690    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1691        let sender_id = message.sender_id();
1692        let request_id = message.message_id();
1693        let type_name = message.payload_type_name();
1694        let original_sender_id = message.original_sender_id();
1695
1696        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1697            &self.handler_set,
1698            message,
1699            self.clone().into(),
1700            cx.clone(),
1701        ) {
1702            let client_id = self.id();
1703            log::debug!(
1704                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1705                client_id,
1706                original_sender_id,
1707                type_name
1708            );
1709            cx.spawn(async move |_| match future.await {
1710                Ok(()) => {
1711                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1712                }
1713                Err(error) => {
1714                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1715                }
1716            })
1717            .detach();
1718        } else {
1719            log::info!("unhandled message {}", type_name);
1720            self.peer
1721                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1722                .log_err();
1723        }
1724    }
1725
1726    pub fn add_message_to_client_handler(
1727        self: &Arc<Client>,
1728        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1729    ) {
1730        self.message_to_client_handlers
1731            .lock()
1732            .push(Box::new(handler));
1733    }
1734
1735    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1736        cx.update(|cx| {
1737            for handler in self.message_to_client_handlers.lock().iter() {
1738                handler(&message, cx);
1739            }
1740        })
1741        .ok();
1742    }
1743
1744    pub fn telemetry(&self) -> &Arc<Telemetry> {
1745        &self.telemetry
1746    }
1747}
1748
1749impl ProtoClient for Client {
1750    fn request(
1751        &self,
1752        envelope: proto::Envelope,
1753        request_type: &'static str,
1754    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1755        self.request_dynamic(envelope, request_type).boxed()
1756    }
1757
1758    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1759        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1760        let connection_id = self.connection_id()?;
1761        self.peer.send_dynamic(connection_id, envelope)
1762    }
1763
1764    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1765        log::debug!(
1766            "rpc respond. client_id:{}, name:{}",
1767            self.id(),
1768            message_type
1769        );
1770        let connection_id = self.connection_id()?;
1771        self.peer.send_dynamic(connection_id, envelope)
1772    }
1773
1774    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1775        &self.handler_set
1776    }
1777
1778    fn is_via_collab(&self) -> bool {
1779        true
1780    }
1781}
1782
/// Scheme name used by Zed deep links of the form `zed://...`
/// (the scheme only, without the `://` separator).
pub const ZED_URL_SCHEME: &str = "zed";
1785
1786/// Parses the given link into a Zed link.
1787///
1788/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1789/// Returns [`None`] otherwise.
1790pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1791    let server_url = &ClientSettings::get_global(cx).server_url;
1792    if let Some(stripped) = link
1793        .strip_prefix(server_url)
1794        .and_then(|result| result.strip_prefix('/'))
1795    {
1796        return Some(stripped);
1797    }
1798    if let Some(stripped) = link
1799        .strip_prefix(ZED_URL_SCHEME)
1800        .and_then(|result| result.strip_prefix("://"))
1801    {
1802        return Some(stripped);
1803    }
1804
1805    None
1806}
1807
// Unit tests for `Client`. They cover:
// - connection and reconnection status transitions;
// - credential caching and re-authentication;
// - connection timeouts;
// - lifetimes of message-handler subscriptions.
// They run against `FakeServer` and `FakeHttpClient`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // After the server drops the connection, the client reconnects once
    // connections are allowed again and the clock advances. It reuses its cached
    // credentials (auth count stays 1) and authenticates again only after the
    // server rolls the access token (auth count becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones so the reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Re-allow connections and advance the clock so a retry can succeed.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // While the fake HTTP client answers 503, reconnecting yields a
    // `ReconnectionError`. Once it answers 200 again, the client reconnects
    // without the server counting a new authentication.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // If establishing the connection never completes, both the initial connect
    // and a later reconnection attempt give up once `CONNECTION_TIMEOUT` elapses.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // The overridden connection future stays pending forever, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Advance far enough for a reconnection attempt to start (asserted as `Reconnecting`).
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // `sign_in` runs the authenticate override only when the server rejects the
    // current credentials with 401. Credentials that are still accepted (200)
    // and an unavailable server (503) both leave the auth count unchanged.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // Each authentication increments the counter and uses the new count as the access token.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Starting a second `connect` while the first is still authenticating
    // causes the first authentication future to be dropped.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        // Each authentication counts itself, then waits forever. The deferred
        // closure bumps `dropped_auth_count` when that future is dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing `_authenticate` does not drop the first task. The drop
        // counted below therefore presumably comes from the client abandoning
        // the first authentication.
        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity message handlers reach entities subscribed by id. `JoinProject`
    // messages with `project_id` 1 and 2 must reach entities 1 and 2, even after
    // the subscription for entity 3 has been dropped.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // Registering a new handler for a message type after dropping the previous
    // handler's subscription works, and the new handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // A handler can drop its own subscription (stored on the entity) while it
    // is handling a message and still run to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a message-handler target. `id` tells entities
    // apart, and `subscription` lets a handler hold (and drop) its own subscription.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global, then calls `init_settings`
    // (presumably registering this crate's settings, such as `ClientSettings`).
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}