client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod socks;
   5pub mod telemetry;
   6pub mod user;
   7
   8use anyhow::{anyhow, bail, Context as _, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use chrono::{DateTime, Utc};
  16use clock::SystemClock;
  17use futures::{
  18    channel::oneshot, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
  19    TryFutureExt as _, TryStreamExt,
  20};
  21use gpui::{actions, AppContext, AsyncAppContext, Global, Model, Task, WeakModel};
  22use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
  23use parking_lot::RwLock;
  24use postage::watch;
  25use proto::{AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet};
  26use rand::prelude::*;
  27use release_channel::{AppVersion, ReleaseChannel};
  28use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::{Settings, SettingsSources};
  32use socks::connect_socks_proxy_stream;
  33use std::fmt;
  34use std::pin::Pin;
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        atomic::{AtomicU64, Ordering},
  44        Arc, LazyLock, Weak,
  45    },
  46    time::{Duration, Instant},
  47};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use url::Url;
  51use util::{ResultExt, TryFutureExt};
  52
  53pub use rpc::*;
  54pub use telemetry_events::Event;
  55pub use user::*;
  56
  57#[derive(Debug, Clone, Eq, PartialEq)]
  58pub struct DevServerToken(pub String);
  59
  60impl fmt::Display for DevServerToken {
  61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66static ZED_SERVER_URL: LazyLock<Option<String>> =
  67    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  68static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  69
  70/// An environment variable whose presence indicates that the development auth
  71/// provider should be used.
  72///
  73/// Only works in development. Setting this environment variable in other release
  74/// channels is a no-op.
  75pub static ZED_DEVELOPMENT_AUTH: LazyLock<bool> = LazyLock::new(|| {
  76    std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty())
  77});
  78pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  79    std::env::var("ZED_IMPERSONATE")
  80        .ok()
  81        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  82});
  83
  84pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  85    std::env::var("ZED_ADMIN_API_TOKEN")
  86        .ok()
  87        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  88});
  89
  90pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  91    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  92
  93pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  94    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  95
  96pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  97pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
  98pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  99
// Actions in the `client` namespace; their handlers are registered by `init`.
actions!(client, [SignIn, SignOut, Reconnect]);
 101
 102#[derive(Clone, Serialize, Deserialize, JsonSchema)]
 103#[serde(default)]
 104pub struct ClientSettings {
 105    /// The server to connect to. If the environment variable
 106    /// ZED_SERVER_URL is set, it will override this setting.
 107    pub server_url: String,
 108}
 109
 110impl Default for ClientSettings {
 111    fn default() -> Self {
 112        Self {
 113            server_url: "https://zed.dev".to_owned(),
 114        }
 115    }
 116}
 117
 118impl Settings for ClientSettings {
 119    const KEY: Option<&'static str> = None;
 120
 121    type FileContent = Self;
 122
 123    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 124        let mut result = sources.json_merge::<Self>()?;
 125        if let Some(server_url) = &*ZED_SERVER_URL {
 126            result.server_url.clone_from(server_url)
 127        }
 128        Ok(result)
 129    }
 130}
 131
 132#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 133#[serde(default)]
 134pub struct ProxySettings {
 135    /// Set a proxy to use. The proxy protocol is specified by the URI scheme.
 136    ///
 137    /// Supported URI scheme: `http`, `https`, `socks4`, `socks4a`, `socks5`,
 138    /// `socks5h`. `http` will be used when no scheme is specified.
 139    ///
 140    /// By default no proxy will be used, or Zed will try get proxy settings from
 141    /// environment variables.
 142    ///
 143    /// Examples:
 144    ///   - "proxy": "socks5://localhost:10808"
 145    ///   - "proxy": "http://127.0.0.1:10809"
 146    #[schemars(example = "Self::example_1")]
 147    #[schemars(example = "Self::example_2")]
 148    pub proxy: Option<String>,
 149}
 150
 151impl ProxySettings {
 152    fn example_1() -> String {
 153        "http://127.0.0.1:10809".to_owned()
 154    }
 155    fn example_2() -> String {
 156        "socks5://localhost:10808".to_owned()
 157    }
 158}
 159
 160impl Settings for ProxySettings {
 161    const KEY: Option<&'static str> = None;
 162
 163    type FileContent = Self;
 164
 165    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 166        Ok(Self {
 167            proxy: sources
 168                .user
 169                .and_then(|value| value.proxy.clone())
 170                .or(sources.default.proxy.clone()),
 171        })
 172    }
 173}
 174
/// Registers this module's settings types (`TelemetrySettings`,
/// `ClientSettings`, `ProxySettings`) via `Settings::register`.
pub fn init_settings(cx: &mut AppContext) {
    TelemetrySettings::register(cx);
    ClientSettings::register(cx);
    ProxySettings::register(cx);
}
 180
 181pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 182    let client = Arc::downgrade(client);
 183    cx.on_action({
 184        let client = client.clone();
 185        move |_: &SignIn, cx| {
 186            if let Some(client) = client.upgrade() {
 187                cx.spawn(
 188                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 189                )
 190                .detach();
 191            }
 192        }
 193    });
 194
 195    cx.on_action({
 196        let client = client.clone();
 197        move |_: &SignOut, cx| {
 198            if let Some(client) = client.upgrade() {
 199                cx.spawn(|cx| async move {
 200                    client.sign_out(&cx).await;
 201                })
 202                .detach();
 203            }
 204        }
 205    });
 206
 207    cx.on_action({
 208        let client = client.clone();
 209        move |_: &Reconnect, cx| {
 210            if let Some(client) = client.upgrade() {
 211                cx.spawn(|cx| async move {
 212                    client.reconnect(&cx);
 213                })
 214                .detach();
 215            }
 216        }
 217    });
 218}
 219
