client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context as _, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use clock::SystemClock;
  14use collections::HashMap;
  15use futures::{
  16    channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
  17    TryFutureExt as _, TryStreamExt,
  18};
  19use gpui::{
  20    actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, BorrowAppContext, Global, Model,
  21    Task, WeakModel,
  22};
  23use lazy_static::lazy_static;
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use release_channel::{AppVersion, ReleaseChannel};
  28use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::{Settings, SettingsSources, SettingsStore};
  32use std::fmt;
  33use std::pin::Pin;
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        atomic::{AtomicU64, Ordering},
  43        Arc, Weak,
  44    },
  45    time::{Duration, Instant},
  46};
  47use telemetry::Telemetry;
  48use thiserror::Error;
  49use url::Url;
  50use util::http::{HttpClient, HttpClientWithUrl};
  51use util::{ResultExt, TryFutureExt};
  52
  53pub use rpc::*;
  54pub use telemetry_events::Event;
  55pub use user::*;
  56
  57#[derive(Debug, Clone, Eq, PartialEq)]
  58pub struct DevServerToken(pub String);
  59
  60impl fmt::Display for DevServerToken {
  61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66lazy_static! {
  67    static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
  68    static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
  69    /// An environment variable whose presence indicates that the development auth
  70    /// provider should be used.
  71    ///
  72    /// Only works in development. Setting this environment variable in other release
  73    /// channels is a no-op.
  74    pub static ref ZED_DEVELOPMENT_AUTH: bool =
  75        std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
  76    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  77        .ok()
  78        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  79    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  80        .ok()
  81        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  82    pub static ref ZED_APP_PATH: Option<PathBuf> =
  83        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  84    pub static ref ZED_ALWAYS_ACTIVE: bool =
  85        std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
  86}
  87
  88pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  89pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  90
// Client-level actions; `init` registers their handlers (sign in, sign out, reconnect).
actions!(client, [SignIn, SignOut, Reconnect]);
  92
// On-disk shape of the client settings. Plain `//` comments are used on purpose: schemars
// turns `///` doc comments into JSON-schema descriptions, which would change the schema.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // Optional server URL. `ClientSettings::load` merges the settings sources, and
    // `ZED_SERVER_URL` overrides the merged result when set.
    server_url: Option<String>,
}
  97
/// Resolved client settings, produced by merging all settings sources.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Base URL of the Zed server. `ZED_SERVER_URL`, when set, takes precedence over the
    /// value from settings files.
    pub server_url: String,
}
 102
 103impl Settings for ClientSettings {
 104    const KEY: Option<&'static str> = None;
 105
 106    type FileContent = ClientSettingsContent;
 107
 108    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 109        let mut result = sources.json_merge::<Self>()?;
 110        if let Some(server_url) = &*ZED_SERVER_URL {
 111            result.server_url.clone_from(&server_url)
 112        }
 113        Ok(result)
 114    }
 115}
 116
 117pub fn init_settings(cx: &mut AppContext) {
 118    TelemetrySettings::register(cx);
 119    cx.update_global(|store: &mut SettingsStore, cx| {
 120        store.register_setting::<ClientSettings>(cx);
 121    });
 122}
 123
 124pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 125    let client = Arc::downgrade(client);
 126    cx.on_action({
 127        let client = client.clone();
 128        move |_: &SignIn, cx| {
 129            if let Some(client) = client.upgrade() {
 130                cx.spawn(
 131                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 132                )
 133                .detach();
 134            }
 135        }
 136    });
 137
 138    cx.on_action({
 139        let client = client.clone();
 140        move |_: &SignOut, cx| {
 141            if let Some(client) = client.upgrade() {
 142                cx.spawn(|cx| async move {
 143                    client.sign_out(&cx).await;
 144                })
 145                .detach();
 146            }
 147        }
 148    });
 149
 150    cx.on_action({
 151        let client = client.clone();
 152        move |_: &Reconnect, cx| {
 153            if let Some(client) = client.upgrade() {
 154                cx.spawn(|cx| async move {
 155                    client.reconnect(&cx);
 156                })
 157                .detach();
 158            }
 159        }
 160    });
 161}
 162
/// Newtype that stores the app-wide [`Client`] as a gpui global (read by `Client::global`,
/// written by `Client::set_global`).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 166
/// Connection to the Zed server. Owns the RPC peer, the HTTP client, telemetry, credential
/// storage, and the routing tables used to dispatch incoming messages to handlers.
pub struct Client {
    /// Numeric id of this client. Starts at `0`; `authenticate_and_connect` sets it to the
    /// user id when connecting with user credentials.
    id: AtomicU64,
    /// RPC peer used to exchange messages with the server.
    peer: Arc<Peer>,
    /// HTTP client configured with the server base URL.
    http: Arc<HttpClientWithUrl>,
    telemetry: Arc<Telemetry>,
    /// Where credentials are read from and persisted to (keychain or development file, chosen
    /// in `Client::new`).
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// Mutable connection and routing state, behind a read/write lock.
    state: RwLock<ClientState>,

    // Boxed-closure types are inherently verbose, so the lint is silenced rather than adding aliases.
    /// Test-only replacement for the authentication step, installed by `override_authenticate`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for connection establishment, installed by
    /// `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 197
/// Error produced while establishing the connection to the server.
///
/// `From<WebsocketError>` maps HTTP 401 responses to `Unauthorized`, HTTP 426 responses to
/// `UpgradeRequired`, and every other websocket error to `Other`.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server requires a different client version (HTTP 426 Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server rejected the credentials (HTTP 401 Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including non-HTTP websocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure while building the websocket HTTP request/response (`http::Error`).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 213
 214impl From<WebsocketError> for EstablishConnectionError {
 215    fn from(error: WebsocketError) -> Self {
 216        if let WebsocketError::Http(response) = &error {
 217            match response.status() {
 218                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 219                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 220                _ => {}
 221            }
 222        }
 223        EstablishConnectionError::Other(error.into())
 224    }
 225}
 226
 227impl EstablishConnectionError {
 228    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 229        Self::Other(error.into())
 230    }
 231}
 232
