client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod proxy;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow, bail};
  10use async_recursion::async_recursion;
  11use async_tungstenite::tungstenite::{
  12    client::IntoClientRequest,
  13    error::Error as WebsocketError,
  14    http::{HeaderValue, Request, StatusCode},
  15};
  16use chrono::{DateTime, Utc};
  17use clock::SystemClock;
  18use cloud_api_client::CloudApiClient;
  19use credentials_provider::CredentialsProvider;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{AsyncBody, HttpClient, HttpClientWithUrl, http};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use settings::{Settings, SettingsSources};
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        Arc, LazyLock, Weak,
  44        atomic::{AtomicU64, Ordering},
  45    },
  46    time::{Duration, Instant},
  47};
  48use std::{cmp, pin::Pin};
  49use telemetry::Telemetry;
  50use thiserror::Error;
  51use tokio::net::TcpStream;
  52use url::Url;
  53use util::{ConnectionResult, ResultExt};
  54
  55pub use rpc::*;
  56pub use telemetry_events::Event;
  57pub use user::*;
  58
/// Server URL override read once from the `ZED_SERVER_URL` environment variable.
///
/// When set, `ClientSettings::load` copies it over the `server_url` loaded from
/// settings. `None` if the variable is unset or not valid Unicode.
static ZED_SERVER_URL: LazyLock<Option<String>> =
    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
/// RPC URL override read once from the `ZED_RPC_URL` environment variable.
///
/// `None` if the variable is unset or not valid Unicode.
static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  62
  63pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  64    std::env::var("ZED_IMPERSONATE")
  65        .ok()
  66        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  67});
  68
  69pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  70    std::env::var("ZED_ADMIN_API_TOKEN")
  71        .ok()
  72        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  73});
  74
  75pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  76    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  77
  78pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  79    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  80
/// Delay before the first automatic reconnection attempt after a lost connection.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Upper bound for the reconnection delay, which doubles after each failed attempt.
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
/// How long `Client::authenticate_and_connect` waits for the connection to be established.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  84
// Actions in the `client` namespace; `init` installs a handler for each of them.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
  96
// Settings-file form of `ClientSettings`. Plain `//` comments are used here so the
// JsonSchema derive does not pick them up as schema descriptions.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // Optional server URL; merged into `ClientSettings::server_url`.
    server_url: Option<String>,
}
 101
/// Resolved client settings.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Server URL. `ZED_SERVER_URL` overrides the merged settings value when it is set.
    pub server_url: String,
}
 106
 107impl Settings for ClientSettings {
 108    const KEY: Option<&'static str> = None;
 109
 110    type FileContent = ClientSettingsContent;
 111
 112    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 113        let mut result = sources.json_merge::<Self>()?;
 114        if let Some(server_url) = &*ZED_SERVER_URL {
 115            result.server_url.clone_from(server_url)
 116        }
 117        Ok(result)
 118    }
 119
 120    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
 121}
 122
// Settings-file form of `ProxySettings`. Plain `//` comments are used here so the
// JsonSchema derive does not pick them up as schema descriptions.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProxySettingsContent {
    // Configured proxy; can be imported from VS Code's `http.proxy` setting.
    proxy: Option<String>,
}
 127
/// Resolved proxy settings.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Configured proxy, or `None` if no settings source provides one.
    pub proxy: Option<String>,
}
 132
 133impl Settings for ProxySettings {
 134    const KEY: Option<&'static str> = None;
 135
 136    type FileContent = ProxySettingsContent;
 137
 138    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 139        Ok(Self {
 140            proxy: sources
 141                .user
 142                .or(sources.server)
 143                .and_then(|value| value.proxy.clone())
 144                .or(sources.default.proxy.clone()),
 145        })
 146    }
 147
 148    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 149        vscode.string_setting("http.proxy", &mut current.proxy);
 150    }
 151}
 152
/// Registers the settings types defined in this module: telemetry, disable-AI,
/// client, and proxy settings.
pub fn init_settings(cx: &mut App) {
    TelemetrySettings::register(cx);
    DisableAiSettings::register(cx);
    ClientSettings::register(cx);
    ProxySettings::register(cx);
}
 159
 160pub fn init(client: &Arc<Client>, cx: &mut App) {
 161    let client = Arc::downgrade(client);
 162    cx.on_action({
 163        let client = client.clone();
 164        move |_: &SignIn, cx| {
 165            if let Some(client) = client.upgrade() {
 166                cx.spawn(
 167                    async move |cx| match client.authenticate_and_connect(true, &cx).await {
 168                        ConnectionResult::Timeout => {
 169                            log::error!("Initial authentication timed out");
 170                        }
 171                        ConnectionResult::ConnectionReset => {
 172                            log::error!("Initial authentication connection reset");
 173                        }
 174                        ConnectionResult::Result(r) => {
 175                            r.log_err();
 176                        }
 177                    },
 178                )
 179                .detach();
 180            }
 181        }
 182    });
 183
 184    cx.on_action({
 185        let client = client.clone();
 186        move |_: &SignOut, cx| {
 187            if let Some(client) = client.upgrade() {
 188                cx.spawn(async move |cx| {
 189                    client.sign_out(&cx).await;
 190                })
 191                .detach();
 192            }
 193        }
 194    });
 195
 196    cx.on_action({
 197        let client = client.clone();
 198        move |_: &Reconnect, cx| {
 199            if let Some(client) = client.upgrade() {
 200                cx.spawn(async move |cx| {
 201                    client.reconnect(&cx);
 202                })
 203                .detach();
 204            }
 205        }
 206    });
 207}
 208
/// App-global slot holding the shared [`Client`]; read by [`Client::global`] and
/// written by [`Client::set_global`].
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 212
/// Client for the Zed server.
///
/// It owns the RPC peer, the HTTP and cloud API clients, telemetry, credential
/// storage, the connection [`Status`] state, and the registered message handlers.
pub struct Client {
    /// Client id; `authenticate_and_connect` sets it to the authenticated user's id.
    id: AtomicU64,
    peer: Arc<Peer>,
    /// HTTP client; `Client::production` builds it with the `ClientSettings` server URL.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client; it receives the credentials after authentication.
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, connection status channel, and the reconnect task.
    state: RwLock<ClientState>,
    /// Registered message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,