/// Wrapper that stores the app-wide [`Client`] as a gpui global; accessed via
/// `Client::global` and `Client::set_global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 223
/// Client handle that owns the RPC peer, the HTTP client, telemetry,
/// credential storage, the connection state, and the registered message
/// handlers.
pub struct Client {
    /// Client id set via `set_id` (the user id once user credentials are
    /// known); accessed with `SeqCst` ordering.
    id: AtomicU64,
    /// RPC peer, created with `Peer::new(0)` in `Client::new`.
    peer: Arc<Peer>,
    /// HTTP client; `Client::new` also hands a clone to `Telemetry::new`.
    http: Arc<HttpClientWithUrl>,
    telemetry: Arc<Telemetry>,
    /// Reads/writes/deletes persisted credentials (chosen in `Client::new`).
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// In-memory credentials, status channel, and reconnect task.
    state: RwLock<ClientState>,
    /// Registered message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,

    /// Test-only hook installed by `override_authenticate`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only hook installed by `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL installed by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 258
 259#[derive(Error, Debug)]
 260pub enum EstablishConnectionError {
 261    #[error("upgrade required")]
 262    UpgradeRequired,
 263    #[error("unauthorized")]
 264    Unauthorized,
 265    #[error("{0}")]
 266    Other(#[from] anyhow::Error),
 267    #[error("{0}")]
 268    Http(#[from] http_client::Error),
 269    #[error("{0}")]
 270    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 271    #[error("{0}")]
 272    Io(#[from] std::io::Error),
 273    #[error("{0}")]
 274    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 275}
 276
 277impl From<WebsocketError> for EstablishConnectionError {
 278    fn from(error: WebsocketError) -> Self {
 279        if let WebsocketError::Http(response) = &error {
 280            match response.status() {
 281                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 282                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 283                _ => {}
 284            }
 285        }
 286        EstablishConnectionError::Other(error.into())
 287    }
 288}
 289
 290impl EstablishConnectionError {
 291    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 292        Self::Other(error.into())
 293    }
 294}
 295
 296#[derive(Copy, Clone, Debug, PartialEq)]
 297pub enum Status {
 298    SignedOut,
 299    UpgradeRequired,
 300    Authenticating,
 301    Connecting,
 302    ConnectionError,
 303    Connected {
 304        peer_id: PeerId,
 305        connection_id: ConnectionId,
 306    },
 307    ConnectionLost,
 308    Reauthenticating,
 309    Reconnecting,
 310    ReconnectionError {
 311        next_reconnection: Instant,
 312    },
 313}
 314
 315impl Status {
 316    pub fn is_connected(&self) -> bool {
 317        matches!(self, Self::Connected { .. })
 318    }
 319
 320    pub fn is_signed_out(&self) -> bool {
 321        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 322    }
 323}
 324
 325struct ClientState {
 326    credentials: Option<Credentials>,
 327    status: (watch::Sender<Status>, watch::Receiver<Status>),
 328    _reconnect_task: Option<Task<()>>,
 329}
 330
 331#[derive(Clone, Debug, Eq, PartialEq)]
 332pub enum Credentials {
 333    DevServer { token: DevServerToken },
 334    User { user_id: u64, access_token: String },
 335}
 336
 337impl Credentials {
 338    pub fn authorization_header(&self) -> String {
 339        match self {
 340            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 341            Credentials::User {
 342                user_id,
 343                access_token,
 344            } => format!("{} {}", user_id, access_token),
 345        }
 346    }
 347}
 348
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain). `Client::new` chooses either a
/// `DevelopmentCredentialsProvider` (Dev channel with `ZED_DEVELOPMENT_AUTH`)
/// or a `KeychainCredentialsProvider`.
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// `None` means the provider had nothing to return; `Client::has_credentials`
    /// reports exactly this.
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (`user_id` + `access_token`) can be persisted
    /// through this interface; dev-server tokens are never written.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    ///
    /// Called when the server rejects credentials that were read from this
    /// provider.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}

impl Default for ClientState {
    /// Signed out, with no credentials and no reconnect task.
    fn default() -> Self {
        Self {
            credentials: None,
            status: watch::channel_with(Status::SignedOut),
            _reconnect_task: None,
        }
    }
}
 384
 385pub enum Subscription {
 386    Entity {
 387        client: Weak<Client>,
 388        id: (TypeId, u64),
 389    },
 390    Message {
 391        client: Weak<Client>,
 392        id: TypeId,
 393    },
 394}
 395
 396impl Drop for Subscription {
 397    fn drop(&mut self) {
 398        match self {
 399            Subscription::Entity { client, id } => {
 400                if let Some(client) = client.upgrade() {
 401                    let mut state = client.handler_set.lock();
 402                    let _ = state.entities_by_type_and_remote_id.remove(id);
 403                }
 404            }
 405            Subscription::Message { client, id } => {
 406                if let Some(client) = client.upgrade() {
 407                    let mut state = client.handler_set.lock();
 408                    let _ = state.entity_types_by_message_type.remove(id);
 409                    let _ = state.message_handlers.remove(id);
 410                }
 411            }
 412        }
 413    }
 414}
 415
/// An entity subscription reserved by `Client::subscribe_to_entity` but not
/// yet bound to a model.
///
/// While pending, the handler set holds an `EntityMessageSubscriber::Pending`
/// entry for `(T, remote_id)` whose queued messages are replayed by
/// [`PendingEntitySubscription::set_model`]. If dropped unbound, the entry is
/// removed and its queued messages are logged as unhandled.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    // Set by `set_model` so that `Drop` leaves the now-bound entry alone.
    consumed: bool,
}

impl<T: 'static> PendingEntitySubscription<T> {
    /// Binds the subscription to `model`, replays messages queued while it was
    /// pending, and returns the live [`Subscription`].
    ///
    /// # Panics
    /// Panics (`unreachable!`) if the entry for this subscription is missing
    /// or no longer `Pending`.
    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
        // Mark consumed first: `self` is dropped when this method returns, and
        // the `Drop` impl must not remove the entry we are about to bind.
        self.consumed = true;
        let mut handlers = self.client.handler_set.lock();
        let id = (TypeId::of::<T>(), self.remote_id);
        // `subscribe_to_entity` inserted a `Pending` entry under `id`; assumes
        // nothing else replaces it while this value is alive.
        let Some(EntityMessageSubscriber::Pending(messages)) =
            handlers.entities_by_type_and_remote_id.remove(&id)
        else {
            unreachable!()
        };

