client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod llm_token;
   5mod proxy;
   6pub mod telemetry;
   7pub mod user;
   8pub mod zed_urls;
   9
  10use anyhow::{Context as _, Result, anyhow};
  11use async_tungstenite::tungstenite::{
  12    client::IntoClientRequest,
  13    error::Error as WebsocketError,
  14    http::{HeaderValue, Request, StatusCode},
  15};
  16use clock::SystemClock;
  17use cloud_api_client::websocket_protocol::MessageToClient;
  18use cloud_api_client::{ClientApiError, CloudApiClient};
  19use cloud_api_types::OrganizationId;
  20use credentials_provider::CredentialsProvider;
  21use feature_flags::FeatureFlagAppExt as _;
  22use futures::{
  23    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  24    channel::{mpsc, oneshot},
  25    future::BoxFuture,
  26};
  27use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  28use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
  29use language_model::LlmApiToken;
  30use parking_lot::{Mutex, RwLock};
  31use postage::watch;
  32use proxy::connect_proxy_stream;
  33use rand::prelude::*;
  34use release_channel::{AppVersion, ReleaseChannel};
  35use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  36use serde::{Deserialize, Serialize};
  37use settings::{RegisterSetting, Settings, SettingsContent};
  38use std::{
  39    any::TypeId,
  40    convert::TryFrom,
  41    future::Future,
  42    marker::PhantomData,
  43    path::PathBuf,
  44    sync::{
  45        Arc, LazyLock, Weak,
  46        atomic::{AtomicU64, Ordering},
  47    },
  48    time::{Duration, Instant},
  49};
  50use std::{cmp, pin::Pin};
  51use telemetry::Telemetry;
  52use thiserror::Error;
  53use tokio::net::TcpStream;
  54use url::Url;
  55use util::{ConnectionResult, ResultExt};
  56
  57pub use llm_token::*;
  58pub use rpc::*;
  59pub use telemetry_events::Event;
  60pub use user::*;
  61
  62static ZED_SERVER_URL: LazyLock<Option<String>> =
  63    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  64static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  65
  66pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  67    std::env::var("ZED_IMPERSONATE")
  68        .ok()
  69        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  70});
  71
  72pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());
  73
  74pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  75    std::env::var("ZED_ADMIN_API_TOKEN")
  76        .ok()
  77        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  78});
  79
  80pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  81    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  82
  83pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  84    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));
  85
  86pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  87pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  88pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  89
// Client-level actions; `init` registers a handler for each of them.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
 101
 102#[derive(Deserialize, RegisterSetting)]
 103pub struct ClientSettings {
 104    pub server_url: String,
 105}
 106
 107impl Settings for ClientSettings {
 108    fn from_settings(content: &settings::SettingsContent) -> Self {
 109        if let Some(server_url) = &*ZED_SERVER_URL {
 110            return Self {
 111                server_url: server_url.clone(),
 112            };
 113        }
 114        Self {
 115            server_url: content.server_url.clone().unwrap(),
 116        }
 117    }
 118}
 119
 120#[derive(Deserialize, Default, RegisterSetting)]
 121pub struct ProxySettings {
 122    pub proxy: Option<String>,
 123}
 124
 125impl ProxySettings {
 126    pub fn proxy_url(&self) -> Option<Url> {
 127        self.proxy
 128            .as_deref()
 129            .map(str::trim)
 130            .filter(|input| !input.is_empty())
 131            .and_then(|input| {
 132                input
 133                    .parse::<Url>()
 134                    .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
 135                    .ok()
 136            })
 137            .or_else(read_proxy_from_env)
 138    }
 139}
 140
 141impl Settings for ProxySettings {
 142    fn from_settings(content: &settings::SettingsContent) -> Self {
 143        Self {
 144            proxy: content
 145                .proxy
 146                .as_deref()
 147                .map(str::trim)
 148                .filter(|proxy| !proxy.is_empty())
 149                .map(ToOwned::to_owned),
 150        }
 151    }
 152}
 153
 154pub fn init(client: &Arc<Client>, cx: &mut App) {
 155    let client = Arc::downgrade(client);
 156    cx.on_action({
 157        let client = client.clone();
 158        move |_: &SignIn, cx| {
 159            if let Some(client) = client.upgrade() {
 160                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
 161                    .detach_and_log_err(cx);
 162            }
 163        }
 164    })
 165    .on_action({
 166        let client = client.clone();
 167        move |_: &SignOut, cx| {
 168            if let Some(client) = client.upgrade() {
 169                cx.spawn(async move |cx| {
 170                    client.sign_out(cx).await;
 171                })
 172                .detach();
 173            }
 174        }
 175    })
 176    .on_action({
 177        let client = client;
 178        move |_: &Reconnect, cx| {
 179            if let Some(client) = client.upgrade() {
 180                cx.spawn(async move |cx| {
 181                    client.reconnect(cx);
 182                })
 183                .detach();
 184            }
 185        }
 186    });
 187}
 188
/// Callback taking a [`MessageToClient`] and the app context; instances are stored
/// in `Client::message_to_client_handlers`.
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;

/// gpui global holding the app-wide [`Client`]; written by `Client::set_global`
/// and read by `Client::global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 194
/// Connection to the Zed server: owns the RPC [`Peer`], the HTTP and Cloud API
/// clients, telemetry, credential storage, and the authentication/connection state.
pub struct Client {
    /// Client id read by `id` and written by `set_id` (`sign_in` sets it to the user id).
    id: AtomicU64,
    peer: Arc<Peer>,
    /// HTTP client bound to the server URL.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client; `sign_in` hands it the credentials on success.
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    /// Stored-credential access keyed by the configured server URL.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status watch channel and pending reconnect task.
    state: RwLock<ClientState>,
    /// Registered RPC message handlers and the entities they route to.
    handler_set: Mutex<ProtoMessageHandlerSet>,
    /// Handlers for Cloud WebSocket messages (their dispatch is outside this view).
    message_to_client_handlers: Mutex<Vec<MessageToClientHandler>>,
    // NOTE(review): only the initial `None` is visible here; the sender is used elsewhere.
    sign_out_tx: Mutex<Option<mpsc::UnboundedSender<()>>>,