    /// Test-only override installed by `override_authenticate`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only override installed by `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL override installed by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 247
/// Errors that can occur while establishing the server connection.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced from a WebSocket HTTP 426 (Upgrade Required) response.
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced from a WebSocket HTTP 401 (Unauthorized) response.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, wrapped as an [`anyhow::Error`].
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An invalid HTTP header value.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// An I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// An error from the `http` crate used by tungstenite.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 263
 264impl From<WebsocketError> for EstablishConnectionError {
 265    fn from(error: WebsocketError) -> Self {
 266        if let WebsocketError::Http(response) = &error {
 267            match response.status() {
 268                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 269                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 270                _ => {}
 271            }
 272        }
 273        EstablishConnectionError::Other(error.into())
 274    }
 275}
 276
 277impl EstablishConnectionError {
 278    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 279        Self::Other(error.into())
 280    }
 281}
 282
/// Connection state of a [`Client`], published through a watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; this is the initial state.
    SignedOut,
    /// The server requires a newer client; `is_signed_out` treats this as signed out.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Connecting with credentials in hand, starting from `SignedOut`.
    Connecting,
    /// Set when an authentication (and presumably connection) attempt fails.
    ConnectionError,
    /// Connected to the server.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection was lost; setting this status spawns the reconnect task.
    ConnectionLost,
    /// Obtaining credentials when not starting from `SignedOut`.
    Reauthenticating,
    /// Connecting with credentials in hand when not starting from `SignedOut`.
    Reconnecting,
    /// A reconnection attempt failed; the reconnect task waits until roughly `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 301
 302impl Status {
 303    pub fn is_connected(&self) -> bool {
 304        matches!(self, Self::Connected { .. })
 305    }
 306
 307    pub fn is_signing_in(&self) -> bool {
 308        matches!(
 309            self,
 310            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 311        )
 312    }
 313
 314    pub fn is_signed_out(&self) -> bool {
 315        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 316    }
 317}
 318
/// Mutable connection state guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current session, if any.
    credentials: Option<Credentials>,
    /// Sender/receiver pair for the current [`Status`]; `Client::status` hands out
    /// clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect task spawned by `Client::set_status` on `ConnectionLost`. It is held
    /// here so later status changes (and `teardown`) can drop it.
    _reconnect_task: Option<Task<()>>,
}
 324
 325#[derive(Clone, Debug, Eq, PartialEq)]
 326pub struct Credentials {
 327    pub user_id: u64,
 328    pub access_token: String,
 329}
 330
 331impl Credentials {
 332    pub fn authorization_header(&self) -> String {
 333        format!("{} {}", self.user_id, self.access_token)
 334    }
 335}
 336
/// Wrapper around the global [`CredentialsProvider`] that stores Zed credentials
/// keyed by the configured server URL.
pub struct ClientCredentialsProvider {
    provider: Arc<dyn CredentialsProvider>,
}
 340
 341impl ClientCredentialsProvider {
 342    pub fn new(cx: &App) -> Self {
 343        Self {
 344            provider: <dyn CredentialsProvider>::global(cx),
 345        }
 346    }
 347
 348    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 349        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 350    }
 351
 352    /// Reads the credentials from the provider.
 353    fn read_credentials<'a>(
 354        &'a self,
 355        cx: &'a AsyncApp,
 356    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 357        async move {
 358            if IMPERSONATE_LOGIN.is_some() {
 359                return None;
 360            }
 361
 362            let server_url = self.server_url(cx).ok()?;
 363            let (user_id, access_token) = self
 364                .provider
 365                .read_credentials(&server_url, cx)
 366                .await
 367                .log_err()
 368                .flatten()?;
 369
 370            Some(Credentials {
 371                user_id: user_id.parse().ok()?,
 372                access_token: String::from_utf8(access_token).ok()?,
 373            })
 374        }
 375        .boxed_local()
 376    }
 377
 378    /// Writes the credentials to the provider.
 379    fn write_credentials<'a>(
 380        &'a self,
 381        user_id: u64,
 382        access_token: String,
 383        cx: &'a AsyncApp,
 384    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 385        async move {
 386            let server_url = self.server_url(cx)?;
 387            self.provider
 388                .write_credentials(
 389                    &server_url,
 390                    &user_id.to_string(),
 391                    access_token.as_bytes(),
 392                    cx,
 393                )
 394                .await
 395        }
 396        .boxed_local()
 397    }
 398
 399    /// Deletes the credentials from the provider.
 400    fn delete_credentials<'a>(
 401        &'a self,
 402        cx: &'a AsyncApp,
 403    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 404        async move {
 405            let server_url = self.server_url(cx)?;
 406            self.provider.delete_credentials(&server_url, cx).await
 407        }
 408        .boxed_local()
 409    }
 410}
 411
 412impl Default for ClientState {
 413    fn default() -> Self {
 414        Self {
 415            credentials: None,
 416            status: watch::channel_with(Status::SignedOut),
 417            _reconnect_task: None,
 418        }
 419    }
 420}
 421
/// Handle for a registration in the client's handler set.
///
/// Dropping it removes the registration if the client is still alive.
pub enum Subscription {
    /// Created by `PendingEntitySubscription::set_entity`; keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Created when registering a message or request handler; keyed by message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 432
 433impl Drop for Subscription {
 434    fn drop(&mut self) {
 435        match self {
 436            Subscription::Entity { client, id } => {
 437                if let Some(client) = client.upgrade() {
 438                    let mut state = client.handler_set.lock();
 439                    let _ = state.entities_by_type_and_remote_id.remove(id);
 440                }
 441            }
 442            Subscription::Message { client, id } => {
 443                if let Some(client) = client.upgrade() {
 444                    let mut state = client.handler_set.lock();
 445                    let _ = state.entity_types_by_message_type.remove(id);
 446                    let _ = state.message_handlers.remove(id);
 447                }
 448            }
 449        }
 450    }
 451}
 452
/// Subscription returned by `Client::subscribe_to_entity` that is not yet bound to an entity.
///
/// Call [`PendingEntitySubscription::set_entity`] to complete it. Dropping it instead
/// removes the pending entry and logs any messages that were queued for it.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    // Ties the subscription to entity type `T` without storing a `T`.
    _entity_type: PhantomData<T>,
    // Set by `set_entity` so that `Drop` leaves the installed entry alone.
    consumed: bool,
}
 459