        handlers.entities_by_type_and_remote_id.insert(
            id,
            EntityMessageSubscriber::Entity {
                handle: model.downgrade().into(),
            },
        );
        // Release the lock before dispatching: `handle_message` presumably
        // locks `handler_set` again, and parking_lot mutexes are not reentrant.
        drop(handlers);
        for message in messages {
            let client_id = self.client.id();
            let type_name = message.payload_type_name();
            let sender_id = message.original_sender_id();
            log::debug!(
                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                sender_id,
                type_name
            );
            self.client.handle_message(message, cx);
        }
        Subscription::Entity {
            client: Arc::downgrade(&self.client),
            id,
        }
    }
}

impl<T: 'static> Drop for PendingEntitySubscription<T> {
    /// If never bound via `set_model`, removes the pending entry and logs the
    /// type of each queued message, which is then discarded.
    fn drop(&mut self) {
        if !self.consumed {
            let mut state = self.client.handler_set.lock();
            if let Some(EntityMessageSubscriber::Pending(messages)) = state
                .entities_by_type_and_remote_id
                .remove(&(TypeId::of::<T>(), self.remote_id))
            {
                for message in messages {
                    log::info!("unhandled message {}", message.payload_type_name());
                }
            }
        }
    }
}
 475
 476#[derive(Copy, Clone)]
 477pub struct TelemetrySettings {
 478    pub diagnostics: bool,
 479    pub metrics: bool,
 480}
 481
 482/// Control what info is collected by Zed.
 483#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 484pub struct TelemetrySettingsContent {
 485    /// Send debug info like crash reports.
 486    ///
 487    /// Default: true
 488    pub diagnostics: Option<bool>,
 489    /// Send anonymized usage data like what languages you're using Zed with.
 490    ///
 491    /// Default: true
 492    pub metrics: Option<bool>,
 493}
 494
 495impl settings::Settings for TelemetrySettings {
 496    const KEY: Option<&'static str> = Some("telemetry");
 497
 498    type FileContent = TelemetrySettingsContent;
 499
 500    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 501        Ok(Self {
 502            diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
 503                sources
 504                    .default
 505                    .diagnostics
 506                    .ok_or_else(Self::missing_default)?,
 507            ),
 508            metrics: sources
 509                .user
 510                .as_ref()
 511                .and_then(|v| v.metrics)
 512                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 513        })
 514    }
 515}
 516
 517impl Client {
 518    pub fn new(
 519        clock: Arc<dyn SystemClock>,
 520        http: Arc<HttpClientWithUrl>,
 521        cx: &mut AppContext,
 522    ) -> Arc<Self> {
 523        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 524            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 525            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 526            | None => false,
 527        };
 528
 529        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 530            if use_zed_development_auth {
 531                Arc::new(DevelopmentCredentialsProvider {
 532                    path: paths::config_dir().join("development_auth"),
 533                })
 534            } else {
 535                Arc::new(KeychainCredentialsProvider)
 536            };
 537
 538        Arc::new(Self {
 539            id: AtomicU64::new(0),
 540            peer: Peer::new(0),
 541            telemetry: Telemetry::new(clock, http.clone(), cx),
 542            http,
 543            credentials_provider,
 544            state: Default::default(),
 545            handler_set: Default::default(),
 546
 547            #[cfg(any(test, feature = "test-support"))]
 548            authenticate: Default::default(),
 549            #[cfg(any(test, feature = "test-support"))]
 550            establish_connection: Default::default(),
 551            #[cfg(any(test, feature = "test-support"))]
 552            rpc_url: RwLock::default(),
 553        })
 554    }
 555
 556    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 557        let user_agent = format!(
 558            "Zed/{} ({}; {})",
 559            AppVersion::global(cx),
 560            std::env::consts::OS,
 561            std::env::consts::ARCH
 562        );
 563        let clock = Arc::new(clock::RealSystemClock);
 564        let http = Arc::new(HttpClientWithUrl::new(
 565            &ClientSettings::get_global(cx).server_url,
 566            Some(user_agent),
 567            ProxySettings::get_global(cx).proxy.clone(),
 568        ));
 569        Self::new(clock, http.clone(), cx)
 570    }
 571
 572    pub fn id(&self) -> u64 {
 573        self.id.load(Ordering::SeqCst)
 574    }
 575
 576    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 577        self.http.clone()
 578    }
 579
 580    pub fn set_id(&self, id: u64) -> &Self {
 581        self.id.store(id, Ordering::SeqCst);
 582        self
 583    }
 584
 585    #[cfg(any(test, feature = "test-support"))]
 586    pub fn teardown(&self) {
 587        let mut state = self.state.write();
 588        state._reconnect_task.take();
 589        self.handler_set.lock().clear();
 590        self.peer.teardown();
 591    }
 592
 593    #[cfg(any(test, feature = "test-support"))]
 594    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 595    where
 596        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 597    {
 598        *self.authenticate.write() = Some(Box::new(authenticate));
 599        self
 600    }
 601
 602    #[cfg(any(test, feature = "test-support"))]
 603    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 604    where
 605        F: 'static
 606            + Send
 607            + Sync
 608            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 609    {
 610        *self.establish_connection.write() = Some(Box::new(connect));
 611        self
 612    }
 613
 614    #[cfg(any(test, feature = "test-support"))]
 615    pub fn override_rpc_url(&self, url: Url) -> &Self {
 616        *self.rpc_url.write() = Some(url);
 617        self
 618    }
 619
 620    pub fn global(cx: &AppContext) -> Arc<Self> {
 621        cx.global::<GlobalClient>().0.clone()
 622    }
 623    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 624        cx.set_global(GlobalClient(client))
 625    }
 626
 627    pub fn user_id(&self) -> Option<u64> {
 628        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 629            Some(*user_id)
 630        } else {
 631            None
 632        }
 633    }
 634
 635    pub fn peer_id(&self) -> Option<PeerId> {
 636        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 637            Some(*peer_id)
 638        } else {
 639            None
 640        }
 641    }
 642
 643    pub fn status(&self) -> watch::Receiver<Status> {
 644        self.state.read().status.1.clone()
 645    }
 646
    /// Publishes `status` to every receiver from [`Client::status`] and then
    /// adjusts the background reconnect task:
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop, replacing any previous one.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and
    ///   drops any reconnect task.
    ///
    /// The state write lock is held for the whole update.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                // Retry `authenticate_and_connect` with randomized backoff. The
                // loop only continues while a failed attempt leaves the status at
                // `ConnectionError`. Success, or any other resulting status, ends it.
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed in test builds keeps the backoff deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            // Advertise when the next attempt happens, then wait.
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Scale by a random factor in [0.5, 2.5], clamped to
                            // [INITIAL_RECONNECTION_DELAY, MAX_RECONNECTION_DELAY].
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 692
 693    pub fn subscribe_to_entity<T>(
 694        self: &Arc<Self>,
 695        remote_id: u64,
 696    ) -> Result<PendingEntitySubscription<T>>
 697    where
 698        T: 'static,
 699    {
 700        let id = (TypeId::of::<T>(), remote_id);
 701
 702        let mut state = self.handler_set.lock();
 703        if state.entities_by_type_and_remote_id.contains_key(&id) {
 704            return Err(anyhow!("already subscribed to entity"));
 705        }
 706
 707        state
 708            .entities_by_type_and_remote_id
 709            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 710
 711        Ok(PendingEntitySubscription {
 712            client: self.clone(),
 713            remote_id,
 714            consumed: false,
 715            _entity_type: PhantomData,
 716        })
 717    }
 718
 719    #[track_caller]
 720    pub fn add_message_handler<M, E, H, F>(
 721        self: &Arc<Self>,
 722        entity: WeakModel<E>,
 723        handler: H,
 724    ) -> Subscription
 725    where
 726        M: EnvelopedMessage,
 727        E: 'static,
 728        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 729        F: 'static + Future<Output = Result<()>>,
 730    {
 731        self.add_message_handler_impl(entity, move |model, message, _, cx| {
 732            handler(model, message, cx)
 733        })
 734    }
 735
 736    fn add_message_handler_impl<M, E, H, F>(
 737        self: &Arc<Self>,
 738        entity: WeakModel<E>,
 739        handler: H,
 740    ) -> Subscription
 741    where
 742        M: EnvelopedMessage,
 743        E: 'static,
 744        H: 'static
 745            + Sync
 746            + Fn(Model<E>, TypedEnvelope<M>, AnyProtoClient, AsyncAppContext) -> F
 747            + Send
 748            + Sync,
 749        F: 'static + Future<Output = Result<()>>,
 750    {
 751        let message_type_id = TypeId::of::<M>();
 752        let mut state = self.handler_set.lock();
 753        state
 754            .models_by_message_type
 755            .insert(message_type_id, entity.into());
 756
 757        let prev_handler = state.message_handlers.insert(
 758            message_type_id,
 759            Arc::new(move |subscriber, envelope, client, cx| {
 760                let subscriber = subscriber.downcast::<E>().unwrap();
 761                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 762                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 763            }),
 764        );
 765        if prev_handler.is_some() {
 766            let location = std::panic::Location::caller();
 767            panic!(
 768                "{}:{} registered handler for the same message {} twice",
 769                location.file(),
 770                location.line(),
 771                std::any::type_name::<M>()
 772            );
 773        }
 774
 775        Subscription::Message {
 776            client: Arc::downgrade(self),
 777            id: message_type_id,
 778        }
 779    }
 780
 781    pub fn add_request_handler<M, E, H, F>(
 782        self: &Arc<Self>,
 783        model: WeakModel<E>,
 784        handler: H,
 785    ) -> Subscription
 786    where
 787        M: RequestMessage,
 788        E: 'static,
 789        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 790        F: 'static + Future<Output = Result<M::Response>>,
 791    {
 792        self.add_message_handler_impl(model, move |handle, envelope, this, cx| {
 793            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 794        })
 795    }
 796
 797    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 798        receipt: Receipt<T>,
 799        response: F,
 800        client: AnyProtoClient,
 801    ) -> Result<()> {
 802        match response.await {
 803            Ok(response) => {
 804                client.send_response(receipt.message_id, response)?;
 805                Ok(())
 806            }
 807            Err(error) => {
 808                client.send_response(receipt.message_id, error.to_proto())?;
 809                Err(error)
 810            }
 811        }
 812    }
 813
 814    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 815        self.credentials_provider
 816            .read_credentials(cx)
 817            .await
 818            .is_some()
 819    }
 820
 821    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 822        self.state.write().credentials = Some(Credentials::DevServer { token });
 823        self
 824    }
 825
 826    #[async_recursion(?Send)]
 827    pub async fn authenticate_and_connect(
 828        self: &Arc<Self>,
 829        try_provider: bool,
 830        cx: &AsyncAppContext,
 831    ) -> anyhow::Result<()> {
 832        let was_disconnected = match *self.status().borrow() {
 833            Status::SignedOut => true,
 834            Status::ConnectionError
 835            | Status::ConnectionLost
 836            | Status::Authenticating { .. }
 837            | Status::Reauthenticating { .. }
 838            | Status::ReconnectionError { .. } => false,
 839            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
 840                return Ok(())
 841            }
 842            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
 843        };
 844        if was_disconnected {
 845            self.set_status(Status::Authenticating, cx);
 846        } else {
 847            self.set_status(Status::Reauthenticating, cx)
 848        }
 849
 850        let mut read_from_provider = false;
 851        let mut credentials = self.state.read().credentials.clone();
 852        if credentials.is_none() && try_provider {
 853            credentials = self.credentials_provider.read_credentials(cx).await;
 854            read_from_provider = credentials.is_some();
 855        }
 856
 857        if credentials.is_none() {
 858            let mut status_rx = self.status();
 859            let _ = status_rx.next().await;
 860            futures::select_biased! {
 861                authenticate = self.authenticate(cx).fuse() => {
 862                    match authenticate {
 863                        Ok(creds) => credentials = Some(creds),
 864                        Err(err) => {
 865                            self.set_status(Status::ConnectionError, cx);
 866                            return Err(err);
 867                        }
 868                    }
 869                }
 870                _ = status_rx.next().fuse() => {
 871                    return Err(anyhow!("authentication canceled"));
 872                }
 873            }
 874        }
 875        let credentials = credentials.unwrap();
 876        if let Credentials::User { user_id, .. } = &credentials {
 877            self.set_id(*user_id);
 878        }
 879
 880        if was_disconnected {
 881            self.set_status(Status::Connecting, cx);
 882        } else {
 883            self.set_status(Status::Reconnecting, cx);
 884        }
 885
 886        let mut timeout =
 887            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
 888        futures::select_biased! {
 889            connection = self.establish_connection(&credentials, cx).fuse() => {
 890                match connection {
 891                    Ok(conn) => {
 892                        self.state.write().credentials = Some(credentials.clone());
 893                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
 894                            if let Credentials::User{user_id, access_token} = credentials {
 895                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
 896                            }
 897                        }
 898
 899                        futures::select_biased! {
 900                            result = self.set_connection(conn, cx).fuse() => result,
 901                            _ = timeout => {
 902                                self.set_status(Status::ConnectionError, cx);
 903                                Err(anyhow!("timed out waiting on hello message from server"))
 904                            }
 905                        }
 906                    }
 907                    Err(EstablishConnectionError::Unauthorized) => {
 908                        self.state.write().credentials.take();
 909                        if read_from_provider {
 910                            self.credentials_provider.delete_credentials(cx).await.log_err();
 911                            self.set_status(Status::SignedOut, cx);
 912                            self.authenticate_and_connect(false, cx).await
 913                        } else {
 914                            self.set_status(Status::ConnectionError, cx);
 915                            Err(EstablishConnectionError::Unauthorized)?
 916                        }
 917                    }
 918                    Err(EstablishConnectionError::UpgradeRequired) => {
 919                        self.set_status(Status::UpgradeRequired, cx);
 920                        Err(EstablishConnectionError::UpgradeRequired)?
 921                    }
 922                    Err(error) => {
 923                        self.set_status(Status::ConnectionError, cx);
 924                        Err(error)?
 925                    }
 926                }
 927            }
 928            _ = &mut timeout => {
 929                self.set_status(Status::ConnectionError, cx);
 930                Err(anyhow!("timed out trying to establish connection"))
 931            }
 932        }
 933    }
 934
    /// Registers `conn` with the peer and waits for the server's handshake.
    ///
    /// The first incoming message must be a `proto::Hello` carrying a peer id. On success
    /// the status is set to `Status::Connected { peer_id, connection_id }` and two detached
    /// tasks are spawned: one dispatching incoming messages through `handle_message`, and
    /// one awaiting the connection's I/O future so the status can be updated when it ends.
    ///
    /// # Errors
    /// If no message arrives, the first message is not a `Hello`, or the `Hello` lacks a
    /// peer id, the connection is disconnected from the peer and the error is returned.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        // The peer receives a closure that creates timers on the background executor.
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O future on the background executor.
        let handle_io = executor.spawn(handle_io);

        // Handshake: the server's first message must be a `Hello` containing our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            // Captured before `into_any` consumes the message, for the error below.
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // A failed handshake must not leave the connection registered.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch every subsequent incoming message.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // Update the status once the connection's I/O future completes.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if the status still refers to this exact connection;
                        // otherwise it already reflects some later state.
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1032
1033    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1034        #[cfg(any(test, feature = "test-support"))]
1035        if let Some(callback) = self.authenticate.read().as_ref() {
1036            return callback(cx);
1037        }
1038
1039        self.authenticate_with_browser(cx)
1040    }
1041
1042    fn establish_connection(
1043        self: &Arc<Self>,
1044        credentials: &Credentials,
1045        cx: &AsyncAppContext,
1046    ) -> Task<Result<Connection, EstablishConnectionError>> {
1047        #[cfg(any(test, feature = "test-support"))]
1048        if let Some(callback) = self.establish_connection.read().as_ref() {
1049            return callback(credentials, cx);
1050        }
1051
1052        self.establish_websocket_connection(credentials, cx)
1053    }
1054
1055    fn rpc_url(
1056        &self,
1057        http: Arc<HttpClientWithUrl>,
1058        release_channel: Option<ReleaseChannel>,
1059    ) -> impl Future<Output = Result<Url>> {
1060        #[cfg(any(test, feature = "test-support"))]
1061        let url_override = self.rpc_url.read().clone();
1062
1063        async move {
1064            #[cfg(any(test, feature = "test-support"))]
1065            if let Some(url) = url_override {
1066                return Ok(url);
1067            }
1068
1069            if let Some(url) = &*ZED_RPC_URL {
1070                return Url::parse(url).context("invalid rpc url");
1071            }
1072
1073            let mut url = http.build_url("/rpc");
1074            if let Some(preview_param) =
1075                release_channel.and_then(|channel| channel.release_query_param())
1076            {
1077                url += "?";
1078                url += preview_param;
1079            }
1080
1081            let response = http.get(&url, Default::default(), false).await?;
1082            let collab_url = if response.status().is_redirection() {
1083                response
1084                    .headers()
1085                    .get("Location")
1086                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1087                    .to_str()
1088                    .map_err(EstablishConnectionError::other)?
1089                    .to_string()
1090            } else {
1091                Err(anyhow!(
1092                    "unexpected /rpc response status {}",
1093                    response.status()
1094                ))?
1095            };
1096
1097            Url::parse(&collab_url).context("invalid rpc url")
1098        }
1099    }
1100
    /// Opens a WebSocket connection to the collab server using `credentials`.
    ///
    /// The release channel and app version are read from app globals up front. If they are
    /// unavailable, they fall back to `None` and an empty string respectively. The returned
    /// background task does the following:
    /// - resolves the RPC URL;
    /// - opens a stream to its host and port via `connect_socks_proxy_stream`, passing the
    ///   HTTP client's proxy, if any;
    /// - rewrites the `http`/`https` scheme to `ws`/`wss`;
    /// - performs the WebSocket handshake with authorization, protocol-version, app-version,
    ///   and release-channel headers, using TLS for `https` URLs.
    ///
    /// # Errors
    /// The task fails if the URL cannot be resolved or has an unsupported scheme or no host,
    /// if the stream cannot be opened, if a header value is invalid, or if the handshake
    /// fails.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let credentials = credentials.clone();
        // Built here so the background task owns everything it needs.
        let rpc_url = self.rpc_url(http, release_channel);
        cx.background_executor().spawn(async move {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // NOTE(review): this relies on `Url::set_scheme` accepting http(s) -> ws(s).
            // All four are "special" schemes in the `url` crate, so the `unwrap` is expected
            // never to fire.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = rpc_url.into_client_request()?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                "Authorization",
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );

            // The two arms differ only in whether the handshake runs over TLS. They yield
            // different stream types, which is why `Connection::new` is repeated in each arm.
            match url_scheme {
                Https => {
                    let (stream, _) =
                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                Http => {
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
            }
        })
    }
