client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod socks;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{Context as _, Result, anyhow, bail};
  10use async_recursion::async_recursion;
  11use async_tungstenite::tungstenite::{
  12    client::IntoClientRequest,
  13    error::Error as WebsocketError,
  14    http::{HeaderValue, Request, StatusCode},
  15};
  16use chrono::{DateTime, Utc};
  17use clock::SystemClock;
  18use credentials_provider::CredentialsProvider;
  19use futures::{
  20    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  21    channel::oneshot, future::BoxFuture,
  22};
  23use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  24use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
  25use parking_lot::RwLock;
  26use postage::watch;
  27use rand::prelude::*;
  28use release_channel::{AppVersion, ReleaseChannel};
  29use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  30use schemars::JsonSchema;
  31use serde::{Deserialize, Serialize};
  32use settings::{Settings, SettingsSources};
  33use socks::connect_socks_proxy_stream;
  34use std::pin::Pin;
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        Arc, LazyLock, Weak,
  44        atomic::{AtomicU64, Ordering},
  45    },
  46    time::{Duration, Instant},
  47};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use url::Url;
  51use util::{ResultExt, TryFutureExt};
  52
  53pub use rpc::*;
  54pub use telemetry_events::Event;
  55pub use user::*;
  56
  57static ZED_SERVER_URL: LazyLock<Option<String>> =
  58    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  59static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  60
  61pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  62    std::env::var("ZED_IMPERSONATE")
  63        .ok()
  64        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  65});
  66
  67pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  68    std::env::var("ZED_ADMIN_API_TOKEN")
  69        .ok()
  70        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  71});
  72
  73pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  74    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  75
  76pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  77    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  78
  79pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  80pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
  81pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  82
// Declares the `SignIn`, `SignOut`, and `Reconnect` actions under the `client` namespace;
// their handlers are registered by `init` below.
actions!(client, [SignIn, SignOut, Reconnect]);
  84
  85#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
  86pub struct ClientSettingsContent {
  87    server_url: Option<String>,
  88}
  89
  90#[derive(Deserialize)]
  91pub struct ClientSettings {
  92    pub server_url: String,
  93}
  94
  95impl Settings for ClientSettings {
  96    const KEY: Option<&'static str> = None;
  97
  98    type FileContent = ClientSettingsContent;
  99
 100    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 101        let mut result = sources.json_merge::<Self>()?;
 102        if let Some(server_url) = &*ZED_SERVER_URL {
 103            result.server_url.clone_from(server_url)
 104        }
 105        Ok(result)
 106    }
 107
 108    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
 109}
 110
 111#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 112pub struct ProxySettingsContent {
 113    proxy: Option<String>,
 114}
 115
 116#[derive(Deserialize, Default)]
 117pub struct ProxySettings {
 118    pub proxy: Option<String>,
 119}
 120
 121impl Settings for ProxySettings {
 122    const KEY: Option<&'static str> = None;
 123
 124    type FileContent = ProxySettingsContent;
 125
 126    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 127        Ok(Self {
 128            proxy: sources
 129                .user
 130                .or(sources.server)
 131                .and_then(|value| value.proxy.clone())
 132                .or(sources.default.proxy.clone()),
 133        })
 134    }
 135
 136    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 137        vscode.string_setting("http.proxy", &mut current.proxy);
 138    }
 139}
 140
 141pub fn init_settings(cx: &mut App) {
 142    TelemetrySettings::register(cx);
 143    ClientSettings::register(cx);
 144    ProxySettings::register(cx);
 145}
 146
 147pub fn init(client: &Arc<Client>, cx: &mut App) {
 148    let client = Arc::downgrade(client);
 149    cx.on_action({
 150        let client = client.clone();
 151        move |_: &SignIn, cx| {
 152            if let Some(client) = client.upgrade() {
 153                cx.spawn(async move |cx| {
 154                    client.authenticate_and_connect(true, &cx).log_err().await
 155                })
 156                .detach();
 157            }
 158        }
 159    });
 160
 161    cx.on_action({
 162        let client = client.clone();
 163        move |_: &SignOut, cx| {
 164            if let Some(client) = client.upgrade() {
 165                cx.spawn(async move |cx| {
 166                    client.sign_out(&cx).await;
 167                })
 168                .detach();
 169            }
 170        }
 171    });
 172
 173    cx.on_action({
 174        let client = client.clone();
 175        move |_: &Reconnect, cx| {
 176            if let Some(client) = client.upgrade() {
 177                cx.spawn(async move |cx| {
 178                    client.reconnect(&cx);
 179                })
 180                .detach();
 181            }
 182        }
 183    });
 184}
 185