/// Connection status of the client, published through the watch channel in `ClientState`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No session; this is the initial status.
    SignedOut,
    /// The server answered with `UpgradeRequired`; `authenticate_and_connect` refuses to run
    /// from this status.
    UpgradeRequired,
    /// Obtaining credentials after having been signed out.
    Authenticating,
    /// Connecting after authenticating from the signed-out status.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected to the server.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// An established connection was lost; `set_status` starts the reconnection loop.
    ConnectionLost,
    /// Obtaining credentials when not starting from `SignedOut`.
    Reauthenticating,
    /// Connecting when not starting from `SignedOut`.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 251
 252impl Status {
 253    pub fn is_connected(&self) -> bool {
 254        matches!(self, Self::Connected { .. })
 255    }
 256
 257    pub fn is_signed_out(&self) -> bool {
 258        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 259    }
 260}
 261
/// Mutable state of a [`Client`], kept behind `Client::state`.
struct ClientState {
    /// Credentials currently in use, if any.
    credentials: Option<Credentials>,
    /// Watch channel for [`Status`]. `set_status` writes through the sender; `Client::status`
    /// hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that reads the remote entity id from an envelope of that
    /// type (installed by `add_entity_message_handler`).
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection loop spawned on `ConnectionLost`. Cleared when connected, signed out, or an
    /// upgrade is required (assumes dropping a gpui `Task` stops it — confirm).
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the randomized reconnection delay.
    reconnect_interval: Duration,
    /// Entity subscriptions keyed by (entity type, remote id): either a bound model or messages
    /// waiting for `PendingEntitySubscription::set_model`.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered for each message type via `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    /// Entity type that each entity-message type is addressed to.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    // Type-erased handler signature is long; the lint is silenced instead of introducing an alias.
    /// Type-erased handler for each message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyModel,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 286
/// Target of an entity subscription.
enum WeakSubscriber {
    /// Bound to a model, held weakly.
    Entity { handle: AnyWeakModel },
    /// Not yet bound. Messages are held here until `PendingEntitySubscription::set_model`
    /// replays them, or logged as unhandled if the pending subscription is dropped.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
 291
 292#[derive(Clone, Debug, Eq, PartialEq)]
 293pub enum Credentials {
 294    DevServer { token: DevServerToken },
 295    User { user_id: u64, access_token: String },
 296}
 297
 298impl Credentials {
 299    pub fn authorization_header(&self) -> String {
 300        match self {
 301            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 302            Credentials::User {
 303                user_id,
 304                access_token,
 305            } => format!("{} {}", user_id, access_token),
 306        }
 307    }
 308}
 309
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain).
///
/// Every returned future borrows both the provider and the async context for `'a`. The
/// futures are boxed trait objects without a `Send` bound.
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// Returns `None` when no credentials are available.
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (user id + access token) can be persisted through this method.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
 335
 336impl Default for ClientState {
 337    fn default() -> Self {
 338        Self {
 339            credentials: None,
 340            status: watch::channel_with(Status::SignedOut),
 341            entity_id_extractors: Default::default(),
 342            _reconnect_task: None,
 343            reconnect_interval: Duration::from_secs(5),
 344            models_by_message_type: Default::default(),
 345            entities_by_type_and_remote_id: Default::default(),
 346            entity_types_by_message_type: Default::default(),
 347            message_handlers: Default::default(),
 348        }
 349    }
 350}
 351
/// Registration handle returned by the client. Dropping it undoes the registration, provided
/// the client is still alive (only a weak reference is held).
pub enum Subscription {
    /// Entity subscription keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler registration keyed by message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 362
 363impl Drop for Subscription {
 364    fn drop(&mut self) {
 365        match self {
 366            Subscription::Entity { client, id } => {
 367                if let Some(client) = client.upgrade() {
 368                    let mut state = client.state.write();
 369                    let _ = state.entities_by_type_and_remote_id.remove(id);
 370                }
 371            }
 372            Subscription::Message { client, id } => {
 373                if let Some(client) = client.upgrade() {
 374                    let mut state = client.state.write();
 375                    let _ = state.entity_types_by_message_type.remove(id);
 376                    let _ = state.message_handlers.remove(id);
 377                }
 378            }
 379        }
 380    }
 381}
 382
/// Entity subscription returned by `Client::subscribe_to_entity` before a model is attached.
///
/// Bind it with `set_model`. If it is dropped unbound, its pending entry is removed and any
/// messages held for it are logged as unhandled.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    /// Remote id of the entity (second half of the routing key).
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the now-bound entry alone.
    consumed: bool,
}
 389
 390impl<T: 'static> PendingEntitySubscription<T> {
 391    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
 392        self.consumed = true;
 393        let mut state = self.client.state.write();
 394        let id = (TypeId::of::<T>(), self.remote_id);
 395        let Some(WeakSubscriber::Pending(messages)) =
 396            state.entities_by_type_and_remote_id.remove(&id)
 397        else {
 398            unreachable!()
 399        };
 400
 401        state.entities_by_type_and_remote_id.insert(
 402            id,
 403            WeakSubscriber::Entity {
 404                handle: model.downgrade().into(),
 405            },
 406        );
 407        drop(state);
 408        for message in messages {
 409            self.client.handle_message(message, cx);
 410        }
 411        Subscription::Entity {
 412            client: Arc::downgrade(&self.client),
 413            id,
 414        }
 415    }
 416}
 417
 418impl<T: 'static> Drop for PendingEntitySubscription<T> {
 419    fn drop(&mut self) {
 420        if !self.consumed {
 421            let mut state = self.client.state.write();
 422            if let Some(WeakSubscriber::Pending(messages)) = state
 423                .entities_by_type_and_remote_id
 424                .remove(&(TypeId::of::<T>(), self.remote_id))
 425            {
 426                for message in messages {
 427                    log::info!("unhandled message {}", message.payload_type_name());
 428                }
 429            }
 430        }
 431    }
 432}
 433