1192
1193    pub fn authenticate_with_browser(
1194        self: &Arc<Self>,
1195        cx: &AsyncAppContext,
1196    ) -> Task<Result<Credentials>> {
1197        let http = self.http.clone();
1198        let this = self.clone();
1199        cx.spawn(|cx| async move {
1200            let background = cx.background_executor().clone();
1201
1202            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1203            cx.update(|cx| {
1204                cx.spawn(move |cx| async move {
1205                    let url = open_url_rx.await?;
1206                    cx.update(|cx| cx.open_url(&url))
1207                })
1208                .detach_and_log_err(cx);
1209            })
1210            .log_err();
1211
1212            let credentials = background
1213                .clone()
1214                .spawn(async move {
1215                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1216                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1217                    // any other app running on the user's device.
1218                    let (public_key, private_key) =
1219                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1220                    let public_key_string = String::try_from(public_key)
1221                        .expect("failed to serialize public key for auth");
1222
1223                    if let Some((login, token)) =
1224                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1225                    {
1226                        eprintln!("authenticate as admin {login}, {token}");
1227
1228                        return this
1229                            .authenticate_as_admin(http, login.clone(), token.clone())
1230                            .await;
1231                    }
1232
1233                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1234                    let server =
1235                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1236                    let port = server.server_addr().port();
1237
1238                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1239                    // that the user is signing in from a Zed app running on the same device.
1240                    let mut url = http.build_url(&format!(
1241                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1242                        port, public_key_string
1243                    ));
1244
1245                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1246                        log::info!("impersonating user @{}", impersonate_login);
1247                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1248                    }
1249
1250                    open_url_tx.send(url).log_err();
1251
1252                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1253                    // access token from the query params.
1254                    //
1255                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1256                    // custom URL scheme instead of this local HTTP server.
1257                    let (user_id, access_token) = background
1258                        .spawn(async move {
1259                            for _ in 0..100 {
1260                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1261                                    let path = req.url();
1262                                    let mut user_id = None;
1263                                    let mut access_token = None;
1264                                    let url = Url::parse(&format!("http://example.com{}", path))
1265                                        .context("failed to parse login notification url")?;
1266                                    for (key, value) in url.query_pairs() {
1267                                        if key == "access_token" {
1268                                            access_token = Some(value.to_string());
1269                                        } else if key == "user_id" {
1270                                            user_id = Some(value.to_string());
1271                                        }
1272                                    }
1273
1274                                    let post_auth_url =
1275                                        http.build_url("/native_app_signin_succeeded");
1276                                    req.respond(
1277                                        tiny_http::Response::empty(302).with_header(
1278                                            tiny_http::Header::from_bytes(
1279                                                &b"Location"[..],
1280                                                post_auth_url.as_bytes(),
1281                                            )
1282                                            .unwrap(),
1283                                        ),
1284                                    )
1285                                    .context("failed to respond to login http request")?;
1286                                    return Ok((
1287                                        user_id
1288                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1289                                        access_token.ok_or_else(|| {
1290                                            anyhow!("missing access_token parameter")
1291                                        })?,
1292                                    ));
1293                                }
1294                            }
1295
1296                            Err(anyhow!("didn't receive login redirect"))
1297                        })
1298                        .await?;
1299
1300                    let access_token = private_key
1301                        .decrypt_string(&access_token)
1302                        .context("failed to decrypt access token")?;
1303
1304                    Ok(Credentials::User {
1305                        user_id: user_id.parse()?,
1306                        access_token,
1307                    })
1308                })
1309                .await?;
1310
1311            cx.update(|cx| cx.activate(true))?;
1312            Ok(credentials)
1313        })
1314    }
1315
1316    async fn authenticate_as_admin(
1317        self: &Arc<Self>,
1318        http: Arc<HttpClientWithUrl>,
1319        login: String,
1320        mut api_token: String,
1321    ) -> Result<Credentials> {
1322        #[derive(Deserialize)]
1323        struct AuthenticatedUserResponse {
1324            user: User,
1325        }
1326
1327        #[derive(Deserialize)]
1328        struct User {
1329            id: u64,
1330        }
1331
1332        let github_user = {
1333            #[derive(Deserialize)]
1334            struct GithubUser {
1335                id: i32,
1336                login: String,
1337                created_at: DateTime<Utc>,
1338            }
1339
1340            let request = {
1341                let mut request_builder =
1342                    Request::get(&format!("https://api.github.com/users/{login}"));
1343                if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
1344                    request_builder =
1345                        request_builder.header("Authorization", format!("Bearer {}", github_token));
1346                }
1347
1348                request_builder.body(AsyncBody::empty())?
1349            };
1350
1351            let mut response = http
1352                .send(request)
1353                .await
1354                .context("error fetching GitHub user")?;
1355
1356            let mut body = Vec::new();
1357            response
1358                .body_mut()
1359                .read_to_end(&mut body)
1360                .await
1361                .context("error reading GitHub user")?;
1362
1363            if !response.status().is_success() {
1364                let text = String::from_utf8_lossy(body.as_slice());
1365                bail!(
1366                    "status error {}, response: {text:?}",
1367                    response.status().as_u16()
1368                );
1369            }
1370
1371            serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
1372                log::error!("Error deserializing: {:?}", err);
1373                log::error!(
1374                    "GitHub API response text: {:?}",
1375                    String::from_utf8_lossy(body.as_slice())
1376                );
1377                anyhow!("error deserializing GitHub user")
1378            })?
1379        };
1380
1381        let query_params = [
1382            ("github_login", &github_user.login),
1383            ("github_user_id", &github_user.id.to_string()),
1384            (
1385                "github_user_created_at",
1386                &github_user.created_at.to_rfc3339(),
1387            ),
1388        ];
1389
1390        // Use the collab server's admin API to retrieve the ID
1391        // of the impersonated user.
1392        let mut url = self.rpc_url(http.clone(), None).await?;
1393        url.set_path("/user");
1394        url.set_query(Some(
1395            &query_params
1396                .iter()
1397                .map(|(key, value)| {
1398                    format!(
1399                        "{}={}",
1400                        key,
1401                        url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
1402                    )
1403                })
1404                .collect::<Vec<String>>()
1405                .join("&"),
1406        ));
1407        let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
1408            .header("Authorization", format!("token {api_token}"))
1409            .body("".into())?;
1410
1411        let mut response = http.send(request).await?;
1412        let mut body = String::new();
1413        response.body_mut().read_to_string(&mut body).await?;
1414        if !response.status().is_success() {
1415            Err(anyhow!(
1416                "admin user request failed {} - {}",
1417                response.status().as_u16(),
1418                body,
1419            ))?;
1420        }
1421        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1422
1423        // Use the admin API token to authenticate as the impersonated user.
1424        api_token.insert_str(0, "ADMIN_TOKEN:");
1425        Ok(Credentials::User {
1426            user_id: response.user.id,
1427            access_token: api_token,
1428        })
1429    }
1430
1431    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1432        self.state.write().credentials = None;
1433        self.disconnect(cx);
1434
1435        if self.has_credentials(cx).await {
1436            self.credentials_provider
1437                .delete_credentials(cx)
1438                .await
1439                .log_err();
1440        }
1441    }
1442
    /// Tears down the peer and sets the status to `Status::SignedOut`.
    ///
    /// Stored credentials are not touched here; `sign_out` also clears them.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1447
    /// Tears down the peer and sets the status to `Status::ConnectionLost`.
    ///
    /// NOTE(review): presumably code observing the status reacts to `ConnectionLost` by
    /// re-establishing the connection. That logic is not in this method; confirm where it
    /// lives.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1452