    // Test-support hook installed by `override_authenticate`. The boxed `Fn`
    // type is spelled out inline, which trips clippy's type-complexity lint.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    // Test-support hook installed by `override_establish_connection`; allowed
    // for the same type-complexity reason as above.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    // Test-support RPC URL set by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 231
/// Error produced when establishing a connection fails (the error type of the
/// `establish_connection` test hook, among others).
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced from a WebSocket HTTP error with status 426 (see `From<WebsocketError>`).
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced from a WebSocket HTTP error with status 401 (see `From<WebsocketError>`).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including every other WebSocket error.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An invalid HTTP header value.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// An I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    // NOTE(review): despite its name this variant wraps `http::Error`, not a
    // WebSocket protocol error (those go through `From<WebsocketError>`).
    /// An `http` crate error.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 247
 248impl From<WebsocketError> for EstablishConnectionError {
 249    fn from(error: WebsocketError) -> Self {
 250        if let WebsocketError::Http(response) = &error {
 251            match response.status() {
 252                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 253                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 254                _ => {}
 255            }
 256        }
 257        EstablishConnectionError::Other(error.into())
 258    }
 259}
 260
 261impl EstablishConnectionError {
 262    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 263        Self::Other(error.into())
 264    }
 265}
 266
/// Authentication/connection state of a [`Client`], published through the watch
/// channel returned by `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial state (see `ClientState::default`).
    SignedOut,
    /// Treated as signed out by `is_signed_out`; presumably set when the server
    /// rejects this client version — confirm in `connect`.
    UpgradeRequired,
    /// `sign_in` started from a signed-out state.
    Authenticating,
    /// `sign_in` succeeded from a signed-out state.
    Authenticated,
    /// Credential validation or the `authenticate` flow failed.
    AuthenticationError,
    // Presumably set by `connect` (not shown here) — confirm.
    Connecting,
    // Presumably set by `connect` (not shown here); the reconnect loop retries while in it.
    ConnectionError,
    /// Connected to the server as `peer_id` over `connection_id`.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection was lost; entering it makes `set_status` spawn the reconnect loop.
    ConnectionLost,
    /// `sign_in` started while not signed out.
    Reauthenticating,
    /// `sign_in` succeeded while not signed out.
    Reauthenticated,
    // Presumably set by `connect` during a reconnect (not shown here) — confirm.
    Reconnecting,
    /// A reconnect attempt failed; the reconnect loop waits before retrying at
    /// around `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 288
 289impl Status {
 290    pub fn is_connected(&self) -> bool {
 291        matches!(self, Self::Connected { .. })
 292    }
 293
 294    pub fn was_connected(&self) -> bool {
 295        matches!(
 296            self,
 297            Self::ConnectionLost
 298                | Self::Reauthenticating
 299                | Self::Reauthenticated
 300                | Self::Reconnecting
 301        )
 302    }
 303
 304    /// Returns whether the client is currently connected or was connected at some point.
 305    pub fn is_or_was_connected(&self) -> bool {
 306        self.is_connected() || self.was_connected()
 307    }
 308
 309    pub fn is_signing_in(&self) -> bool {
 310        matches!(
 311            self,
 312            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 313        )
 314    }
 315
 316    pub fn is_signed_out(&self) -> bool {
 317        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 318    }
 319}
 320
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials stored by the most recent successful `sign_in`.
    credentials: Option<Credentials>,
    /// Watch channel carrying the current [`Status`]: `set_status` writes through
    /// the sender, and `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect task spawned on `ConnectionLost` and cleared on `Connected`,
    /// `SignedOut` and `UpgradeRequired`. It is only held to keep the task alive
    /// (presumably dropping a gpui `Task` cancels it — confirm).
    _reconnect_task: Option<Task<()>>,
}
 326
 327#[derive(Clone, Debug, Eq, PartialEq)]
 328pub struct Credentials {
 329    pub user_id: u64,
 330    pub access_token: String,
 331}
 332
 333impl Credentials {
 334    pub fn authorization_header(&self) -> String {
 335        format!("{} {}", self.user_id, self.access_token)
 336    }
 337}
 338
 339pub struct ClientCredentialsProvider {
 340    provider: Arc<dyn CredentialsProvider>,
 341}
 342
 343impl ClientCredentialsProvider {
 344    pub fn new(cx: &App) -> Self {
 345        Self {
 346            provider: zed_credentials_provider::global(cx),
 347        }
 348    }
 349
 350    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 351        Ok(cx.update(|cx| ClientSettings::get_global(cx).server_url.clone()))
 352    }
 353
 354    /// Reads the credentials from the provider.
 355    fn read_credentials<'a>(
 356        &'a self,
 357        cx: &'a AsyncApp,
 358    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 359        async move {
 360            if IMPERSONATE_LOGIN.is_some() {
 361                return None;
 362            }
 363
 364            let server_url = self.server_url(cx).ok()?;
 365            let (user_id, access_token) = self
 366                .provider
 367                .read_credentials(&server_url, cx)
 368                .await
 369                .log_err()
 370                .flatten()?;
 371
 372            Some(Credentials {
 373                user_id: user_id.parse().ok()?,
 374                access_token: String::from_utf8(access_token).ok()?,
 375            })
 376        }
 377        .boxed_local()
 378    }
 379
 380    /// Writes the credentials to the provider.
 381    fn write_credentials<'a>(
 382        &'a self,
 383        user_id: u64,
 384        access_token: String,
 385        cx: &'a AsyncApp,
 386    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 387        async move {
 388            let server_url = self.server_url(cx)?;
 389            self.provider
 390                .write_credentials(
 391                    &server_url,
 392                    &user_id.to_string(),
 393                    access_token.as_bytes(),
 394                    cx,
 395                )
 396                .await
 397        }
 398        .boxed_local()
 399    }
 400
 401    /// Deletes the credentials from the provider.
 402    fn delete_credentials<'a>(
 403        &'a self,
 404        cx: &'a AsyncApp,
 405    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 406        async move {
 407            let server_url = self.server_url(cx)?;
 408            self.provider.delete_credentials(&server_url, cx).await
 409        }
 410        .boxed_local()
 411    }
 412}
 413
 414impl Default for ClientState {
 415    fn default() -> Self {
 416        Self {
 417            credentials: None,
 418            status: watch::channel_with(Status::SignedOut),
 419            _reconnect_task: None,
 420        }
 421    }
 422}
 423