/// Resolved telemetry settings: the user's value for each flag where set, otherwise the default.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether diagnostic data such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
 439
// NOTE: schemars uses the `///` comments in this struct as JSON-schema descriptions, so they
// are user-visible settings text. Extra maintainer notes use `//` to keep the schema unchanged.
// A `None` field means "not set in this file"; `TelemetrySettings::load` then falls back to
// the default settings.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 452
 453impl settings::Settings for TelemetrySettings {
 454    const KEY: Option<&'static str> = Some("telemetry");
 455
 456    type FileContent = TelemetrySettingsContent;
 457
 458    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 459        Ok(Self {
 460            diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
 461                sources
 462                    .default
 463                    .diagnostics
 464                    .ok_or_else(Self::missing_default)?,
 465            ),
 466            metrics: sources
 467                .user
 468                .as_ref()
 469                .and_then(|v| v.metrics)
 470                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 471        })
 472    }
 473}
 474
 475impl Client {
 476    pub fn new(
 477        clock: Arc<dyn SystemClock>,
 478        http: Arc<HttpClientWithUrl>,
 479        cx: &mut AppContext,
 480    ) -> Arc<Self> {
 481        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 482            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 483            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 484            | None => false,
 485        };
 486
 487        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 488            if use_zed_development_auth {
 489                Arc::new(DevelopmentCredentialsProvider {
 490                    path: util::paths::CONFIG_DIR.join("development_auth"),
 491                })
 492            } else {
 493                Arc::new(KeychainCredentialsProvider)
 494            };
 495
 496        Arc::new(Self {
 497            id: AtomicU64::new(0),
 498            peer: Peer::new(0),
 499            telemetry: Telemetry::new(clock, http.clone(), cx),
 500            http,
 501            credentials_provider,
 502            state: Default::default(),
 503
 504            #[cfg(any(test, feature = "test-support"))]
 505            authenticate: Default::default(),
 506            #[cfg(any(test, feature = "test-support"))]
 507            establish_connection: Default::default(),
 508        })
 509    }
 510
 511    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 512        let clock = Arc::new(clock::RealSystemClock);
 513        let http = Arc::new(HttpClientWithUrl::new(
 514            &ClientSettings::get_global(cx).server_url,
 515        ));
 516        Self::new(clock, http.clone(), cx)
 517    }
 518
 519    pub fn id(&self) -> u64 {
 520        self.id.load(Ordering::SeqCst)
 521    }
 522
 523    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 524        self.http.clone()
 525    }
 526
 527    pub fn set_id(&self, id: u64) -> &Self {
 528        self.id.store(id, Ordering::SeqCst);
 529        self
 530    }
 531
 532    #[cfg(any(test, feature = "test-support"))]
 533    pub fn teardown(&self) {
 534        let mut state = self.state.write();
 535        state._reconnect_task.take();
 536        state.message_handlers.clear();
 537        state.models_by_message_type.clear();
 538        state.entities_by_type_and_remote_id.clear();
 539        state.entity_id_extractors.clear();
 540        self.peer.teardown();
 541    }
 542
 543    #[cfg(any(test, feature = "test-support"))]
 544    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 545    where
 546        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 547    {
 548        *self.authenticate.write() = Some(Box::new(authenticate));
 549        self
 550    }
 551
 552    #[cfg(any(test, feature = "test-support"))]
 553    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 554    where
 555        F: 'static
 556            + Send
 557            + Sync
 558            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 559    {
 560        *self.establish_connection.write() = Some(Box::new(connect));
 561        self
 562    }
 563
 564    pub fn global(cx: &AppContext) -> Arc<Self> {
 565        cx.global::<GlobalClient>().0.clone()
 566    }
 567    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 568        cx.set_global(GlobalClient(client))
 569    }
 570
 571    pub fn user_id(&self) -> Option<u64> {
 572        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 573            Some(*user_id)
 574        } else {
 575            None
 576        }
 577    }
 578
 579    pub fn peer_id(&self) -> Option<PeerId> {
 580        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 581            Some(*peer_id)
 582        } else {
 583            None
 584        }
 585    }
 586
 587    pub fn status(&self) -> watch::Receiver<Status> {
 588        self.state.read().status.1.clone()
 589    }
 590
    /// Publishes `status` to the watch channel and updates reconnection bookkeeping.
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns a task that retries `authenticate_and_connect`. Before each
    ///   wait it publishes `ReconnectionError`. The delay starts at `INITIAL_RECONNECTION_DELAY`,
    ///   grows by a random factor in `[1, 2]`, and is capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and drops any reconnect task.
    ///
    /// After logging, the state write lock is held for the remainder of the call, including the
    /// telemetry update.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                // Copied out so the spawned task does not need the state lock to read it.
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed in test builds so the retry delays are deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only while the failed attempt left the client in
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Randomized growth (factor in [1, 2]), never exceeding the cap.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 636
 637    pub fn subscribe_to_entity<T>(
 638        self: &Arc<Self>,
 639        remote_id: u64,
 640    ) -> Result<PendingEntitySubscription<T>>
 641    where
 642        T: 'static,
 643    {
 644        let id = (TypeId::of::<T>(), remote_id);
 645
 646        let mut state = self.state.write();
 647        if state.entities_by_type_and_remote_id.contains_key(&id) {
 648            return Err(anyhow!("already subscribed to entity"));
 649        }
 650
 651        state
 652            .entities_by_type_and_remote_id
 653            .insert(id, WeakSubscriber::Pending(Default::default()));
 654
 655        Ok(PendingEntitySubscription {
 656            client: self.clone(),
 657            remote_id,
 658            consumed: false,
 659            _entity_type: PhantomData,
 660        })
 661    }
 662
 663    #[track_caller]
 664    pub fn add_message_handler<M, E, H, F>(
 665        self: &Arc<Self>,
 666        entity: WeakModel<E>,
 667        handler: H,
 668    ) -> Subscription
 669    where
 670        M: EnvelopedMessage,
 671        E: 'static,
 672        H: 'static
 673            + Sync
 674            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 675            + Send
 676            + Sync,
 677        F: 'static + Future<Output = Result<()>>,
 678    {
 679        let message_type_id = TypeId::of::<M>();
 680        let mut state = self.state.write();
 681        state
 682            .models_by_message_type
 683            .insert(message_type_id, entity.into());
 684
 685        let prev_handler = state.message_handlers.insert(
 686            message_type_id,
 687            Arc::new(move |subscriber, envelope, client, cx| {
 688                let subscriber = subscriber.downcast::<E>().unwrap();
 689                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 690                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 691            }),
 692        );
 693        if prev_handler.is_some() {
 694            let location = std::panic::Location::caller();
 695            panic!(
 696                "{}:{} registered handler for the same message {} twice",
 697                location.file(),
 698                location.line(),
 699                std::any::type_name::<M>()
 700            );
 701        }
 702
 703        Subscription::Message {
 704            client: Arc::downgrade(self),
 705            id: message_type_id,
 706        }
 707    }
 708
 709    pub fn add_request_handler<M, E, H, F>(
 710        self: &Arc<Self>,
 711        model: WeakModel<E>,
 712        handler: H,
 713    ) -> Subscription
 714    where
 715        M: RequestMessage,
 716        E: 'static,
 717        H: 'static
 718            + Sync
 719            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 720            + Send
 721            + Sync,
 722        F: 'static + Future<Output = Result<M::Response>>,
 723    {
 724        self.add_message_handler(model, move |handle, envelope, this, cx| {
 725            Self::respond_to_request(
 726                envelope.receipt(),
 727                handler(handle, envelope, this.clone(), cx),
 728                this,
 729            )
 730        })
 731    }
 732
 733    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 734    where
 735        M: EntityMessage,
 736        E: 'static,
 737        H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 738        F: 'static + Future<Output = Result<()>>,
 739    {
 740        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
 741            handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
 742        })
 743    }
 744
 745    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 746    where
 747        M: EntityMessage,
 748        E: 'static,
 749        H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 750        F: 'static + Future<Output = Result<()>>,
 751    {
 752        let model_type_id = TypeId::of::<E>();
 753        let message_type_id = TypeId::of::<M>();
 754
 755        let mut state = self.state.write();
 756        state
 757            .entity_types_by_message_type
 758            .insert(message_type_id, model_type_id);
 759        state
 760            .entity_id_extractors
 761            .entry(message_type_id)
 762            .or_insert_with(|| {
 763                |envelope| {
 764                    envelope
 765                        .as_any()
 766                        .downcast_ref::<TypedEnvelope<M>>()
 767                        .unwrap()
 768                        .payload
 769                        .remote_entity_id()
 770                }
 771            });
 772        let prev_handler = state.message_handlers.insert(
 773            message_type_id,
 774            Arc::new(move |handle, envelope, client, cx| {
 775                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 776                handler(handle, *envelope, client.clone(), cx).boxed_local()
 777            }),
 778        );
 779        if prev_handler.is_some() {
 780            panic!("registered handler for the same message twice");
 781        }
 782    }
 783
 784    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 785    where
 786        M: EntityMessage + RequestMessage,
 787        E: 'static,
 788        H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 789        F: 'static + Future<Output = Result<M::Response>>,
 790    {
 791        self.add_model_message_handler(move |entity, envelope, client, cx| {
 792            Self::respond_to_request::<M, _>(
 793                envelope.receipt(),
 794                handler(entity, envelope, client.clone(), cx),
 795                client,
 796            )
 797        })
 798    }
 799
 800    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 801        receipt: Receipt<T>,
 802        response: F,
 803        client: Arc<Self>,
 804    ) -> Result<()> {
 805        match response.await {
 806            Ok(response) => {
 807                client.respond(receipt, response)?;
 808                Ok(())
 809            }
 810            Err(error) => {
 811                client.respond_with_error(receipt, error.to_proto())?;
 812                Err(error)
 813            }
 814        }
 815    }
 816
 817    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 818        self.credentials_provider
 819            .read_credentials(cx)
 820            .await
 821            .is_some()
 822    }
 823
 824    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 825        self.state.write().credentials = Some(Credentials::DevServer { token });
 826        self
 827    }
 828
 829    #[async_recursion(?Send)]
 830    pub async fn authenticate_and_connect(
 831        self: &Arc<Self>,
 832        try_provider: bool,
 833        cx: &AsyncAppContext,
 834    ) -> anyhow::Result<()> {
 835        let was_disconnected = match *self.status().borrow() {
 836            Status::SignedOut => true,
 837            Status::ConnectionError
 838            | Status::ConnectionLost
 839            | Status::Authenticating { .. }
 840            | Status::Reauthenticating { .. }
 841            | Status::ReconnectionError { .. } => false,
 842            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
 843                return Ok(())
 844            }
 845            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
 846        };
 847        if was_disconnected {
 848            self.set_status(Status::Authenticating, cx);
 849        } else {
 850            self.set_status(Status::Reauthenticating, cx)
 851        }
 852
 853        let mut read_from_provider = false;
 854        let mut credentials = self.state.read().credentials.clone();
 855        if credentials.is_none() && try_provider {
 856            credentials = self.credentials_provider.read_credentials(cx).await;
 857            read_from_provider = credentials.is_some();
 858        }
 859
 860        if credentials.is_none() {
 861            let mut status_rx = self.status();
 862            let _ = status_rx.next().await;
 863            futures::select_biased! {
 864                authenticate = self.authenticate(cx).fuse() => {
 865                    match authenticate {
 866                        Ok(creds) => credentials = Some(creds),
 867                        Err(err) => {
 868                            self.set_status(Status::ConnectionError, cx);
 869                            return Err(err);
 870                        }
 871                    }
 872                }
 873                _ = status_rx.next().fuse() => {
 874                    return Err(anyhow!("authentication canceled"));
 875                }
 876            }
 877        }
 878        let credentials = credentials.unwrap();
 879        if let Credentials::User { user_id, .. } = &credentials {
 880            self.set_id(*user_id);
 881        }
 882
 883        if was_disconnected {
 884            self.set_status(Status::Connecting, cx);
 885        } else {
 886            self.set_status(Status::Reconnecting, cx);
 887        }
 888
 889        let mut timeout =
 890            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
 891        futures::select_biased! {
 892            connection = self.establish_connection(&credentials, cx).fuse() => {
 893                match connection {
 894                    Ok(conn) => {
 895                        self.state.write().credentials = Some(credentials.clone());
 896                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
 897                            if let Credentials::User{user_id, access_token} = credentials {
 898                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
 899                            }
 900                        }
 901
 902                        futures::select_biased! {
 903                            result = self.set_connection(conn, cx).fuse() => result,
 904                            _ = timeout => {
 905                                self.set_status(Status::ConnectionError, cx);
 906                                Err(anyhow!("timed out waiting on hello message from server"))
 907                            }
 908                        }
 909                    }
 910                    Err(EstablishConnectionError::Unauthorized) => {
 911                        self.state.write().credentials.take();
 912                        if read_from_provider {
 913                            self.credentials_provider.delete_credentials(cx).await.log_err();
 914                            self.set_status(Status::SignedOut, cx);
 915                            self.authenticate_and_connect(false, cx).await
 916                        } else {
 917                            self.set_status(Status::ConnectionError, cx);
 918                            Err(EstablishConnectionError::Unauthorized)?
 919                        }
 920                    }
 921                    Err(EstablishConnectionError::UpgradeRequired) => {
 922                        self.set_status(Status::UpgradeRequired, cx);
 923                        Err(EstablishConnectionError::UpgradeRequired)?
 924                    }
 925                    Err(error) => {
 926                        self.set_status(Status::ConnectionError, cx);
 927                        Err(error)?
 928                    }
 929                }
 930            }
 931            _ = &mut timeout => {
 932                self.set_status(Status::ConnectionError, cx);
 933                Err(anyhow!("timed out trying to establish connection"))
 934            }
 935        }
 936    }
 937
    /// Registers `conn` with the peer and completes the handshake with the server.
    ///
    /// Waits for the server's `Hello` message to learn this client's peer id, then
    /// reports `Status::Connected`. It then spawns two detached tasks: one dispatches
    /// every subsequent incoming message through `handle_message`, and one watches the
    /// connection's I/O future and updates the status when it finishes.
    ///
    /// # Errors
    /// Returns an error (after disconnecting the peer connection) if the stream ends
    /// before a message arrives, the first message is not a `Hello`, or the `Hello`
    /// carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor right away, so the
        // hello message below can actually be received.
        let handle_io = executor.spawn(handle_io);

        // The first message on a new connection must be the server's `Hello`, which
        // tells us our own peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Don't leave a half-established connection registered with the peer.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch all remaining incoming messages, one at a time.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // React to the connection's I/O finishing: a clean close signs out only if this
        // exact connection is still the current one; an I/O error marks it as lost.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1035
    /// Obtains fresh credentials for this client.
    ///
    /// In test builds (or with the `test-support` feature), a callback installed in
    /// `self.authenticate` takes precedence. Otherwise this delegates to
    /// `authenticate_with_browser`.
    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.authenticate.read().as_ref() {
            return callback(cx);
        }

        self.authenticate_with_browser(cx)
    }

    /// Opens a transport connection to the server using `credentials`.
    ///
    /// In test builds (or with the `test-support` feature), a callback installed in
    /// `self.establish_connection` takes precedence. Otherwise this delegates to
    /// `establish_websocket_connection`.
    fn establish_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.establish_connection.read().as_ref() {
            return callback(credentials, cx);
        }

        self.establish_websocket_connection(credentials, cx)
    }