1453    fn connection_id(&self) -> Result<ConnectionId> {
1454        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1455            Ok(connection_id)
1456        } else {
1457            Err(anyhow!("not connected"))
1458        }
1459    }
1460
1461    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1462        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1463        self.peer.send(self.connection_id()?, message)
1464    }
1465
1466    pub fn request<T: RequestMessage>(
1467        &self,
1468        request: T,
1469    ) -> impl Future<Output = Result<T::Response>> {
1470        self.request_envelope(request)
1471            .map_ok(|envelope| envelope.payload)
1472    }
1473
1474    pub fn request_stream<T: RequestMessage>(
1475        &self,
1476        request: T,
1477    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1478        let client_id = self.id.load(Ordering::SeqCst);
1479        log::debug!(
1480            "rpc request start. client_id:{}. name:{}",
1481            client_id,
1482            T::NAME
1483        );
1484        let response = self
1485            .connection_id()
1486            .map(|conn_id| self.peer.request_stream(conn_id, request));
1487        async move {
1488            let response = response?.await;
1489            log::debug!(
1490                "rpc request finish. client_id:{}. name:{}",
1491                client_id,
1492                T::NAME
1493            );
1494            response
1495        }
1496    }
1497
1498    pub fn request_envelope<T: RequestMessage>(
1499        &self,
1500        request: T,
1501    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1502        let client_id = self.id();
1503        log::debug!(
1504            "rpc request start. client_id:{}. name:{}",
1505            client_id,
1506            T::NAME
1507        );
1508        let response = self
1509            .connection_id()
1510            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1511        async move {
1512            let response = response?.await;
1513            log::debug!(
1514                "rpc request finish. client_id:{}. name:{}",
1515                client_id,
1516                T::NAME
1517            );
1518            response
1519        }
1520    }
1521
1522    pub fn request_dynamic(
1523        &self,
1524        envelope: proto::Envelope,
1525        request_type: &'static str,
1526    ) -> impl Future<Output = Result<proto::Envelope>> {
1527        let client_id = self.id();
1528        log::debug!(
1529            "rpc request start. client_id:{}. name:{}",
1530            client_id,
1531            request_type
1532        );
1533        let response = self
1534            .connection_id()
1535            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1536        async move {
1537            let response = response?.await;
1538            log::debug!(
1539                "rpc request finish. client_id:{}. name:{}",
1540                client_id,
1541                request_type
1542            );
1543            Ok(response?.0)
1544        }
1545    }
1546
1547    fn handle_message(
1548        self: &Arc<Client>,
1549        message: Box<dyn AnyTypedEnvelope>,
1550        cx: &AsyncAppContext,
1551    ) {
1552        let sender_id = message.sender_id();
1553        let request_id = message.message_id();
1554        let type_name = message.payload_type_name();
1555        let original_sender_id = message.original_sender_id();
1556
1557        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1558            &self.handler_set,
1559            message,
1560            self.clone().into(),
1561            cx.clone(),
1562        ) {
1563            let client_id = self.id();
1564            log::debug!(
1565                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1566                client_id,
1567                original_sender_id,
1568                type_name
1569            );
1570            cx.spawn(move |_| async move {
1571                match future.await {
1572                    Ok(()) => {
1573                        log::debug!(
1574                            "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1575                            client_id,
1576                            original_sender_id,
1577                            type_name
1578                        );
1579                    }
1580                    Err(error) => {
1581                        log::error!(
1582                            "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1583                            client_id,
1584                            original_sender_id,
1585                            type_name,
1586                            error
1587                        );
1588                    }
1589                }
1590            })
1591            .detach();
1592        } else {
1593            log::info!("unhandled message {}", type_name);
1594            self.peer
1595                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1596                .log_err();
1597        }
1598    }
1599
    /// Returns the client's shared telemetry handle.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1603}