impl<T: 'static> PendingEntitySubscription<T> {
    /// Binds this pending subscription to `entity` and returns the active [`Subscription`].
    ///
    /// Replaces the pending entry with a weak handle to `entity`, then passes every
    /// message queued while pending to `Client::handle_message`, in stored order.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) if there is no pending entry for this entity type and remote id.
    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
        // Mark as consumed first so `Drop` does not remove the entry installed below.
        self.consumed = true;
        let mut handlers = self.client.handler_set.lock();
        let id = (TypeId::of::<T>(), self.remote_id);
        let Some(EntityMessageSubscriber::Pending(messages)) =
            handlers.entities_by_type_and_remote_id.remove(&id)
        else {
            unreachable!()
        };

        handlers.entities_by_type_and_remote_id.insert(
            id,
            EntityMessageSubscriber::Entity {
                handle: entity.downgrade().into(),
            },
        );
        // Release the handler-set lock before replaying queued messages.
        // NOTE(review): presumably `handle_message` (not shown here) locks the handler
        // set itself, so holding the lock here would deadlock — confirm.
        drop(handlers);
        for message in messages {
            let client_id = self.client.id();
            let type_name = message.payload_type_name();
            let sender_id = message.original_sender_id();
            log::debug!(
                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                sender_id,
                type_name
            );
            self.client.handle_message(message, cx);
        }
        Subscription::Entity {
            client: Arc::downgrade(&self.client),
            id,
        }
    }
}
 496
 497impl<T: 'static> Drop for PendingEntitySubscription<T> {
 498    fn drop(&mut self) {
 499        if !self.consumed {
 500            let mut state = self.client.handler_set.lock();
 501            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 502                .entities_by_type_and_remote_id
 503                .remove(&(TypeId::of::<T>(), self.remote_id))
 504            {
 505                for message in messages {
 506                    log::info!("unhandled message {}", message.payload_type_name());
 507                }
 508            }
 509        }
 510    }
 511}
 512
/// Resolved telemetry preferences (the `telemetry` settings key).
#[derive(Copy, Clone, Deserialize, Debug)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
 518
// Settings-file form of `TelemetrySettings`. The `///` comments below are kept
// unchanged; presumably the JsonSchema derive surfaces them as schema descriptions.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 531
 532impl settings::Settings for TelemetrySettings {
 533    const KEY: Option<&'static str> = Some("telemetry");
 534
 535    type FileContent = TelemetrySettingsContent;
 536
 537    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 538        sources.json_merge()
 539    }
 540
 541    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 542        vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
 543            Some(s == "all")
 544        });
 545        vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
 546            Some(matches!(s, "all" | "error" | "crash"))
 547        });
 548        // we could translate telemetry.telemetryLevel, but just because users didn't want
 549        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 550        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 551    }
 552}
 553