/// Handle to a registration in the client's handler set. Dropping it removes the
/// registration, provided the client is still alive.
pub enum Subscription {
    /// Entity subscription keyed by `(entity TypeId, remote id)`; created by
    /// `PendingEntitySubscription::set_entity`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler registration keyed by the message's `TypeId`; created by
    /// `add_message_handler_impl`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 434
 435impl Drop for Subscription {
 436    fn drop(&mut self) {
 437        match self {
 438            Subscription::Entity { client, id } => {
 439                if let Some(client) = client.upgrade() {
 440                    let mut state = client.handler_set.lock();
 441                    let _ = state.entities_by_type_and_remote_id.remove(id);
 442                }
 443            }
 444            Subscription::Message { client, id } => {
 445                if let Some(client) = client.upgrade() {
 446                    let mut state = client.handler_set.lock();
 447                    let _ = state.entity_types_by_message_type.remove(id);
 448                    let _ = state.message_handlers.remove(id);
 449                }
 450            }
 451        }
 452    }
 453}
 454
/// Entity subscription returned by `Client::subscribe_to_entity` before the local
/// entity exists.
///
/// While pending, the handler set holds an `EntityMessageSubscriber::Pending`
/// entry. Queued messages are replayed by `set_entity`. If the value is dropped
/// without calling `set_entity`, they are logged and discarded.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    /// Remote id; combined with `TypeId::of::<T>()` it forms the handler-set key.
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_entity` so that `Drop` leaves the now-bound entry alone.
    consumed: bool,
}
 461
 462impl<T: 'static> PendingEntitySubscription<T> {
 463    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 464        self.consumed = true;
 465        let mut handlers = self.client.handler_set.lock();
 466        let id = (TypeId::of::<T>(), self.remote_id);
 467        let Some(EntityMessageSubscriber::Pending(messages)) =
 468            handlers.entities_by_type_and_remote_id.remove(&id)
 469        else {
 470            unreachable!()
 471        };
 472
 473        handlers.entities_by_type_and_remote_id.insert(
 474            id,
 475            EntityMessageSubscriber::Entity {
 476                handle: entity.downgrade().into(),
 477            },
 478        );
 479        drop(handlers);
 480        for message in messages {
 481            let client_id = self.client.id();
 482            let type_name = message.payload_type_name();
 483            let sender_id = message.original_sender_id();
 484            log::debug!(
 485                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 486                client_id,
 487                sender_id,
 488                type_name
 489            );
 490            self.client.handle_message(message, cx);
 491        }
 492        Subscription::Entity {
 493            client: Arc::downgrade(&self.client),
 494            id,
 495        }
 496    }
 497}
 498
 499impl<T: 'static> Drop for PendingEntitySubscription<T> {
 500    fn drop(&mut self) {
 501        if !self.consumed {
 502            let mut state = self.client.handler_set.lock();
 503            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 504                .entities_by_type_and_remote_id
 505                .remove(&(TypeId::of::<T>(), self.remote_id))
 506            {
 507                for message in messages {
 508                    log::info!("unhandled message {}", message.payload_type_name());
 509                }
 510            }
 511        }
 512    }
 513}
 514
 515#[derive(Copy, Clone, Deserialize, Debug, RegisterSetting)]
 516pub struct TelemetrySettings {
 517    pub diagnostics: bool,
 518    pub metrics: bool,
 519}
 520
 521impl settings::Settings for TelemetrySettings {
 522    fn from_settings(content: &SettingsContent) -> Self {
 523        Self {
 524            diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
 525            metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
 526        }
 527    }
 528}
 529
 530impl Client {
 531    pub fn new(
 532        clock: Arc<dyn SystemClock>,
 533        http: Arc<HttpClientWithUrl>,
 534        cx: &mut App,
 535    ) -> Arc<Self> {
 536        Arc::new(Self {
 537            id: AtomicU64::new(0),
 538            peer: Peer::new(0),
 539            telemetry: Telemetry::new(clock, http.clone(), cx),
 540            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 541            http,
 542            credentials_provider: ClientCredentialsProvider::new(cx),
 543            state: Default::default(),
 544            handler_set: Default::default(),
 545            message_to_client_handlers: Mutex::new(Vec::new()),
 546            sign_out_tx: Mutex::new(None),
 547
 548            #[cfg(any(test, feature = "test-support"))]
 549            authenticate: Default::default(),
 550            #[cfg(any(test, feature = "test-support"))]
 551            establish_connection: Default::default(),
 552            #[cfg(any(test, feature = "test-support"))]
 553            rpc_url: RwLock::default(),
 554        })
 555    }
 556
 557    pub fn production(cx: &mut App) -> Arc<Self> {
 558        let clock = Arc::new(clock::RealSystemClock);
 559        let http = Arc::new(HttpClientWithUrl::new_url(
 560            cx.http_client(),
 561            &ClientSettings::get_global(cx).server_url,
 562            cx.http_client().proxy().cloned(),
 563        ));
 564        Self::new(clock, http, cx)
 565    }
 566
 567    pub fn id(&self) -> u64 {
 568        self.id.load(Ordering::SeqCst)
 569    }
 570
 571    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 572        self.http.clone()
 573    }
 574
 575    pub fn credentials_provider(&self) -> Arc<dyn CredentialsProvider> {
 576        self.credentials_provider.provider.clone()
 577    }
 578
 579    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 580        self.cloud_client.clone()
 581    }
 582
 583    pub fn set_id(&self, id: u64) -> &Self {
 584        self.id.store(id, Ordering::SeqCst);
 585        self
 586    }
 587
 588    #[cfg(any(test, feature = "test-support"))]
 589    pub fn teardown(&self) {
 590        let mut state = self.state.write();
 591        state._reconnect_task.take();
 592        self.handler_set.lock().clear();
 593        self.peer.teardown();
 594    }
 595
 596    #[cfg(any(test, feature = "test-support"))]
 597    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 598    where
 599        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 600    {
 601        *self.authenticate.write() = Some(Box::new(authenticate));
 602        self
 603    }
 604
 605    #[cfg(any(test, feature = "test-support"))]
 606    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 607    where
 608        F: 'static
 609            + Send
 610            + Sync
 611            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 612    {
 613        *self.establish_connection.write() = Some(Box::new(connect));
 614        self
 615    }
 616
 617    #[cfg(any(test, feature = "test-support"))]
 618    pub fn override_rpc_url(&self, url: Url) -> &Self {
 619        *self.rpc_url.write() = Some(url);
 620        self
 621    }
 622
 623    pub fn global(cx: &App) -> Arc<Self> {
 624        cx.global::<GlobalClient>().0.clone()
 625    }
 626    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 627        cx.set_global(GlobalClient(client))
 628    }
 629
 630    pub fn user_id(&self) -> Option<u64> {
 631        self.state
 632            .read()
 633            .credentials
 634            .as_ref()
 635            .map(|credentials| credentials.user_id)
 636    }
 637
 638    pub fn peer_id(&self) -> Option<PeerId> {
 639        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 640            Some(*peer_id)
 641        } else {
 642            None
 643        }
 644    }
 645
 646    pub fn status(&self) -> watch::Receiver<Status> {
 647        self.state.read().status.1.clone()
 648    }
 649
 650    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
 651        log::info!("set status on client {}: {:?}", self.id(), status);
 652        let mut state = self.state.write();
 653        *state.status.0.borrow_mut() = status;
 654
 655        match status {
 656            Status::Connected { .. } => {
 657                state._reconnect_task = None;
 658            }
 659            Status::ConnectionLost => {
 660                let client = self.clone();
 661                state._reconnect_task = Some(cx.spawn(async move |cx| {
 662                    #[cfg(any(test, feature = "test-support"))]
 663                    let mut rng = StdRng::seed_from_u64(0);
 664                    #[cfg(not(any(test, feature = "test-support")))]
 665                    let mut rng = StdRng::from_os_rng();
 666
 667                    let mut delay = INITIAL_RECONNECTION_DELAY;
 668                    loop {
 669                        match client.connect(true, cx).await {
 670                            ConnectionResult::Timeout => {
 671                                log::error!("client connect attempt timed out")
 672                            }
 673                            ConnectionResult::ConnectionReset => {
 674                                log::error!("client connect attempt reset")
 675                            }
 676                            ConnectionResult::Result(r) => {
 677                                if let Err(error) = r {
 678                                    log::error!("failed to connect: {error}");
 679                                } else {
 680                                    break;
 681                                }
 682                            }
 683                        }
 684
 685                        if matches!(
 686                            *client.status().borrow(),
 687                            Status::AuthenticationError | Status::ConnectionError
 688                        ) {
 689                            client.set_status(
 690                                Status::ReconnectionError {
 691                                    next_reconnection: Instant::now() + delay,
 692                                },
 693                                cx,
 694                            );
 695                            let jitter = Duration::from_millis(
 696                                rng.random_range(0..delay.as_millis() as u64),
 697                            );
 698                            cx.background_executor().timer(delay + jitter).await;
 699                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
 700                        } else {
 701                            break;
 702                        }
 703                    }
 704                }));
 705            }
 706            Status::SignedOut | Status::UpgradeRequired => {
 707                self.telemetry.set_authenticated_user_info(None, false);
 708                state._reconnect_task.take();
 709            }
 710            _ => {}
 711        }
 712    }
 713
 714    pub fn subscribe_to_entity<T>(
 715        self: &Arc<Self>,
 716        remote_id: u64,
 717    ) -> Result<PendingEntitySubscription<T>>
 718    where
 719        T: 'static,
 720    {
 721        let id = (TypeId::of::<T>(), remote_id);
 722
 723        let mut state = self.handler_set.lock();
 724        anyhow::ensure!(
 725            !state.entities_by_type_and_remote_id.contains_key(&id),
 726            "already subscribed to entity"
 727        );
 728
 729        state
 730            .entities_by_type_and_remote_id
 731            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 732
 733        Ok(PendingEntitySubscription {
 734            client: self.clone(),
 735            remote_id,
 736            consumed: false,
 737            _entity_type: PhantomData,
 738        })
 739    }
 740
 741    #[track_caller]
 742    pub fn add_message_handler<M, E, H, F>(
 743        self: &Arc<Self>,
 744        entity: WeakEntity<E>,
 745        handler: H,
 746    ) -> Subscription
 747    where
 748        M: EnvelopedMessage,
 749        E: 'static,
 750        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 751        F: 'static + Future<Output = Result<()>>,
 752    {
 753        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 754            handler(entity, message, cx)
 755        })
 756    }
 757
 758    fn add_message_handler_impl<M, E, H, F>(
 759        self: &Arc<Self>,
 760        entity: WeakEntity<E>,
 761        handler: H,
 762    ) -> Subscription
 763    where
 764        M: EnvelopedMessage,
 765        E: 'static,
 766        H: 'static
 767            + Sync
 768            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 769            + Send
 770            + Sync,
 771        F: 'static + Future<Output = Result<()>>,
 772    {
 773        let message_type_id = TypeId::of::<M>();
 774        let mut state = self.handler_set.lock();
 775        state
 776            .entities_by_message_type
 777            .insert(message_type_id, entity.into());
 778
 779        let prev_handler = state.message_handlers.insert(
 780            message_type_id,
 781            Arc::new(move |subscriber, envelope, client, cx| {
 782                let subscriber = subscriber.downcast::<E>().unwrap();
 783                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 784                handler(subscriber, *envelope, client, cx).boxed_local()
 785            }),
 786        );
 787        if prev_handler.is_some() {
 788            let location = std::panic::Location::caller();
 789            panic!(
 790                "{}:{} registered handler for the same message {} twice",
 791                location.file(),
 792                location.line(),
 793                std::any::type_name::<M>()
 794            );
 795        }
 796
 797        Subscription::Message {
 798            client: Arc::downgrade(self),
 799            id: message_type_id,
 800        }
 801    }
 802
 803    pub fn add_request_handler<M, E, H, F>(
 804        self: &Arc<Self>,
 805        entity: WeakEntity<E>,
 806        handler: H,
 807    ) -> Subscription
 808    where
 809        M: RequestMessage,
 810        E: 'static,
 811        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 812        F: 'static + Future<Output = Result<M::Response>>,
 813    {
 814        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 815            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 816        })
 817    }
 818
 819    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 820        receipt: Receipt<T>,
 821        response: F,
 822        client: AnyProtoClient,
 823    ) -> Result<()> {
 824        match response.await {
 825            Ok(response) => {
 826                client.send_response(receipt.message_id, response)?;
 827                Ok(())
 828            }
 829            Err(error) => {
 830                client.send_response(receipt.message_id, error.to_proto())?;
 831                Err(error)
 832            }
 833        }
 834    }
 835
 836    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 837        self.credentials_provider
 838            .read_credentials(cx)
 839            .await
 840            .is_some()
 841    }
 842
    /// Authenticates the client and returns the credentials in use.
    ///
    /// Sources are tried in order:
    /// 1. the credentials already held in memory, if the Cloud API accepts them;
    /// 2. when `try_provider` is true, credentials from the credentials provider;
    ///    these are deleted from the provider if the Cloud API rejects them;
    /// 3. the interactive `authenticate` flow. Its result is written back to the
    ///    provider unless `ZED_IMPERSONATE` is set.
    ///
    /// Status becomes `Authenticating` (or `Reauthenticating` when the client is
    /// not signed out) at the start, and `Authenticated` / `Reauthenticated` on
    /// success. On success the credentials are also stored in the client state,
    /// handed to the Cloud API client, and the user id becomes the client id.
    ///
    /// # Errors
    ///
    /// - A validation request fails: status becomes `AuthenticationError`.
    /// - `authenticate` fails: status becomes `AuthenticationError`.
    /// - The status changes while `authenticate` is pending: returns
    ///   "authentication canceled".
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // 1. Reuse the in-memory credentials if the server still accepts them.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // 2. Fall back to stored credentials; discard them if they are rejected.
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // 3. Interactive authentication, cancelable by any status change.
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Presumably drains the receiver's current value (postage watch yields
            // it first) so the cancel arm below only fires on a later change — confirm.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            // Impersonated credentials are never persisted.
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path above either filled `credentials` or returned early.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `user_id as u32` silently truncates ids above u32::MAX.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
 923
 924    async fn validate_credentials(
 925        self: &Arc<Self>,
 926        credentials: &Credentials,
 927        cx: &AsyncApp,
 928    ) -> Result<bool> {
 929        match self
 930            .cloud_client
 931            .validate_credentials(credentials.user_id as u32, &credentials.access_token)
 932            .await
 933        {
 934            Ok(valid) => Ok(valid),
 935            Err(err) => {
 936                self.set_status(Status::AuthenticationError, cx);
 937                Err(anyhow!("failed to validate credentials: {}", err))
 938            }
 939        }
 940    }
 941
 942    /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
 943    async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
 944        let connect_task = cx.update({
 945            let cloud_client = self.cloud_client.clone();
 946            move |cx| cloud_client.connect(cx)
 947        })?;
 948        let connection = connect_task.await?;
 949
 950        let (mut messages, task) = cx.update(|cx| connection.spawn(cx));
 951        task.detach();
 952
 953        cx.spawn({
 954            let this = self.clone();
 955            async move |cx| {
 956                while let Some(message) = messages.next().await {
 957                    if let Some(message) = message.log_err() {
 958                        this.handle_message_to_client(message, cx);
 959                    }
 960                }
 961            }
 962        })
 963        .detach();
 964
 965        Ok(())
 966    }
 967
    /// Performs a sign-in and also (optionally) connects to Collab.
    ///
    /// Only Zed staff automatically connect to Collab.
    ///
    /// Returns `Ok(())` immediately if the client is already connected. A failure to
    /// connect to Cloud is logged rather than returned. The Collab connection is made in a
    /// detached task once the feature flags report staff status; its errors are logged.
    pub async fn sign_in_with_optional_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<()> {
        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
        if self.status().borrow().is_connected() {
            return Ok(());
        }

        // Capture the staff flag once feature flags are ready. The sender is wrapped in an
        // `Option` and `take()`n so it is sent at most once, even if the callback runs again.
        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
        let mut is_staff_tx = Some(is_staff_tx);
        cx.update(|cx| {
            cx.on_flags_ready(move |state, _cx| {
                if let Some(is_staff_tx) = is_staff_tx.take() {
                    is_staff_tx.send(state.is_staff).log_err();
                }
            })
            .detach();
        });

        let credentials = self.sign_in(try_provider, cx).await?;

        // Best-effort: a Cloud connection failure does not fail the sign-in.
        self.connect_to_cloud(cx).await.log_err();

        // Wait for the staff flag off the caller's path; only staff connect to Collab.
        cx.update(move |cx| {
            cx.spawn({
                let client = self.clone();
                async move |cx| {
                    let is_staff = is_staff_rx.await?;
                    if is_staff {
                        match client.connect_with_credentials(credentials, cx).await {
                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
                            ConnectionResult::Result(result) => {
                                result.context("client auth and connect")
                            }
                        }
                    } else {
                        Ok(())
                    }
                }
            })
            .detach_and_log_err(cx);
        });

        Ok(())
    }
1019
1020    pub async fn connect(
1021        self: &Arc<Self>,
1022        try_provider: bool,
1023        cx: &AsyncApp,
1024    ) -> ConnectionResult<()> {
1025        let was_disconnected = match *self.status().borrow() {
1026            Status::SignedOut | Status::Authenticated => true,
1027            Status::ConnectionError
1028            | Status::ConnectionLost
1029            | Status::Authenticating
1030            | Status::AuthenticationError
1031            | Status::Reauthenticating
1032            | Status::Reauthenticated
1033            | Status::ReconnectionError { .. } => false,
1034            Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1035                return ConnectionResult::Result(Ok(()));
1036            }
1037            Status::UpgradeRequired => {
1038                return ConnectionResult::Result(
1039                    Err(EstablishConnectionError::UpgradeRequired)
1040                        .context("client auth and connect"),
1041                );
1042            }
1043        };
1044        let credentials = match self.sign_in(try_provider, cx).await {
1045            Ok(credentials) => credentials,
1046            Err(err) => return ConnectionResult::Result(Err(err)),
1047        };
1048
1049        if was_disconnected {
1050            self.set_status(Status::Connecting, cx);
1051        } else {
1052            self.set_status(Status::Reconnecting, cx);
1053        }
1054
1055        self.connect_with_credentials(credentials, cx).await
1056    }
1057
    /// Connects to Collab with already-obtained `credentials`.
    ///
    /// Both establishing the transport and the hello handshake (`set_connection`) race
    /// against one shared `CONNECTION_TIMEOUT` timer, so the whole attempt is bounded by
    /// that timeout. Every failure path updates the status before returning:
    /// `UpgradeRequired` when the server rejects this client version, `ConnectionError`
    /// otherwise (including timeouts).
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // `select_biased!` requires fused futures; the same timer is reused by the inner
        // select so the handshake does not get a fresh timeout budget.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        // Biased: the connection branch is polled before the timer.
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        // Distinct status so later `connect` calls fail fast instead of retrying.
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1105
    /// Registers `conn` with the peer and completes the hello handshake.
    ///
    /// Waits for the server's `Hello` message to learn this client's peer id, then sets the
    /// status to `Connected` and spawns two detached tasks. One dispatches incoming
    /// messages. The other watches the connection's I/O task and updates the status when
    /// it ends.
    ///
    /// If the handshake fails, the connection is removed from the peer and the error is
    /// returned. This method does not change the status on that path; callers do.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // The first message on a fresh connection must be `proto::Hello` carrying our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        // When I/O ends cleanly, sign out only if this connection is still the current one
        // (a newer connection may have replaced it). An I/O error marks the connection lost.
        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1189
1190    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1191        #[cfg(any(test, feature = "test-support"))]
1192        if let Some(callback) = self.authenticate.read().as_ref() {
1193            return callback(cx);
1194        }
1195
1196        self.authenticate_with_browser(cx)
1197    }
1198
1199    fn establish_connection(
1200        self: &Arc<Self>,
1201        credentials: &Credentials,
1202        cx: &AsyncApp,
1203    ) -> Task<Result<Connection, EstablishConnectionError>> {
1204        #[cfg(any(test, feature = "test-support"))]
1205        if let Some(callback) = self.establish_connection.read().as_ref() {
1206            return callback(credentials, cx);
1207        }
1208
1209        self.establish_websocket_connection(credentials, cx)
1210    }
1211
1212    fn rpc_url(
1213        &self,
1214        http: Arc<HttpClientWithUrl>,
1215        release_channel: Option<ReleaseChannel>,
1216    ) -> impl Future<Output = Result<url::Url>> + use<> {
1217        #[cfg(any(test, feature = "test-support"))]
1218        let url_override = self.rpc_url.read().clone();
1219
1220        async move {
1221            #[cfg(any(test, feature = "test-support"))]
1222            if let Some(url) = url_override {
1223                return Ok(url);
1224            }
1225
1226            if let Some(url) = &*ZED_RPC_URL {
1227                return Url::parse(url).context("invalid rpc url");
1228            }
1229
1230            let mut url = http.build_url("/rpc");
1231            if let Some(preview_param) =
1232                release_channel.and_then(|channel| channel.release_query_param())
1233            {
1234                url += "?";
1235                url += preview_param;
1236            }
1237
1238            let response = http.get(&url, Default::default(), false).await?;
1239            anyhow::ensure!(
1240                response.status().is_redirection(),
1241                "unexpected /rpc response status {}",
1242                response.status()
1243            );
1244            let collab_url = response
1245                .headers()
1246                .get("Location")
1247                .context("missing location header in /rpc response")?
1248                .to_str()
1249                .map_err(EstablishConnectionError::other)?
1250                .to_string();
1251            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1252        }
1253    }
1254
    /// Opens an authenticated WebSocket connection to the Collab RPC endpoint.
    ///
    /// The method resolves the RPC URL and opens a TCP stream to its host on the Tokio
    /// runtime, going through the configured proxy if there is one. It then upgrades the
    /// stream to a WebSocket, using `ws`/`wss` to match the URL's `http`/`https` scheme,
    /// and wraps the result in a `Connection`.
    ///
    /// The upgrade request carries:
    /// - the authorization header,
    /// - the protocol version, app version and release channel headers,
    /// - the user-agent, system id and metrics id headers, when available.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Everything needed from `self`/`cx` is captured up front so the spawned task owns it.
        let release_channel = cx.update(|cx| ReleaseChannel::try_global(cx));
        let app_version = cx.update(|cx| AppVersion::global(cx).to_string());

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            // Only http(s) RPC URLs are accepted; the scheme is remembered so it can be
            // mapped to ws(s) after the TCP connection is open.
            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // The raw TCP (or proxied) stream is opened on the Tokio runtime.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // NOTE(review): this `unwrap` relies on the `url` crate allowing a switch from
            // http(s) to ws(s) (all four are "special" schemes) — confirm if the accepted
            // schemes above ever change.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            // Perform the WebSocket handshake over the already-open stream, supplying Zed's
            // TLS configuration as the connector.
            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Adapt the WebSocket's error types to `anyhow` for `Connection`.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1356
1357    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1358        let http = self.http.clone();
1359        let this = self.clone();
1360        cx.spawn(async move |cx| {
1361            let background = cx.background_executor().clone();
1362
1363            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1364            cx.update(|cx| {
1365                cx.spawn(async move |cx| {
1366                    if let Ok(url) = open_url_rx.await {
1367                        cx.update(|cx| cx.open_url(&url));
1368                    }
1369                })
1370                .detach();
1371            });
1372
1373            let credentials = background
1374                .clone()
1375                .spawn(async move {
1376                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1377                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1378                    // any other app running on the user's device.
1379                    let (public_key, private_key) =
1380                        rpc::auth::keypair().context("failed to generate keypair for auth")?;
1381                    let public_key_string = String::try_from(public_key)
1382                        .context("failed to serialize public key for auth")?;
1383
1384                    if let Some((login, token)) =
1385                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1386                    {
1387                        if !*USE_WEB_LOGIN {
1388                            eprintln!("authenticate as admin {login}, {token}");
1389
1390                            return this
1391                                .authenticate_as_admin(http, login.clone(), token.clone())
1392                                .await;
1393                        }
1394                    }
1395
1396                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1397                    let server = tiny_http::Server::http("127.0.0.1:0")
1398                        .map_err(|e| anyhow!(e).context("failed to bind callback port"))?;
1399                    let port = server
1400                        .server_addr()
1401                        .to_ip()
1402                        .context("server not bound to a TCP address")?
1403                        .port();
1404
1405                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1406                    // that the user is signing in from a Zed app running on the same device.
1407                    let url = http.build_url(&format!(
1408                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1409                        port, public_key_string
1410                    ));
1411
1412                    open_url_tx.send(url).log_err();
1413
1414                    #[derive(Deserialize)]
1415                    struct CallbackParams {
1416                        pub user_id: String,
1417                        pub access_token: String,
1418                    }
1419
1420                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1421                    // access token from the query params.
1422                    //
1423                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1424                    // custom URL scheme instead of this local HTTP server.
1425                    let (user_id, access_token) = background
1426                        .spawn(async move {
1427                            for _ in 0..100 {
1428                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1429                                    let path = req.url();
1430                                    let url = Url::parse(&format!("http://example.com{}", path))
1431                                        .context("failed to parse login notification url")?;
1432                                    let callback_params: CallbackParams =
1433                                        serde_urlencoded::from_str(url.query().unwrap_or_default())
1434                                            .context(
1435                                                "failed to parse sign-in callback query parameters",
1436                                            )?;
1437
1438                                    let post_auth_url =
1439                                        http.build_url("/native_app_signin_succeeded");
1440                                    req.respond(
1441                                        tiny_http::Response::empty(302).with_header(
1442                                            tiny_http::Header::from_bytes(
1443                                                &b"Location"[..],
1444                                                post_auth_url.as_bytes(),
1445                                            )
1446                                            .unwrap(),
1447                                        ),
1448                                    )
1449                                    .context("failed to respond to login http request")?;
1450                                    return Ok((
1451                                        callback_params.user_id,
1452                                        callback_params.access_token,
1453                                    ));
1454                                }
1455                            }
1456
1457                            anyhow::bail!("didn't receive login redirect");
1458                        })
1459                        .await?;
1460
1461                    let access_token = private_key
1462                        .decrypt_string(&access_token)
1463                        .context("failed to decrypt access token")?;
1464
1465                    Ok(Credentials {
1466                        user_id: user_id.parse()?,
1467                        access_token,
1468                    })
1469                })
1470                .await?;
1471
1472            cx.update(|cx| cx.activate(true));
1473            Ok(credentials)
1474        })
1475    }
1476
1477    async fn authenticate_as_admin(
1478        self: &Arc<Self>,
1479        http: Arc<HttpClientWithUrl>,
1480        login: String,
1481        api_token: String,
1482    ) -> Result<Credentials> {
1483        #[derive(Serialize)]
1484        struct ImpersonateUserBody {
1485            github_login: String,
1486        }
1487
1488        #[derive(Deserialize)]
1489        struct ImpersonateUserResponse {
1490            user_id: u64,
1491            access_token: String,
1492        }
1493
1494        let url = self
1495            .http
1496            .build_zed_cloud_url("/internal/users/impersonate")?;
1497        let request = Request::post(url.as_str())
1498            .header("Content-Type", "application/json")
1499            .header("Authorization", format!("Bearer {api_token}"))
1500            .body(
1501                serde_json::to_string(&ImpersonateUserBody {
1502                    github_login: login,
1503                })?
1504                .into(),
1505            )?;
1506
1507        let mut response = http.send(request).await?;
1508        let mut body = String::new();
1509        response.body_mut().read_to_string(&mut body).await?;
1510        anyhow::ensure!(
1511            response.status().is_success(),
1512            "admin user request failed {} - {}",
1513            response.status().as_u16(),
1514            body,
1515        );
1516        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1517
1518        Ok(Credentials {
1519            user_id: response.user_id,
1520            access_token: response.access_token,
1521        })
1522    }
1523
1524    pub async fn acquire_llm_token(
1525        &self,
1526        llm_token: &LlmApiToken,
1527        organization_id: Option<OrganizationId>,
1528    ) -> Result<String> {
1529        let system_id = self.telemetry().system_id().map(|x| x.to_string());
1530        let cloud_client = self.cloud_client();
1531        match llm_token
1532            .acquire(&cloud_client, system_id, organization_id)
1533            .await
1534        {
1535            Ok(token) => Ok(token),
1536            Err(ClientApiError::Unauthorized) => {
1537                self.request_sign_out();
1538                Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
1539            }
1540            Err(err) => Err(anyhow::Error::from(err)),
1541        }
1542    }
1543
1544    pub async fn refresh_llm_token(
1545        &self,
1546        llm_token: &LlmApiToken,
1547        organization_id: Option<OrganizationId>,
1548    ) -> Result<String> {
1549        let system_id = self.telemetry().system_id().map(|x| x.to_string());
1550        let cloud_client = self.cloud_client();
1551        match llm_token
1552            .refresh(&cloud_client, system_id, organization_id)
1553            .await
1554        {
1555            Ok(token) => Ok(token),
1556            Err(ClientApiError::Unauthorized) => {
1557                self.request_sign_out();
1558                return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
1559            }
1560            Err(err) => return Err(anyhow::Error::from(err)),
1561        }
1562    }
1563
1564    pub async fn clear_and_refresh_llm_token(
1565        &self,
1566        llm_token: &LlmApiToken,
1567        organization_id: Option<OrganizationId>,
1568    ) -> Result<String> {
1569        let system_id = self.telemetry().system_id().map(|x| x.to_string());
1570        let cloud_client = self.cloud_client();
1571        match llm_token
1572            .clear_and_refresh(&cloud_client, system_id, organization_id)
1573            .await
1574        {
1575            Ok(token) => Ok(token),
1576            Err(ClientApiError::Unauthorized) => {
1577                self.request_sign_out();
1578                return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
1579            }
1580            Err(err) => return Err(anyhow::Error::from(err)),
1581        }
1582    }
1583
1584    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1585        self.state.write().credentials = None;
1586        self.cloud_client.clear_credentials();
1587        self.disconnect(cx);
1588
1589        if self.has_credentials(cx).await {
1590            self.credentials_provider
1591                .delete_credentials(cx)
1592                .await
1593                .log_err();
1594        }
1595    }
1596
1597    /// Requests a sign out to be performed asynchronously.
1598    pub fn request_sign_out(&self) {
1599        if let Some(sign_out_tx) = self.sign_out_tx.lock().clone() {
1600            sign_out_tx.unbounded_send(()).ok();
1601        }
1602    }
1603
1604    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1605        self.peer.teardown();
1606        self.set_status(Status::SignedOut, cx);
1607    }
1608
1609    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1610        self.peer.teardown();
1611        self.set_status(Status::ConnectionLost, cx);
1612    }
1613
1614    fn connection_id(&self) -> Result<ConnectionId> {
1615        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1616            Ok(connection_id)
1617        } else {
1618            anyhow::bail!("not connected");
1619        }
1620    }
1621
1622    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1623        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1624        self.peer.send(self.connection_id()?, message)
1625    }
1626
1627    pub fn request<T: RequestMessage>(
1628        &self,
1629        request: T,
1630    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1631        self.request_envelope(request)
1632            .map_ok(|envelope| envelope.payload)
1633    }
1634
1635    pub fn request_stream<T: RequestMessage>(
1636        &self,
1637        request: T,
1638    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1639        let client_id = self.id.load(Ordering::SeqCst);
1640        log::debug!(
1641            "rpc request start. client_id:{}. name:{}",
1642            client_id,
1643            T::NAME
1644        );
1645        let response = self
1646            .connection_id()
1647            .map(|conn_id| self.peer.request_stream(conn_id, request));
1648        async move {
1649            let response = response?.await;
1650            log::debug!(
1651                "rpc request finish. client_id:{}. name:{}",
1652                client_id,
1653                T::NAME
1654            );
1655            response
1656        }
1657    }
1658
1659    pub fn request_envelope<T: RequestMessage>(
1660        &self,
1661        request: T,
1662    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1663        let client_id = self.id();
1664        log::debug!(
1665            "rpc request start. client_id:{}. name:{}",
1666            client_id,
1667            T::NAME
1668        );
1669        let response = self
1670            .connection_id()
1671            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1672        async move {
1673            let response = response?.await;
1674            log::debug!(
1675                "rpc request finish. client_id:{}. name:{}",
1676                client_id,
1677                T::NAME
1678            );
1679            response
1680        }
1681    }
1682
1683    pub fn request_dynamic(
1684        &self,
1685        envelope: proto::Envelope,
1686        request_type: &'static str,
1687    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1688        let client_id = self.id();
1689        log::debug!(
1690            "rpc request start. client_id:{}. name:{}",
1691            client_id,
1692            request_type
1693        );
1694        let response = self
1695            .connection_id()
1696            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1697        async move {
1698            let response = response?.await;
1699            log::debug!(
1700                "rpc request finish. client_id:{}. name:{}",
1701                client_id,
1702                request_type
1703            );
1704            Ok(response?.0)
1705        }
1706    }
1707
1708    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1709        let sender_id = message.sender_id();
1710        let request_id = message.message_id();
1711        let type_name = message.payload_type_name();
1712        let original_sender_id = message.original_sender_id();
1713
1714        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1715            &self.handler_set,
1716            message,
1717            self.clone().into(),
1718            cx.clone(),
1719        ) {
1720            let client_id = self.id();
1721            log::debug!(
1722                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1723                client_id,
1724                original_sender_id,
1725                type_name
1726            );
1727            cx.spawn(async move |_| match future.await {
1728                Ok(()) => {
1729                    log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1730                }
1731                Err(error) => {
1732                    log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1733                }
1734            })
1735            .detach();
1736        } else {
1737            log::info!("unhandled message {}", type_name);
1738            self.peer
1739                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1740                .log_err();
1741        }
1742    }
1743
1744    pub fn add_message_to_client_handler(
1745        self: &Arc<Client>,
1746        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1747    ) {
1748        self.message_to_client_handlers
1749            .lock()
1750            .push(Box::new(handler));
1751    }
1752
1753    fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1754        cx.update(|cx| {
1755            for handler in self.message_to_client_handlers.lock().iter() {
1756                handler(&message, cx);
1757            }
1758        });
1759    }
1760
1761    pub fn telemetry(&self) -> &Arc<Telemetry> {
1762        &self.telemetry
1763    }
1764}
1765
1766impl ProtoClient for Client {
1767    fn request(
1768        &self,
1769        envelope: proto::Envelope,
1770        request_type: &'static str,
1771    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1772        self.request_dynamic(envelope, request_type).boxed()
1773    }
1774
1775    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1776        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1777        let connection_id = self.connection_id()?;
1778        self.peer.send_dynamic(connection_id, envelope)
1779    }
1780
1781    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1782        log::debug!(
1783            "rpc respond. client_id:{}, name:{}",
1784            self.id(),
1785            message_type
1786        );
1787        let connection_id = self.connection_id()?;
1788        self.peer.send_dynamic(connection_id, envelope)
1789    }
1790
1791    fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1792        &self.handler_set
1793    }
1794
1795    fn is_via_collab(&self) -> bool {
1796        true
1797    }
1798
1799    fn has_wsl_interop(&self) -> bool {
1800        false
1801    }
1802}
1803
/// Scheme name (without the `://` separator) used by `zed://` deep links.
pub const ZED_URL_SCHEME: &str = "zed";
1806
/// A link that the application handles itself instead of opening it in the browser.
///
/// Produced by `parse_zed_link`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ZedLink {
    /// Join a channel, e.g. `zed.dev/channel/channel-name-123` or
    /// `zed://channel/channel-name-123`.
    Channel {
        /// Numeric id taken from the end of the channel slug.
        channel_id: u64,
    },
    /// Open a channel's notes, e.g. `zed.dev/channel/channel-name-123/notes`,
    /// optionally targeting a heading via `notes#heading`.
    ChannelNotes {
        /// Numeric id taken from the end of the channel slug.
        channel_id: u64,
        /// The text after `notes#`, if the link had one.
        heading: Option<String>,
    },
}
1818
1819/// Parses the given link into a Zed link.
1820///
1821/// Returns a [`Some`] containing the parsed link if the link is a recognized Zed link
1822/// that should be handled internally by the application.
1823/// Returns [`None`] for links that should be opened in the browser.
1824pub fn parse_zed_link(link: &str, cx: &App) -> Option<ZedLink> {
1825    let server_url = &ClientSettings::get_global(cx).server_url;
1826    let path = link
1827        .strip_prefix(server_url)
1828        .and_then(|result| result.strip_prefix('/'))
1829        .or_else(|| {
1830            link.strip_prefix(ZED_URL_SCHEME)
1831                .and_then(|result| result.strip_prefix("://"))
1832        })?;
1833
1834    let mut parts = path.split('/');
1835
1836    if parts.next() != Some("channel") {
1837        return None;
1838    }
1839
1840    let slug = parts.next()?;
1841    let id_str = slug.split('-').next_back()?;
1842    let channel_id = id_str.parse::<u64>().ok()?;
1843
1844    let Some(next) = parts.next() else {
1845        return Some(ZedLink::Channel { channel_id });
1846    };
1847
1848    if let Some(heading) = next.strip_prefix("notes#") {
1849        return Some(ZedLink::ChannelNotes {
1850            channel_id,
1851            heading: Some(heading.to_string()),
1852        });
1853    }
1854
1855    if next == "notes" {
1856        return Some(ZedLink::ChannelNotes {
1857            channel_id,
1858            heading: None,
1859        });
1860    }
1861
1862    None
1863}
1864
// Tests for `Client`: proxy settings parsing, connection/reconnection and
// authentication behavior against a `FakeServer`/`FakeHttpClient`, and
// routing of incoming messages to entity/message handlers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // A whitespace-only proxy setting is treated as absent; a real proxy URL
    // is passed through unchanged.
    #[test]
    fn test_proxy_settings_trims_and_ignores_empty_proxy() {
        let mut content = SettingsContent::default();
        content.proxy = Some("   ".to_owned());
        assert_eq!(ProxySettings::from_settings(&content).proxy, None);

        content.proxy = Some("http://127.0.0.1:10809".to_owned());
        assert_eq!(
            ProxySettings::from_settings(&content).proxy.as_deref(),
            Some("http://127.0.0.1:10809")
        );
    }

    // After a server disconnect the client reconnects. It reuses cached
    // credentials while they are still valid, and re-authenticates once the
    // server's access token has been rolled.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Roll the access token so the cached credentials become invalid; the
        // client must then clear them and authenticate again.
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A 503 from the HTTP client while reconnecting surfaces as a
    // `ReconnectionError`. Once responses return to 200, the client reconnects
    // without authenticating again.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // Both the initial connection and a later reconnection attempt give up
    // after `CONNECTION_TIMEOUT` when establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Connection establishment never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // `sign_in` only re-runs authentication when the server rejects the
    // credentials with 401. A 503 is reported as an error without re-authenticating.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // Each authentication bumps the counter and uses the new count as the
        // access token, so the test can tell which authentication produced the credentials.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Starting a second `connect` while the first authentication is still
    // pending drops the first authentication future.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        // Authentication never completes. The deferred closure counts how many
        // authentication futures have been dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity message handlers route messages by entity id. Dropping the
    // subscription for one entity id does not stop delivery to the others.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
            features: Vec::new(),
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
            features: Vec::new(),
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // A handler for a message type can be registered again after the previous
    // subscription for that type has been dropped, and the new one receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // A handler may drop its own subscription while it is running. The
    // subscription is stored on the entity and taken (then dropped) inside the handler.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a message-handler target. `id` tells handlers
    // which entity received a message, and `subscription` lets a test stash a
    // handler subscription on the entity.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global before each gpui test runs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}