client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod cloud;
   5mod proxy;
   6pub mod telemetry;
   7pub mod user;
   8pub mod zed_urls;
   9
  10use anyhow::{Context as _, Result, anyhow};
  11use async_recursion::async_recursion;
  12use async_tungstenite::tungstenite::{
  13    client::IntoClientRequest,
  14    error::Error as WebsocketError,
  15    http::{HeaderValue, Request, StatusCode},
  16};
  17use clock::SystemClock;
  18use cloud_api_client::CloudApiClient;
  19use credentials_provider::CredentialsProvider;
  20use futures::{
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22    channel::oneshot, future::BoxFuture,
  23};
  24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
  25use http_client::{HttpClient, HttpClientWithUrl, http};
  26use parking_lot::RwLock;
  27use postage::watch;
  28use proxy::connect_proxy_stream;
  29use rand::prelude::*;
  30use release_channel::{AppVersion, ReleaseChannel};
  31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use settings::{Settings, SettingsSources};
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        Arc, LazyLock, Weak,
  44        atomic::{AtomicU64, Ordering},
  45    },
  46    time::{Duration, Instant},
  47};
  48use std::{cmp, pin::Pin};
  49use telemetry::Telemetry;
  50use thiserror::Error;
  51use tokio::net::TcpStream;
  52use url::Url;
  53use util::{ConnectionResult, ResultExt};
  54
  55pub use cloud::*;
  56pub use rpc::*;
  57pub use telemetry_events::Event;
  58pub use user::*;
  59
  60static ZED_SERVER_URL: LazyLock<Option<String>> =
  61    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  62static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  63
  64pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  65    std::env::var("ZED_IMPERSONATE")
  66        .ok()
  67        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  68});
  69
  70pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  71    std::env::var("ZED_ADMIN_API_TOKEN")
  72        .ok()
  73        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  74});
  75
  76pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  77    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  78
  79pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  80    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  81
  82pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  83pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
  84pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  85
  86actions!(
  87    client,
  88    [
  89        /// Signs in to Zed account.
  90        SignIn,
  91        /// Signs out of Zed account.
  92        SignOut,
  93        /// Reconnects to the collaboration server.
  94        Reconnect
  95    ]
  96);
  97
  98#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
  99pub struct ClientSettingsContent {
 100    server_url: Option<String>,
 101}
 102
 103#[derive(Deserialize)]
 104pub struct ClientSettings {
 105    pub server_url: String,
 106}
 107
 108impl Settings for ClientSettings {
 109    const KEY: Option<&'static str> = None;
 110
 111    type FileContent = ClientSettingsContent;
 112
 113    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 114        let mut result = sources.json_merge::<Self>()?;
 115        if let Some(server_url) = &*ZED_SERVER_URL {
 116            result.server_url.clone_from(server_url)
 117        }
 118        Ok(result)
 119    }
 120
 121    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
 122}
 123
 124#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 125pub struct ProxySettingsContent {
 126    proxy: Option<String>,
 127}
 128
 129#[derive(Deserialize, Default)]
 130pub struct ProxySettings {
 131    pub proxy: Option<String>,
 132}
 133
 134impl Settings for ProxySettings {
 135    const KEY: Option<&'static str> = None;
 136
 137    type FileContent = ProxySettingsContent;
 138
 139    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 140        Ok(Self {
 141            proxy: sources
 142                .user
 143                .or(sources.server)
 144                .and_then(|value| value.proxy.clone())
 145                .or(sources.default.proxy.clone()),
 146        })
 147    }
 148
 149    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 150        vscode.string_setting("http.proxy", &mut current.proxy);
 151    }
 152}
 153
 154pub fn init_settings(cx: &mut App) {
 155    TelemetrySettings::register(cx);
 156    ClientSettings::register(cx);
 157    ProxySettings::register(cx);
 158}
 159
 160pub fn init(client: &Arc<Client>, cx: &mut App) {
 161    let client = Arc::downgrade(client);
 162    cx.on_action({
 163        let client = client.clone();
 164        move |_: &SignIn, cx| {
 165            if let Some(client) = client.upgrade() {
 166                cx.spawn(
 167                    async move |cx| match client.authenticate_and_connect(true, &cx).await {
 168                        ConnectionResult::Timeout => {
 169                            log::error!("Initial authentication timed out");
 170                        }
 171                        ConnectionResult::ConnectionReset => {
 172                            log::error!("Initial authentication connection reset");
 173                        }
 174                        ConnectionResult::Result(r) => {
 175                            r.log_err();
 176                        }
 177                    },
 178                )
 179                .detach();
 180            }
 181        }
 182    });
 183
 184    cx.on_action({
 185        let client = client.clone();
 186        move |_: &SignOut, cx| {
 187            if let Some(client) = client.upgrade() {
 188                cx.spawn(async move |cx| {
 189                    client.sign_out(&cx).await;
 190                })
 191                .detach();
 192            }
 193        }
 194    });
 195
 196    cx.on_action({
 197        let client = client.clone();
 198        move |_: &Reconnect, cx| {
 199            if let Some(client) = client.upgrade() {
 200                cx.spawn(async move |cx| {
 201                    client.reconnect(&cx);
 202                })
 203                .detach();
 204            }
 205        }
 206    });
 207}
 208
 209struct GlobalClient(Arc<Client>);
 210
 211impl Global for GlobalClient {}
 212
 213pub struct Client {
 214    id: AtomicU64,
 215    peer: Arc<Peer>,
 216    http: Arc<HttpClientWithUrl>,
 217    cloud_client: Arc<CloudApiClient>,
 218    telemetry: Arc<Telemetry>,
 219    credentials_provider: ClientCredentialsProvider,
 220    state: RwLock<ClientState>,
 221    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
 222
 223    #[allow(clippy::type_complexity)]
 224    #[cfg(any(test, feature = "test-support"))]
 225    authenticate:
 226        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,
 227
 228    #[allow(clippy::type_complexity)]
 229    #[cfg(any(test, feature = "test-support"))]
 230    establish_connection: RwLock<
 231        Option<
 232            Box<
 233                dyn 'static
 234                    + Send
 235                    + Sync
 236                    + Fn(
 237                        &Credentials,
 238                        &AsyncApp,
 239                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 240            >,
 241        >,
 242    >,
 243
 244    #[cfg(any(test, feature = "test-support"))]
 245    rpc_url: RwLock<Option<Url>>,
 246}
 247
 248#[derive(Error, Debug)]
 249pub enum EstablishConnectionError {
 250    #[error("upgrade required")]
 251    UpgradeRequired,
 252    #[error("unauthorized")]
 253    Unauthorized,
 254    #[error("{0}")]
 255    Other(#[from] anyhow::Error),
 256    #[error("{0}")]
 257    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 258    #[error("{0}")]
 259    Io(#[from] std::io::Error),
 260    #[error("{0}")]
 261    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 262}
 263
 264impl From<WebsocketError> for EstablishConnectionError {
 265    fn from(error: WebsocketError) -> Self {
 266        if let WebsocketError::Http(response) = &error {
 267            match response.status() {
 268                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 269                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 270                _ => {}
 271            }
 272        }
 273        EstablishConnectionError::Other(error.into())
 274    }
 275}
 276
 277impl EstablishConnectionError {
 278    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 279        Self::Other(error.into())
 280    }
 281}
 282
 283#[derive(Copy, Clone, Debug, PartialEq)]
 284pub enum Status {
 285    SignedOut,
 286    UpgradeRequired,
 287    Authenticating,
 288    Connecting,
 289    ConnectionError,
 290    Connected {
 291        peer_id: PeerId,
 292        connection_id: ConnectionId,
 293    },
 294    ConnectionLost,
 295    Reauthenticating,
 296    Reconnecting,
 297    ReconnectionError {
 298        next_reconnection: Instant,
 299    },
 300}
 301
 302impl Status {
 303    pub fn is_connected(&self) -> bool {
 304        matches!(self, Self::Connected { .. })
 305    }
 306
 307    pub fn is_signing_in(&self) -> bool {
 308        matches!(
 309            self,
 310            Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
 311        )
 312    }
 313
 314    pub fn is_signed_out(&self) -> bool {
 315        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 316    }
 317}
 318
 319struct ClientState {
 320    credentials: Option<Credentials>,
 321    status: (watch::Sender<Status>, watch::Receiver<Status>),
 322    _reconnect_task: Option<Task<()>>,
 323}
 324
 325#[derive(Clone, Debug, Eq, PartialEq)]
 326pub struct Credentials {
 327    pub user_id: u64,
 328    pub access_token: String,
 329}
 330
 331impl Credentials {
 332    pub fn authorization_header(&self) -> String {
 333        format!("{} {}", self.user_id, self.access_token)
 334    }
 335}
 336
 337pub struct ClientCredentialsProvider {
 338    provider: Arc<dyn CredentialsProvider>,
 339}
 340
 341impl ClientCredentialsProvider {
 342    pub fn new(cx: &App) -> Self {
 343        Self {
 344            provider: <dyn CredentialsProvider>::global(cx),
 345        }
 346    }
 347
 348    fn server_url(&self, cx: &AsyncApp) -> Result<String> {
 349        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
 350    }
 351
 352    /// Reads the credentials from the provider.
 353    fn read_credentials<'a>(
 354        &'a self,
 355        cx: &'a AsyncApp,
 356    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
 357        async move {
 358            if IMPERSONATE_LOGIN.is_some() {
 359                return None;
 360            }
 361
 362            let server_url = self.server_url(cx).ok()?;
 363            let (user_id, access_token) = self
 364                .provider
 365                .read_credentials(&server_url, cx)
 366                .await
 367                .log_err()
 368                .flatten()?;
 369
 370            Some(Credentials {
 371                user_id: user_id.parse().ok()?,
 372                access_token: String::from_utf8(access_token).ok()?,
 373            })
 374        }
 375        .boxed_local()
 376    }
 377
 378    /// Writes the credentials to the provider.
 379    fn write_credentials<'a>(
 380        &'a self,
 381        user_id: u64,
 382        access_token: String,
 383        cx: &'a AsyncApp,
 384    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 385        async move {
 386            let server_url = self.server_url(cx)?;
 387            self.provider
 388                .write_credentials(
 389                    &server_url,
 390                    &user_id.to_string(),
 391                    access_token.as_bytes(),
 392                    cx,
 393                )
 394                .await
 395        }
 396        .boxed_local()
 397    }
 398
 399    /// Deletes the credentials from the provider.
 400    fn delete_credentials<'a>(
 401        &'a self,
 402        cx: &'a AsyncApp,
 403    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
 404        async move {
 405            let server_url = self.server_url(cx)?;
 406            self.provider.delete_credentials(&server_url, cx).await
 407        }
 408        .boxed_local()
 409    }
 410}
 411
 412impl Default for ClientState {
 413    fn default() -> Self {
 414        Self {
 415            credentials: None,
 416            status: watch::channel_with(Status::SignedOut),
 417            _reconnect_task: None,
 418        }
 419    }
 420}
 421
 422pub enum Subscription {
 423    Entity {
 424        client: Weak<Client>,
 425        id: (TypeId, u64),
 426    },
 427    Message {
 428        client: Weak<Client>,
 429        id: TypeId,
 430    },
 431}
 432
 433impl Drop for Subscription {
 434    fn drop(&mut self) {
 435        match self {
 436            Subscription::Entity { client, id } => {
 437                if let Some(client) = client.upgrade() {
 438                    let mut state = client.handler_set.lock();
 439                    let _ = state.entities_by_type_and_remote_id.remove(id);
 440                }
 441            }
 442            Subscription::Message { client, id } => {
 443                if let Some(client) = client.upgrade() {
 444                    let mut state = client.handler_set.lock();
 445                    let _ = state.entity_types_by_message_type.remove(id);
 446                    let _ = state.message_handlers.remove(id);
 447                }
 448            }
 449        }
 450    }
 451}
 452
 453pub struct PendingEntitySubscription<T: 'static> {
 454    client: Arc<Client>,
 455    remote_id: u64,
 456    _entity_type: PhantomData<T>,
 457    consumed: bool,
 458}
 459
 460impl<T: 'static> PendingEntitySubscription<T> {
 461    pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
 462        self.consumed = true;
 463        let mut handlers = self.client.handler_set.lock();
 464        let id = (TypeId::of::<T>(), self.remote_id);
 465        let Some(EntityMessageSubscriber::Pending(messages)) =
 466            handlers.entities_by_type_and_remote_id.remove(&id)
 467        else {
 468            unreachable!()
 469        };
 470
 471        handlers.entities_by_type_and_remote_id.insert(
 472            id,
 473            EntityMessageSubscriber::Entity {
 474                handle: entity.downgrade().into(),
 475            },
 476        );
 477        drop(handlers);
 478        for message in messages {
 479            let client_id = self.client.id();
 480            let type_name = message.payload_type_name();
 481            let sender_id = message.original_sender_id();
 482            log::debug!(
 483                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 484                client_id,
 485                sender_id,
 486                type_name
 487            );
 488            self.client.handle_message(message, cx);
 489        }
 490        Subscription::Entity {
 491            client: Arc::downgrade(&self.client),
 492            id,
 493        }
 494    }
 495}
 496
 497impl<T: 'static> Drop for PendingEntitySubscription<T> {
 498    fn drop(&mut self) {
 499        if !self.consumed {
 500            let mut state = self.client.handler_set.lock();
 501            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 502                .entities_by_type_and_remote_id
 503                .remove(&(TypeId::of::<T>(), self.remote_id))
 504            {
 505                for message in messages {
 506                    log::info!("unhandled message {}", message.payload_type_name());
 507                }
 508            }
 509        }
 510    }
 511}
 512
 513#[derive(Copy, Clone, Deserialize, Debug)]
 514pub struct TelemetrySettings {
 515    pub diagnostics: bool,
 516    pub metrics: bool,
 517}
 518
 519/// Control what info is collected by Zed.
 520#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, Debug)]
 521pub struct TelemetrySettingsContent {
 522    /// Send debug info like crash reports.
 523    ///
 524    /// Default: true
 525    pub diagnostics: Option<bool>,
 526    /// Send anonymized usage data like what languages you're using Zed with.
 527    ///
 528    /// Default: true
 529    pub metrics: Option<bool>,
 530}
 531
 532impl settings::Settings for TelemetrySettings {
 533    const KEY: Option<&'static str> = Some("telemetry");
 534
 535    type FileContent = TelemetrySettingsContent;
 536
 537    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
 538        sources.json_merge()
 539    }
 540
 541    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
 542        vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
 543            Some(s == "all")
 544        });
 545        vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
 546            Some(matches!(s, "all" | "error" | "crash"))
 547        });
 548        // we could translate telemetry.telemetryLevel, but just because users didn't want
 549        // to send microsoft telemetry doesn't mean they don't want to send it to zed. their
 550        // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
 551    }
 552}
 553
 554impl Client {
 555    pub fn new(
 556        clock: Arc<dyn SystemClock>,
 557        http: Arc<HttpClientWithUrl>,
 558        cx: &mut App,
 559    ) -> Arc<Self> {
 560        Arc::new(Self {
 561            id: AtomicU64::new(0),
 562            peer: Peer::new(0),
 563            telemetry: Telemetry::new(clock, http.clone(), cx),
 564            cloud_client: Arc::new(CloudApiClient::new(http.clone())),
 565            http,
 566            credentials_provider: ClientCredentialsProvider::new(cx),
 567            state: Default::default(),
 568            handler_set: Default::default(),
 569
 570            #[cfg(any(test, feature = "test-support"))]
 571            authenticate: Default::default(),
 572            #[cfg(any(test, feature = "test-support"))]
 573            establish_connection: Default::default(),
 574            #[cfg(any(test, feature = "test-support"))]
 575            rpc_url: RwLock::default(),
 576        })
 577    }
 578
 579    pub fn production(cx: &mut App) -> Arc<Self> {
 580        let clock = Arc::new(clock::RealSystemClock);
 581        let http = Arc::new(HttpClientWithUrl::new_url(
 582            cx.http_client(),
 583            &ClientSettings::get_global(cx).server_url,
 584            cx.http_client().proxy().cloned(),
 585        ));
 586        Self::new(clock, http, cx)
 587    }
 588
 589    pub fn id(&self) -> u64 {
 590        self.id.load(Ordering::SeqCst)
 591    }
 592
 593    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 594        self.http.clone()
 595    }
 596
 597    pub fn cloud_client(&self) -> Arc<CloudApiClient> {
 598        self.cloud_client.clone()
 599    }
 600
 601    pub fn set_id(&self, id: u64) -> &Self {
 602        self.id.store(id, Ordering::SeqCst);
 603        self
 604    }
 605
 606    #[cfg(any(test, feature = "test-support"))]
 607    pub fn teardown(&self) {
 608        let mut state = self.state.write();
 609        state._reconnect_task.take();
 610        self.handler_set.lock().clear();
 611        self.peer.teardown();
 612    }
 613
 614    #[cfg(any(test, feature = "test-support"))]
 615    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 616    where
 617        F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
 618    {
 619        *self.authenticate.write() = Some(Box::new(authenticate));
 620        self
 621    }
 622
 623    #[cfg(any(test, feature = "test-support"))]
 624    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 625    where
 626        F: 'static
 627            + Send
 628            + Sync
 629            + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
 630    {
 631        *self.establish_connection.write() = Some(Box::new(connect));
 632        self
 633    }
 634
 635    #[cfg(any(test, feature = "test-support"))]
 636    pub fn override_rpc_url(&self, url: Url) -> &Self {
 637        *self.rpc_url.write() = Some(url);
 638        self
 639    }
 640
 641    pub fn global(cx: &App) -> Arc<Self> {
 642        cx.global::<GlobalClient>().0.clone()
 643    }
 644    pub fn set_global(client: Arc<Client>, cx: &mut App) {
 645        cx.set_global(GlobalClient(client))
 646    }
 647
 648    pub fn user_id(&self) -> Option<u64> {
 649        self.state
 650            .read()
 651            .credentials
 652            .as_ref()
 653            .map(|credentials| credentials.user_id)
 654    }
 655
 656    pub fn peer_id(&self) -> Option<PeerId> {
 657        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 658            Some(*peer_id)
 659        } else {
 660            None
 661        }
 662    }
 663
 664    pub fn status(&self) -> watch::Receiver<Status> {
 665        self.state.read().status.1.clone()
 666    }
 667
 668    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
 669        log::info!("set status on client {}: {:?}", self.id(), status);
 670        let mut state = self.state.write();
 671        *state.status.0.borrow_mut() = status;
 672
 673        match status {
 674            Status::Connected { .. } => {
 675                state._reconnect_task = None;
 676            }
 677            Status::ConnectionLost => {
 678                let client = self.clone();
 679                state._reconnect_task = Some(cx.spawn(async move |cx| {
 680                    #[cfg(any(test, feature = "test-support"))]
 681                    let mut rng = StdRng::seed_from_u64(0);
 682                    #[cfg(not(any(test, feature = "test-support")))]
 683                    let mut rng = StdRng::from_entropy();
 684
 685                    let mut delay = INITIAL_RECONNECTION_DELAY;
 686                    loop {
 687                        match client.authenticate_and_connect(true, &cx).await {
 688                            ConnectionResult::Timeout => {
 689                                log::error!("client connect attempt timed out")
 690                            }
 691                            ConnectionResult::ConnectionReset => {
 692                                log::error!("client connect attempt reset")
 693                            }
 694                            ConnectionResult::Result(r) => {
 695                                if let Err(error) = r {
 696                                    log::error!("failed to connect: {error}");
 697                                } else {
 698                                    break;
 699                                }
 700                            }
 701                        }
 702
 703                        if matches!(*client.status().borrow(), Status::ConnectionError) {
 704                            client.set_status(
 705                                Status::ReconnectionError {
 706                                    next_reconnection: Instant::now() + delay,
 707                                },
 708                                &cx,
 709                            );
 710                            let jitter =
 711                                Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
 712                            cx.background_executor().timer(delay + jitter).await;
 713                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
 714                        } else {
 715                            break;
 716                        }
 717                    }
 718                }));
 719            }
 720            Status::SignedOut | Status::UpgradeRequired => {
 721                self.telemetry.set_authenticated_user_info(None, false);
 722                state._reconnect_task.take();
 723            }
 724            _ => {}
 725        }
 726    }
 727
 728    pub fn subscribe_to_entity<T>(
 729        self: &Arc<Self>,
 730        remote_id: u64,
 731    ) -> Result<PendingEntitySubscription<T>>
 732    where
 733        T: 'static,
 734    {
 735        let id = (TypeId::of::<T>(), remote_id);
 736
 737        let mut state = self.handler_set.lock();
 738        anyhow::ensure!(
 739            !state.entities_by_type_and_remote_id.contains_key(&id),
 740            "already subscribed to entity"
 741        );
 742
 743        state
 744            .entities_by_type_and_remote_id
 745            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 746
 747        Ok(PendingEntitySubscription {
 748            client: self.clone(),
 749            remote_id,
 750            consumed: false,
 751            _entity_type: PhantomData,
 752        })
 753    }
 754
 755    #[track_caller]
 756    pub fn add_message_handler<M, E, H, F>(
 757        self: &Arc<Self>,
 758        entity: WeakEntity<E>,
 759        handler: H,
 760    ) -> Subscription
 761    where
 762        M: EnvelopedMessage,
 763        E: 'static,
 764        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 765        F: 'static + Future<Output = Result<()>>,
 766    {
 767        self.add_message_handler_impl(entity, move |entity, message, _, cx| {
 768            handler(entity, message, cx)
 769        })
 770    }
 771
 772    fn add_message_handler_impl<M, E, H, F>(
 773        self: &Arc<Self>,
 774        entity: WeakEntity<E>,
 775        handler: H,
 776    ) -> Subscription
 777    where
 778        M: EnvelopedMessage,
 779        E: 'static,
 780        H: 'static
 781            + Sync
 782            + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
 783            + Send
 784            + Sync,
 785        F: 'static + Future<Output = Result<()>>,
 786    {
 787        let message_type_id = TypeId::of::<M>();
 788        let mut state = self.handler_set.lock();
 789        state
 790            .entities_by_message_type
 791            .insert(message_type_id, entity.into());
 792
 793        let prev_handler = state.message_handlers.insert(
 794            message_type_id,
 795            Arc::new(move |subscriber, envelope, client, cx| {
 796                let subscriber = subscriber.downcast::<E>().unwrap();
 797                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 798                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 799            }),
 800        );
 801        if prev_handler.is_some() {
 802            let location = std::panic::Location::caller();
 803            panic!(
 804                "{}:{} registered handler for the same message {} twice",
 805                location.file(),
 806                location.line(),
 807                std::any::type_name::<M>()
 808            );
 809        }
 810
 811        Subscription::Message {
 812            client: Arc::downgrade(self),
 813            id: message_type_id,
 814        }
 815    }
 816
 817    pub fn add_request_handler<M, E, H, F>(
 818        self: &Arc<Self>,
 819        entity: WeakEntity<E>,
 820        handler: H,
 821    ) -> Subscription
 822    where
 823        M: RequestMessage,
 824        E: 'static,
 825        H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
 826        F: 'static + Future<Output = Result<M::Response>>,
 827    {
 828        self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
 829            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 830        })
 831    }
 832
 833    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 834        receipt: Receipt<T>,
 835        response: F,
 836        client: AnyProtoClient,
 837    ) -> Result<()> {
 838        match response.await {
 839            Ok(response) => {
 840                client.send_response(receipt.message_id, response)?;
 841                Ok(())
 842            }
 843            Err(error) => {
 844                client.send_response(receipt.message_id, error.to_proto())?;
 845                Err(error)
 846            }
 847        }
 848    }
 849
 850    pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
 851        self.credentials_provider
 852            .read_credentials(cx)
 853            .await
 854            .is_some()
 855    }
 856
 857    #[async_recursion(?Send)]
 858    pub async fn authenticate_and_connect(
 859        self: &Arc<Self>,
 860        try_provider: bool,
 861        cx: &AsyncApp,
 862    ) -> ConnectionResult<()> {
 863        let was_disconnected = match *self.status().borrow() {
 864            Status::SignedOut => true,
 865            Status::ConnectionError
 866            | Status::ConnectionLost
 867            | Status::Authenticating { .. }
 868            | Status::Reauthenticating { .. }
 869            | Status::ReconnectionError { .. } => false,
 870            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
 871                return ConnectionResult::Result(Ok(()));
 872            }
 873            Status::UpgradeRequired => {
 874                return ConnectionResult::Result(
 875                    Err(EstablishConnectionError::UpgradeRequired)
 876                        .context("client auth and connect"),
 877                );
 878            }
 879        };
 880        if was_disconnected {
 881            self.set_status(Status::Authenticating, cx);
 882        } else {
 883            self.set_status(Status::Reauthenticating, cx)
 884        }
 885
 886        let mut read_from_provider = false;
 887        let mut credentials = self.state.read().credentials.clone();
 888        if credentials.is_none() && try_provider {
 889            credentials = self.credentials_provider.read_credentials(cx).await;
 890            read_from_provider = credentials.is_some();
 891        }
 892
 893        if credentials.is_none() {
 894            let mut status_rx = self.status();
 895            let _ = status_rx.next().await;
 896            futures::select_biased! {
 897                authenticate = self.authenticate(cx).fuse() => {
 898                    match authenticate {
 899                        Ok(creds) => credentials = Some(creds),
 900                        Err(err) => {
 901                            self.set_status(Status::ConnectionError, cx);
 902                            return ConnectionResult::Result(Err(err));
 903                        }
 904                    }
 905                }
 906                _ = status_rx.next().fuse() => {
 907                    return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
 908                }
 909            }
 910        }
 911        let credentials = credentials.unwrap();
 912        self.set_id(credentials.user_id);
 913        self.cloud_client
 914            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
 915
 916        if was_disconnected {
 917            self.set_status(Status::Connecting, cx);
 918        } else {
 919            self.set_status(Status::Reconnecting, cx);
 920        }
 921
 922        let mut timeout =
 923            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
 924        futures::select_biased! {
 925            connection = self.establish_connection(&credentials, cx).fuse() => {
 926                match connection {
 927                    Ok(conn) => {
 928                        self.state.write().credentials = Some(credentials.clone());
 929                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
 930                            self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
 931                        }
 932
 933                        futures::select_biased! {
 934                            result = self.set_connection(conn, cx).fuse() => {
 935                                match result.context("client auth and connect") {
 936                                    Ok(()) => ConnectionResult::Result(Ok(())),
 937                                    Err(err) => {
 938                                        self.set_status(Status::ConnectionError, cx);
 939                                        ConnectionResult::Result(Err(err))
 940                                    },
 941                                }
 942                            },
 943                            _ = timeout => {
 944                                self.set_status(Status::ConnectionError, cx);
 945                                ConnectionResult::Timeout
 946                            }
 947                        }
 948                    }
 949                    Err(EstablishConnectionError::Unauthorized) => {
 950                        self.state.write().credentials.take();
 951                        if read_from_provider {
 952                            self.credentials_provider.delete_credentials(cx).await.log_err();
 953                            self.set_status(Status::SignedOut, cx);
 954                            self.authenticate_and_connect(false, cx).await
 955                        } else {
 956                            self.set_status(Status::ConnectionError, cx);
 957                            ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
 958                        }
 959                    }
 960                    Err(EstablishConnectionError::UpgradeRequired) => {
 961                        self.set_status(Status::UpgradeRequired, cx);
 962                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
 963                    }
 964                    Err(error) => {
 965                        self.set_status(Status::ConnectionError, cx);
 966                        ConnectionResult::Result(Err(error).context("client auth and connect"))
 967                    }
 968                }
 969            }
 970            _ = &mut timeout => {
 971                self.set_status(Status::ConnectionError, cx);
 972                ConnectionResult::Timeout
 973            }
 974        }
 975    }
 976
    /// Finishes bringing up an already-established `Connection`.
    ///
    /// Registers `conn` with the peer and spawns the returned I/O future on the
    /// background executor. It then waits for the server's first message, which must be
    /// a `proto::Hello` carrying this client's peer id. On success the status becomes
    /// `Status::Connected` and two detached tasks are spawned:
    /// - one that feeds every incoming message to `handle_message`;
    /// - one that watches the I/O future and updates the status when the connection ends.
    ///
    /// # Errors
    /// Returns an error, after disconnecting the new connection, when:
    /// - no message arrives;
    /// - the first message is not a `Hello`;
    /// - the `Hello` has no peer id.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        // Hand the peer a timer factory backed by our executor (presumably for its own
        // keepalive/timeout handling — not visible from here).
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // Handshake: the first message must be a `Hello` that assigns our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        // If the handshake fails, tear down the half-registered connection before
        // propagating the error.
        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch incoming messages until the stream ends.
        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, &cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        // Watch the connection's I/O future and update the status when it finishes.
        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // A clean shutdown only signs out if this connection is still the
                    // current one; the status may already reflect a newer connection.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, &cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, &cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1060