1604
1605impl ProtoClient for Client {
1606    fn request(
1607        &self,
1608        envelope: proto::Envelope,
1609        request_type: &'static str,
1610    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1611        self.request_dynamic(envelope, request_type).boxed()
1612    }
1613
1614    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1615        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1616        let connection_id = self.connection_id()?;
1617        self.peer.send_dynamic(connection_id, envelope)
1618    }
1619
1620    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1621        log::debug!(
1622            "rpc respond. client_id:{}, name:{}",
1623            self.id(),
1624            message_type
1625        );
1626        let connection_id = self.connection_id()?;
1627        self.peer.send_dynamic(connection_id, envelope)
1628    }
1629
1630    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1631        &self.handler_set
1632    }
1633}
1634
/// JSON record read and written by `DevelopmentCredentialsProvider`.
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
    /// The signed-in user's id.
    user_id: u64,
    /// The user's access token, stored unencrypted in the JSON file.
    access_token: String,
}
1640
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
    /// Path of the JSON file that credentials are read from, written to, and deleted from.
    path: PathBuf,
}
1651
1652impl CredentialsProvider for DevelopmentCredentialsProvider {
1653    fn read_credentials<'a>(
1654        &'a self,
1655        _cx: &'a AsyncAppContext,
1656    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1657        async move {
1658            if IMPERSONATE_LOGIN.is_some() {
1659                return None;
1660            }
1661
1662            let json = std::fs::read(&self.path).log_err()?;
1663
1664            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1665
1666            Some(Credentials::User {
1667                user_id: credentials.user_id,
1668                access_token: credentials.access_token,
1669            })
1670        }
1671        .boxed_local()
1672    }
1673
1674    fn write_credentials<'a>(
1675        &'a self,
1676        user_id: u64,
1677        access_token: String,
1678        _cx: &'a AsyncAppContext,
1679    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1680        async move {
1681            let json = serde_json::to_string(&DevelopmentCredentials {
1682                user_id,
1683                access_token,
1684            })?;
1685
1686            std::fs::write(&self.path, json)?;
1687
1688            Ok(())
1689        }
1690        .boxed_local()
1691    }
1692
1693    fn delete_credentials<'a>(
1694        &'a self,
1695        _cx: &'a AsyncAppContext,
1696    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1697        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1698    }
1699}
1700
/// A credentials provider that stores credentials in the system keychain.
///
/// Credentials are read, written, and deleted under the currently configured
/// `ClientSettings::server_url`.
struct KeychainCredentialsProvider;
1703
1704impl CredentialsProvider for KeychainCredentialsProvider {
1705    fn read_credentials<'a>(
1706        &'a self,
1707        cx: &'a AsyncAppContext,
1708    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1709        async move {
1710            if IMPERSONATE_LOGIN.is_some() {
1711                return None;
1712            }
1713
1714            let (user_id, access_token) = cx
1715                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1716                .log_err()?
1717                .await
1718                .log_err()??;
1719
1720            Some(Credentials::User {
1721                user_id: user_id.parse().ok()?,
1722                access_token: String::from_utf8(access_token).ok()?,
1723            })
1724        }
1725        .boxed_local()
1726    }
1727
1728    fn write_credentials<'a>(
1729        &'a self,
1730        user_id: u64,
1731        access_token: String,
1732        cx: &'a AsyncAppContext,
1733    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1734        async move {
1735            cx.update(move |cx| {
1736                cx.write_credentials(
1737                    &ClientSettings::get_global(cx).server_url,
1738                    &user_id.to_string(),
1739                    access_token.as_bytes(),
1740                )
1741            })?
1742            .await
1743        }
1744        .boxed_local()
1745    }
1746
1747    fn delete_credentials<'a>(
1748        &'a self,
1749        cx: &'a AsyncAppContext,
1750    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1751        async move {
1752            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1753                .await
1754        }
1755        .boxed_local()
1756    }
1757}
1758
/// The URL scheme used by Zed links (as in `zed://...`), without the `://`
/// separator.
pub static ZED_URL_SCHEME: &str = "zed";
1761
1762/// Parses the given link into a Zed link.
1763///
1764/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1765/// Returns [`None`] otherwise.
1766pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1767    let server_url = &ClientSettings::get_global(cx).server_url;
1768    if let Some(stripped) = link
1769        .strip_prefix(server_url)
1770        .and_then(|result| result.strip_prefix('/'))
1771    {
1772        return Some(stripped);
1773    }
1774    if let Some(stripped) = link
1775        .strip_prefix(ZED_URL_SCHEME)
1776        .and_then(|result| result.strip_prefix("://"))
1777    {
1778        return Some(stripped);
1779    }
1780
1781    None
1782}
1783
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // Checks that the client reconnects after the server drops the connection,
    // reusing cached credentials, and re-authenticates only after the server rolls
    // the access token (which invalidates the cached credentials).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Checks that both the initial connection attempt and a later reconnection
    // attempt report an error status once `CONNECTION_TIMEOUT` elapses while
    // establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Checks that starting a second `authenticate_and_connect` while the first is
    // still authenticating drops the in-flight authentication future (observed via
    // the deferred drop counter) and starts a new one.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Checks that entity messages reach the model subscribed to the matching entity
    // id, and that dropping one entity's subscription does not affect the others.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Checks that after dropping a message handler's subscription, a new handler
    // for the same message type and model can be registered and receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // Checks that a message handler can drop its own subscription (stored on the
    // model) from inside the handler while processing a message.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        // `downgrade` only borrows `model`, so no clone is needed here.
        let subscription = client.add_message_handler(
            model.downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used as a message-handler target in these tests.
    #[derive(Default)]
    struct TestModel {
        // Compared against the entity id in `test_subscribing_to_entity`.
        id: usize,
        // Lets a handler drop its own subscription in
        // `test_dropping_subscription_in_handler`.
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` global and then calls `init_settings`, so
    // that `ClientSettings` can be read by the client under test.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}