/// Newtype that stores the shared [`Client`] as an app global; read by `Client::global`
/// and written by `Client::set_global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 189
 190pub struct Client {
 191    id: AtomicU64,
 192    peer: Arc<Peer>,
 193    http: Arc<HttpClientWithUrl>,
 194    telemetry: Arc<Telemetry>,
 195    credentials_provider: ClientCredentialsProvider,
 196    state: RwLock<ClientState>,
 197    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
 198
 199    #[allow(clippy::type_complexity)]
 200    #[cfg(any(test, feature = "test-support"))]
 201    authenticate:
 202        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,
 203
 204    #[allow(clippy::type_complexity)]
 205    #[cfg(any(test, feature = "test-support"))]
 206    establish_connection: RwLock<
 207        Option<
 208            Box<
 209                dyn 'static
 210                    + Send
 211                    + Sync
 212                    + Fn(
 213                        &Credentials,
 214                        &AsyncApp,
 215                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 216            >,
 217        >,
 218    >,
 219
 220    #[cfg(any(test, feature = "test-support"))]
 221    rpc_url: RwLock<Option<Url>>,
 222}
 223
/// Errors returned while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server requires a newer client; produced e.g. from an HTTP 426 response by
    /// `From<WebsocketError>`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server rejected the credentials; produced e.g. from an HTTP 401 response by
    /// `From<WebsocketError>`.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A header value could not be constructed.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// An I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite its name, this wraps an `http::Error` (presumably from building the
    /// websocket request — confirm at construction sites).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 239
 240impl From<WebsocketError> for EstablishConnectionError {
 241    fn from(error: WebsocketError) -> Self {
 242        if let WebsocketError::Http(response) = &error {
 243            match response.status() {
 244                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 245                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 246                _ => {}
 247            }
 248        }
 249        EstablishConnectionError::Other(error.into())
 250    }
 251}
 252
 253impl EstablishConnectionError {
 254    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 255        Self::Other(error.into())
 256    }
 257}
 258
/// Connection state of a [`Client`], broadcast to observers through `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial state.
    SignedOut,
    /// The server reported that this client must be upgraded; `is_signed_out` treats
    /// this like `SignedOut`.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening a connection, starting from `SignedOut`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected to the server.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection dropped; setting this status starts a background reconnection loop.
    ConnectionLost,
    /// Obtaining credentials while recovering from a lost or failed connection.
    Reauthenticating,
    /// Opening a connection while recovering from a lost or failed connection.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 277
 278impl Status {
 279    pub fn is_connected(&self) -> bool {
 280        matches!(self, Self::Connected { .. })
 281    }
 282
 283    pub fn is_signed_out(&self) -> bool {
 284        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 285    }
 286}
 287
/// Mutable connection state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials stored once `establish_connection` succeeds and cleared on
    /// `Unauthorized`; reused by later `authenticate_and_connect` calls.
    credentials: Option<Credentials>,
    /// Both halves of the status watch channel; `Client::status` hands out clones of the
    /// receiver, `Client::set_status` writes through the sender.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnection loop spawned on `Status::ConnectionLost`; replaced or dropped on
    /// `Connected`, `SignedOut`/`UpgradeRequired`, and `teardown`.
    _reconnect_task: Option<Task<()>>,
}
 293
 294#[derive(Clone, Debug, Eq, PartialEq)]
 295pub struct Credentials {
 296    pub user_id: u64,
 297    pub access_token: String,
 298}
 299
 300impl Credentials {
 301    pub fn authorization_header(&self) -> String {
 302        format!("{} {}", self.user_id, self.access_token)
 303    }
 304}
 305
 306pub struct ClientCredentialsProvider {
 307    provider: Arc<dyn CredentialsProvider>,
 308}
 309
 310impl ClientCredentialsProvider {
 311    pub fn new(cx: &App) -> Self {
 312        Self {
 313            provider: <dyn CredentialsProvider>::global(cx),
 314        }
 315    }
 316
 317    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 318        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 319    }
 320
 321    /// Reads the credentials from the provider.
 322    fn read_credentials<'a>(
 323        &'a self,
 324        cx: &'a AsyncApp,
 325    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 326        async move {
 327            if IMPERSONATE_LOGIN.is_some() {
 328                return None;
 329            }
 330
 331            let server_url = self.server_url(cx).ok()?;
 332            let (user_id, access_token) = self
 333                .provider
 334                .read_credentials(&server_url, cx)
 335                .await
 336                .log_err()
 337                .flatten()?;
 338
 339            Some(Credentials {
 340                user_id: user_id.parse().ok()?,
 341                access_token: String::from_utf8(access_token).ok()?,
 342            })
 343        }
 344        .boxed_local()
 345    }
 346
 347    /// Writes the credentials to the provider.
 348    fn write_credentials<'a>(
 349        &'a self,
 350        user_id: u64,
 351        access_token: String,
 352        cx: &'a AsyncApp,
 353    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 354        async move {
 355            let server_url = self.server_url(cx)?;
 356            self.provider
 357                .write_credentials(
 358                    &server_url,
 359                    &user_id.to_string(),
 360                    access_token.as_bytes(),
 361                    cx,
 362                )
 363                .await
 364        }
 365        .boxed_local()
 366    }
 367
 368    /// Deletes the credentials from the provider.
 369    fn delete_credentials<'a>(
 370        &'a self,
 371        cx: &'a AsyncApp,
 372    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 373        async move {
 374            let server_url = self.server_url(cx)?;
 375            self.provider.delete_credentials(&server_url, cx).await
 376        }
 377        .boxed_local()
 378    }
 379}
 380
 381impl Default for ClientState {
 382    fn default() -> Self {
 383        Self {
 384            credentials: None,
 385            status: watch::channel_with(Status::SignedOut),
 386            _reconnect_task: None,
 387        }
 388    }
 389}
 390
 391pub enum Subscription {
 392    Entity {
 393        client: Weak<Client>,
 394        id: (TypeId, u64),
 395    },
 396    Message {
 397        client: Weak<Client>,
 398        id: TypeId,
 399    },
 400}
 401
 402impl Drop for Subscription {
 403    fn drop(&mut self) {
 404        match self {
 405            Subscription::Entity { client, id } => {
 406                if let Some(client) = client.upgrade() {
 407                    let mut state = client.handler_set.lock();
 408                    let _ = state.entities_by_type_and_remote_id.remove(id);
 409                }
 410            }
 411            Subscription::Message { client, id } => {
 412                if let Some(client) = client.upgrade() {
 413                    let mut state = client.handler_set.lock();
 414                    let _ = state.entity_types_by_message_type.remove(id);
 415                    let _ = state.message_handlers.remove(id);
 416                }
 417            }
 418        }
 419    }
 420}
 421