1061    fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1062        #[cfg(any(test, feature = "test-support"))]
1063        if let Some(callback) = self.authenticate.read().as_ref() {
1064            return callback(cx);
1065        }
1066
1067        self.authenticate_with_browser(cx)
1068    }
1069
1070    fn establish_connection(
1071        self: &Arc<Self>,
1072        credentials: &Credentials,
1073        cx: &AsyncApp,
1074    ) -> Task<Result<Connection, EstablishConnectionError>> {
1075        #[cfg(any(test, feature = "test-support"))]
1076        if let Some(callback) = self.establish_connection.read().as_ref() {
1077            return callback(credentials, cx);
1078        }
1079
1080        self.establish_websocket_connection(credentials, cx)
1081    }
1082
1083    fn rpc_url(
1084        &self,
1085        http: Arc<HttpClientWithUrl>,
1086        release_channel: Option<ReleaseChannel>,
1087    ) -> impl Future<Output = Result<url::Url>> + use<> {
1088        #[cfg(any(test, feature = "test-support"))]
1089        let url_override = self.rpc_url.read().clone();
1090
1091        async move {
1092            #[cfg(any(test, feature = "test-support"))]
1093            if let Some(url) = url_override {
1094                return Ok(url);
1095            }
1096
1097            if let Some(url) = &*ZED_RPC_URL {
1098                return Url::parse(url).context("invalid rpc url");
1099            }
1100
1101            let mut url = http.build_url("/rpc");
1102            if let Some(preview_param) =
1103                release_channel.and_then(|channel| channel.release_query_param())
1104            {
1105                url += "?";
1106                url += preview_param;
1107            }
1108
1109            let response = http.get(&url, Default::default(), false).await?;
1110            anyhow::ensure!(
1111                response.status().is_redirection(),
1112                "unexpected /rpc response status {}",
1113                response.status()
1114            );
1115            let collab_url = response
1116                .headers()
1117                .get("Location")
1118                .context("missing location header in /rpc response")?
1119                .to_str()
1120                .map_err(EstablishConnectionError::other)?
1121                .to_string();
1122            Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1123        }
1124    }
1125
1126    fn establish_websocket_connection(
1127        self: &Arc<Self>,
1128        credentials: &Credentials,
1129        cx: &AsyncApp,
1130    ) -> Task<Result<Connection, EstablishConnectionError>> {
1131        let release_channel = cx
1132            .update(|cx| ReleaseChannel::try_global(cx))
1133            .ok()
1134            .flatten();
1135        let app_version = cx
1136            .update(|cx| AppVersion::global(cx).to_string())
1137            .ok()
1138            .unwrap_or_default();
1139
1140        let http = self.http.clone();
1141        let proxy = http.proxy().cloned();
1142        let user_agent = http.user_agent().cloned();
1143        let credentials = credentials.clone();
1144        let rpc_url = self.rpc_url(http, release_channel);
1145        let system_id = self.telemetry.system_id();
1146        let metrics_id = self.telemetry.metrics_id();
1147        cx.spawn(async move |cx| {
1148            use HttpOrHttps::*;
1149
1150            #[derive(Debug)]
1151            enum HttpOrHttps {
1152                Http,
1153                Https,
1154            }
1155
1156            let mut rpc_url = rpc_url.await?;
1157            let url_scheme = match rpc_url.scheme() {
1158                "https" => Https,
1159                "http" => Http,
1160                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1161            };
1162            let rpc_host = rpc_url
1163                .host_str()
1164                .zip(rpc_url.port_or_known_default())
1165                .context("missing host in rpc url")?;
1166
1167            let stream = {
1168                let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
1169                let _guard = handle.enter();
1170                match proxy {
1171                    Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
1172                    None => Box::new(TcpStream::connect(rpc_host).await?),
1173                }
1174            };
1175
1176            log::info!("connected to rpc endpoint {}", rpc_url);
1177
1178            rpc_url
1179                .set_scheme(match url_scheme {
1180                    Https => "wss",
1181                    Http => "ws",
1182                })
1183                .unwrap();
1184
1185            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
1186            // for us from the RPC URL.
1187            //
1188            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
1189            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
1190
1191            // We then modify the request to add our desired headers.
1192            let request_headers = request.headers_mut();
1193            request_headers.insert(
1194                http::header::AUTHORIZATION,
1195                HeaderValue::from_str(&credentials.authorization_header())?,
1196            );
1197            request_headers.insert(
1198                "x-zed-protocol-version",
1199                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
1200            );
1201            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
1202            request_headers.insert(
1203                "x-zed-release-channel",
1204                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
1205            );
1206            if let Some(user_agent) = user_agent {
1207                request_headers.insert(http::header::USER_AGENT, user_agent);
1208            }
1209            if let Some(system_id) = system_id {
1210                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
1211            }
1212            if let Some(metrics_id) = metrics_id {
1213                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
1214            }
1215
1216            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
1217                request,
1218                stream,
1219                Some(Arc::new(http_client_tls::tls_config()).into()),
1220                None,
1221            )
1222            .await?;
1223
1224            Ok(Connection::new(
1225                stream
1226                    .map_err(|error| anyhow!(error))
1227                    .sink_map_err(|error| anyhow!(error)),
1228            ))
1229        })
1230    }
1231
1232    pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1233        let http = self.http.clone();
1234        let this = self.clone();
1235        cx.spawn(async move |cx| {
1236            let background = cx.background_executor().clone();
1237
1238            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1239            cx.update(|cx| {
1240                cx.spawn(async move |cx| {
1241                    let url = open_url_rx.await?;
1242                    cx.update(|cx| cx.open_url(&url))
1243                })
1244                .detach_and_log_err(cx);
1245            })
1246            .log_err();
1247
1248            let credentials = background
1249                .clone()
1250                .spawn(async move {
1251                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1252                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1253                    // any other app running on the user's device.
1254                    let (public_key, private_key) =
1255                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1256                    let public_key_string = String::try_from(public_key)
1257                        .expect("failed to serialize public key for auth");
1258
1259                    if let Some((login, token)) =
1260                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1261                    {
1262                        eprintln!("authenticate as admin {login}, {token}");
1263
1264                        return this
1265                            .authenticate_as_admin(http, login.clone(), token.clone())
1266                            .await;
1267                    }
1268
1269                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1270                    let server =
1271                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1272                    let port = server.server_addr().port();
1273
1274                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1275                    // that the user is signing in from a Zed app running on the same device.
1276                    let mut url = http.build_url(&format!(
1277                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1278                        port, public_key_string
1279                    ));
1280
1281                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1282                        log::info!("impersonating user @{}", impersonate_login);
1283                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1284                    }
1285
1286                    open_url_tx.send(url).log_err();
1287
1288                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1289                    // access token from the query params.
1290                    //
1291                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1292                    // custom URL scheme instead of this local HTTP server.
1293                    let (user_id, access_token) = background
1294                        .spawn(async move {
1295                            for _ in 0..100 {
1296                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1297                                    let path = req.url();
1298                                    let mut user_id = None;
1299                                    let mut access_token = None;
1300                                    let url = Url::parse(&format!("http://example.com{}", path))
1301                                        .context("failed to parse login notification url")?;
1302                                    for (key, value) in url.query_pairs() {
1303                                        if key == "access_token" {
1304                                            access_token = Some(value.to_string());
1305                                        } else if key == "user_id" {
1306                                            user_id = Some(value.to_string());
1307                                        }
1308                                    }
1309
1310                                    let post_auth_url =
1311                                        http.build_url("/native_app_signin_succeeded");
1312                                    req.respond(
1313                                        tiny_http::Response::empty(302).with_header(
1314                                            tiny_http::Header::from_bytes(
1315                                                &b"Location"[..],
1316                                                post_auth_url.as_bytes(),
1317                                            )
1318                                            .unwrap(),
1319                                        ),
1320                                    )
1321                                    .context("failed to respond to login http request")?;
1322                                    return Ok((
1323                                        user_id.context("missing user_id parameter")?,
1324                                        access_token.context("missing access_token parameter")?,
1325                                    ));
1326                                }
1327                            }
1328
1329                            anyhow::bail!("didn't receive login redirect");
1330                        })
1331                        .await?;
1332
1333                    let access_token = private_key
1334                        .decrypt_string(&access_token)
1335                        .context("failed to decrypt access token")?;
1336
1337                    Ok(Credentials {
1338                        user_id: user_id.parse()?,
1339                        access_token,
1340                    })
1341                })
1342                .await?;
1343
1344            cx.update(|cx| cx.activate(true))?;
1345            Ok(credentials)
1346        })
1347    }
1348
1349    async fn authenticate_as_admin(
1350        self: &Arc<Self>,
1351        http: Arc<HttpClientWithUrl>,
1352        login: String,
1353        api_token: String,
1354    ) -> Result<Credentials> {
1355        #[derive(Serialize)]
1356        struct ImpersonateUserBody {
1357            github_login: String,
1358        }
1359
1360        #[derive(Deserialize)]
1361        struct ImpersonateUserResponse {
1362            user_id: u64,
1363            access_token: String,
1364        }
1365
1366        let url = self
1367            .http
1368            .build_zed_cloud_url("/internal/users/impersonate", &[])?;
1369        let request = Request::post(url.as_str())
1370            .header("Content-Type", "application/json")
1371            .header("Authorization", format!("Bearer {api_token}"))
1372            .body(
1373                serde_json::to_string(&ImpersonateUserBody {
1374                    github_login: login,
1375                })?
1376                .into(),
1377            )?;
1378
1379        let mut response = http.send(request).await?;
1380        let mut body = String::new();
1381        response.body_mut().read_to_string(&mut body).await?;
1382        anyhow::ensure!(
1383            response.status().is_success(),
1384            "admin user request failed {} - {}",
1385            response.status().as_u16(),
1386            body,
1387        );
1388        let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1389
1390        Ok(Credentials {
1391            user_id: response.user_id,
1392            access_token: response.access_token,
1393        })
1394    }
1395
    /// Signs the user out.
    ///
    /// Steps:
    /// 1. Forget the in-memory credentials and the cloud client's credentials.
    /// 2. Disconnect via `disconnect`.
    /// 3. If `has_credentials` reports stored credentials, delete them from the
    ///    credentials provider. Deletion failures are logged, not returned.
    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
        self.state.write().credentials = None;
        self.cloud_client.clear_credentials();
        self.disconnect(cx);

        if self.has_credentials(cx).await {
            self.credentials_provider
                .delete_credentials(cx)
                .await
                .log_err();
        }
    }