1057
1058    async fn get_rpc_url(
1059        http: Arc<HttpClientWithUrl>,
1060        release_channel: Option<ReleaseChannel>,
1061    ) -> Result<Url> {
1062        if let Some(url) = &*ZED_RPC_URL {
1063            return Url::parse(url).context("invalid rpc url");
1064        }
1065
1066        let mut url = http.build_url("/rpc");
1067        if let Some(preview_param) =
1068            release_channel.and_then(|channel| channel.release_query_param())
1069        {
1070            url += "?";
1071            url += preview_param;
1072        }
1073        let response = http.get(&url, Default::default(), false).await?;
1074        let collab_url = if response.status().is_redirection() {
1075            response
1076                .headers()
1077                .get("Location")
1078                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1079                .to_str()
1080                .map_err(EstablishConnectionError::other)?
1081                .to_string()
1082        } else {
1083            Err(anyhow!(
1084                "unexpected /rpc response status {}",
1085                response.status()
1086            ))?
1087        };
1088
1089        Url::parse(&collab_url).context("invalid rpc url")
1090    }
1091
    /// Opens a websocket connection to the collab RPC endpoint.
    ///
    /// Builds the upgrade request with the `Authorization`, protocol-version,
    /// app-version, and release-channel headers. It then resolves the RPC URL via
    /// `get_rpc_url` on the background executor, opens a TCP stream to its host and
    /// port, and performs the websocket handshake over it:
    /// - `https` URLs are rewritten to `wss` and use TLS.
    /// - `http` URLs are rewritten to `ws` and use plain TCP.
    /// - Any other scheme is rejected.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Both lookups degrade gracefully if the app context is no longer available.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let request = Request::builder()
            .header("Authorization", credentials.authorization_header())
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
            .header("x-zed-app-version", app_version)
            .header(
                "x-zed-release-channel",
                release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
            );

        let http = self.http.clone();
        cx.background_executor().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            match rpc_url.scheme() {
                "https" => {
                    // https -> wss: both are "special" schemes, so `set_scheme` can't fail.
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    // http -> ws: both are "special" schemes, so `set_scheme` can't fail.
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1152
1153    pub fn authenticate_with_browser(
1154        self: &Arc<Self>,
1155        cx: &AsyncAppContext,
1156    ) -> Task<Result<Credentials>> {
1157        let http = self.http.clone();
1158        cx.spawn(|cx| async move {
1159            let background = cx.background_executor().clone();
1160
1161            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1162            cx.update(|cx| {
1163                cx.spawn(move |cx| async move {
1164                    let url = open_url_rx.await?;
1165                    cx.update(|cx| cx.open_url(&url))
1166                })
1167                .detach_and_log_err(cx);
1168            })
1169            .log_err();
1170
1171            let credentials = background
1172                .clone()
1173                .spawn(async move {
1174                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1175                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1176                    // any other app running on the user's device.
1177                    let (public_key, private_key) =
1178                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1179                    let public_key_string = String::try_from(public_key)
1180                        .expect("failed to serialize public key for auth");
1181
1182                    if let Some((login, token)) =
1183                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1184                    {
1185                        eprintln!("authenticate as admin {login}, {token}");
1186
1187                        return Self::authenticate_as_admin(http, login.clone(), token.clone())
1188                            .await;
1189                    }
1190
1191                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1192                    let server =
1193                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1194                    let port = server.server_addr().port();
1195
1196                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1197                    // that the user is signing in from a Zed app running on the same device.
1198                    let mut url = http.build_url(&format!(
1199                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1200                        port, public_key_string
1201                    ));
1202
1203                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1204                        log::info!("impersonating user @{}", impersonate_login);
1205                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1206                    }
1207
1208                    open_url_tx.send(url).log_err();
1209
1210                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1211                    // access token from the query params.
1212                    //
1213                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1214                    // custom URL scheme instead of this local HTTP server.
1215                    let (user_id, access_token) = background
1216                        .spawn(async move {
1217                            for _ in 0..100 {
1218                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1219                                    let path = req.url();
1220                                    let mut user_id = None;
1221                                    let mut access_token = None;
1222                                    let url = Url::parse(&format!("http://example.com{}", path))
1223                                        .context("failed to parse login notification url")?;
1224                                    for (key, value) in url.query_pairs() {
1225                                        if key == "access_token" {
1226                                            access_token = Some(value.to_string());
1227                                        } else if key == "user_id" {
1228                                            user_id = Some(value.to_string());
1229                                        }
1230                                    }
1231
1232                                    let post_auth_url =
1233                                        http.build_url("/native_app_signin_succeeded");
1234                                    req.respond(
1235                                        tiny_http::Response::empty(302).with_header(
1236                                            tiny_http::Header::from_bytes(
1237                                                &b"Location"[..],
1238                                                post_auth_url.as_bytes(),
1239                                            )
1240                                            .unwrap(),
1241                                        ),
1242                                    )
1243                                    .context("failed to respond to login http request")?;
1244                                    return Ok((
1245                                        user_id
1246                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1247                                        access_token.ok_or_else(|| {
1248                                            anyhow!("missing access_token parameter")
1249                                        })?,
1250                                    ));
1251                                }
1252                            }
1253
1254                            Err(anyhow!("didn't receive login redirect"))
1255                        })
1256                        .await?;
1257
1258                    let access_token = private_key
1259                        .decrypt_string(&access_token)
1260                        .context("failed to decrypt access token")?;
1261
1262                    Ok(Credentials::User {
1263                        user_id: user_id.parse()?,
1264                        access_token,
1265                    })
1266                })
1267                .await?;
1268
1269            cx.update(|cx| cx.activate(true))?;
1270            Ok(credentials)
1271        })
1272    }
1273
1274    async fn authenticate_as_admin(
1275        http: Arc<HttpClientWithUrl>,
1276        login: String,
1277        mut api_token: String,
1278    ) -> Result<Credentials> {
1279        #[derive(Deserialize)]
1280        struct AuthenticatedUserResponse {
1281            user: User,
1282        }
1283
1284        #[derive(Deserialize)]
1285        struct User {
1286            id: u64,
1287        }
1288
1289        // Use the collab server's admin API to retrieve the id
1290        // of the impersonated user.
1291        let mut url = Self::get_rpc_url(http.clone(), None).await?;
1292        url.set_path("/user");
1293        url.set_query(Some(&format!("github_login={login}")));
1294        let request = Request::get(url.as_str())
1295            .header("Authorization", format!("token {api_token}"))
1296            .body("".into())?;
1297
1298        let mut response = http.send(request).await?;
1299        let mut body = String::new();
1300        response.body_mut().read_to_string(&mut body).await?;
1301        if !response.status().is_success() {
1302            Err(anyhow!(
1303                "admin user request failed {} - {}",
1304                response.status().as_u16(),
1305                body,
1306            ))?;
1307        }
1308        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1309
1310        // Use the admin API token to authenticate as the impersonated user.
1311        api_token.insert_str(0, "ADMIN_TOKEN:");
1312        Ok(Credentials::User {
1313            user_id: response.user.id,
1314            access_token: api_token,
1315        })
1316    }
1317
1318    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1319        self.state.write().credentials = None;
1320        self.disconnect(&cx);
1321
1322        if self.has_credentials(cx).await {
1323            self.credentials_provider
1324                .delete_credentials(cx)
1325                .await
1326                .log_err();
1327        }
1328    }
1329
    /// Tears down the peer and reports `Status::SignedOut`.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }

    /// Tears down the peer and reports `Status::ConnectionLost`.
    ///
    /// NOTE(review): presumably the status observer elsewhere reacts to
    /// `ConnectionLost` by reconnecting; that logic is not visible here.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1339
