client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context as _, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use clock::SystemClock;
  14use collections::HashMap;
  15use futures::{
  16    channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
  17    TryFutureExt as _, TryStreamExt,
  18};
  19use gpui::{
  20    actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
  21};
  22use http::{HttpClient, HttpClientWithUrl};
  23use lazy_static::lazy_static;
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use release_channel::{AppVersion, ReleaseChannel};
  28use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::{Settings, SettingsSources};
  32use std::fmt;
  33use std::pin::Pin;
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        atomic::{AtomicU64, Ordering},
  43        Arc, Weak,
  44    },
  45    time::{Duration, Instant},
  46};
  47use telemetry::Telemetry;
  48use thiserror::Error;
  49use url::Url;
  50use util::{ResultExt, TryFutureExt};
  51
  52pub use rpc::*;
  53pub use telemetry_events::Event;
  54pub use user::*;
  55
  56#[derive(Debug, Clone, Eq, PartialEq)]
  57pub struct DevServerToken(pub String);
  58
  59impl fmt::Display for DevServerToken {
  60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  61        write!(f, "{}", self.0)
  62    }
  63}
  64
  65lazy_static! {
  66    static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
  67    static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
  68    /// An environment variable whose presence indicates that the development auth
  69    /// provider should be used.
  70    ///
  71    /// Only works in development. Setting this environment variable in other release
  72    /// channels is a no-op.
  73    pub static ref ZED_DEVELOPMENT_AUTH: bool =
  74        std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
  75    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  76        .ok()
  77        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  78    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  79        .ok()
  80        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  81    pub static ref ZED_APP_PATH: Option<PathBuf> =
  82        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  83    pub static ref ZED_ALWAYS_ACTIVE: bool =
  84        std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
  85}
  86
  87pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  88pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
  89pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  90
// Client actions; their handlers are registered by `init`.
actions!(client, [SignIn, SignOut, Reconnect]);
  92
// Settings-file content for `ClientSettings`. Plain `//` comments are used here on
// purpose: `///` doc comments would become part of the generated JSON schema.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // Server URL; see `ClientSettings::server_url`.
    server_url: Option<String>,
}

/// Resolved client settings, produced by merging all settings sources.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// URL handed to the HTTP client built by `Client::production`. The
    /// `ZED_SERVER_URL` environment variable overrides it when set.
    pub server_url: String,
}
 102
 103impl Settings for ClientSettings {
 104    const KEY: Option<&'static str> = None;
 105
 106    type FileContent = ClientSettingsContent;
 107
 108    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 109        let mut result = sources.json_merge::<Self>()?;
 110        if let Some(server_url) = &*ZED_SERVER_URL {
 111            result.server_url.clone_from(&server_url)
 112        }
 113        Ok(result)
 114    }
 115}
 116
// Settings-file content for `ProxySettings`. A user-provided `proxy` takes precedence
// over the default settings (see `ProxySettings::load`). Plain `//` comments keep the
// generated JSON schema unchanged.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProxySettingsContent {
    proxy: Option<String>,
}

/// Resolved proxy configuration.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Proxy handed to the HTTP client built by `Client::production`, if any.
    pub proxy: Option<String>,
}
 126
 127impl Settings for ProxySettings {
 128    const KEY: Option<&'static str> = None;
 129
 130    type FileContent = ProxySettingsContent;
 131
 132    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 133        Ok(Self {
 134            proxy: sources
 135                .user
 136                .and_then(|value| value.proxy.clone())
 137                .or(sources.default.proxy.clone()),
 138        })
 139    }
 140}
 141
/// Registers this module's settings types: [`TelemetrySettings`], [`ClientSettings`],
/// and [`ProxySettings`].
///
/// NOTE(review): `Client::production` reads `ClientSettings` and `ProxySettings` via
/// `get_global`, so this presumably has to run first.
pub fn init_settings(cx: &mut AppContext) {
    TelemetrySettings::register(cx);
    ClientSettings::register(cx);
    ProxySettings::register(cx);
}
 147
 148pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 149    let client = Arc::downgrade(client);
 150    cx.on_action({
 151        let client = client.clone();
 152        move |_: &SignIn, cx| {
 153            if let Some(client) = client.upgrade() {
 154                cx.spawn(
 155                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 156                )
 157                .detach();
 158            }
 159        }
 160    });
 161
 162    cx.on_action({
 163        let client = client.clone();
 164        move |_: &SignOut, cx| {
 165            if let Some(client) = client.upgrade() {
 166                cx.spawn(|cx| async move {
 167                    client.sign_out(&cx).await;
 168                })
 169                .detach();
 170            }
 171        }
 172    });
 173
 174    cx.on_action({
 175        let client = client.clone();
 176        move |_: &Reconnect, cx| {
 177            if let Some(client) = client.upgrade() {
 178                cx.spawn(|cx| async move {
 179                    client.reconnect(&cx);
 180                })
 181                .detach();
 182            }
 183        }
 184    });
 185}
 186
/// Newtype that stores the app-global [`Client`] as a gpui [`Global`]
/// (see [`Client::global`] and [`Client::set_global`]).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 190
/// Server client. Owns the RPC [`Peer`], the HTTP client, telemetry, the credentials
/// provider, and the lock-protected connection and message-routing state.
pub struct Client {
    /// Identifier included in status log lines. `authenticate_and_connect` sets it to
    /// the user id when connecting with user credentials.
    id: AtomicU64,
    peer: Arc<Peer>,
    http: Arc<HttpClientWithUrl>,
    telemetry: Arc<Telemetry>,
    /// Reads and persists credentials; chosen in [`Client::new`].
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// Status, credentials, and message-routing tables.
    state: RwLock<ClientState>,