/// Whether to disable all AI features in Zed.
///
/// Default: false
#[derive(Copy, Clone, Debug)]
pub struct DisableAiSettings {
    /// `true` when AI features should be disabled.
    pub disable_ai: bool,
}
 561
 562impl settings::Settings for DisableAiSettings {
 563    const KEY: Option<&'static str> = Some("disable_ai");
 564
 565    type FileContent = Option<bool>;
 566
 567    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 568        Ok(Self {
 569            disable_ai: sources
 570                .user
 571                .or(sources.server)
 572                .copied()
 573                .flatten()
 574                .unwrap_or(sources.default.ok_or_else(Self::missing_default)?),
 575        })
 576    }
 577
 578    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
 579}
 580
 581impl Client {
 582    pub fn new(
 583        clock: Arc<dyn SystemClock>,
 584        http: Arc<HttpClientWithUrl>,
 585        cx: &mut App,
 586    ) -> Arc<Self> {
 587        Arc::new(Self {
 588            id: AtomicU64::new(0),
 589            peer: Peer::new(0),
 590            telemetry: Telemetry::new(clock, http.clone(), cx),
 591            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 592            http,
 593            credentials_provider: ClientCredentialsProvider::new(cx),
 594            state: Default::default(),
 595            handler_set: Default::default(),
 596
 597            #[cfg(any(test, feature = "test-support"))]
 598            authenticate: Default::default(),
 599            #[cfg(any(test, feature = "test-support"))]
 600            establish_connection: Default::default(),
 601            #[cfg(any(test, feature = "test-support"))]
 602            rpc_url: RwLock::default(),
 603        })
 604    }
 605
 606    pub fn production(cx: &mut App) -> Arc<Self> {
 607        let clock = Arc::new(clock::RealSystemClock);
 608        let http = Arc::new(HttpClientWithUrl::new_url(
 609            cx.http_client(),
 610            &ClientSettings::get_global(cx).server_url,
 611            cx.http_client().proxy().cloned(),
 612        ));
 613        Self::new(clock, http, cx)
 614    }
 615
    /// Returns the client id, as last stored by [`Client::set_id`].
    pub fn id(&self) -> u64 {
        self.id.load(Ordering::SeqCst)
    }
 619
 620    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 621        self.http.clone()
 622    }
 623
    /// Stores `id` as the client id and returns `self` so calls can be chained.
    pub fn set_id(&self, id: u64) -> &Self {
        self.id.store(id, Ordering::SeqCst);
        self
    }
 628
    /// Test-only: drops the reconnect task, clears the handler set, and tears down the peer.
    ///
    /// The `state` write lock stays held while the handler set and the peer are cleared.
    #[cfg(any(test, feature = "test-support"))]
    pub fn teardown(&self) {
        let mut state = self.state.write();
        state._reconnect_task.take();
        self.handler_set.lock().clear();
        self.peer.teardown();
    }
 636
    /// Test-only: stores `authenticate` as an override for the authentication step.
    ///
    /// NOTE(review): presumably `Client::authenticate` (not shown here) consults it — confirm.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
    where
        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
    {
        *self.authenticate.write() = Some(Box::new(authenticate));
        self
    }
 645
    /// Test-only: stores `connect` as an override for establishing the connection.
    ///
    /// NOTE(review): presumably `Client::establish_connection` (not shown here) consults it — confirm.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        *self.establish_connection.write() = Some(Box::new(connect));
        self
    }
 657
    /// Test-only: stores `url` as the RPC URL override and returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_rpc_url(&self, url: Url) -> &Self {
        *self.rpc_url.write() = Some(url);
        self
    }
 663
 664    pub fn global(cx: &App) -> Arc<Self> {
 665        cx.global::<GlobalClient>().0.clone()
 666    }
    /// Installs `client` as the app-global client returned by [`Client::global`].
    pub fn set_global(client: Arc<Client>, cx: &mut App) {
        cx.set_global(GlobalClient(client))
    }
 670
 671    pub fn user_id(&self) -> Option<u64> {
 672        self.state
 673            .read()
 674            .credentials
 675            .as_ref()
 676            .map(|credentials| credentials.user_id)
 677    }
 678
 679    pub fn peer_id(&self) -> Option<PeerId> {
 680        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 681            Some(*peer_id)
 682        } else {
 683            None
 684        }
 685    }
 686
    /// Returns a new receiver for the client's [`Status`] watch channel.
    pub fn status(&self) -> watch::Receiver<Status> {
        self.state.read().status.1.clone()
    }
 690
    /// Publishes `status` on the status watch channel and reacts to the transition.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task. It calls `authenticate_and_connect`
    ///   repeatedly. After a failed attempt that leaves the status at `ConnectionError`,
    ///   it publishes `ReconnectionError` and sleeps for the current delay plus random
    ///   jitter (below that delay). The delay starts at [`INITIAL_RECONNECTION_DELAY`]
    ///   and doubles up to [`MAX_RECONNECTION_DELAY`]. The task stops once an attempt
    ///   succeeds or the status is anything other than `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and drops any
    ///   pending reconnect task.
    ///
    /// The `state` write lock is held for the entire call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // Fixed seed in test builds so the jitter is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        match client.authenticate_and_connect(true, &cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    break;
                                }
                            }
                        }

                        // Retry only while the failed attempt left us in `ConnectionError`;
                        // any other status means someone else changed the state, so stop.
                        if matches!(*client.status().borrow(), Status::ConnectionError) {
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            let jitter =
                                Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
                            cx.background_executor().timer(delay + jitter).await;
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 750
 751    pub fn subscribe_to_entity<T>(
 752        self: &Arc<Self>,
 753        remote_id: u64,
 754    ) -> Result<PendingEntitySubscription<T>>
 755    where
 756        T: 'static,
 757    {
 758        let id = (TypeId::of::<T>(), remote_id);
 759
 760        let mut state = self.handler_set.lock();
 761        anyhow::ensure!(
 762            !state.entities_by_type_and_remote_id.contains_key(&id),
 763            "already subscribed to entity"
 764        );
 765
 766        state
 767            .entities_by_type_and_remote_id
 768            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 769
 770        Ok(PendingEntitySubscription {
 771            client: self.clone(),
 772            remote_id,
 773            consumed: false,
 774            _entity_type: PhantomData,
 775        })
 776    }
 777
 778    #[track_caller]
 779    pub fn add_message_handler<M, E, H, F>(
 780        self: &Arc<Self>,
 781        entity: WeakEntity<E>,
 782        handler: H,
 783    ) -> Subscription
 784    where
 785        M: EnvelopedMessage,
 786        E: 'static,
 787        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 788        F: 'static + Future<Output = Result<()>>,
 789    {
 790        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 791            handler(entity, message, cx)
 792        })
 793    }
 794
 795    fn add_message_handler_impl<M, E, H, F>(
 796        self: &Arc<Self>,
 797        entity: WeakEntity<E>,
 798        handler: H,
 799    ) -> Subscription
 800    where
 801        M: EnvelopedMessage,
 802        E: 'static,
 803        H: 'static
 804            + Sync
 805            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 806            + Send
 807            + Sync,
 808        F: 'static + Future<Output = Result<()>>,
 809    {
 810        let message_type_id = TypeId::of::<M>();
 811        let mut state = self.handler_set.lock();
 812        state
 813            .entities_by_message_type
 814            .insert(message_type_id, entity.into());
 815
 816        let prev_handler = state.message_handlers.insert(
 817            message_type_id,
 818            Arc::new(move |subscriber, envelope, client, cx| {
 819                let subscriber = subscriber.downcast::<E>().unwrap();
 820                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 821                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 822            }),
 823        );
 824        if prev_handler.is_some() {
 825            let location = std::panic::Location::caller();
 826            panic!(
 827                "{}:{} registered handler for the same message {} twice",
 828                location.file(),
 829                location.line(),
 830                std::any::type_name::<M>()
 831            );
 832        }
 833
 834        Subscription::Message {
 835            client: Arc::downgrade(self),
 836            id: message_type_id,
 837        }
 838    }
 839
 840    pub fn add_request_handler<M, E, H, F>(
 841        self: &Arc<Self>,
 842        entity: WeakEntity<E>,
 843        handler: H,
 844    ) -> Subscription
 845    where
 846        M: RequestMessage,
 847        E: 'static,
 848        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 849        F: 'static + Future<Output = Result<M::Response>>,
 850    {
 851        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 852            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 853        })
 854    }
 855
 856    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 857        receipt: Receipt<T>,
 858        response: F,
 859        client: AnyProtoClient,
 860    ) -> Result<()> {
 861        match response.await {
 862            Ok(response) => {
 863                client.send_response(receipt.message_id, response)?;
 864                Ok(())
 865            }
 866            Err(error) => {
 867                client.send_response(receipt.message_id, error.to_proto())?;
 868                Err(error)
 869            }
 870        }
 871    }
 872
 873    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 874        self.credentials_provider
 875            .read_credentials(cx)
 876            .await
 877            .is_some()
 878    }
 879
    /// Authenticates (if needed) and connects this client to the collab server.
    ///
    /// Steps:
    /// 1. Returns `Ok(())` immediately when already connected or when a connection attempt is
    ///    already in flight. Fails immediately when the status is `UpgradeRequired`.
    /// 2. Sets the status to `Authenticating` (coming from `SignedOut`) or `Reauthenticating`.
    /// 3. Obtains credentials in this order:
    ///    - the credentials cached in memory;
    ///    - failing that, and if `try_provider` is set, the credentials provider;
    ///    - failing that, the interactive [`Self::authenticate`] flow. This flow is
    ///      abandoned if the status changes while it runs.
    /// 4. Establishes the connection and performs the server handshake. Both are bounded by a
    ///    single `CONNECTION_TIMEOUT` timer.
    ///
    /// After a successful connection the credentials are cached in memory. They are also
    /// written to the provider, unless they came from the provider or `IMPERSONATE_LOGIN` is
    /// set. If the server rejects credentials that were read from the provider, they are
    /// deleted and this method retries once with `try_provider = false`.
    ///
    /// Returns [`ConnectionResult::Timeout`] when the timer fires first. Otherwise it returns
    /// [`ConnectionResult::Result`] carrying the outcome. Failure paths set the status to
    /// `ConnectionError` (or `UpgradeRequired`). The one exception is cancelled authentication,
    /// which leaves the status as set by whoever changed it.
    // `async_recursion` boxes the future because this method calls itself; `?Send` means the
    // boxed future is not required to be `Send`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // Distinguish a fresh sign-in from a re-authentication, and bail out early for
        // states where there is nothing to do.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return ConnectionResult::Result(Ok(()));
            }
            Status::UpgradeRequired => {
                return ConnectionResult::Result(
                    Err(EstablishConnectionError::UpgradeRequired)
                        .context("client auth and connect"),
                );
            }
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Track whether the credentials came from the provider. Such credentials are not
        // written back on success, and are deleted from the provider if the server rejects them.
        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            let mut status_rx = self.status();
            // Discard the first value the receiver yields (presumably the current status), so
            // that the select below only reacts to a *subsequent* status change. Such a change
            // cancels the interactive authentication.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return ConnectionResult::Result(Err(err));
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
                }
            }
        }
        // Every path above that leaves `credentials` empty has already returned.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single timer bounds both establishing the connection and the handshake performed
        // by `set_connection` below.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only credentials that were freshly obtained, and never
                        // persist impersonation credentials.
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_provider {
                            // Stored credentials are stale: drop them and retry once. With
                            // `try_provider = false` the retry goes through interactive
                            // authentication.
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
 999
    /// Registers `conn` with the peer, waits for the server's `Hello`, and marks the client
    /// as connected.
    ///
    /// After the handshake, two tasks are detached:
    /// - one dispatches incoming messages through [`Self::handle_message`];
    /// - one awaits the connection's I/O task and updates the status when it ends. A clean
    ///   close sets `SignedOut`, but only if this connection is still the current one. An
    ///   I/O error sets `ConnectionLost`.
    ///
    /// # Errors
    /// Returns an error, after disconnecting this connection from the peer, in three cases:
    /// no message arrives, the first message is not a `proto::Hello`, or the hello carries
    /// no peer id.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        // The peer receives a timer factory backed by the background executor. Presumably it
        // uses this for keepalive/timeout scheduling — confirm in `rpc::Peer`.
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // The first message on a fresh connection must be the server's Hello, which carries
        // this client's peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, &cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Sign out only if the status still refers to this connection. A newer
                    // connection may have replaced it in the meantime.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, &cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, &cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1083
1084    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1085        #[cfg(any(test, feature = "test-support"))]
1086        if let Some(callback) = self.authenticate.read().as_ref() {
1087            return callback(cx);
1088        }
1089
1090        self.authenticate_with_browser(cx)
1091    }
1092
1093    fn establish_connection(
1094        self: &Arc<Self>,
1095        credentials: &Credentials,
1096        cx: &AsyncApp,
1097    ) -> Task<Result<Connection, EstablishConnectionError>> {
1098        #[cfg(any(test, feature = "test-support"))]
1099        if let Some(callback) = self.establish_connection.read().as_ref() {
1100            return callback(credentials, cx);
1101        }
1102
1103        self.establish_websocket_connection(credentials, cx)
1104    }
1105
1106    fn rpc_url(
1107        &self,
1108        http: Arc<HttpClientWithUrl>,
1109        release_channel: Option<ReleaseChannel>,
1110    ) -> impl Future<Output = Result<url::Url>> + use<> {
1111        #[cfg(any(test, feature = "test-support"))]
1112        let url_override = self.rpc_url.read().clone();
1113
1114        async move {
1115            #[cfg(any(test, feature = "test-support"))]
1116            if let Some(url) = url_override {
1117                return Ok(url);
1118            }
1119
1120            if let Some(url) = &*ZED_RPC_URL {
1121                return Url::parse(url).context("invalid rpc url");
1122            }
1123
1124            let mut url = http.build_url("/rpc");
1125            if let Some(preview_param) =
1126                release_channel.and_then(|channel| channel.release_query_param())
1127            {
1128                url += "?";
1129                url += preview_param;
1130            }
1131
1132            let response = http.get(&url, Default::default(), false).await?;
1133            anyhow::ensure!(
1134                response.status().is_redirection(),
1135                "unexpected /rpc response status {}",
1136                response.status()
1137            );
1138            let collab_url = response
1139                .headers()
1140                .get("Location")
1141                .context("missing location header in /rpc response")?
1142                .to_str()
1143                .map_err(EstablishConnectionError::other)?
1144                .to_string();
1145            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1146        }
1147    }
1148
1149    fn establish_websocket_connection(
1150        self: &Arc<Self>,
1151        credentials: &Credentials,
1152        cx: &AsyncApp,
1153    ) -> Task<Result<Connection, EstablishConnectionError>> {
1154        let release_channel = cx
1155            .update(|cx| ReleaseChannel::try_global(cx))
1156            .ok()
1157            .flatten();
1158        let app_version = cx
1159            .update(|cx| AppVersion::global(cx).to_string())
1160            .ok()
1161            .unwrap_or_default();
1162
1163        let http = self.http.clone();
1164        let proxy = http.proxy().cloned();
1165        let user_agent = http.user_agent().cloned();
1166        let credentials = credentials.clone();
1167        let rpc_url = self.rpc_url(http, release_channel);
1168        let system_id = self.telemetry.system_id();
1169        let metrics_id = self.telemetry.metrics_id();
1170        cx.spawn(async move |cx| {
1171            use HttpOrHttps::*;
1172
1173            #[derive(Debug)]
1174            enum HttpOrHttps {
1175                Http,
1176                Https,
1177            }
1178
1179            let mut rpc_url = rpc_url.await?;
1180            let url_scheme = match rpc_url.scheme() {
1181                "https" => Https,
1182                "http" => Http,
1183                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1184            };
1185            let rpc_host = rpc_url
1186                .host_str()
1187                .zip(rpc_url.port_or_known_default())
1188                .context("missing host in rpc url")?;
1189
1190            let stream = {
1191                let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
1192                let _guard = handle.enter();
1193                match proxy {
1194                    Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
1195                    None => Box::new(TcpStream::connect(rpc_host).await?),
1196                }
1197            };
1198
1199            log::info!("connected to rpc endpoint {}", rpc_url);
1200
1201            rpc_url
1202                .set_scheme(match url_scheme {
1203                    Https => "wss",
1204                    Http => "ws",
1205                })
1206                .unwrap();
1207
1208            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
1209            // for us from the RPC URL.
1210            //
1211            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
1212            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
1213
1214            // We then modify the request to add our desired headers.
1215            let request_headers = request.headers_mut();
1216            request_headers.insert(
1217                http::header::AUTHORIZATION,
1218                HeaderValue::from_str(&credentials.authorization_header())?,
1219            );
1220            request_headers.insert(
1221                "x-zed-protocol-version",
1222                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
1223            );
1224            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
1225            request_headers.insert(
1226                "x-zed-release-channel",
1227                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
1228            );
1229            if let Some(user_agent) = user_agent {
1230                request_headers.insert(http::header::USER_AGENT, user_agent);
1231            }
1232            if let Some(system_id) = system_id {
1233                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
1234            }
1235            if let Some(metrics_id) = metrics_id {
1236                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
1237            }
1238
1239            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
1240                request,
1241                stream,
1242                Some(Arc::new(http_client_tls::tls_config()).into()),
1243                None,
1244            )
1245            .await?;
1246
1247            Ok(Connection::new(
1248                stream
1249                    .map_err(|error| anyhow!(error))
1250                    .sink_map_err(|error| anyhow!(error)),
1251            ))
1252        })
1253    }
1254
1255    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1256        let http = self.http.clone();
1257        let this = self.clone();
1258        cx.spawn(async move |cx| {
1259            let background = cx.background_executor().clone();
1260
1261            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1262            cx.update(|cx| {
1263                cx.spawn(async move |cx| {
1264                    let url = open_url_rx.await?;
1265                    cx.update(|cx| cx.open_url(&url))
1266                })
1267                .detach_and_log_err(cx);
1268            })
1269            .log_err();
1270
1271            let credentials = background
1272                .clone()
1273                .spawn(async move {
1274                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1275                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1276                    // any other app running on the user's device.
1277                    let (public_key, private_key) =
1278                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1279                    let public_key_string = String::try_from(public_key)
1280                        .expect("failed to serialize public key for auth");
1281
1282                    if let Some((login, token)) =
1283                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1284                    {
1285                        eprintln!("authenticate as admin {login}, {token}");
1286
1287                        return this
1288                            .authenticate_as_admin(http, login.clone(), token.clone())
1289                            .await;
1290                    }
1291
1292                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1293                    let server =
1294                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1295                    let port = server.server_addr().port();
1296
1297                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1298                    // that the user is signing in from a Zed app running on the same device.
1299                    let mut url = http.build_url(&format!(
1300                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1301                        port, public_key_string
1302                    ));
1303
1304                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1305                        log::info!("impersonating user @{}", impersonate_login);
1306                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1307                    }
1308
1309                    open_url_tx.send(url).log_err();
1310
1311                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1312                    // access token from the query params.
1313                    //
1314                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1315                    // custom URL scheme instead of this local HTTP server.
1316                    let (user_id, access_token) = background
1317                        .spawn(async move {
1318                            for _ in 0..100 {
1319                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1320                                    let path = req.url();
1321                                    let mut user_id = None;
1322                                    let mut access_token = None;
1323                                    let url = Url::parse(&format!("http://example.com{}", path))
1324                                        .context("failed to parse login notification url")?;
1325                                    for (key, value) in url.query_pairs() {
1326                                        if key == "access_token" {
1327                                            access_token = Some(value.to_string());
1328                                        } else if key == "user_id" {
1329                                            user_id = Some(value.to_string());
1330                                        }
1331                                    }
1332
1333                                    let post_auth_url =
1334                                        http.build_url("/native_app_signin_succeeded");
1335                                    req.respond(
1336                                        tiny_http::Response::empty(302).with_header(
1337                                            tiny_http::Header::from_bytes(
1338                                                &b"Location"[..],
1339                                                post_auth_url.as_bytes(),
1340                                            )
1341                                            .unwrap(),
1342                                        ),
1343                                    )
1344                                    .context("failed to respond to login http request")?;
1345                                    return Ok((
1346                                        user_id.context("missing user_id parameter")?,
1347                                        access_token.context("missing access_token parameter")?,
1348                                    ));
1349                                }
1350                            }
1351
1352                            anyhow::bail!("didn't receive login redirect");
1353                        })
1354                        .await?;
1355
1356                    let access_token = private_key
1357                        .decrypt_string(&access_token)
1358                        .context("failed to decrypt access token")?;
1359
1360                    Ok(Credentials {
1361                        user_id: user_id.parse()?,
1362                        access_token,
1363                    })
1364                })
1365                .await?;
1366
1367            cx.update(|cx| cx.activate(true))?;
1368            Ok(credentials)
1369        })
1370    }
1371
1372    async fn authenticate_as_admin(
1373        self: &Arc<Self>,
1374        http: Arc<HttpClientWithUrl>,
1375        login: String,
1376        mut api_token: String,
1377    ) -> Result<Credentials> {
1378        #[derive(Deserialize)]
1379        struct AuthenticatedUserResponse {
1380            user: User,
1381        }
1382
1383        #[derive(Deserialize)]
1384        struct User {
1385            id: u64,
1386        }
1387
1388        let github_user = {
1389            #[derive(Deserialize)]
1390            struct GithubUser {
1391                id: i32,
1392                login: String,
1393                created_at: DateTime<Utc>,
1394            }
1395
1396            let request = {
1397                let mut request_builder =
1398                    Request::get(&format!("https://api.github.com/users/{login}"));
1399                if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
1400                    request_builder =
1401                        request_builder.header("Authorization", format!("Bearer {}", github_token));
1402                }
1403
1404                request_builder.body(AsyncBody::empty())?
1405            };
1406
1407            let mut response = http
1408                .send(request)
1409                .await
1410                .context("error fetching GitHub user")?;
1411
1412            let mut body = Vec::new();
1413            response
1414                .body_mut()
1415                .read_to_end(&mut body)
1416                .await
1417                .context("error reading GitHub user")?;
1418
1419            if !response.status().is_success() {
1420                let text = String::from_utf8_lossy(body.as_slice());
1421                bail!(
1422                    "status error {}, response: {text:?}",
1423                    response.status().as_u16()
1424                );
1425            }
1426
1427            serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
1428                log::error!("Error deserializing: {:?}", err);
1429                log::error!(
1430                    "GitHub API response text: {:?}",
1431                    String::from_utf8_lossy(body.as_slice())
1432                );
1433                anyhow!("error deserializing GitHub user")
1434            })?
1435        };
1436
1437        let query_params = [
1438            ("github_login", &github_user.login),
1439            ("github_user_id", &github_user.id.to_string()),
1440            (
1441                "github_user_created_at",
1442                &github_user.created_at.to_rfc3339(),
1443            ),
1444        ];
1445
1446        // Use the collab server's admin API to retrieve the ID
1447        // of the impersonated user.
1448        let mut url = self.rpc_url(http.clone(), None).await?;
1449        url.set_path("/user");
1450        url.set_query(Some(
1451            &query_params
1452                .iter()
1453                .map(|(key, value)| {
1454                    format!(
1455                        "{}={}",
1456                        key,
1457                        url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
1458                    )
1459                })
1460                .collect::<Vec<String>>()
1461                .join("&"),
1462        ));
1463        let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
1464            .header("Authorization", format!("token {api_token}"))
1465            .body("".into())?;
1466
1467        let mut response = http.send(request).await?;
1468        let mut body = String::new();
1469        response.body_mut().read_to_string(&mut body).await?;
1470        anyhow::ensure!(
1471            response.status().is_success(),
1472            "admin user request failed {} - {}",
1473            response.status().as_u16(),
1474            body,
1475        );
1476        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1477
1478        // Use the admin API token to authenticate as the impersonated user.
1479        api_token.insert_str(0, "ADMIN_TOKEN:");
1480        Ok(Credentials {
1481            user_id: response.user.id,
1482            access_token: api_token,
1483        })
1484    }
1485
1486    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1487        self.state.write().credentials = None;
1488        self.disconnect(cx);
1489
1490        if self.has_credentials(cx).await {
1491            self.credentials_provider
1492                .delete_credentials(cx)
1493                .await
1494                .log_err();
1495        }
1496    }
1497
1498    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1499        self.peer.teardown();
1500        self.set_status(Status::SignedOut, cx);
1501    }
1502
1503    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1504        self.peer.teardown();
1505        self.set_status(Status::ConnectionLost, cx);
1506    }
1507
1508    fn connection_id(&self) -> Result<ConnectionId> {
1509        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1510            Ok(connection_id)
1511        } else {
1512            anyhow::bail!("not connected");
1513        }
1514    }
1515
1516    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1517        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1518        self.peer.send(self.connection_id()?, message)
1519    }
1520
1521    pub fn request<T: RequestMessage>(
1522        &self,
1523        request: T,
1524    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1525        self.request_envelope(request)
1526            .map_ok(|envelope| envelope.payload)
1527    }
1528
1529    pub fn request_stream<T: RequestMessage>(
1530        &self,
1531        request: T,
1532    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1533        let client_id = self.id.load(Ordering::SeqCst);
1534        log::debug!(
1535            "rpc request start. client_id:{}. name:{}",
1536            client_id,
1537            T::NAME
1538        );
1539        let response = self
1540            .connection_id()
1541            .map(|conn_id| self.peer.request_stream(conn_id, request));
1542        async move {
1543            let response = response?.await;
1544            log::debug!(
1545                "rpc request finish. client_id:{}. name:{}",
1546                client_id,
1547                T::NAME
1548            );
1549            response
1550        }
1551    }
1552
1553    pub fn request_envelope<T: RequestMessage>(
1554        &self,
1555        request: T,
1556    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1557        let client_id = self.id();
1558        log::debug!(
1559            "rpc request start. client_id:{}. name:{}",
1560            client_id,
1561            T::NAME
1562        );
1563        let response = self
1564            .connection_id()
1565            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1566        async move {
1567            let response = response?.await;
1568            log::debug!(
1569                "rpc request finish. client_id:{}. name:{}",
1570                client_id,
1571                T::NAME
1572            );
1573            response
1574        }
1575    }
1576
1577    pub fn request_dynamic(
1578        &self,
1579        envelope: proto::Envelope,
1580        request_type: &'static str,
1581    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1582        let client_id = self.id();
1583        log::debug!(
1584            "rpc request start. client_id:{}. name:{}",
1585            client_id,
1586            request_type
1587        );
1588        let response = self
1589            .connection_id()
1590            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1591        async move {
1592            let response = response?.await;
1593            log::debug!(
1594                "rpc request finish. client_id:{}. name:{}",
1595                client_id,
1596                request_type
1597            );
1598            Ok(response?.0)
1599        }
1600    }
1601
1602    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1603        let sender_id = message.sender_id();
1604        let request_id = message.message_id();
1605        let type_name = message.payload_type_name();
1606        let original_sender_id = message.original_sender_id();
1607
1608        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1609            &self.handler_set,
1610            message,
1611            self.clone().into(),
1612            cx.clone(),
1613        ) {
1614            let client_id = self.id();
1615            log::debug!(
1616                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1617                client_id,
1618                original_sender_id,
1619                type_name
1620            );
1621            cx.spawn(async move |_| match future.await {
1622                Ok(()) => {
1623                    log::debug!(
1624                        "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1625                        client_id,
1626                        original_sender_id,
1627                        type_name
1628                    );
1629                }
1630                Err(error) => {
1631                    log::error!(
1632                        "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1633                        client_id,
1634                        original_sender_id,
1635                        type_name,
1636                        error
1637                    );
1638                }
1639            })
1640            .detach();
1641        } else {
1642            log::info!("unhandled message {}", type_name);
1643            self.peer
1644                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1645                .log_err();
1646        }
1647    }
1648
1649    pub fn telemetry(&self) -> &Arc<Telemetry> {
1650        &self.telemetry
1651    }
1652}
1653
1654impl ProtoClient for Client {
1655    fn request(
1656        &self,
1657        envelope: proto::Envelope,
1658        request_type: &'static str,
1659    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1660        self.request_dynamic(envelope, request_type).boxed()
1661    }
1662
1663    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1664        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1665        let connection_id = self.connection_id()?;
1666        self.peer.send_dynamic(connection_id, envelope)
1667    }
1668
1669    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1670        log::debug!(
1671            "rpc respond. client_id:{}, name:{}",
1672            self.id(),
1673            message_type
1674        );
1675        let connection_id = self.connection_id()?;
1676        self.peer.send_dynamic(connection_id, envelope)
1677    }
1678
1679    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1680        &self.handler_set
1681    }
1682
1683    fn is_via_collab(&self) -> bool {
1684        true
1685    }
1686}
1687
/// URL scheme name used by `zed://` links.
///
/// This is the bare scheme (`"zed"`), without the `://` separator; callers such as
/// [`parse_zed_link`] append `"://"` themselves when matching links.
pub const ZED_URL_SCHEME: &str = "zed";
1690
1691/// Parses the given link into a Zed link.
1692///
1693/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1694/// Returns [`None`] otherwise.
1695pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1696    let server_url = &ClientSettings::get_global(cx).server_url;
1697    if let Some(stripped) = link
1698        .strip_prefix(server_url)
1699        .and_then(|result| result.strip_prefix('/'))
1700    {
1701        return Some(stripped);
1702    }
1703    if let Some(stripped) = link
1704        .strip_prefix(ZED_URL_SCHEME)
1705        .and_then(|result| result.strip_prefix("://"))
1706    {
1707        return Some(stripped);
1708    }
1709
1710    None
1711}
1712
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    /// Drops the connection twice and checks how the client recovers. The first time it
    /// reconnects using its cached credentials (auth count stays at 1). The second time the
    /// server has rolled the access token, so the client must authenticate again (auth count 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Refuse new connections, then drop the current one, and wait for the
        // client to report a failed reconnection attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Allow connections again and advance the fake clock so a retry happens.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// Checks that both the initial connection attempt and a later reconnection attempt give
    /// up after `CONNECTION_TIMEOUT` when establishing the connection never completes. They
    /// report `ConnectionError` and `ReconnectionError` respectively.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                // Never resolves, so the connection attempt hangs until it times out.
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Advance past the initial reconnection delay; the client should then
        // report that it is reconnecting.
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    /// Starts a second `authenticate_and_connect` while the first is still stuck authenticating.
    /// Checks that the first in-flight authentication future gets dropped once the second
    /// attempt starts: after the second call, two authentications have begun and one drop
    /// guard has fired.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        // Number of authentication attempts started.
        let auth_count = Arc::new(Mutex::new(0));
        // Number of authentication futures dropped before completing.
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this future is dropped; it never completes normally
                    // because of the `pending` below.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    /// Registers an entity message handler for `JoinProject` and subscribes three entities by
    /// id. Checks that messages for ids 1 and 2 reach the matching entities even after the
    /// subscription for id 3 is dropped. (Presumably `JoinProject` is routed by its
    /// `project_id`; the test relies on that mapping.)
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                // Signal the channel that corresponds to the entity the message was routed to.
                match entity.read_with(&mut cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    /// Registers a `Ping` handler, drops its subscription, then registers a second handler for
    /// the same message type and entity. Checks that the second handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        // `_done_rx1` is kept alive, so a stray call to the first handler would not make its
        // `try_send` fail; the test only waits on the second handler below.
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    /// Stores a handler's own `Subscription` inside the entity. The handler takes it out, and
    /// so drops it, while handling a `Ping`. The test passes if that handler invocation still
    /// runs to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                // `take()` moves the subscription out of the entity. The returned value is a
                // temporary, so it is dropped at the end of this statement, while the handler
                // is still running.
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    /// Minimal entity used as a message-handler target in these tests.
    #[derive(Default)]
    struct TestEntity {
        // Distinguishes entities in the routing test.
        id: usize,
        // Slot where a test can park a `Subscription` owned by the entity.
        subscription: Option<Subscription>,
    }

    /// Installs a test `SettingsStore` as a global, then calls `init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}