1340    fn connection_id(&self) -> Result<ConnectionId> {
1341        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1342            Ok(connection_id)
1343        } else {
1344            Err(anyhow!("not connected"))
1345        }
1346    }
1347
1348    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1349        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1350        self.peer.send(self.connection_id()?, message)
1351    }
1352
1353    pub fn request<T: RequestMessage>(
1354        &self,
1355        request: T,
1356    ) -> impl Future<Output = Result<T::Response>> {
1357        self.request_envelope(request)
1358            .map_ok(|envelope| envelope.payload)
1359    }
1360
1361    pub fn request_stream<T: RequestMessage>(
1362        &self,
1363        request: T,
1364    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1365        let client_id = self.id.load(Ordering::SeqCst);
1366        log::debug!(
1367            "rpc request start. client_id:{}. name:{}",
1368            client_id,
1369            T::NAME
1370        );
1371        let response = self
1372            .connection_id()
1373            .map(|conn_id| self.peer.request_stream(conn_id, request));
1374        async move {
1375            let response = response?.await;
1376            log::debug!(
1377                "rpc request finish. client_id:{}. name:{}",
1378                client_id,
1379                T::NAME
1380            );
1381            response
1382        }
1383    }
1384
1385    pub fn request_envelope<T: RequestMessage>(
1386        &self,
1387        request: T,
1388    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1389        let client_id = self.id();
1390        log::debug!(
1391            "rpc request start. client_id:{}. name:{}",
1392            client_id,
1393            T::NAME
1394        );
1395        let response = self
1396            .connection_id()
1397            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1398        async move {
1399            let response = response?.await;
1400            log::debug!(
1401                "rpc request finish. client_id:{}. name:{}",
1402                client_id,
1403                T::NAME
1404            );
1405            response
1406        }
1407    }
1408
    /// Replies to the request identified by `receipt` with `response`, logging the
    /// reply at debug level.
    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
        self.peer.respond(receipt, response)
    }

    /// Replies to the request identified by `receipt` with `error` instead of a
    /// response. It logs the same debug message as `respond`.
    fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        error: proto::Error,
    ) -> Result<()> {
        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
        self.peer.respond_with_error(receipt, error)
    }