    /// Test hook installed by `override_authenticate`. Presumably consulted by the
    /// authentication step in place of the real flow when set.
    // The boxed callback type is inherently long; the lint is silenced deliberately.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test hook installed by `override_establish_connection`. Presumably consulted in
    /// place of the real connection logic when set.
    // The boxed callback type is inherently long; the lint is silenced deliberately.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only RPC URL override installed by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 224
/// Reason why establishing a connection to the server failed.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced from an HTTP `UPGRADE_REQUIRED` handshake response (see the
    /// `From<WebsocketError>` impl). Also returned by `authenticate_and_connect` while
    /// the client status is `UpgradeRequired`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced from an HTTP `UNAUTHORIZED` handshake response.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including WebSocket errors with other HTTP statuses.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error converted from `http::Error`.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// NOTE(review): despite its name, this wraps tungstenite's re-exported
    /// `http::Error`. `tungstenite::Error` itself goes through `From<WebsocketError>`.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 240
 241impl From<WebsocketError> for EstablishConnectionError {
 242    fn from(error: WebsocketError) -> Self {
 243        if let WebsocketError::Http(response) = &error {
 244            match response.status() {
 245                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 246                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 247                _ => {}
 248            }
 249        }
 250        EstablishConnectionError::Other(error.into())
 251    }
 252}
 253
 254impl EstablishConnectionError {
 255    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 256        Self::Other(error.into())
 257    }
 258}
 259
/// Connection status of a [`Client`], published through [`Client::status`].
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial status: not signed in.
    SignedOut,
    /// The client must be upgraded. `authenticate_and_connect` fails immediately in this
    /// status, and [`Status::is_signed_out`] treats it as signed out.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Establishing a connection, starting from `SignedOut`.
    Connecting,
    /// Set when authentication fails (and presumably when connecting fails). The
    /// reconnect loop only keeps retrying while in this status.
    ConnectionError,
    /// Connected to the server.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// Entering this status starts the reconnect loop (see `set_status`). Presumably
    /// set when an established connection drops.
    ConnectionLost,
    /// Obtaining credentials again, starting from a non-signed-out status.
    Reauthenticating,
    /// Re-establishing a connection, starting from a non-signed-out status.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is due at `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 278
 279impl Status {
 280    pub fn is_connected(&self) -> bool {
 281        matches!(self, Self::Connected { .. })
 282    }
 283
 284    pub fn is_signed_out(&self) -> bool {
 285        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 286    }
 287}
 288
/// Mutable client state, guarded by `Client::state`.
struct ClientState {
    /// Credentials of the last successful connection, or set directly via
    /// `set_dev_server_token`.
    credentials: Option<Credentials>,
    /// Watch channel holding the current [`Status`]; `Client::status` hands out receivers.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from an envelope
    /// (registered by `add_entity_message_handler`).
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost` (see `set_status`). The leading
    /// underscore suggests it is stored only to keep the task alive.
    _reconnect_task: Option<Task<()>>,
    /// Entity subscribers keyed by (entity type, remote id): either pending message
    /// buffers or bound models.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered per message type by `add_message_handler_impl`.
    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    /// Entity type expected per message type (set by `add_entity_message_handler`).
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type. Each handler downcasts its subscriber and
    /// envelope back to concrete types.
    // The handler signature is inherently long; the lint is silenced deliberately.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyModel,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 312
/// Routing entry for an entity subscription.
enum WeakSubscriber {
    /// Bound to a model, held weakly.
    Entity { handle: AnyWeakModel },
    /// Not yet bound. Holds envelopes (presumably buffered by `handle_message`) that
    /// `PendingEntitySubscription::set_model` replays once a model is bound.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
 317
/// Credentials presented to the server (see [`Credentials::authorization_header`]).
///
/// NOTE(review): the derived `Debug` prints tokens verbatim.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Credentials {
    /// A dev server identified by its token.
    DevServer { token: DevServerToken },
    /// A user identified by id, with an access token.
    User { user_id: u64, access_token: String },
}
 323
 324impl Credentials {
 325    pub fn authorization_header(&self) -> String {
 326        match self {
 327            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 328            Credentials::User {
 329                user_id,
 330                access_token,
 331            } => format!("{} {}", user_id, access_token),
 332        }
 333    }
 334}
 335
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain).
///
/// `Client::new` selects the implementation: `DevelopmentCredentialsProvider` on the
/// `Dev` release channel when `ZED_DEVELOPMENT_AUTH` is set, and
/// `KeychainCredentialsProvider` otherwise. The returned futures borrow both `self` and
/// `cx` for `'a` and are not required to be `Send`.
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// Callers treat `None` as "no stored credentials" (see `Client::has_credentials`).
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (id plus access token) can be stored this way.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
 361
 362impl Default for ClientState {
 363    fn default() -> Self {
 364        Self {
 365            credentials: None,
 366            status: watch::channel_with(Status::SignedOut),
 367            entity_id_extractors: Default::default(),
 368            _reconnect_task: None,
 369            models_by_message_type: Default::default(),
 370            entities_by_type_and_remote_id: Default::default(),
 371            entity_types_by_message_type: Default::default(),
 372            message_handlers: Default::default(),
 373        }
 374    }
 375}
 376
/// Keeps a registration on a [`Client`] alive; dropping it unregisters it.
///
/// The client is held weakly, so a subscription does not keep the client alive.
pub enum Subscription {
    /// Entity routing entry keyed by (entity type, remote id), returned by
    /// [`PendingEntitySubscription::set_model`].
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message handler keyed by the message's [`TypeId`], returned by
    /// `add_message_handler` and `add_request_handler`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 387
 388impl Drop for Subscription {
 389    fn drop(&mut self) {
 390        match self {
 391            Subscription::Entity { client, id } => {
 392                if let Some(client) = client.upgrade() {
 393                    let mut state = client.state.write();
 394                    let _ = state.entities_by_type_and_remote_id.remove(id);
 395                }
 396            }
 397            Subscription::Message { client, id } => {
 398                if let Some(client) = client.upgrade() {
 399                    let mut state = client.state.write();
 400                    let _ = state.entity_types_by_message_type.remove(id);
 401                    let _ = state.message_handlers.remove(id);
 402                }
 403            }
 404        }
 405    }
 406}
 407
/// Entity subscription reserved by [`Client::subscribe_to_entity`]. It becomes a
/// [`Subscription`] once bound with [`PendingEntitySubscription::set_model`].
///
/// Dropping it unbound removes the reservation and logs any buffered messages.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    /// Ties the subscription to entity type `T` without storing a `T`.
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the now-bound entry alone.
    consumed: bool,
}
 414
 415impl<T: 'static> PendingEntitySubscription<T> {
 416    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
 417        self.consumed = true;
 418        let mut state = self.client.state.write();
 419        let id = (TypeId::of::<T>(), self.remote_id);
 420        let Some(WeakSubscriber::Pending(messages)) =
 421            state.entities_by_type_and_remote_id.remove(&id)
 422        else {
 423            unreachable!()
 424        };
 425
 426        state.entities_by_type_and_remote_id.insert(
 427            id,
 428            WeakSubscriber::Entity {
 429                handle: model.downgrade().into(),
 430            },
 431        );
 432        drop(state);
 433        for message in messages {
 434            self.client.handle_message(message, cx);
 435        }
 436        Subscription::Entity {
 437            client: Arc::downgrade(&self.client),
 438            id,
 439        }
 440    }
 441}
 442
 443impl<T: 'static> Drop for PendingEntitySubscription<T> {
 444    fn drop(&mut self) {
 445        if !self.consumed {
 446            let mut state = self.client.state.write();
 447            if let Some(WeakSubscriber::Pending(messages)) = state
 448                .entities_by_type_and_remote_id
 449                .remove(&(TypeId::of::<T>(), self.remote_id))
 450            {
 451                for message in messages {
 452                    log::info!("unhandled message {}", message.payload_type_name());
 453                }
 454            }
 455        }
 456    }
 457}
 458