1408
    /// Tears down the peer (`peer.teardown()`) and sets the status to `Status::SignedOut`.
    ///
    /// Credentials are left untouched. `sign_out` additionally clears them.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1413
    /// Tears down the peer (`peer.teardown()`) and sets the status to
    /// `Status::ConnectionLost`.
    ///
    /// Unlike `disconnect`, the resulting status reports a lost connection rather than a
    /// sign-out. Neither method touches credentials.
    /// NOTE(review): presumably a status observer elsewhere reacts to `ConnectionLost`
    /// by reconnecting; confirm.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1418
1419    fn connection_id(&self) -> Result<ConnectionId> {
1420        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1421            Ok(connection_id)
1422        } else {
1423            anyhow::bail!("not connected");
1424        }
1425    }
1426
1427    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1428        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1429        self.peer.send(self.connection_id()?, message)
1430    }
1431
1432    pub fn request<T: RequestMessage>(
1433        &self,
1434        request: T,
1435    ) -> impl Future<Output = Result<T::Response>> + use<T> {
1436        self.request_envelope(request)
1437            .map_ok(|envelope| envelope.payload)
1438    }
1439
1440    pub fn request_stream<T: RequestMessage>(
1441        &self,
1442        request: T,
1443    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1444        let client_id = self.id.load(Ordering::SeqCst);
1445        log::debug!(
1446            "rpc request start. client_id:{}. name:{}",
1447            client_id,
1448            T::NAME
1449        );
1450        let response = self
1451            .connection_id()
1452            .map(|conn_id| self.peer.request_stream(conn_id, request));
1453        async move {
1454            let response = response?.await;
1455            log::debug!(
1456                "rpc request finish. client_id:{}. name:{}",
1457                client_id,
1458                T::NAME
1459            );
1460            response
1461        }
1462    }
1463
1464    pub fn request_envelope<T: RequestMessage>(
1465        &self,
1466        request: T,
1467    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1468        let client_id = self.id();
1469        log::debug!(
1470            "rpc request start. client_id:{}. name:{}",
1471            client_id,
1472            T::NAME
1473        );
1474        let response = self
1475            .connection_id()
1476            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1477        async move {
1478            let response = response?.await;
1479            log::debug!(
1480                "rpc request finish. client_id:{}. name:{}",
1481                client_id,
1482                T::NAME
1483            );
1484            response
1485        }
1486    }
1487
1488    pub fn request_dynamic(
1489        &self,
1490        envelope: proto::Envelope,
1491        request_type: &'static str,
1492    ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1493        let client_id = self.id();
1494        log::debug!(
1495            "rpc request start. client_id:{}. name:{}",
1496            client_id,
1497            request_type
1498        );
1499        let response = self
1500            .connection_id()
1501            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1502        async move {
1503            let response = response?.await;
1504            log::debug!(
1505                "rpc request finish. client_id:{}. name:{}",
1506                client_id,
1507                request_type
1508            );
1509            Ok(response?.0)
1510        }
1511    }
1512
1513    fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1514        let sender_id = message.sender_id();
1515        let request_id = message.message_id();
1516        let type_name = message.payload_type_name();
1517        let original_sender_id = message.original_sender_id();
1518
1519        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1520            &self.handler_set,
1521            message,
1522            self.clone().into(),
1523            cx.clone(),
1524        ) {
1525            let client_id = self.id();
1526            log::debug!(
1527                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1528                client_id,
1529                original_sender_id,
1530                type_name
1531            );
1532            cx.spawn(async move |_| match future.await {
1533                Ok(()) => {
1534                    log::debug!(
1535                        "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1536                        client_id,
1537                        original_sender_id,
1538                        type_name
1539                    );
1540                }
1541                Err(error) => {
1542                    log::error!(
1543                        "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1544                        client_id,
1545                        original_sender_id,
1546                        type_name,
1547                        error
1548                    );
1549                }
1550            })
1551            .detach();
1552        } else {
1553            log::info!("unhandled message {}", type_name);
1554            self.peer
1555                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1556                .log_err();
1557        }
1558    }
1559
    /// Returns the telemetry handle owned by this client.
    ///
    /// The client itself reads the system and metrics ids from it when connecting.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1563}