/// Subscription reserved by [`Client::subscribe_to_entity`] that is not yet bound to an
/// entity.
///
/// While it exists, the handler set holds a `Pending` entry for `(T, remote_id)`.
/// Dropping it without calling `set_entity` removes that entry and logs any messages
/// queued in it.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    // Set by `set_entity` so that `Drop` leaves the now-active registration alone.
    consumed: bool,
}

impl<T: 'static> PendingEntitySubscription<T> {
    /// Binds the subscription to `entity` and handles every message that was queued while
    /// the subscription was pending.
    ///
    /// Returns a [`Subscription`] that removes the entity registration when dropped.
    ///
    /// # Panics
    /// Panics (via `unreachable!`) if the `Pending` entry for `(T, remote_id)` is gone,
    /// e.g. after `Client::teardown` cleared the handler set.
    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
        self.consumed = true;
        let mut handlers = self.client.handler_set.lock();
        let id = (TypeId::of::<T>(), self.remote_id);
        // Assumes the `Pending` entry inserted by `subscribe_to_entity` is still present.
        let Some(EntityMessageSubscriber::Pending(messages)) =
            handlers.entities_by_type_and_remote_id.remove(&id)
        else {
            unreachable!()
        };

        handlers.entities_by_type_and_remote_id.insert(
            id,
            EntityMessageSubscriber::Entity {
                handle: entity.downgrade().into(),
            },
        );
        // Release the handler-set lock before replaying queued messages, presumably
        // because `handle_message` needs to lock it again.
        drop(handlers);
        for message in messages {
            let client_id = self.client.id();
            let type_name = message.payload_type_name();
            let sender_id = message.original_sender_id();
            log::debug!(
                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                sender_id,
                type_name
            );
            self.client.handle_message(message, cx);
        }
        Subscription::Entity {
            client: Arc::downgrade(&self.client),
            id,
        }
    }
}
 465
 466impl<T: 'static> Drop for PendingEntitySubscription<T> {
 467    fn drop(&mut self) {
 468        if !self.consumed {
 469            let mut state = self.client.handler_set.lock();
 470            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 471                .entities_by_type_and_remote_id
 472                .remove(&(TypeId::of::<T>(), self.remote_id))
 473            {
 474                for message in messages {
 475                    log::info!("unhandled message {}", message.payload_type_name());
 476                }
 477            }
 478        }
 479    }
 480}
 481
/// Resolved telemetry preferences; the settings-file form is [`TelemetrySettingsContent`].
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}