/// Resolved telemetry preferences; the settings-file form is `TelemetrySettingsContent`.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}

// The `///` docs below feed the generated JSON schema, so they are left as-is.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 477
 478impl settings::Settings for TelemetrySettings {
 479    const KEY: Option<&'static str> = Some("telemetry");
 480
 481    type FileContent = TelemetrySettingsContent;
 482
 483    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 484        Ok(Self {
 485            diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
 486                sources
 487                    .default
 488                    .diagnostics
 489                    .ok_or_else(Self::missing_default)?,
 490            ),
 491            metrics: sources
 492                .user
 493                .as_ref()
 494                .and_then(|v| v.metrics)
 495                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 496        })
 497    }
 498}
 499
 500impl Client {
 501    pub fn new(
 502        clock: Arc<dyn SystemClock>,
 503        http: Arc<HttpClientWithUrl>,
 504        cx: &mut AppContext,
 505    ) -> Arc<Self> {
 506        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 507            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 508            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 509            | None => false,
 510        };
 511
 512        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 513            if use_zed_development_auth {
 514                Arc::new(DevelopmentCredentialsProvider {
 515                    path: paths::config_dir().join("development_auth"),
 516                })
 517            } else {
 518                Arc::new(KeychainCredentialsProvider)
 519            };
 520
 521        Arc::new(Self {
 522            id: AtomicU64::new(0),
 523            peer: Peer::new(0),
 524            telemetry: Telemetry::new(clock, http.clone(), cx),
 525            http,
 526            credentials_provider,
 527            state: Default::default(),
 528
 529            #[cfg(any(test, feature = "test-support"))]
 530            authenticate: Default::default(),
 531            #[cfg(any(test, feature = "test-support"))]
 532            establish_connection: Default::default(),
 533            #[cfg(any(test, feature = "test-support"))]
 534            rpc_url: RwLock::default(),
 535        })
 536    }
 537
 538    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 539        let clock = Arc::new(clock::RealSystemClock);
 540        let http = Arc::new(HttpClientWithUrl::new(
 541            &ClientSettings::get_global(cx).server_url,
 542            ProxySettings::get_global(cx).proxy.clone(),
 543        ));
 544        Self::new(clock, http.clone(), cx)
 545    }
 546
    /// Returns the client's id (see [`Client::set_id`]).
    pub fn id(&self) -> u64 {
        self.id.load(Ordering::SeqCst)
    }

    /// Returns a shared handle to the client's HTTP client.
    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
        self.http.clone()
    }

    /// Sets the client's id and returns `self` for chaining. `authenticate_and_connect`
    /// sets it to the user id when connecting with user credentials.
    pub fn set_id(&self, id: u64) -> &Self {
        self.id.store(id, Ordering::SeqCst);
        self
    }
 559
 560    #[cfg(any(test, feature = "test-support"))]
 561    pub fn teardown(&self) {
 562        let mut state = self.state.write();
 563        state._reconnect_task.take();
 564        state.message_handlers.clear();
 565        state.models_by_message_type.clear();
 566        state.entities_by_type_and_remote_id.clear();
 567        state.entity_id_extractors.clear();
 568        self.peer.teardown();
 569    }
 570
    /// Test-only: installs `authenticate` as the authentication hook, presumably used
    /// instead of the real flow. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
    where
        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
    {
        *self.authenticate.write() = Some(Box::new(authenticate));
        self
    }

    /// Test-only: installs `connect` as the connection hook, presumably used instead of
    /// the real connection logic. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        *self.establish_connection.write() = Some(Box::new(connect));
        self
    }

    /// Test-only: stores `url` as the RPC URL override. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_rpc_url(&self, url: Url) -> &Self {
        *self.rpc_url.write() = Some(url);
        self
    }
 597
    /// Returns the client stored by [`Client::set_global`].
    ///
    /// NOTE(review): `cx.global` presumably panics if no client has been set.
    pub fn global(cx: &AppContext) -> Arc<Self> {
        cx.global::<GlobalClient>().0.clone()
    }

    /// Stores `client` as the app-global client.
    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
        cx.set_global(GlobalClient(client))
    }
 604
 605    pub fn user_id(&self) -> Option<u64> {
 606        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 607            Some(*user_id)
 608        } else {
 609            None
 610        }
 611    }
 612
 613    pub fn peer_id(&self) -> Option<PeerId> {
 614        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 615            Some(*peer_id)
 616        } else {
 617            None
 618        }
 619    }
 620
    /// Returns a new receiver on the client's [`Status`] watch channel.
    pub fn status(&self) -> watch::Receiver<Status> {
        self.state.read().status.1.clone()
    }
 624
    /// Stores `status` in the watch channel returned by [`Client::status`] and applies
    /// the side effects tied to it:
    ///
    /// - `Connected`: clears the reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task that retries
    ///   `authenticate_and_connect` with randomized, clamped backoff.
    /// - `SignedOut` / `UpgradeRequired`: clears telemetry user info and the reconnect task.
    ///
    /// The state write lock is held for the whole call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed under test so the backoff sequence is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only while the failed attempt left the status at
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Scale by a random factor in [0.5, 2.5], then clamp to
                            // [INITIAL_RECONNECTION_DELAY, MAX_RECONNECTION_DELAY].
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 670
 671    pub fn subscribe_to_entity<T>(
 672        self: &Arc<Self>,
 673        remote_id: u64,
 674    ) -> Result<PendingEntitySubscription<T>>
 675    where
 676        T: 'static,
 677    {
 678        let id = (TypeId::of::<T>(), remote_id);
 679
 680        let mut state = self.state.write();
 681        if state.entities_by_type_and_remote_id.contains_key(&id) {
 682            return Err(anyhow!("already subscribed to entity"));
 683        }
 684
 685        state
 686            .entities_by_type_and_remote_id
 687            .insert(id, WeakSubscriber::Pending(Default::default()));
 688
 689        Ok(PendingEntitySubscription {
 690            client: self.clone(),
 691            remote_id,
 692            consumed: false,
 693            _entity_type: PhantomData,
 694        })
 695    }
 696
 697    #[track_caller]
 698    pub fn add_message_handler<M, E, H, F>(
 699        self: &Arc<Self>,
 700        entity: WeakModel<E>,
 701        handler: H,
 702    ) -> Subscription
 703    where
 704        M: EnvelopedMessage,
 705        E: 'static,
 706        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 707        F: 'static + Future<Output = Result<()>>,
 708    {
 709        self.add_message_handler_impl(entity, move |model, message, _, cx| {
 710            handler(model, message, cx)
 711        })
 712    }
 713
 714    fn add_message_handler_impl<M, E, H, F>(
 715        self: &Arc<Self>,
 716        entity: WeakModel<E>,
 717        handler: H,
 718    ) -> Subscription
 719    where
 720        M: EnvelopedMessage,
 721        E: 'static,
 722        H: 'static
 723            + Sync
 724            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 725            + Send
 726            + Sync,
 727        F: 'static + Future<Output = Result<()>>,
 728    {
 729        let message_type_id = TypeId::of::<M>();
 730        let mut state = self.state.write();
 731        state
 732            .models_by_message_type
 733            .insert(message_type_id, entity.into());
 734
 735        let prev_handler = state.message_handlers.insert(
 736            message_type_id,
 737            Arc::new(move |subscriber, envelope, client, cx| {
 738                let subscriber = subscriber.downcast::<E>().unwrap();
 739                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 740                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 741            }),
 742        );
 743        if prev_handler.is_some() {
 744            let location = std::panic::Location::caller();
 745            panic!(
 746                "{}:{} registered handler for the same message {} twice",
 747                location.file(),
 748                location.line(),
 749                std::any::type_name::<M>()
 750            );
 751        }
 752
 753        Subscription::Message {
 754            client: Arc::downgrade(self),
 755            id: message_type_id,
 756        }
 757    }
 758
 759    pub fn add_request_handler<M, E, H, F>(
 760        self: &Arc<Self>,
 761        model: WeakModel<E>,
 762        handler: H,
 763    ) -> Subscription
 764    where
 765        M: RequestMessage,
 766        E: 'static,
 767        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 768        F: 'static + Future<Output = Result<M::Response>>,
 769    {
 770        self.add_message_handler_impl(model, move |handle, envelope, this, cx| {
 771            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 772        })
 773    }
 774
 775    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 776    where
 777        M: EntityMessage,
 778        E: 'static,
 779        H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 780        F: 'static + Future<Output = Result<()>>,
 781    {
 782        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, _, cx| {
 783            handler(subscriber.downcast::<E>().unwrap(), message, cx)
 784        })
 785    }
 786
 787    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 788    where
 789        M: EntityMessage,
 790        E: 'static,
 791        H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 792        F: 'static + Future<Output = Result<()>>,
 793    {
 794        let model_type_id = TypeId::of::<E>();
 795        let message_type_id = TypeId::of::<M>();
 796
 797        let mut state = self.state.write();
 798        state
 799            .entity_types_by_message_type
 800            .insert(message_type_id, model_type_id);
 801        state
 802            .entity_id_extractors
 803            .entry(message_type_id)
 804            .or_insert_with(|| {
 805                |envelope| {
 806                    envelope
 807                        .as_any()
 808                        .downcast_ref::<TypedEnvelope<M>>()
 809                        .unwrap()
 810                        .payload
 811                        .remote_entity_id()
 812                }
 813            });
 814        let prev_handler = state.message_handlers.insert(
 815            message_type_id,
 816            Arc::new(move |handle, envelope, client, cx| {
 817                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 818                handler(handle, *envelope, client.clone(), cx).boxed_local()
 819            }),
 820        );
 821        if prev_handler.is_some() {
 822            panic!("registered handler for the same message twice");
 823        }
 824    }
 825
 826    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 827    where
 828        M: EntityMessage + RequestMessage,
 829        E: 'static,
 830        H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 831        F: 'static + Future<Output = Result<M::Response>>,
 832    {
 833        self.add_entity_message_handler::<M, E, _, _>(move |entity, envelope, client, cx| {
 834            Self::respond_to_request::<M, _>(
 835                envelope.receipt(),
 836                handler(entity.downcast::<E>().unwrap(), envelope, cx),
 837                client,
 838            )
 839        })
 840    }
 841
 842    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 843        receipt: Receipt<T>,
 844        response: F,
 845        client: Arc<Self>,
 846    ) -> Result<()> {
 847        match response.await {
 848            Ok(response) => {
 849                client.respond(receipt, response)?;
 850                Ok(())
 851            }
 852            Err(error) => {
 853                client.respond_with_error(receipt, error.to_proto())?;
 854                Err(error)
 855            }
 856        }
 857    }
 858
 859    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 860        self.credentials_provider
 861            .read_credentials(cx)
 862            .await
 863            .is_some()
 864    }
 865
 866    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 867        self.state.write().credentials = Some(Credentials::DevServer { token });
 868        self
 869    }
 870
    /// Authenticates (if needed) and connects to the collab server.
    ///
    /// Returns `Ok(())` immediately when the client is already connected or a
    /// connection attempt is already in progress. It fails immediately with
    /// `EstablishConnectionError::UpgradeRequired` when the status says an
    /// upgrade is required.
    ///
    /// Credentials are taken, in order of preference:
    /// 1. from memory;
    /// 2. from the credentials provider, when `try_provider` is set;
    /// 3. from the interactive `authenticate` flow.
    ///
    /// Interactive authentication is abandoned with an "authentication canceled"
    /// error if the client's status changes while it is in progress.
    ///
    /// Establishing the connection and receiving the server's hello share a single
    /// `CONNECTION_TIMEOUT` timer. If provider-sourced credentials are rejected as
    /// unauthorized, they are deleted from the provider and the whole flow is
    /// retried once with `try_provider = false`.
    ///
    /// The status moves through `Authenticating`/`Reauthenticating` and then
    /// `Connecting`/`Reconnecting`. The first of each pair is used when the client
    /// was signed out. On failure the status is set to an error or
    /// upgrade-required status.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Choose which family of statuses to report. A signed-out client is
        // "connecting"; any other disconnected state is "reconnecting".
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            // Already connected, or another connection attempt is underway.
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Prefer credentials already held in memory, and only consult the provider
        // when asked to. `read_from_provider` records where the credentials came
        // from, so they are not written back to the provider and can be deleted
        // there if the server rejects them.
        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            // NOTE(review): the first `next()` presumably yields the current status
            // (the one just set above). The `select_biased!` below then only wakes
            // on a *later* status change, which cancels authentication.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path above either populated `credentials` or returned early.
        let credentials = credentials.unwrap();
        if let Credentials::User { user_id, .. } = &credentials {
            self.set_id(*user_id);
        }

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single timer bounds both establishing the connection and waiting for
        // the server's hello in `set_connection`.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist newly obtained user credentials, never while
                        // impersonating. This happens before the hello message
                        // has been received.
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            if let Credentials::User{user_id, access_token} = credentials {
                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                            }
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        // The credentials were rejected, so forget them. If they
                        // came from the provider, delete them there too and retry.
                        // With `try_provider = false` and no in-memory credentials,
                        // the retry goes through interactive authentication.
                        self.state.write().credentials.take();
                        if read_from_provider {
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 979
    /// Attaches an established transport `conn` to the peer and completes the
    /// handshake.
    ///
    /// The connection's I/O future is spawned on the background executor. The
    /// first incoming message must be a `proto::Hello` carrying a peer id. If it is
    /// missing, of another type, or lacks a peer id, the connection is disconnected
    /// and an error is returned.
    ///
    /// On success the status becomes `Connected`, and two detached tasks are
    /// spawned:
    /// - one that dispatches each further incoming message through `handle_message`;
    /// - one that waits for the I/O future to finish, then sets the status:
    ///   - `SignedOut` on a clean finish, but only if this same connection is
    ///     still the current one;
    ///   - `ConnectionLost` on an I/O error.
    ///
    /// This function applies no timeout of its own; `authenticate_and_connect`
    /// races it against its connection timer.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        // The peer is given a timer factory backed by the background executor.
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // Nothing else is read from `incoming` until the hello has been handled.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Don't leave a half-open connection behind on a failed handshake.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch incoming messages one at a time for the life of the connection.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // Watch the I/O future to notice when the connection ends.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1077
1078    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1079        #[cfg(any(test, feature = "test-support"))]
1080        if let Some(callback) = self.authenticate.read().as_ref() {
1081            return callback(cx);
1082        }
1083
1084        self.authenticate_with_browser(cx)
1085    }
1086
1087    fn establish_connection(
1088        self: &Arc<Self>,
1089        credentials: &Credentials,
1090        cx: &AsyncAppContext,
1091    ) -> Task<Result<Connection, EstablishConnectionError>> {
1092        #[cfg(any(test, feature = "test-support"))]
1093        if let Some(callback) = self.establish_connection.read().as_ref() {
1094            return callback(credentials, cx);
1095        }
1096
1097        self.establish_websocket_connection(credentials, cx)
1098    }
1099
1100    fn rpc_url(
1101        &self,
1102        http: Arc<HttpClientWithUrl>,
1103        release_channel: Option<ReleaseChannel>,
1104    ) -> impl Future<Output = Result<Url>> {
1105        #[cfg(any(test, feature = "test-support"))]
1106        let url_override = self.rpc_url.read().clone();
1107
1108        async move {
1109            #[cfg(any(test, feature = "test-support"))]
1110            if let Some(url) = url_override {
1111                return Ok(url);
1112            }
1113
1114            if let Some(url) = &*ZED_RPC_URL {
1115                return Url::parse(url).context("invalid rpc url");
1116            }
1117
1118            let mut url = http.build_url("/rpc");
1119            if let Some(preview_param) =
1120                release_channel.and_then(|channel| channel.release_query_param())
1121            {
1122                url += "?";
1123                url += preview_param;
1124            }
1125
1126            let response = http.get(&url, Default::default(), false).await?;
1127            let collab_url = if response.status().is_redirection() {
1128                response
1129                    .headers()
1130                    .get("Location")
1131                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1132                    .to_str()
1133                    .map_err(EstablishConnectionError::other)?
1134                    .to_string()
1135            } else {
1136                Err(anyhow!(
1137                    "unexpected /rpc response status {}",
1138                    response.status()
1139                ))?
1140            };
1141
1142            Url::parse(&collab_url).context("invalid rpc url")
1143        }
1144    }
1145
1146    fn establish_websocket_connection(
1147        self: &Arc<Self>,
1148        credentials: &Credentials,
1149        cx: &AsyncAppContext,
1150    ) -> Task<Result<Connection, EstablishConnectionError>> {
1151        let release_channel = cx
1152            .update(|cx| ReleaseChannel::try_global(cx))
1153            .ok()
1154            .flatten();
1155        let app_version = cx
1156            .update(|cx| AppVersion::global(cx).to_string())
1157            .ok()
1158            .unwrap_or_default();
1159
1160        let request = Request::builder()
1161            .header("Authorization", credentials.authorization_header())
1162            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
1163            .header("x-zed-app-version", app_version)
1164            .header(
1165                "x-zed-release-channel",
1166                release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
1167            );
1168
1169        let http = self.http.clone();
1170        let rpc_url = self.rpc_url(http, release_channel);
1171        cx.background_executor().spawn(async move {
1172            let mut rpc_url = rpc_url.await?;
1173            let rpc_host = rpc_url
1174                .host_str()
1175                .zip(rpc_url.port_or_known_default())
1176                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1177            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1178
1179            log::info!("connected to rpc endpoint {}", rpc_url);
1180
1181            match rpc_url.scheme() {
1182                "https" => {
1183                    rpc_url.set_scheme("wss").unwrap();
1184                    let request = request.uri(rpc_url.as_str()).body(())?;
1185                    let (stream, _) =
1186                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
1187                    Ok(Connection::new(
1188                        stream
1189                            .map_err(|error| anyhow!(error))
1190                            .sink_map_err(|error| anyhow!(error)),
1191                    ))
1192                }
1193                "http" => {
1194                    rpc_url.set_scheme("ws").unwrap();
1195                    let request = request.uri(rpc_url.as_str()).body(())?;
1196                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1197                    Ok(Connection::new(
1198                        stream
1199                            .map_err(|error| anyhow!(error))
1200                            .sink_map_err(|error| anyhow!(error)),
1201                    ))
1202                }
1203                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1204            }
1205        })
1206    }
1207
1208    pub fn authenticate_with_browser(
1209        self: &Arc<Self>,
1210        cx: &AsyncAppContext,
1211    ) -> Task<Result<Credentials>> {
1212        let http = self.http.clone();
1213        let this = self.clone();
1214        cx.spawn(|cx| async move {
1215            let background = cx.background_executor().clone();
1216
1217            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1218            cx.update(|cx| {
1219                cx.spawn(move |cx| async move {
1220                    let url = open_url_rx.await?;
1221                    cx.update(|cx| cx.open_url(&url))
1222                })
1223                .detach_and_log_err(cx);
1224            })
1225            .log_err();
1226
1227            let credentials = background
1228                .clone()
1229                .spawn(async move {
1230                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1231                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1232                    // any other app running on the user's device.
1233                    let (public_key, private_key) =
1234                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1235                    let public_key_string = String::try_from(public_key)
1236                        .expect("failed to serialize public key for auth");
1237
1238                    if let Some((login, token)) =
1239                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1240                    {
1241                        eprintln!("authenticate as admin {login}, {token}");
1242
1243                        return this
1244                            .authenticate_as_admin(http, login.clone(), token.clone())
1245                            .await;
1246                    }
1247
1248                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1249                    let server =
1250                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1251                    let port = server.server_addr().port();
1252
1253                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1254                    // that the user is signing in from a Zed app running on the same device.
1255                    let mut url = http.build_url(&format!(
1256                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1257                        port, public_key_string
1258                    ));
1259
1260                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1261                        log::info!("impersonating user @{}", impersonate_login);
1262                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1263                    }
1264
1265                    open_url_tx.send(url).log_err();
1266
1267                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1268                    // access token from the query params.
1269                    //
1270                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1271                    // custom URL scheme instead of this local HTTP server.
1272                    let (user_id, access_token) = background
1273                        .spawn(async move {
1274                            for _ in 0..100 {
1275                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1276                                    let path = req.url();
1277                                    let mut user_id = None;
1278                                    let mut access_token = None;
1279                                    let url = Url::parse(&format!("http://example.com{}", path))
1280                                        .context("failed to parse login notification url")?;
1281                                    for (key, value) in url.query_pairs() {
1282                                        if key == "access_token" {
1283                                            access_token = Some(value.to_string());
1284                                        } else if key == "user_id" {
1285                                            user_id = Some(value.to_string());
1286                                        }
1287                                    }
1288
1289                                    let post_auth_url =
1290                                        http.build_url("/native_app_signin_succeeded");
1291                                    req.respond(
1292                                        tiny_http::Response::empty(302).with_header(
1293                                            tiny_http::Header::from_bytes(
1294                                                &b"Location"[..],
1295                                                post_auth_url.as_bytes(),
1296                                            )
1297                                            .unwrap(),
1298                                        ),
1299                                    )
1300                                    .context("failed to respond to login http request")?;
1301                                    return Ok((
1302                                        user_id
1303                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1304                                        access_token.ok_or_else(|| {
1305                                            anyhow!("missing access_token parameter")
1306                                        })?,
1307                                    ));
1308                                }
1309                            }
1310
1311                            Err(anyhow!("didn't receive login redirect"))
1312                        })
1313                        .await?;
1314
1315                    let access_token = private_key
1316                        .decrypt_string(&access_token)
1317                        .context("failed to decrypt access token")?;
1318
1319                    Ok(Credentials::User {
1320                        user_id: user_id.parse()?,
1321                        access_token,
1322                    })
1323                })
1324                .await?;
1325
1326            cx.update(|cx| cx.activate(true))?;
1327            Ok(credentials)
1328        })
1329    }
1330
1331    async fn authenticate_as_admin(
1332        self: &Arc<Self>,
1333        http: Arc<HttpClientWithUrl>,
1334        login: String,
1335        mut api_token: String,
1336    ) -> Result<Credentials> {
1337        #[derive(Deserialize)]
1338        struct AuthenticatedUserResponse {
1339            user: User,
1340        }
1341
1342        #[derive(Deserialize)]
1343        struct User {
1344            id: u64,
1345        }
1346
1347        // Use the collab server's admin API to retrieve the id
1348        // of the impersonated user.
1349        let mut url = self.rpc_url(http.clone(), None).await?;
1350        url.set_path("/user");
1351        url.set_query(Some(&format!("github_login={login}")));
1352        let request = Request::get(url.as_str())
1353            .header("Authorization", format!("token {api_token}"))
1354            .body("".into())?;
1355
1356        let mut response = http.send(request).await?;
1357        let mut body = String::new();
1358        response.body_mut().read_to_string(&mut body).await?;
1359        if !response.status().is_success() {
1360            Err(anyhow!(
1361                "admin user request failed {} - {}",
1362                response.status().as_u16(),
1363                body,
1364            ))?;
1365        }
1366        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1367
1368        // Use the admin API token to authenticate as the impersonated user.
1369        api_token.insert_str(0, "ADMIN_TOKEN:");
1370        Ok(Credentials::User {
1371            user_id: response.user.id,
1372            access_token: api_token,
1373        })
1374    }
1375
1376    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1377        self.state.write().credentials = None;
1378        self.disconnect(&cx);
1379
1380        if self.has_credentials(cx).await {
1381            self.credentials_provider
1382                .delete_credentials(cx)
1383                .await
1384                .log_err();
1385        }
1386    }
1387
1388    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1389        self.peer.teardown();
1390        self.set_status(Status::SignedOut, cx);
1391    }
1392
1393    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1394        self.peer.teardown();
1395        self.set_status(Status::ConnectionLost, cx);
1396    }
1397
1398    fn connection_id(&self) -> Result<ConnectionId> {
1399        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1400            Ok(connection_id)
1401        } else {
1402            Err(anyhow!("not connected"))
1403        }
1404    }
1405
1406    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1407        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1408        self.peer.send(self.connection_id()?, message)
1409    }
1410
1411    pub fn request<T: RequestMessage>(
1412        &self,
1413        request: T,
1414    ) -> impl Future<Output = Result<T::Response>> {
1415        self.request_envelope(request)
1416            .map_ok(|envelope| envelope.payload)
1417    }
1418
1419    pub fn request_stream<T: RequestMessage>(
1420        &self,
1421        request: T,
1422    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1423        let client_id = self.id.load(Ordering::SeqCst);
1424        log::debug!(
1425            "rpc request start. client_id:{}. name:{}",
1426            client_id,
1427            T::NAME
1428        );
1429        let response = self
1430            .connection_id()
1431            .map(|conn_id| self.peer.request_stream(conn_id, request));
1432        async move {
1433            let response = response?.await;
1434            log::debug!(
1435                "rpc request finish. client_id:{}. name:{}",
1436                client_id,
1437                T::NAME
1438            );
1439            response
1440        }
1441    }
1442
1443    pub fn request_envelope<T: RequestMessage>(
1444        &self,
1445        request: T,
1446    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1447        let client_id = self.id();
1448        log::debug!(
1449            "rpc request start. client_id:{}. name:{}",
1450            client_id,
1451            T::NAME
1452        );
1453        let response = self
1454            .connection_id()
1455            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1456        async move {
1457            let response = response?.await;
1458            log::debug!(
1459                "rpc request finish. client_id:{}. name:{}",
1460                client_id,
1461                T::NAME
1462            );
1463            response
1464        }
1465    }
1466
1467    pub fn request_dynamic(
1468        &self,
1469        envelope: proto::Envelope,
1470        request_type: &'static str,
1471    ) -> impl Future<Output = Result<proto::Envelope>> {
1472        let client_id = self.id();
1473        log::debug!(
1474            "rpc request start. client_id:{}. name:{}",
1475            client_id,
1476            request_type
1477        );
1478        let response = self
1479            .connection_id()
1480            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1481        async move {
1482            let response = response?.await;
1483            log::debug!(
1484                "rpc request finish. client_id:{}. name:{}",
1485                client_id,
1486                request_type
1487            );
1488            Ok(response?.0)
1489        }
1490    }
1491
1492    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1493        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1494        self.peer.respond(receipt, response)
1495    }
1496
1497    fn respond_with_error<T: RequestMessage>(
1498        &self,
1499        receipt: Receipt<T>,
1500        error: proto::Error,
1501    ) -> Result<()> {
1502        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1503        self.peer.respond_with_error(receipt, error)
1504    }
1505
1506    fn handle_message(
1507        self: &Arc<Client>,
1508        message: Box<dyn AnyTypedEnvelope>,
1509        cx: &AsyncAppContext,
1510    ) {
1511        let mut state = self.state.write();
1512        let type_name = message.payload_type_name();
1513        let payload_type_id = message.payload_type_id();
1514        let sender_id = message.original_sender_id();
1515
1516        let mut subscriber = None;
1517
1518        if let Some(handle) = state
1519            .models_by_message_type
1520            .get(&payload_type_id)
1521            .and_then(|handle| handle.upgrade())
1522        {
1523            subscriber = Some(handle);
1524        } else if let Some((extract_entity_id, entity_type_id)) =
1525            state.entity_id_extractors.get(&payload_type_id).zip(
1526                state
1527                    .entity_types_by_message_type
1528                    .get(&payload_type_id)
1529                    .copied(),
1530            )
1531        {
1532            let entity_id = (extract_entity_id)(message.as_ref());
1533
1534            match state
1535                .entities_by_type_and_remote_id
1536                .get_mut(&(entity_type_id, entity_id))
1537            {
1538                Some(WeakSubscriber::Pending(pending)) => {
1539                    pending.push(message);
1540                    return;
1541                }
1542                Some(weak_subscriber) => match weak_subscriber {
1543                    WeakSubscriber::Entity { handle } => {
1544                        subscriber = handle.upgrade();
1545                    }
1546
1547                    WeakSubscriber::Pending(_) => {}
1548                },
1549                _ => {}
1550            }
1551        }
1552
1553        let subscriber = if let Some(subscriber) = subscriber {
1554            subscriber
1555        } else {
1556            log::info!("unhandled message {}", type_name);
1557            self.peer.respond_with_unhandled_message(message).log_err();
1558            return;
1559        };
1560
1561        let handler = state.message_handlers.get(&payload_type_id).cloned();
1562        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1563        // It also ensures we don't hold the lock while yielding back to the executor, as
1564        // that might cause the executor thread driving this future to block indefinitely.
1565        drop(state);
1566
1567        if let Some(handler) = handler {
1568            let future = handler(subscriber, message, self, cx.clone());
1569            let client_id = self.id();
1570            log::debug!(
1571                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1572                client_id,
1573                sender_id,
1574                type_name
1575            );
1576            cx.spawn(move |_| async move {
1577                    match future.await {
1578                        Ok(()) => {
1579                            log::debug!(
1580                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1581                                client_id,
1582                                sender_id,
1583                                type_name
1584                            );
1585                        }
1586                        Err(error) => {
1587                            log::error!(
1588                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1589                                client_id,
1590                                sender_id,
1591                                type_name,
1592                                error
1593                            );
1594                        }
1595                    }
1596                })
1597                .detach();
1598        } else {
1599            log::info!("unhandled message {}", type_name);
1600            self.peer.respond_with_unhandled_message(message).log_err();
1601        }
1602    }
1603
1604    pub fn telemetry(&self) -> &Arc<Telemetry> {
1605        &self.telemetry
1606    }
1607}
1608
/// Serialized form of the user credentials persisted by
/// `DevelopmentCredentialsProvider`.
///
/// Field order determines the order of keys in the written JSON.
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
    /// Id of the signed-in user.
    user_id: u64,
    /// The user's access token, stored in plain text.
    access_token: String,
}
1614
1615/// A credentials provider that stores credentials in a local file.
1616///
1617/// This MUST only be used in development, as this is not a secure way of storing
1618/// credentials on user machines.
1619///
1620/// Its existence is purely to work around the annoyance of having to constantly
1621/// re-allow access to the system keychain when developing Zed.
1622struct DevelopmentCredentialsProvider {
1623    path: PathBuf,
1624}
1625
1626impl CredentialsProvider for DevelopmentCredentialsProvider {
1627    fn read_credentials<'a>(
1628        &'a self,
1629        _cx: &'a AsyncAppContext,
1630    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1631        async move {
1632            if IMPERSONATE_LOGIN.is_some() {
1633                return None;
1634            }
1635
1636            let json = std::fs::read(&self.path).log_err()?;
1637
1638            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1639
1640            Some(Credentials::User {
1641                user_id: credentials.user_id,
1642                access_token: credentials.access_token,
1643            })
1644        }
1645        .boxed_local()
1646    }
1647
1648    fn write_credentials<'a>(
1649        &'a self,
1650        user_id: u64,
1651        access_token: String,
1652        _cx: &'a AsyncAppContext,
1653    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1654        async move {
1655            let json = serde_json::to_string(&DevelopmentCredentials {
1656                user_id,
1657                access_token,
1658            })?;
1659
1660            std::fs::write(&self.path, json)?;
1661
1662            Ok(())
1663        }
1664        .boxed_local()
1665    }
1666
1667    fn delete_credentials<'a>(
1668        &'a self,
1669        _cx: &'a AsyncAppContext,
1670    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1671        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1672    }
1673}
1674
/// A credentials provider that stores credentials in the system keychain.
///
/// Entries are keyed by the `server_url` from [`ClientSettings`]. The user id is
/// stored as its decimal string form and the access token as raw UTF-8 bytes.
struct KeychainCredentialsProvider;
1677
1678impl CredentialsProvider for KeychainCredentialsProvider {
1679    fn read_credentials<'a>(
1680        &'a self,
1681        cx: &'a AsyncAppContext,
1682    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1683        async move {
1684            if IMPERSONATE_LOGIN.is_some() {
1685                return None;
1686            }
1687
1688            let (user_id, access_token) = cx
1689                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1690                .log_err()?
1691                .await
1692                .log_err()??;
1693
1694            Some(Credentials::User {
1695                user_id: user_id.parse().ok()?,
1696                access_token: String::from_utf8(access_token).ok()?,
1697            })
1698        }
1699        .boxed_local()
1700    }
1701
1702    fn write_credentials<'a>(
1703        &'a self,
1704        user_id: u64,
1705        access_token: String,
1706        cx: &'a AsyncAppContext,
1707    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1708        async move {
1709            cx.update(move |cx| {
1710                cx.write_credentials(
1711                    &ClientSettings::get_global(cx).server_url,
1712                    &user_id.to_string(),
1713                    access_token.as_bytes(),
1714                )
1715            })?
1716            .await
1717        }
1718        .boxed_local()
1719    }
1720
1721    fn delete_credentials<'a>(
1722        &'a self,
1723        cx: &'a AsyncAppContext,
1724    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1725        async move {
1726            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1727                .await
1728        }
1729        .boxed_local()
1730    }
1731}
1732
/// Name of Zed's custom URL scheme, as used in `zed://...` links.
///
/// This is only the scheme name, without the `://` separator; callers such as
/// [`parse_zed_link`] match the separator themselves.
pub static ZED_URL_SCHEME: &str = "zed";
1735
1736/// Parses the given link into a Zed link.
1737///
1738/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1739/// Returns [`None`] otherwise.
1740pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1741    let server_url = &ClientSettings::get_global(cx).server_url;
1742    if let Some(stripped) = link
1743        .strip_prefix(server_url)
1744        .and_then(|result| result.strip_prefix('/'))
1745    {
1746        return Some(stripped);
1747    }
1748    if let Some(stripped) = link
1749        .strip_prefix(ZED_URL_SCHEME)
1750        .and_then(|result| result.strip_prefix("://"))
1751    {
1752        return Some(stripped);
1753    }
1754
1755    None
1756}
1757
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // Reconnecting reuses the cached credentials while the server still accepts
    // them. Once the server rolls the access token, the client must authenticate again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Invalidate the cached token on the server side. The next reconnection attempt
        // fails authentication, so the client must drop its cached credentials and
        // authenticate again.
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // If establishing the connection never completes, both the initial connection and
    // a later reconnection attempt must give up after `CONNECTION_TIMEOUT`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still
    // authenticating drops the first authentication future. The drop is observed
    // through `dropped_auth_count`.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity-scoped handlers deliver each message to the model subscribed for that
    // entity id. Dropping one id's subscription doesn't affect the others.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Once a message handler's subscription is dropped, a new handler for the same model
    // and message type can be registered, and it receives the messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while handling a message. Here it does
    // so by taking the subscription out of the model it was registered for.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            // `downgrade` borrows the handle, so no temporary strong clone is needed.
            model.downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used as the target of message handlers in these tests.
    #[derive(Default)]
    struct TestModel {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global, then calls `init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}