1564
1565impl ProtoClient for Client {
1566    fn request(
1567        &self,
1568        envelope: proto::Envelope,
1569        request_type: &'static str,
1570    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1571        self.request_dynamic(envelope, request_type).boxed()
1572    }
1573
1574    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1575        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1576        let connection_id = self.connection_id()?;
1577        self.peer.send_dynamic(connection_id, envelope)
1578    }
1579
1580    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1581        log::debug!(
1582            "rpc respond. client_id:{}, name:{}",
1583            self.id(),
1584            message_type
1585        );
1586        let connection_id = self.connection_id()?;
1587        self.peer.send_dynamic(connection_id, envelope)
1588    }
1589
1590    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1591        &self.handler_set
1592    }
1593
1594    fn is_via_collab(&self) -> bool {
1595        true
1596    }
1597}
1598
/// The scheme of `zed://` links, without the `://` separator, as recognized by
/// [`parse_zed_link`].
pub const ZED_URL_SCHEME: &str = "zed";
1601
1602/// Parses the given link into a Zed link.
1603///
1604/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1605/// Returns [`None`] otherwise.
1606pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1607    let server_url = &ClientSettings::get_global(cx).server_url;
1608    if let Some(stripped) = link
1609        .strip_prefix(server_url)
1610        .and_then(|result| result.strip_prefix('/'))
1611    {
1612        return Some(stripped);
1613    }
1614    if let Some(stripped) = link
1615        .strip_prefix(ZED_URL_SCHEME)
1616        .and_then(|result| result.strip_prefix("://"))
1617    {
1618        return Some(stripped);
1619    }
1620
1621    None
1622}
1623
1624#[cfg(test)]
1625mod tests {
1626    use super::*;
1627    use crate::test::FakeServer;
1628
1629    use clock::FakeSystemClock;
1630    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
1631    use http_client::FakeHttpClient;
1632    use parking_lot::Mutex;
1633    use proto::TypedEnvelope;
1634    use settings::SettingsStore;
1635    use std::future;
1636
    // Verifies that after the server drops the connection the client reconnects on its own,
    // reusing its cached credentials while they are still accepted, and authenticating again
    // once the server has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        // Creating the fake server connects the client, so the first status observed is
        // `Connected` and exactly one authentication has happened.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, then drain statuses until the
        // client reports that its reconnection attempt failed.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock so the client retries.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Roll the server's access token so the client's cached credentials are no longer
        // accepted; the client must discard them and authenticate again to reconnect.
        server.roll_access_token();
        server.allow_connections();
        // Let pending tasks settle before advancing the clock to trigger the retry.
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }
1677
    // Verifies that both the initial connection attempt and a later automatic reconnection
    // attempt give up with an error status once CONNECTION_TIMEOUT elapses without the
    // connection being established.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        // Authentication resolves immediately with fixed credentials...
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // ...but establishing the connection never completes, so only the timeout can end
        // this attempt.
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        // Advancing the fake clock by CONNECTION_TIMEOUT must fail the attempt, and the
        // spawned authenticate_and_connect call must resolve to an error.
        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        // NOTE(review): this presumably works because FakeServer::for_client installs its own
        // authenticate/connect overrides, replacing the hanging ones above — confirm.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        // Refusing new connections makes the automatic reconnection attempt fail.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Advance by twice INITIAL_RECONNECTION_DELAY (presumably enough to cover any jitter
        // on the retry delay) so the client starts another attempt, which then hangs.
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        // The hanging reconnection attempt must also be abandoned after CONNECTION_TIMEOUT.
        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }
1752
    // Verifies that starting a second authenticate_and_connect while the first is still
    // authenticating drops the first authentication attempt and starts a fresh one, rather
    // than leaving both running.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        // Number of authentication attempts started, and number whose futures were dropped.
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                // Each attempt records itself and then hangs forever; the deferred guard
                // bumps `dropped_auth_count` if this future is dropped before completing.
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing `_authenticate` does not drop the first task (it lives until the end of
        // the scope), so the drop observed below is caused by the client itself.
        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }
1799
    // Verifies that entity-scoped message handlers route each message to the entity
    // subscribed under the matching remote id, and that dropping one entity's subscription
    // does not affect delivery to other entities of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        // One completion channel per entity that is expected to receive a message.
        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                // Signal the channel for whichever entity the message was routed to; entity 3
                // (whose subscription is dropped below) must never be reached.
                match entity.read_with(&mut cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        // Bind remote ids 1..=3 to the local entities with the same `id`.
        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &mut cx.to_async());
        drop(subscription3);

        // NOTE(review): these messages are presumably routed by `project_id` to the entity
        // subscribed under that remote id — confirm against the JoinProject routing.
        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }
1867
1868    #[gpui::test]
1869    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1870        init_test(cx);
1871        let user_id = 5;
1872        let client = cx.update(|cx| {
1873            Client::new(
1874                Arc::new(FakeSystemClock::new()),
1875                FakeHttpClient::with_404_response(),
1876                cx,
1877            )
1878        });
1879        let server = FakeServer::for_client(user_id, &client, cx).await;
1880
1881        let entity = cx.new(|_| TestEntity::default());
1882        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1883        let (done_tx2, done_rx2) = smol::channel::unbounded();
1884        let subscription1 = client.add_message_handler(
1885            entity.downgrade(),
1886            move |_, _: TypedEnvelope<proto::Ping>, _| {
1887                done_tx1.try_send(()).unwrap();
1888                async { Ok(()) }
1889            },
1890        );
1891        drop(subscription1);
1892        let _subscription2 = client.add_message_handler(
1893            entity.downgrade(),
1894            move |_, _: TypedEnvelope<proto::Ping>, _| {
1895                done_tx2.try_send(()).unwrap();
1896                async { Ok(()) }
1897            },
1898        );
1899        server.send(proto::Ping {});
1900        done_rx2.recv().await.unwrap();
1901    }
1902
1903    #[gpui::test]
1904    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1905        init_test(cx);
1906        let user_id = 5;
1907        let client = cx.update(|cx| {
1908            Client::new(
1909                Arc::new(FakeSystemClock::new()),
1910                FakeHttpClient::with_404_response(),
1911                cx,
1912            )
1913        });
1914        let server = FakeServer::for_client(user_id, &client, cx).await;
1915
1916        let entity = cx.new(|_| TestEntity::default());
1917        let (done_tx, done_rx) = smol::channel::unbounded();
1918        let subscription = client.add_message_handler(
1919            entity.clone().downgrade(),
1920            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
1921                entity
1922                    .update(&mut cx, |entity, _| entity.subscription.take())
1923                    .unwrap();
1924                done_tx.try_send(()).unwrap();
1925                async { Ok(()) }
1926            },
1927        );
1928        entity.update(cx, |entity, _| {
1929            entity.subscription = Some(subscription);
1930        });
1931        server.send(proto::Ping {});
1932        done_rx.recv().await.unwrap();
1933    }
1934
1935    #[derive(Default)]
1936    struct TestEntity {
1937        id: usize,
1938        subscription: Option<Subscription>,
1939    }
1940
1941    fn init_test(cx: &mut TestAppContext) {
1942        cx.update(|cx| {
1943            let settings_store = SettingsStore::test(cx);
1944            cx.set_global(settings_store);
1945            init_settings(cx);
1946        });
1947    }
1948}