1422
1423    fn handle_message(
1424        self: &Arc<Client>,
1425        message: Box<dyn AnyTypedEnvelope>,
1426        cx: &AsyncAppContext,
1427    ) {
1428        let mut state = self.state.write();
1429        let type_name = message.payload_type_name();
1430        let payload_type_id = message.payload_type_id();
1431        let sender_id = message.original_sender_id();
1432
1433        let mut subscriber = None;
1434
1435        if let Some(handle) = state
1436            .models_by_message_type
1437            .get(&payload_type_id)
1438            .and_then(|handle| handle.upgrade())
1439        {
1440            subscriber = Some(handle);
1441        } else if let Some((extract_entity_id, entity_type_id)) =
1442            state.entity_id_extractors.get(&payload_type_id).zip(
1443                state
1444                    .entity_types_by_message_type
1445                    .get(&payload_type_id)
1446                    .copied(),
1447            )
1448        {
1449            let entity_id = (extract_entity_id)(message.as_ref());
1450
1451            match state
1452                .entities_by_type_and_remote_id
1453                .get_mut(&(entity_type_id, entity_id))
1454            {
1455                Some(WeakSubscriber::Pending(pending)) => {
1456                    pending.push(message);
1457                    return;
1458                }
1459                Some(weak_subscriber) => match weak_subscriber {
1460                    WeakSubscriber::Entity { handle } => {
1461                        subscriber = handle.upgrade();
1462                    }
1463
1464                    WeakSubscriber::Pending(_) => {}
1465                },
1466                _ => {}
1467            }
1468        }
1469
1470        let subscriber = if let Some(subscriber) = subscriber {
1471            subscriber
1472        } else {
1473            log::info!("unhandled message {}", type_name);
1474            self.peer.respond_with_unhandled_message(message).log_err();
1475            return;
1476        };
1477
1478        let handler = state.message_handlers.get(&payload_type_id).cloned();
1479        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1480        // It also ensures we don't hold the lock while yielding back to the executor, as
1481        // that might cause the executor thread driving this future to block indefinitely.
1482        drop(state);
1483
1484        if let Some(handler) = handler {
1485            let future = handler(subscriber, message, self, cx.clone());
1486            let client_id = self.id();
1487            log::debug!(
1488                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1489                client_id,
1490                sender_id,
1491                type_name
1492            );
1493            cx.spawn(move |_| async move {
1494                    match future.await {
1495                        Ok(()) => {
1496                            log::debug!(
1497                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1498                                client_id,
1499                                sender_id,
1500                                type_name
1501                            );
1502                        }
1503                        Err(error) => {
1504                            log::error!(
1505                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1506                                client_id,
1507                                sender_id,
1508                                type_name,
1509                                error
1510                            );
1511                        }
1512                    }
1513                })
1514                .detach();
1515        } else {
1516            log::info!("unhandled message {}", type_name);
1517            self.peer.respond_with_unhandled_message(message).log_err();
1518        }
1519    }
1520
    /// Returns the telemetry handle owned by this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1524}