/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 500
 501impl settings::Settings for TelemetrySettings {
 502    const KEY: Option<&'static str> = Some("telemetry");
 503
 504    type FileContent = TelemetrySettingsContent;
 505
 506    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 507        Ok(Self {
 508            diagnostics: sources
 509                .user
 510                .as_ref()
 511                .or(sources.server.as_ref())
 512                .and_then(|v| v.diagnostics)
 513                .unwrap_or(
 514                    sources
 515                        .default
 516                        .diagnostics
 517                        .ok_or_else(Self::missing_default)?,
 518                ),
 519            metrics: sources
 520                .user
 521                .as_ref()
 522                .or(sources.server.as_ref())
 523                .and_then(|v| v.metrics)
 524                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 525        })
 526    }
 527
 528    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 529        vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
 530            Some(s == "all")
 531        });
 532        vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
 533            Some(matches!(s, "all" | "error" | "crash"))
 534        });
 535        // we could translate telemetry.telemetryLevel, but just because users didn't want
 536        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 537        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 538    }
 539}
 540
 541impl Client {
 542    pub fn new(
 543        clock: Arc<dyn SystemClock>,
 544        http: Arc<HttpClientWithUrl>,
 545        cx: &mut App,
 546    ) -> Arc<Self> {
 547        Arc::new(Self {
 548            id: AtomicU64::new(0),
 549            peer: Peer::new(0),
 550            telemetry: Telemetry::new(clock, http.clone(), cx),
 551            http,
 552            credentials_provider: ClientCredentialsProvider::new(cx),
 553            state: Default::default(),
 554            handler_set: Default::default(),
 555
 556            #[cfg(any(test, feature = "test-support"))]
 557            authenticate: Default::default(),
 558            #[cfg(any(test, feature = "test-support"))]
 559            establish_connection: Default::default(),
 560            #[cfg(any(test, feature = "test-support"))]
 561            rpc_url: RwLock::default(),
 562        })
 563    }
 564
 565    pub fn production(cx: &mut App) -> Arc<Self> {
 566        let clock = Arc::new(clock::RealSystemClock);
 567        let http = Arc::new(HttpClientWithUrl::new_url(
 568            cx.http_client(),
 569            &ClientSettings::get_global(cx).server_url,
 570            cx.http_client().proxy().cloned(),
 571        ));
 572        Self::new(clock, http, cx)
 573    }
 574
 575    pub fn id(&self) -> u64 {
 576        self.id.load(Ordering::SeqCst)
 577    }
 578
 579    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 580        self.http.clone()
 581    }
 582
 583    pub fn set_id(&self, id: u64) -> &Self {
 584        self.id.store(id, Ordering::SeqCst);
 585        self
 586    }
 587
 588    #[cfg(any(test, feature = "test-support"))]
 589    pub fn teardown(&self) {
 590        let mut state = self.state.write();
 591        state._reconnect_task.take();
 592        self.handler_set.lock().clear();
 593        self.peer.teardown();
 594    }
 595
 596    #[cfg(any(test, feature = "test-support"))]
 597    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 598    where
 599        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 600    {
 601        *self.authenticate.write() = Some(Box::new(authenticate));
 602        self
 603    }
 604
 605    #[cfg(any(test, feature = "test-support"))]
 606    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 607    where
 608        F: 'static
 609            + Send
 610            + Sync
 611            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 612    {
 613        *self.establish_connection.write() = Some(Box::new(connect));
 614        self
 615    }
 616
 617    #[cfg(any(test, feature = "test-support"))]
 618    pub fn override_rpc_url(&self, url: Url) -> &Self {
 619        *self.rpc_url.write() = Some(url);
 620        self
 621    }
 622
 623    pub fn global(cx: &App) -> Arc<Self> {
 624        cx.global::<GlobalClient>().0.clone()
 625    }
 626    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 627        cx.set_global(GlobalClient(client))
 628    }
 629
 630    pub fn user_id(&self) -> Option<u64> {
 631        self.state
 632            .read()
 633            .credentials
 634            .as_ref()
 635            .map(|credentials| credentials.user_id)
 636    }
 637
 638    pub fn peer_id(&self) -> Option<PeerId> {
 639        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 640            Some(*peer_id)
 641        } else {
 642            None
 643        }
 644    }
 645
 646    pub fn status(&self) -> watch::Receiver<Status> {
 647        self.state.read().status.1.clone()
 648    }
 649
    /// Publishes `status` to all status receivers and applies its side effects:
    ///
    /// - `Connected`: drops any reconnection task.
    /// - `ConnectionLost`: spawns a reconnection loop (replacing any existing one).
    /// - `SignedOut` / `UpgradeRequired`: clears telemetry's authenticated-user info and
    ///   drops any reconnection task.
    ///
    /// The state write lock is held while these side effects run.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // Deterministic jitter under test; entropy-seeded otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only after a plain connection error; any other
                        // failure status (e.g. signed out, upgrade required) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Randomized backoff: scale by a factor in [0.5, 2.5], then
                            // clamp to [INITIAL_RECONNECTION_DELAY, MAX_RECONNECTION_DELAY].
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 695
 696    pub fn subscribe_to_entity<T>(
 697        self: &Arc<Self>,
 698        remote_id: u64,
 699    ) -> Result<PendingEntitySubscription<T>>
 700    where
 701        T: 'static,
 702    {
 703        let id = (TypeId::of::<T>(), remote_id);
 704
 705        let mut state = self.handler_set.lock();
 706        if state.entities_by_type_and_remote_id.contains_key(&id) {
 707            return Err(anyhow!("already subscribed to entity"));
 708        }
 709
 710        state
 711            .entities_by_type_and_remote_id
 712            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 713
 714        Ok(PendingEntitySubscription {
 715            client: self.clone(),
 716            remote_id,
 717            consumed: false,
 718            _entity_type: PhantomData,
 719        })
 720    }
 721
 722    #[track_caller]
 723    pub fn add_message_handler<M, E, H, F>(
 724        self: &Arc<Self>,
 725        entity: WeakEntity<E>,
 726        handler: H,
 727    ) -> Subscription
 728    where
 729        M: EnvelopedMessage,
 730        E: 'static,
 731        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 732        F: 'static + Future<Output = Result<()>>,
 733    {
 734        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 735            handler(entity, message, cx)
 736        })
 737    }
 738
 739    fn add_message_handler_impl<M, E, H, F>(
 740        self: &Arc<Self>,
 741        entity: WeakEntity<E>,
 742        handler: H,
 743    ) -> Subscription
 744    where
 745        M: EnvelopedMessage,
 746        E: 'static,
 747        H: 'static
 748            + Sync
 749            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 750            + Send
 751            + Sync,
 752        F: 'static + Future<Output = Result<()>>,
 753    {
 754        let message_type_id = TypeId::of::<M>();
 755        let mut state = self.handler_set.lock();
 756        state
 757            .entities_by_message_type
 758            .insert(message_type_id, entity.into());
 759
 760        let prev_handler = state.message_handlers.insert(
 761            message_type_id,
 762            Arc::new(move |subscriber, envelope, client, cx| {
 763                let subscriber = subscriber.downcast::<E>().unwrap();
 764                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 765                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 766            }),
 767        );
 768        if prev_handler.is_some() {
 769            let location = std::panic::Location::caller();
 770            panic!(
 771                "{}:{} registered handler for the same message {} twice",
 772                location.file(),
 773                location.line(),
 774                std::any::type_name::<M>()
 775            );
 776        }
 777
 778        Subscription::Message {
 779            client: Arc::downgrade(self),
 780            id: message_type_id,
 781        }
 782    }
 783
 784    pub fn add_request_handler<M, E, H, F>(
 785        self: &Arc<Self>,
 786        entity: WeakEntity<E>,
 787        handler: H,
 788    ) -> Subscription
 789    where
 790        M: RequestMessage,
 791        E: 'static,
 792        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 793        F: 'static + Future<Output = Result<M::Response>>,
 794    {
 795        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 796            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 797        })
 798    }
 799
 800    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 801        receipt: Receipt<T>,
 802        response: F,
 803        client: AnyProtoClient,
 804    ) -> Result<()> {
 805        match response.await {
 806            Ok(response) => {
 807                client.send_response(receipt.message_id, response)?;
 808                Ok(())
 809            }
 810            Err(error) => {
 811                client.send_response(receipt.message_id, error.to_proto())?;
 812                Err(error)
 813            }
 814        }
 815    }
 816
 817    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 818        self.credentials_provider
 819            .read_credentials(cx)
 820            .await
 821            .is_some()
 822    }
 823
    /// Signs in if needed and connects to the server.
    ///
    /// Returns `Ok(())` immediately if the status is `Connected`, `Connecting`, or
    /// `Reconnecting`. Fails immediately with `EstablishConnectionError::UpgradeRequired`
    /// if the status is `UpgradeRequired`.
    ///
    /// Credentials are taken from the first of these that yields them:
    /// 1. the cached state;
    /// 2. the credentials provider (only when `try_provider` is true);
    /// 3. `Client::authenticate`, which is abandoned if the status changes meanwhile.
    ///
    /// A single [`CONNECTION_TIMEOUT`] timer covers both establishing the connection and
    /// receiving the server hello.
    ///
    /// If credentials read from the provider are rejected as unauthorized, they are
    /// deleted and the flow restarts with `try_provider = false`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(());
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            let mut status_rx = self.status();
            // Presumably the first `next()` yields the current status (watch semantics),
            // so the `next()` in the select below fires only on a later status change,
            // which cancels authentication.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path that leaves `credentials` as `None` has returned above.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer, shared by the connection attempt and the wait for the server hello.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist the credentials unless they came from the provider or we
                        // are impersonating another user.
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_provider {
                            // Stored credentials are stale: forget them and retry without
                            // consulting the provider again.
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 928
    /// Registers `conn` with the peer, waits for the server's `Hello` message and, on success,
    /// marks the client as `Connected` and spawns two detached tasks: one that dispatches
    /// incoming messages and one that watches the connection's I/O task.
    ///
    /// Returns an error — after calling `self.peer.disconnect(connection_id)` — if the incoming
    /// stream ends before any message arrives, if the first message is not a `proto::Hello`,
    /// or if the hello carries no peer id. This function does not change the client status on
    /// that error path; the caller is responsible for it.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor; the task handle is watched
        // further below to detect when the connection ends.
        let handle_io = executor.spawn(handle_io);

        // The very first message on a fresh connection must be the server's hello, which
        // tells us our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::debug!("got server hello");
            // Capture the type name before `into_any` consumes the message, for the error text.
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Drop the half-established connection from the peer before bailing out.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch every subsequent incoming message to its registered handler.
        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, &cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        // React to the end of the I/O task.
        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // A clean shutdown only signs out if this exact connection is still the
                    // current one; a status set by a newer connection is left alone.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, &cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, &cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1018
1019    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1020        #[cfg(any(test, feature = "test-support"))]
1021        if let Some(callback) = self.authenticate.read().as_ref() {
1022            return callback(cx);
1023        }
1024
1025        self.authenticate_with_browser(cx)
1026    }
1027
1028    fn establish_connection(
1029        self: &Arc<Self>,
1030        credentials: &Credentials,
1031        cx: &AsyncApp,
1032    ) -> Task<Result<Connection, EstablishConnectionError>> {
1033        #[cfg(any(test, feature = "test-support"))]
1034        if let Some(callback) = self.establish_connection.read().as_ref() {
1035            return callback(credentials, cx);
1036        }
1037
1038        self.establish_websocket_connection(credentials, cx)
1039    }
1040
1041    fn rpc_url(
1042        &self,
1043        http: Arc<HttpClientWithUrl>,
1044        release_channel: Option<ReleaseChannel>,
1045    ) -> impl Future<Output = Result<url::Url>> + use<> {
1046        #[cfg(any(test, feature = "test-support"))]
1047        let url_override = self.rpc_url.read().clone();
1048
1049        async move {
1050            #[cfg(any(test, feature = "test-support"))]
1051            if let Some(url) = url_override {
1052                return Ok(url);
1053            }
1054
1055            if let Some(url) = &*ZED_RPC_URL {
1056                return Url::parse(url).context("invalid rpc url");
1057            }
1058
1059            let mut url = http.build_url("/rpc");
1060            if let Some(preview_param) =
1061                release_channel.and_then(|channel| channel.release_query_param())
1062            {
1063                url += "?";
1064                url += preview_param;
1065            }
1066
1067            let response = http.get(&url, Default::default(), false).await?;
1068            let collab_url = if response.status().is_redirection() {
1069                response
1070                    .headers()
1071                    .get("Location")
1072                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1073                    .to_str()
1074                    .map_err(EstablishConnectionError::other)?
1075                    .to_string()
1076            } else {
1077                Err(anyhow!(
1078                    "unexpected /rpc response status {}",
1079                    response.status()
1080                ))?
1081            };
1082
1083            Url::parse(&collab_url).context("invalid rpc url")
1084        }
1085    }
1086
1087    fn establish_websocket_connection(
1088        self: &Arc<Self>,
1089        credentials: &Credentials,
1090        cx: &AsyncApp,
1091    ) -> Task<Result<Connection, EstablishConnectionError>> {
1092        let release_channel = cx
1093            .update(|cx| ReleaseChannel::try_global(cx))
1094            .ok()
1095            .flatten();
1096        let app_version = cx
1097            .update(|cx| AppVersion::global(cx).to_string())
1098            .ok()
1099            .unwrap_or_default();
1100
1101        let http = self.http.clone();
1102        let proxy = http.proxy().cloned();
1103        let credentials = credentials.clone();
1104        let rpc_url = self.rpc_url(http, release_channel);
1105        let system_id = self.telemetry.system_id();
1106        let metrics_id = self.telemetry.metrics_id();
1107        cx.spawn(async move |cx| {
1108            use HttpOrHttps::*;
1109
1110            #[derive(Debug)]
1111            enum HttpOrHttps {
1112                Http,
1113                Https,
1114            }
1115
1116            let mut rpc_url = rpc_url.await?;
1117            let url_scheme = match rpc_url.scheme() {
1118                "https" => Https,
1119                "http" => Http,
1120                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1121            };
1122            let rpc_host = rpc_url
1123                .host_str()
1124                .zip(rpc_url.port_or_known_default())
1125                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1126
1127            let stream = {
1128                let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
1129                let _guard = handle.enter();
1130                connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?
1131            };
1132
1133            log::info!("connected to rpc endpoint {}", rpc_url);
1134
1135            rpc_url
1136                .set_scheme(match url_scheme {
1137                    Https => "wss",
1138                    Http => "ws",
1139                })
1140                .unwrap();
1141
1142            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
1143            // for us from the RPC URL.
1144            //
1145            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
1146            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
1147
1148            // We then modify the request to add our desired headers.
1149            let request_headers = request.headers_mut();
1150            request_headers.insert(
1151                "Authorization",
1152                HeaderValue::from_str(&credentials.authorization_header())?,
1153            );
1154            request_headers.insert(
1155                "x-zed-protocol-version",
1156                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
1157            );
1158            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
1159            request_headers.insert(
1160                "x-zed-release-channel",
1161                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
1162            );
1163            if let Some(system_id) = system_id {
1164                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
1165            }
1166            if let Some(metrics_id) = metrics_id {
1167                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
1168            }
1169
1170            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
1171                request,
1172                stream,
1173                Some(Arc::new(http_client_tls::tls_config()).into()),
1174                None,
1175            )
1176            .await?;
1177
1178            Ok(Connection::new(
1179                stream
1180                    .map_err(|error| anyhow!(error))
1181                    .sink_map_err(|error| anyhow!(error)),
1182            ))
1183        })
1184    }
1185
1186    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1187        let http = self.http.clone();
1188        let this = self.clone();
1189        cx.spawn(async move |cx| {
1190            let background = cx.background_executor().clone();
1191
1192            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1193            cx.update(|cx| {
1194                cx.spawn(async move |cx| {
1195                    let url = open_url_rx.await?;
1196                    cx.update(|cx| cx.open_url(&url))
1197                })
1198                .detach_and_log_err(cx);
1199            })
1200            .log_err();
1201
1202            let credentials = background
1203                .clone()
1204                .spawn(async move {
1205                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1206                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1207                    // any other app running on the user's device.
1208                    let (public_key, private_key) =
1209                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1210                    let public_key_string = String::try_from(public_key)
1211                        .expect("failed to serialize public key for auth");
1212
1213                    if let Some((login, token)) =
1214                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1215                    {
1216                        eprintln!("authenticate as admin {login}, {token}");
1217
1218                        return this
1219                            .authenticate_as_admin(http, login.clone(), token.clone())
1220                            .await;
1221                    }
1222
1223                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1224                    let server =
1225                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1226                    let port = server.server_addr().port();
1227
1228                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1229                    // that the user is signing in from a Zed app running on the same device.
1230                    let mut url = http.build_url(&format!(
1231                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1232                        port, public_key_string
1233                    ));
1234
1235                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1236                        log::info!("impersonating user @{}", impersonate_login);
1237                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1238                    }
1239
1240                    open_url_tx.send(url).log_err();
1241
1242                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1243                    // access token from the query params.
1244                    //
1245                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1246                    // custom URL scheme instead of this local HTTP server.
1247                    let (user_id, access_token) = background
1248                        .spawn(async move {
1249                            for _ in 0..100 {
1250                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1251                                    let path = req.url();
1252                                    let mut user_id = None;
1253                                    let mut access_token = None;
1254                                    let url = Url::parse(&format!("http://example.com{}", path))
1255                                        .context("failed to parse login notification url")?;
1256                                    for (key, value) in url.query_pairs() {
1257                                        if key == "access_token" {
1258                                            access_token = Some(value.to_string());
1259                                        } else if key == "user_id" {
1260                                            user_id = Some(value.to_string());
1261                                        }
1262                                    }
1263
1264                                    let post_auth_url =
1265                                        http.build_url("/native_app_signin_succeeded");
1266                                    req.respond(
1267                                        tiny_http::Response::empty(302).with_header(
1268                                            tiny_http::Header::from_bytes(
1269                                                &b"Location"[..],
1270                                                post_auth_url.as_bytes(),
1271                                            )
1272                                            .unwrap(),
1273                                        ),
1274                                    )
1275                                    .context("failed to respond to login http request")?;
1276                                    return Ok((
1277                                        user_id
1278                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1279                                        access_token.ok_or_else(|| {
1280                                            anyhow!("missing access_token parameter")
1281                                        })?,
1282                                    ));
1283                                }
1284                            }
1285
1286                            Err(anyhow!("didn't receive login redirect"))
1287                        })
1288                        .await?;
1289
1290                    let access_token = private_key
1291                        .decrypt_string(&access_token)
1292                        .context("failed to decrypt access token")?;
1293
1294                    Ok(Credentials {
1295                        user_id: user_id.parse()?,
1296                        access_token,
1297                    })
1298                })
1299                .await?;
1300
1301            cx.update(|cx| cx.activate(true))?;
1302            Ok(credentials)
1303        })
1304    }
1305
1306    async fn authenticate_as_admin(
1307        self: &Arc<Self>,
1308        http: Arc<HttpClientWithUrl>,
1309        login: String,
1310        mut api_token: String,
1311    ) -> Result<Credentials> {
1312        #[derive(Deserialize)]
1313        struct AuthenticatedUserResponse {
1314            user: User,
1315        }
1316
1317        #[derive(Deserialize)]
1318        struct User {
1319            id: u64,
1320        }
1321
1322        let github_user = {
1323            #[derive(Deserialize)]
1324            struct GithubUser {
1325                id: i32,
1326                login: String,
1327                created_at: DateTime<Utc>,
1328            }
1329
1330            let request = {
1331                let mut request_builder =
1332                    Request::get(&format!("https://api.github.com/users/{login}"));
1333                if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
1334                    request_builder =
1335                        request_builder.header("Authorization", format!("Bearer {}", github_token));
1336                }
1337
1338                request_builder.body(AsyncBody::empty())?
1339            };
1340
1341            let mut response = http
1342                .send(request)
1343                .await
1344                .context("error fetching GitHub user")?;
1345
1346            let mut body = Vec::new();
1347            response
1348                .body_mut()
1349                .read_to_end(&mut body)
1350                .await
1351                .context("error reading GitHub user")?;
1352
1353            if !response.status().is_success() {
1354                let text = String::from_utf8_lossy(body.as_slice());
1355                bail!(
1356                    "status error {}, response: {text:?}",
1357                    response.status().as_u16()
1358                );
1359            }
1360
1361            serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
1362                log::error!("Error deserializing: {:?}", err);
1363                log::error!(
1364                    "GitHub API response text: {:?}",
1365                    String::from_utf8_lossy(body.as_slice())
1366                );
1367                anyhow!("error deserializing GitHub user")
1368            })?
1369        };
1370
1371        let query_params = [
1372            ("github_login", &github_user.login),
1373            ("github_user_id", &github_user.id.to_string()),
1374            (
1375                "github_user_created_at",
1376                &github_user.created_at.to_rfc3339(),
1377            ),
1378        ];
1379
1380        // Use the collab server's admin API to retrieve the ID
1381        // of the impersonated user.
1382        let mut url = self.rpc_url(http.clone(), None).await?;
1383        url.set_path("/user");
1384        url.set_query(Some(
1385            &query_params
1386                .iter()
1387                .map(|(key, value)| {
1388                    format!(
1389                        "{}={}",
1390                        key,
1391                        url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
1392                    )
1393                })
1394                .collect::<Vec<String>>()
1395                .join("&"),
1396        ));
1397        let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
1398            .header("Authorization", format!("token {api_token}"))
1399            .body("".into())?;
1400
1401        let mut response = http.send(request).await?;
1402        let mut body = String::new();
1403        response.body_mut().read_to_string(&mut body).await?;
1404        if !response.status().is_success() {
1405            Err(anyhow!(
1406                "admin user request failed {} - {}",
1407                response.status().as_u16(),
1408                body,
1409            ))?;
1410        }
1411        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1412
1413        // Use the admin API token to authenticate as the impersonated user.
1414        api_token.insert_str(0, "ADMIN_TOKEN:");
1415        Ok(Credentials {
1416            user_id: response.user.id,
1417            access_token: api_token,
1418        })
1419    }
1420
1421    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1422        self.state.write().credentials = None;
1423        self.disconnect(cx);
1424
1425        if self.has_credentials(cx).await {
1426            self.credentials_provider
1427                .delete_credentials(cx)
1428                .await
1429                .log_err();
1430        }
1431    }
1432
1433    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1434        self.peer.teardown();
1435        self.set_status(Status::SignedOut, cx);
1436    }
1437
1438    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1439        self.peer.teardown();
1440        self.set_status(Status::ConnectionLost, cx);
1441    }
1442
1443    fn connection_id(&self) -> Result<ConnectionId> {
1444        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1445            Ok(connection_id)
1446        } else {
1447            Err(anyhow!("not connected"))
1448        }
1449    }
1450
1451    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1452        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1453        self.peer.send(self.connection_id()?, message)
1454    }
1455
1456    pub fn request<T: RequestMessage>(
1457        &self,
1458        request: T,
1459    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1460        self.request_envelope(request)
1461            .map_ok(|envelope| envelope.payload)
1462    }
1463
1464    pub fn request_stream<T: RequestMessage>(
1465        &self,
1466        request: T,
1467    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1468        let client_id = self.id.load(Ordering::SeqCst);
1469        log::debug!(
1470            "rpc request start. client_id:{}. name:{}",
1471            client_id,
1472            T::NAME
1473        );
1474        let response = self
1475            .connection_id()
1476            .map(|conn_id| self.peer.request_stream(conn_id, request));
1477        async move {
1478            let response = response?.await;
1479            log::debug!(
1480                "rpc request finish. client_id:{}. name:{}",
1481                client_id,
1482                T::NAME
1483            );
1484            response
1485        }
1486    }
1487
1488    pub fn request_envelope<T: RequestMessage>(
1489        &self,
1490        request: T,
1491    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1492        let client_id = self.id();
1493        log::debug!(
1494            "rpc request start. client_id:{}. name:{}",
1495            client_id,
1496            T::NAME
1497        );
1498        let response = self
1499            .connection_id()
1500            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1501        async move {
1502            let response = response?.await;
1503            log::debug!(
1504                "rpc request finish. client_id:{}. name:{}",
1505                client_id,
1506                T::NAME
1507            );
1508            response
1509        }
1510    }
1511
1512    pub fn request_dynamic(
1513        &self,
1514        envelope: proto::Envelope,
1515        request_type: &'static str,
1516    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1517        let client_id = self.id();
1518        log::debug!(
1519            "rpc request start. client_id:{}. name:{}",
1520            client_id,
1521            request_type
1522        );
1523        let response = self
1524            .connection_id()
1525            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1526        async move {
1527            let response = response?.await;
1528            log::debug!(
1529                "rpc request finish. client_id:{}. name:{}",
1530                client_id,
1531                request_type
1532            );
1533            Ok(response?.0)
1534        }
1535    }
1536
1537    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1538        let sender_id = message.sender_id();
1539        let request_id = message.message_id();
1540        let type_name = message.payload_type_name();
1541        let original_sender_id = message.original_sender_id();
1542
1543        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1544            &self.handler_set,
1545            message,
1546            self.clone().into(),
1547            cx.clone(),
1548        ) {
1549            let client_id = self.id();
1550            log::debug!(
1551                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1552                client_id,
1553                original_sender_id,
1554                type_name
1555            );
1556            cx.spawn(async move |_| match future.await {
1557                Ok(()) => {
1558                    log::debug!(
1559                        "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1560                        client_id,
1561                        original_sender_id,
1562                        type_name
1563                    );
1564                }
1565                Err(error) => {
1566                    log::error!(
1567                        "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1568                        client_id,
1569                        original_sender_id,
1570                        type_name,
1571                        error
1572                    );
1573                }
1574            })
1575            .detach();
1576        } else {
1577            log::info!("unhandled message {}", type_name);
1578            self.peer
1579                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1580                .log_err();
1581        }
1582    }
1583
1584    pub fn telemetry(&self) -> &Arc<Telemetry> {
1585        &self.telemetry
1586    }
1587}
1588
1589impl ProtoClient for Client {
1590    fn request(
1591        &self,
1592        envelope: proto::Envelope,
1593        request_type: &'static str,
1594    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1595        self.request_dynamic(envelope, request_type).boxed()
1596    }
1597
1598    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1599        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1600        let connection_id = self.connection_id()?;
1601        self.peer.send_dynamic(connection_id, envelope)
1602    }
1603
1604    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1605        log::debug!(
1606            "rpc respond. client_id:{}, name:{}",
1607            self.id(),
1608            message_type
1609        );
1610        let connection_id = self.connection_id()?;
1611        self.peer.send_dynamic(connection_id, envelope)
1612    }
1613
1614    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1615        &self.handler_set
1616    }
1617
1618    fn is_via_collab(&self) -> bool {
1619        true
1620    }
1621}
1622
/// The URL scheme used for `zed://` links, without the `://` separator.
pub const ZED_URL_SCHEME: &str = "zed";
1625
1626/// Parses the given link into a Zed link.
1627///
1628/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1629/// Returns [`None`] otherwise.
1630pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1631    let server_url = &ClientSettings::get_global(cx).server_url;
1632    if let Some(stripped) = link
1633        .strip_prefix(server_url)
1634        .and_then(|result| result.strip_prefix('/'))
1635    {
1636        return Some(stripped);
1637    }
1638    if let Some(stripped) = link
1639        .strip_prefix(ZED_URL_SCHEME)
1640        .and_then(|result| result.strip_prefix("://"))
1641    {
1642        return Some(stripped);
1643    }
1644
1645    None
1646}
1647
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // Verifies that the client reconnects after the server drops the connection:
    // cached credentials are reused while still valid, and the client
    // re-authenticates once the server has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Verifies that both the initial connection attempt and a later reconnection
    // attempt report an error status once `CONNECTION_TIMEOUT` elapses without
    // the connection being established.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Verifies that starting a second `authenticate_and_connect` while the first
    // authentication is still pending drops the first authentication future
    // (observed through the deferred `dropped_auth_count` increment).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this authentication future is dropped (e.g. cancelled).
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Verifies that entity message handlers are routed by entity id, and that
    // dropping one entity's subscription does not affect the others.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match entity.update(&mut cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // Verifies that a handler for a message type can be registered again after the
    // previous subscription for that type is dropped, that the new handler
    // receives the message, and that the dropped handler is never invoked.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
        // The first handler was unregistered before the message arrived, so it must
        // never have sent anything (its channel is either empty or already closed).
        assert!(done_rx1.try_recv().is_err());
    }

    // Verifies that a message handler can take and drop its own subscription
    // (stored on the entity) while it is handling a message.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a message-handler target; `subscription` lets a test
    // tie a subscription's lifetime to the entity.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` global and then calls `init_settings` so the
    // client's settings are available to the tests above.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}