1525
1526#[derive(Serialize, Deserialize)]
1527struct DevelopmentCredentials {
1528    user_id: u64,
1529    access_token: String,
1530}
1531
1532/// A credentials provider that stores credentials in a local file.
1533///
1534/// This MUST only be used in development, as this is not a secure way of storing
1535/// credentials on user machines.
1536///
1537/// Its existence is purely to work around the annoyance of having to constantly
1538/// re-allow access to the system keychain when developing Zed.
1539struct DevelopmentCredentialsProvider {
1540    path: PathBuf,
1541}
1542
1543impl CredentialsProvider for DevelopmentCredentialsProvider {
1544    fn read_credentials<'a>(
1545        &'a self,
1546        _cx: &'a AsyncAppContext,
1547    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1548        async move {
1549            if IMPERSONATE_LOGIN.is_some() {
1550                return None;
1551            }
1552
1553            let json = std::fs::read(&self.path).log_err()?;
1554
1555            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1556
1557            Some(Credentials::User {
1558                user_id: credentials.user_id,
1559                access_token: credentials.access_token,
1560            })
1561        }
1562        .boxed_local()
1563    }
1564
1565    fn write_credentials<'a>(
1566        &'a self,
1567        user_id: u64,
1568        access_token: String,
1569        _cx: &'a AsyncAppContext,
1570    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1571        async move {
1572            let json = serde_json::to_string(&DevelopmentCredentials {
1573                user_id,
1574                access_token,
1575            })?;
1576
1577            std::fs::write(&self.path, json)?;
1578
1579            Ok(())
1580        }
1581        .boxed_local()
1582    }
1583
1584    fn delete_credentials<'a>(
1585        &'a self,
1586        _cx: &'a AsyncAppContext,
1587    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1588        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1589    }
1590}
1591
1592/// A credentials provider that stores credentials in the system keychain.
1593struct KeychainCredentialsProvider;
1594
1595impl CredentialsProvider for KeychainCredentialsProvider {
1596    fn read_credentials<'a>(
1597        &'a self,
1598        cx: &'a AsyncAppContext,
1599    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1600        async move {
1601            if IMPERSONATE_LOGIN.is_some() {
1602                return None;
1603            }
1604
1605            let (user_id, access_token) = cx
1606                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1607                .log_err()?
1608                .await
1609                .log_err()??;
1610
1611            Some(Credentials::User {
1612                user_id: user_id.parse().ok()?,
1613                access_token: String::from_utf8(access_token).ok()?,
1614            })
1615        }
1616        .boxed_local()
1617    }
1618
1619    fn write_credentials<'a>(
1620        &'a self,
1621        user_id: u64,
1622        access_token: String,
1623        cx: &'a AsyncAppContext,
1624    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1625        async move {
1626            cx.update(move |cx| {
1627                cx.write_credentials(
1628                    &ClientSettings::get_global(cx).server_url,
1629                    &user_id.to_string(),
1630                    access_token.as_bytes(),
1631                )
1632            })?
1633            .await
1634        }
1635        .boxed_local()
1636    }
1637
1638    fn delete_credentials<'a>(
1639        &'a self,
1640        cx: &'a AsyncAppContext,
1641    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1642        async move {
1643            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1644                .await
1645        }
1646        .boxed_local()
1647    }
1648}
1649
1650/// prefix for the zed:// url scheme
1651pub static ZED_URL_SCHEME: &str = "zed";
1652
1653/// Parses the given link into a Zed link.
1654///
1655/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1656/// Returns [`None`] otherwise.
1657pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1658    let server_url = &ClientSettings::get_global(cx).server_url;
1659    if let Some(stripped) = link
1660        .strip_prefix(server_url)
1661        .and_then(|result| result.strip_prefix('/'))
1662    {
1663        return Some(stripped);
1664    }
1665    if let Some(stripped) = link
1666        .strip_prefix(ZED_URL_SCHEME)
1667        .and_then(|result| result.strip_prefix("://"))
1668    {
1669        return Some(stripped);
1670    }
1671
1672    None
1673}
1674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use parking_lot::Mutex;
    use settings::SettingsStore;
    use std::future;
    use util::http::FakeHttpClient;

    // After the server drops the connection, the client reconnects. It reuses cached credentials
    // while they are still accepted and re-authenticates once the server rolls the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones; skip intermediate statuses until the
        // failed reconnection attempt is reported.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Advance the fake clock (presumably past the reconnection backoff) and wait for the
        // client to come back online.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Both the initial connection attempt and a later reconnection attempt give up after
    // `CONNECTION_TIMEOUT` when establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // A connection future that never resolves, so only the timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Let the reconnection delay elapse so a new attempt starts.
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still authenticating
    // drops the first authentication future and starts a new one.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    // Drop guard: bumps `dropped_auth_count` when this never-finishing
                    // authentication future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity messages for project ids 1 and 2 reach the models subscribed to those ids. A dropped
    // subscription for another id (3) must not break delivery to the others.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a handler subscription is dropped, registering a new handler for the same message
    // type succeeds and the new handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the model) while it is running, and
    // message handling still completes.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        // `downgrade` borrows the handle, so no intermediate clone is needed.
        let subscription = client.add_message_handler(
            model.downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as a message-handler target in these tests.
    #[derive(Default)]
    struct TestModel {
        /// Identifier the entity-subscription test uses to tell models apart.
        id: usize,
        /// Handler subscription owned by the model, so a handler can drop it via `take()`.
        subscription: Option<Subscription>,
    }

    /// Installs a test [`SettingsStore`] as a global and registers this crate's settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}