client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context as _, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use clock::SystemClock;
  14use collections::HashMap;
  15use futures::{
  16    channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
  17    TryFutureExt as _, TryStreamExt,
  18};
  19use gpui::{
  20    actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
  21};
  22use http::{HttpClient, HttpClientWithUrl};
  23use lazy_static::lazy_static;
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use release_channel::{AppVersion, ReleaseChannel};
  28use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::{Settings, SettingsSources};
  32use std::fmt;
  33use std::pin::Pin;
  34use std::{
  35    any::TypeId,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{
  42        atomic::{AtomicU64, Ordering},
  43        Arc, Weak,
  44    },
  45    time::{Duration, Instant},
  46};
  47use telemetry::Telemetry;
  48use thiserror::Error;
  49use url::Url;
  50use util::{ResultExt, TryFutureExt};
  51
  52pub use rpc::*;
  53pub use telemetry_events::Event;
  54pub use user::*;
  55
  56#[derive(Debug, Clone, Eq, PartialEq)]
  57pub struct DevServerToken(pub String);
  58
  59impl fmt::Display for DevServerToken {
  60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  61        write!(f, "{}", self.0)
  62    }
  63}
  64
  65lazy_static! {
  66    static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
  67    static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
  68    /// An environment variable whose presence indicates that the development auth
  69    /// provider should be used.
  70    ///
  71    /// Only works in development. Setting this environment variable in other release
  72    /// channels is a no-op.
  73    pub static ref ZED_DEVELOPMENT_AUTH: bool =
  74        std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
  75    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  76        .ok()
  77        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  78    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  79        .ok()
  80        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  81    pub static ref ZED_APP_PATH: Option<PathBuf> =
  82        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  83    pub static ref ZED_ALWAYS_ACTIVE: bool =
  84        std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
  85}
  86
  87pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  88pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  89
  90actions!(client, [SignIn, SignOut, Reconnect]);
  91
  92#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
  93pub struct ClientSettingsContent {
  94    server_url: Option<String>,
  95}
  96
  97#[derive(Deserialize)]
  98pub struct ClientSettings {
  99    pub server_url: String,
 100}
 101
 102impl Settings for ClientSettings {
 103    const KEY: Option<&'static str> = None;
 104
 105    type FileContent = ClientSettingsContent;
 106
 107    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 108        let mut result = sources.json_merge::<Self>()?;
 109        if let Some(server_url) = &*ZED_SERVER_URL {
 110            result.server_url.clone_from(&server_url)
 111        }
 112        Ok(result)
 113    }
 114}
 115
 116#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 117pub struct ProxySettingsContent {
 118    proxy: Option<String>,
 119}
 120
 121#[derive(Deserialize, Default)]
 122pub struct ProxySettings {
 123    pub proxy: Option<String>,
 124}
 125
 126impl Settings for ProxySettings {
 127    const KEY: Option<&'static str> = None;
 128
 129    type FileContent = ProxySettingsContent;
 130
 131    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 132        Ok(Self {
 133            proxy: sources
 134                .user
 135                .and_then(|value| value.proxy.clone())
 136                .or(sources.default.proxy.clone()),
 137        })
 138    }
 139}
 140
 141pub fn init_settings(cx: &mut AppContext) {
 142    TelemetrySettings::register(cx);
 143    ClientSettings::register(cx);
 144    ProxySettings::register(cx);
 145}
 146
 147pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 148    let client = Arc::downgrade(client);
 149    cx.on_action({
 150        let client = client.clone();
 151        move |_: &SignIn, cx| {
 152            if let Some(client) = client.upgrade() {
 153                cx.spawn(
 154                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 155                )
 156                .detach();
 157            }
 158        }
 159    });
 160
 161    cx.on_action({
 162        let client = client.clone();
 163        move |_: &SignOut, cx| {
 164            if let Some(client) = client.upgrade() {
 165                cx.spawn(|cx| async move {
 166                    client.sign_out(&cx).await;
 167                })
 168                .detach();
 169            }
 170        }
 171    });
 172
 173    cx.on_action({
 174        let client = client.clone();
 175        move |_: &Reconnect, cx| {
 176            if let Some(client) = client.upgrade() {
 177                cx.spawn(|cx| async move {
 178                    client.reconnect(&cx);
 179                })
 180                .detach();
 181            }
 182        }
 183    });
 184}
 185
 186struct GlobalClient(Arc<Client>);
 187
 188impl Global for GlobalClient {}
 189
 190pub struct Client {
 191    id: AtomicU64,
 192    peer: Arc<Peer>,
 193    http: Arc<HttpClientWithUrl>,
 194    telemetry: Arc<Telemetry>,
 195    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
 196    state: RwLock<ClientState>,
 197
 198    #[allow(clippy::type_complexity)]
 199    #[cfg(any(test, feature = "test-support"))]
 200    authenticate: RwLock<
 201        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 202    >,
 203
 204    #[allow(clippy::type_complexity)]
 205    #[cfg(any(test, feature = "test-support"))]
 206    establish_connection: RwLock<
 207        Option<
 208            Box<
 209                dyn 'static
 210                    + Send
 211                    + Sync
 212                    + Fn(
 213                        &Credentials,
 214                        &AsyncAppContext,
 215                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 216            >,
 217        >,
 218    >,
 219}
 220
 221#[derive(Error, Debug)]
 222pub enum EstablishConnectionError {
 223    #[error("upgrade required")]
 224    UpgradeRequired,
 225    #[error("unauthorized")]
 226    Unauthorized,
 227    #[error("{0}")]
 228    Other(#[from] anyhow::Error),
 229    #[error("{0}")]
 230    Http(#[from] http::Error),
 231    #[error("{0}")]
 232    Io(#[from] std::io::Error),
 233    #[error("{0}")]
 234    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 235}
 236
 237impl From<WebsocketError> for EstablishConnectionError {
 238    fn from(error: WebsocketError) -> Self {
 239        if let WebsocketError::Http(response) = &error {
 240            match response.status() {
 241                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 242                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 243                _ => {}
 244            }
 245        }
 246        EstablishConnectionError::Other(error.into())
 247    }
 248}
 249
 250impl EstablishConnectionError {
 251    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 252        Self::Other(error.into())
 253    }
 254}
 255
 256#[derive(Copy, Clone, Debug, PartialEq)]
 257pub enum Status {
 258    SignedOut,
 259    UpgradeRequired,
 260    Authenticating,
 261    Connecting,
 262    ConnectionError,
 263    Connected {
 264        peer_id: PeerId,
 265        connection_id: ConnectionId,
 266    },
 267    ConnectionLost,
 268    Reauthenticating,
 269    Reconnecting,
 270    ReconnectionError {
 271        next_reconnection: Instant,
 272    },
 273}
 274
 275impl Status {
 276    pub fn is_connected(&self) -> bool {
 277        matches!(self, Self::Connected { .. })
 278    }
 279
 280    pub fn is_signed_out(&self) -> bool {
 281        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 282    }
 283}
 284
 285struct ClientState {
 286    credentials: Option<Credentials>,
 287    status: (watch::Sender<Status>, watch::Receiver<Status>),
 288    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 289    _reconnect_task: Option<Task<()>>,
 290    reconnect_interval: Duration,
 291    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
 292    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 293    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 294    #[allow(clippy::type_complexity)]
 295    message_handlers: HashMap<
 296        TypeId,
 297        Arc<
 298            dyn Send
 299                + Sync
 300                + Fn(
 301                    AnyModel,
 302                    Box<dyn AnyTypedEnvelope>,
 303                    &Arc<Client>,
 304                    AsyncAppContext,
 305                ) -> LocalBoxFuture<'static, Result<()>>,
 306        >,
 307    >,
 308}
 309
 310enum WeakSubscriber {
 311    Entity { handle: AnyWeakModel },
 312    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
 313}
 314
 315#[derive(Clone, Debug, Eq, PartialEq)]
 316pub enum Credentials {
 317    DevServer { token: DevServerToken },
 318    User { user_id: u64, access_token: String },
 319}
 320
 321impl Credentials {
 322    pub fn authorization_header(&self) -> String {
 323        match self {
 324            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 325            Credentials::User {
 326                user_id,
 327                access_token,
 328            } => format!("{} {}", user_id, access_token),
 329        }
 330    }
 331}
 332
 333/// A provider for [`Credentials`].
 334///
 335/// Used to abstract over reading and writing credentials to some form of
 336/// persistence (like the system keychain).
 337trait CredentialsProvider {
 338    /// Reads the credentials from the provider.
 339    fn read_credentials<'a>(
 340        &'a self,
 341        cx: &'a AsyncAppContext,
 342    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;
 343
 344    /// Writes the credentials to the provider.
 345    fn write_credentials<'a>(
 346        &'a self,
 347        user_id: u64,
 348        access_token: String,
 349        cx: &'a AsyncAppContext,
 350    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
 351
 352    /// Deletes the credentials from the provider.
 353    fn delete_credentials<'a>(
 354        &'a self,
 355        cx: &'a AsyncAppContext,
 356    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
 357}
 358
 359impl Default for ClientState {
 360    fn default() -> Self {
 361        Self {
 362            credentials: None,
 363            status: watch::channel_with(Status::SignedOut),
 364            entity_id_extractors: Default::default(),
 365            _reconnect_task: None,
 366            reconnect_interval: Duration::from_secs(5),
 367            models_by_message_type: Default::default(),
 368            entities_by_type_and_remote_id: Default::default(),
 369            entity_types_by_message_type: Default::default(),
 370            message_handlers: Default::default(),
 371        }
 372    }
 373}
 374
 375pub enum Subscription {
 376    Entity {
 377        client: Weak<Client>,
 378        id: (TypeId, u64),
 379    },
 380    Message {
 381        client: Weak<Client>,
 382        id: TypeId,
 383    },
 384}
 385
 386impl Drop for Subscription {
 387    fn drop(&mut self) {
 388        match self {
 389            Subscription::Entity { client, id } => {
 390                if let Some(client) = client.upgrade() {
 391                    let mut state = client.state.write();
 392                    let _ = state.entities_by_type_and_remote_id.remove(id);
 393                }
 394            }
 395            Subscription::Message { client, id } => {
 396                if let Some(client) = client.upgrade() {
 397                    let mut state = client.state.write();
 398                    let _ = state.entity_types_by_message_type.remove(id);
 399                    let _ = state.message_handlers.remove(id);
 400                }
 401            }
 402        }
 403    }
 404}
 405
 406pub struct PendingEntitySubscription<T: 'static> {
 407    client: Arc<Client>,
 408    remote_id: u64,
 409    _entity_type: PhantomData<T>,
 410    consumed: bool,
 411}
 412
 413impl<T: 'static> PendingEntitySubscription<T> {
 414    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
 415        self.consumed = true;
 416        let mut state = self.client.state.write();
 417        let id = (TypeId::of::<T>(), self.remote_id);
 418        let Some(WeakSubscriber::Pending(messages)) =
 419            state.entities_by_type_and_remote_id.remove(&id)
 420        else {
 421            unreachable!()
 422        };
 423
 424        state.entities_by_type_and_remote_id.insert(
 425            id,
 426            WeakSubscriber::Entity {
 427                handle: model.downgrade().into(),
 428            },
 429        );
 430        drop(state);
 431        for message in messages {
 432            self.client.handle_message(message, cx);
 433        }
 434        Subscription::Entity {
 435            client: Arc::downgrade(&self.client),
 436            id,
 437        }
 438    }
 439}
 440
 441impl<T: 'static> Drop for PendingEntitySubscription<T> {
 442    fn drop(&mut self) {
 443        if !self.consumed {
 444            let mut state = self.client.state.write();
 445            if let Some(WeakSubscriber::Pending(messages)) = state
 446                .entities_by_type_and_remote_id
 447                .remove(&(TypeId::of::<T>(), self.remote_id))
 448            {
 449                for message in messages {
 450                    log::info!("unhandled message {}", message.payload_type_name());
 451                }
 452            }
 453        }
 454    }
 455}
 456
 457#[derive(Copy, Clone)]
 458pub struct TelemetrySettings {
 459    pub diagnostics: bool,
 460    pub metrics: bool,
 461}
 462
 463/// Control what info is collected by Zed.
 464#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 465pub struct TelemetrySettingsContent {
 466    /// Send debug info like crash reports.
 467    ///
 468    /// Default: true
 469    pub diagnostics: Option<bool>,
 470    /// Send anonymized usage data like what languages you're using Zed with.
 471    ///
 472    /// Default: true
 473    pub metrics: Option<bool>,
 474}
 475
 476impl settings::Settings for TelemetrySettings {
 477    const KEY: Option<&'static str> = Some("telemetry");
 478
 479    type FileContent = TelemetrySettingsContent;
 480
 481    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 482        Ok(Self {
 483            diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
 484                sources
 485                    .default
 486                    .diagnostics
 487                    .ok_or_else(Self::missing_default)?,
 488            ),
 489            metrics: sources
 490                .user
 491                .as_ref()
 492                .and_then(|v| v.metrics)
 493                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 494        })
 495    }
 496}
 497
 498impl Client {
 499    pub fn new(
 500        clock: Arc<dyn SystemClock>,
 501        http: Arc<HttpClientWithUrl>,
 502        cx: &mut AppContext,
 503    ) -> Arc<Self> {
 504        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 505            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 506            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 507            | None => false,
 508        };
 509
 510        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 511            if use_zed_development_auth {
 512                Arc::new(DevelopmentCredentialsProvider {
 513                    path: util::paths::CONFIG_DIR.join("development_auth"),
 514                })
 515            } else {
 516                Arc::new(KeychainCredentialsProvider)
 517            };
 518
 519        Arc::new(Self {
 520            id: AtomicU64::new(0),
 521            peer: Peer::new(0),
 522            telemetry: Telemetry::new(clock, http.clone(), cx),
 523            http,
 524            credentials_provider,
 525            state: Default::default(),
 526
 527            #[cfg(any(test, feature = "test-support"))]
 528            authenticate: Default::default(),
 529            #[cfg(any(test, feature = "test-support"))]
 530            establish_connection: Default::default(),
 531        })
 532    }
 533
 534    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 535        let clock = Arc::new(clock::RealSystemClock);
 536        let http = Arc::new(HttpClientWithUrl::new(
 537            &ClientSettings::get_global(cx).server_url,
 538            ProxySettings::get_global(cx).proxy.clone(),
 539        ));
 540        Self::new(clock, http.clone(), cx)
 541    }
 542
 543    pub fn id(&self) -> u64 {
 544        self.id.load(Ordering::SeqCst)
 545    }
 546
 547    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 548        self.http.clone()
 549    }
 550
 551    pub fn set_id(&self, id: u64) -> &Self {
 552        self.id.store(id, Ordering::SeqCst);
 553        self
 554    }
 555
 556    #[cfg(any(test, feature = "test-support"))]
 557    pub fn teardown(&self) {
 558        let mut state = self.state.write();
 559        state._reconnect_task.take();
 560        state.message_handlers.clear();
 561        state.models_by_message_type.clear();
 562        state.entities_by_type_and_remote_id.clear();
 563        state.entity_id_extractors.clear();
 564        self.peer.teardown();
 565    }
 566
 567    #[cfg(any(test, feature = "test-support"))]
 568    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 569    where
 570        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 571    {
 572        *self.authenticate.write() = Some(Box::new(authenticate));
 573        self
 574    }
 575
 576    #[cfg(any(test, feature = "test-support"))]
 577    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 578    where
 579        F: 'static
 580            + Send
 581            + Sync
 582            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 583    {
 584        *self.establish_connection.write() = Some(Box::new(connect));
 585        self
 586    }
 587
 588    pub fn global(cx: &AppContext) -> Arc<Self> {
 589        cx.global::<GlobalClient>().0.clone()
 590    }
 591    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 592        cx.set_global(GlobalClient(client))
 593    }
 594
 595    pub fn user_id(&self) -> Option<u64> {
 596        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 597            Some(*user_id)
 598        } else {
 599            None
 600        }
 601    }
 602
 603    pub fn peer_id(&self) -> Option<PeerId> {
 604        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 605            Some(*peer_id)
 606        } else {
 607            None
 608        }
 609    }
 610
 611    pub fn status(&self) -> watch::Receiver<Status> {
 612        self.state.read().status.1.clone()
 613    }
 614
 615    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 616        log::info!("set status on client {}: {:?}", self.id(), status);
 617        let mut state = self.state.write();
 618        *state.status.0.borrow_mut() = status;
 619
 620        match status {
 621            Status::Connected { .. } => {
 622                state._reconnect_task = None;
 623            }
 624            Status::ConnectionLost => {
 625                let this = self.clone();
 626                let reconnect_interval = state.reconnect_interval;
 627                state._reconnect_task = Some(cx.spawn(move |cx| async move {
 628                    #[cfg(any(test, feature = "test-support"))]
 629                    let mut rng = StdRng::seed_from_u64(0);
 630                    #[cfg(not(any(test, feature = "test-support")))]
 631                    let mut rng = StdRng::from_entropy();
 632
 633                    let mut delay = INITIAL_RECONNECTION_DELAY;
 634                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
 635                        log::error!("failed to connect {}", error);
 636                        if matches!(*this.status().borrow(), Status::ConnectionError) {
 637                            this.set_status(
 638                                Status::ReconnectionError {
 639                                    next_reconnection: Instant::now() + delay,
 640                                },
 641                                &cx,
 642                            );
 643                            cx.background_executor().timer(delay).await;
 644                            delay = delay
 645                                .mul_f32(rng.gen_range(1.0..=2.0))
 646                                .min(reconnect_interval);
 647                        } else {
 648                            break;
 649                        }
 650                    }
 651                }));
 652            }
 653            Status::SignedOut | Status::UpgradeRequired => {
 654                self.telemetry.set_authenticated_user_info(None, false);
 655                state._reconnect_task.take();
 656            }
 657            _ => {}
 658        }
 659    }
 660
 661    pub fn subscribe_to_entity<T>(
 662        self: &Arc<Self>,
 663        remote_id: u64,
 664    ) -> Result<PendingEntitySubscription<T>>
 665    where
 666        T: 'static,
 667    {
 668        let id = (TypeId::of::<T>(), remote_id);
 669
 670        let mut state = self.state.write();
 671        if state.entities_by_type_and_remote_id.contains_key(&id) {
 672            return Err(anyhow!("already subscribed to entity"));
 673        }
 674
 675        state
 676            .entities_by_type_and_remote_id
 677            .insert(id, WeakSubscriber::Pending(Default::default()));
 678
 679        Ok(PendingEntitySubscription {
 680            client: self.clone(),
 681            remote_id,
 682            consumed: false,
 683            _entity_type: PhantomData,
 684        })
 685    }
 686
 687    #[track_caller]
 688    pub fn add_message_handler<M, E, H, F>(
 689        self: &Arc<Self>,
 690        entity: WeakModel<E>,
 691        handler: H,
 692    ) -> Subscription
 693    where
 694        M: EnvelopedMessage,
 695        E: 'static,
 696        H: 'static
 697            + Sync
 698            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 699            + Send
 700            + Sync,
 701        F: 'static + Future<Output = Result<()>>,
 702    {
 703        let message_type_id = TypeId::of::<M>();
 704        let mut state = self.state.write();
 705        state
 706            .models_by_message_type
 707            .insert(message_type_id, entity.into());
 708
 709        let prev_handler = state.message_handlers.insert(
 710            message_type_id,
 711            Arc::new(move |subscriber, envelope, client, cx| {
 712                let subscriber = subscriber.downcast::<E>().unwrap();
 713                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 714                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 715            }),
 716        );
 717        if prev_handler.is_some() {
 718            let location = std::panic::Location::caller();
 719            panic!(
 720                "{}:{} registered handler for the same message {} twice",
 721                location.file(),
 722                location.line(),
 723                std::any::type_name::<M>()
 724            );
 725        }
 726
 727        Subscription::Message {
 728            client: Arc::downgrade(self),
 729            id: message_type_id,
 730        }
 731    }
 732
 733    pub fn add_request_handler<M, E, H, F>(
 734        self: &Arc<Self>,
 735        model: WeakModel<E>,
 736        handler: H,
 737    ) -> Subscription
 738    where
 739        M: RequestMessage,
 740        E: 'static,
 741        H: 'static
 742            + Sync
 743            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 744            + Send
 745            + Sync,
 746        F: 'static + Future<Output = Result<M::Response>>,
 747    {
 748        self.add_message_handler(model, move |handle, envelope, this, cx| {
 749            Self::respond_to_request(
 750                envelope.receipt(),
 751                handler(handle, envelope, this.clone(), cx),
 752                this,
 753            )
 754        })
 755    }
 756
 757    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 758    where
 759        M: EntityMessage,
 760        E: 'static,
 761        H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 762        F: 'static + Future<Output = Result<()>>,
 763    {
 764        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
 765            handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
 766        })
 767    }
 768
 769    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 770    where
 771        M: EntityMessage,
 772        E: 'static,
 773        H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 774        F: 'static + Future<Output = Result<()>>,
 775    {
 776        let model_type_id = TypeId::of::<E>();
 777        let message_type_id = TypeId::of::<M>();
 778
 779        let mut state = self.state.write();
 780        state
 781            .entity_types_by_message_type
 782            .insert(message_type_id, model_type_id);
 783        state
 784            .entity_id_extractors
 785            .entry(message_type_id)
 786            .or_insert_with(|| {
 787                |envelope| {
 788                    envelope
 789                        .as_any()
 790                        .downcast_ref::<TypedEnvelope<M>>()
 791                        .unwrap()
 792                        .payload
 793                        .remote_entity_id()
 794                }
 795            });
 796        let prev_handler = state.message_handlers.insert(
 797            message_type_id,
 798            Arc::new(move |handle, envelope, client, cx| {
 799                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 800                handler(handle, *envelope, client.clone(), cx).boxed_local()
 801            }),
 802        );
 803        if prev_handler.is_some() {
 804            panic!("registered handler for the same message twice");
 805        }
 806    }
 807
 808    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 809    where
 810        M: EntityMessage + RequestMessage,
 811        E: 'static,
 812        H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 813        F: 'static + Future<Output = Result<M::Response>>,
 814    {
 815        self.add_model_message_handler(move |entity, envelope, client, cx| {
 816            Self::respond_to_request::<M, _>(
 817                envelope.receipt(),
 818                handler(entity, envelope, client.clone(), cx),
 819                client,
 820            )
 821        })
 822    }
 823
 824    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 825        receipt: Receipt<T>,
 826        response: F,
 827        client: Arc<Self>,
 828    ) -> Result<()> {
 829        match response.await {
 830            Ok(response) => {
 831                client.respond(receipt, response)?;
 832                Ok(())
 833            }
 834            Err(error) => {
 835                client.respond_with_error(receipt, error.to_proto())?;
 836                Err(error)
 837            }
 838        }
 839    }
 840
 841    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 842        self.credentials_provider
 843            .read_credentials(cx)
 844            .await
 845            .is_some()
 846    }
 847
 848    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 849        self.state.write().credentials = Some(Credentials::DevServer { token });
 850        self
 851    }
 852
 853    #[async_recursion(?Send)]
 854    pub async fn authenticate_and_connect(
 855        self: &Arc<Self>,
 856        try_provider: bool,
 857        cx: &AsyncAppContext,
 858    ) -> anyhow::Result<()> {
 859        let was_disconnected = match *self.status().borrow() {
 860            Status::SignedOut => true,
 861            Status::ConnectionError
 862            | Status::ConnectionLost
 863            | Status::Authenticating { .. }
 864            | Status::Reauthenticating { .. }
 865            | Status::ReconnectionError { .. } => false,
 866            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
 867                return Ok(())
 868            }
 869            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
 870        };
 871        if was_disconnected {
 872            self.set_status(Status::Authenticating, cx);
 873        } else {
 874            self.set_status(Status::Reauthenticating, cx)
 875        }
 876
 877        let mut read_from_provider = false;
 878        let mut credentials = self.state.read().credentials.clone();
 879        if credentials.is_none() && try_provider {
 880            credentials = self.credentials_provider.read_credentials(cx).await;
 881            read_from_provider = credentials.is_some();
 882        }
 883
 884        if credentials.is_none() {
 885            let mut status_rx = self.status();
 886            let _ = status_rx.next().await;
 887            futures::select_biased! {
 888                authenticate = self.authenticate(cx).fuse() => {
 889                    match authenticate {
 890                        Ok(creds) => credentials = Some(creds),
 891                        Err(err) => {
 892                            self.set_status(Status::ConnectionError, cx);
 893                            return Err(err);
 894                        }
 895                    }
 896                }
 897                _ = status_rx.next().fuse() => {
 898                    return Err(anyhow!("authentication canceled"));
 899                }
 900            }
 901        }
 902        let credentials = credentials.unwrap();
 903        if let Credentials::User { user_id, .. } = &credentials {
 904            self.set_id(*user_id);
 905        }
 906
 907        if was_disconnected {
 908            self.set_status(Status::Connecting, cx);
 909        } else {
 910            self.set_status(Status::Reconnecting, cx);
 911        }
 912
 913        let mut timeout =
 914            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
 915        futures::select_biased! {
 916            connection = self.establish_connection(&credentials, cx).fuse() => {
 917                match connection {
 918                    Ok(conn) => {
 919                        self.state.write().credentials = Some(credentials.clone());
 920                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
 921                            if let Credentials::User{user_id, access_token} = credentials {
 922                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
 923                            }
 924                        }
 925
 926                        futures::select_biased! {
 927                            result = self.set_connection(conn, cx).fuse() => result,
 928                            _ = timeout => {
 929                                self.set_status(Status::ConnectionError, cx);
 930                                Err(anyhow!("timed out waiting on hello message from server"))
 931                            }
 932                        }
 933                    }
 934                    Err(EstablishConnectionError::Unauthorized) => {
 935                        self.state.write().credentials.take();
 936                        if read_from_provider {
 937                            self.credentials_provider.delete_credentials(cx).await.log_err();
 938                            self.set_status(Status::SignedOut, cx);
 939                            self.authenticate_and_connect(false, cx).await
 940                        } else {
 941                            self.set_status(Status::ConnectionError, cx);
 942                            Err(EstablishConnectionError::Unauthorized)?
 943                        }
 944                    }
 945                    Err(EstablishConnectionError::UpgradeRequired) => {
 946                        self.set_status(Status::UpgradeRequired, cx);
 947                        Err(EstablishConnectionError::UpgradeRequired)?
 948                    }
 949                    Err(error) => {
 950                        self.set_status(Status::ConnectionError, cx);
 951                        Err(error)?
 952                    }
 953                }
 954            }
 955            _ = &mut timeout => {
 956                self.set_status(Status::ConnectionError, cx);
 957                Err(anyhow!("timed out trying to establish connection"))
 958            }
 959        }
 960    }
 961
 962    async fn set_connection(
 963        self: &Arc<Self>,
 964        conn: Connection,
 965        cx: &AsyncAppContext,
 966    ) -> Result<()> {
 967        let executor = cx.background_executor();
 968        log::info!("add connection to peer");
 969        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
 970            let executor = executor.clone();
 971            move |duration| executor.timer(duration)
 972        });
 973        let handle_io = executor.spawn(handle_io);
 974
 975        let peer_id = async {
 976            log::info!("waiting for server hello");
 977            let message = incoming
 978                .next()
 979                .await
 980                .ok_or_else(|| anyhow!("no hello message received"))?;
 981            log::info!("got server hello");
 982            let hello_message_type_name = message.payload_type_name().to_string();
 983            let hello = message
 984                .into_any()
 985                .downcast::<TypedEnvelope<proto::Hello>>()
 986                .map_err(|_| {
 987                    anyhow!(
 988                        "invalid hello message received: {:?}",
 989                        hello_message_type_name
 990                    )
 991                })?;
 992            let peer_id = hello
 993                .payload
 994                .peer_id
 995                .ok_or_else(|| anyhow!("invalid peer id"))?;
 996            Ok(peer_id)
 997        };
 998
 999        let peer_id = match peer_id.await {
1000            Ok(peer_id) => peer_id,
1001            Err(error) => {
1002                self.peer.disconnect(connection_id);
1003                return Err(error);
1004            }
1005        };
1006
1007        log::info!(
1008            "set status to connected (connection id: {:?}, peer id: {:?})",
1009            connection_id,
1010            peer_id
1011        );
1012        self.set_status(
1013            Status::Connected {
1014                peer_id,
1015                connection_id,
1016            },
1017            cx,
1018        );
1019
1020        cx.spawn({
1021            let this = self.clone();
1022            |cx| {
1023                async move {
1024                    while let Some(message) = incoming.next().await {
1025                        this.handle_message(message, &cx);
1026                        // Don't starve the main thread when receiving lots of messages at once.
1027                        smol::future::yield_now().await;
1028                    }
1029                }
1030            }
1031        })
1032        .detach();
1033
1034        cx.spawn({
1035            let this = self.clone();
1036            move |cx| async move {
1037                match handle_io.await {
1038                    Ok(()) => {
1039                        if *this.status().borrow()
1040                            == (Status::Connected {
1041                                connection_id,
1042                                peer_id,
1043                            })
1044                        {
1045                            this.set_status(Status::SignedOut, &cx);
1046                        }
1047                    }
1048                    Err(err) => {
1049                        log::error!("connection error: {:?}", err);
1050                        this.set_status(Status::ConnectionLost, &cx);
1051                    }
1052                }
1053            }
1054        })
1055        .detach();
1056
1057        Ok(())
1058    }
1059
1060    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1061        #[cfg(any(test, feature = "test-support"))]
1062        if let Some(callback) = self.authenticate.read().as_ref() {
1063            return callback(cx);
1064        }
1065
1066        self.authenticate_with_browser(cx)
1067    }
1068
1069    fn establish_connection(
1070        self: &Arc<Self>,
1071        credentials: &Credentials,
1072        cx: &AsyncAppContext,
1073    ) -> Task<Result<Connection, EstablishConnectionError>> {
1074        #[cfg(any(test, feature = "test-support"))]
1075        if let Some(callback) = self.establish_connection.read().as_ref() {
1076            return callback(credentials, cx);
1077        }
1078
1079        self.establish_websocket_connection(credentials, cx)
1080    }
1081
1082    async fn get_rpc_url(
1083        http: Arc<HttpClientWithUrl>,
1084        release_channel: Option<ReleaseChannel>,
1085    ) -> Result<Url> {
1086        if let Some(url) = &*ZED_RPC_URL {
1087            return Url::parse(url).context("invalid rpc url");
1088        }
1089
1090        let mut url = http.build_url("/rpc");
1091        if let Some(preview_param) =
1092            release_channel.and_then(|channel| channel.release_query_param())
1093        {
1094            url += "?";
1095            url += preview_param;
1096        }
1097        let response = http.get(&url, Default::default(), false).await?;
1098        let collab_url = if response.status().is_redirection() {
1099            response
1100                .headers()
1101                .get("Location")
1102                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1103                .to_str()
1104                .map_err(EstablishConnectionError::other)?
1105                .to_string()
1106        } else {
1107            Err(anyhow!(
1108                "unexpected /rpc response status {}",
1109                response.status()
1110            ))?
1111        };
1112
1113        Url::parse(&collab_url).context("invalid rpc url")
1114    }
1115
1116    fn establish_websocket_connection(
1117        self: &Arc<Self>,
1118        credentials: &Credentials,
1119        cx: &AsyncAppContext,
1120    ) -> Task<Result<Connection, EstablishConnectionError>> {
1121        let release_channel = cx
1122            .update(|cx| ReleaseChannel::try_global(cx))
1123            .ok()
1124            .flatten();
1125        let app_version = cx
1126            .update(|cx| AppVersion::global(cx).to_string())
1127            .ok()
1128            .unwrap_or_default();
1129
1130        let request = Request::builder()
1131            .header("Authorization", credentials.authorization_header())
1132            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
1133            .header("x-zed-app-version", app_version)
1134            .header(
1135                "x-zed-release-channel",
1136                release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
1137            );
1138
1139        let http = self.http.clone();
1140        cx.background_executor().spawn(async move {
1141            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
1142            let rpc_host = rpc_url
1143                .host_str()
1144                .zip(rpc_url.port_or_known_default())
1145                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1146            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1147
1148            log::info!("connected to rpc endpoint {}", rpc_url);
1149
1150            match rpc_url.scheme() {
1151                "https" => {
1152                    rpc_url.set_scheme("wss").unwrap();
1153                    let request = request.uri(rpc_url.as_str()).body(())?;
1154                    let (stream, _) =
1155                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
1156                    Ok(Connection::new(
1157                        stream
1158                            .map_err(|error| anyhow!(error))
1159                            .sink_map_err(|error| anyhow!(error)),
1160                    ))
1161                }
1162                "http" => {
1163                    rpc_url.set_scheme("ws").unwrap();
1164                    let request = request.uri(rpc_url.as_str()).body(())?;
1165                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1166                    Ok(Connection::new(
1167                        stream
1168                            .map_err(|error| anyhow!(error))
1169                            .sink_map_err(|error| anyhow!(error)),
1170                    ))
1171                }
1172                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1173            }
1174        })
1175    }
1176
1177    pub fn authenticate_with_browser(
1178        self: &Arc<Self>,
1179        cx: &AsyncAppContext,
1180    ) -> Task<Result<Credentials>> {
1181        let http = self.http.clone();
1182        cx.spawn(|cx| async move {
1183            let background = cx.background_executor().clone();
1184
1185            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1186            cx.update(|cx| {
1187                cx.spawn(move |cx| async move {
1188                    let url = open_url_rx.await?;
1189                    cx.update(|cx| cx.open_url(&url))
1190                })
1191                .detach_and_log_err(cx);
1192            })
1193            .log_err();
1194
1195            let credentials = background
1196                .clone()
1197                .spawn(async move {
1198                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1199                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1200                    // any other app running on the user's device.
1201                    let (public_key, private_key) =
1202                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1203                    let public_key_string = String::try_from(public_key)
1204                        .expect("failed to serialize public key for auth");
1205
1206                    if let Some((login, token)) =
1207                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1208                    {
1209                        eprintln!("authenticate as admin {login}, {token}");
1210
1211                        return Self::authenticate_as_admin(http, login.clone(), token.clone())
1212                            .await;
1213                    }
1214
1215                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1216                    let server =
1217                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1218                    let port = server.server_addr().port();
1219
1220                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1221                    // that the user is signing in from a Zed app running on the same device.
1222                    let mut url = http.build_url(&format!(
1223                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1224                        port, public_key_string
1225                    ));
1226
1227                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1228                        log::info!("impersonating user @{}", impersonate_login);
1229                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1230                    }
1231
1232                    open_url_tx.send(url).log_err();
1233
1234                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1235                    // access token from the query params.
1236                    //
1237                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1238                    // custom URL scheme instead of this local HTTP server.
1239                    let (user_id, access_token) = background
1240                        .spawn(async move {
1241                            for _ in 0..100 {
1242                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1243                                    let path = req.url();
1244                                    let mut user_id = None;
1245                                    let mut access_token = None;
1246                                    let url = Url::parse(&format!("http://example.com{}", path))
1247                                        .context("failed to parse login notification url")?;
1248                                    for (key, value) in url.query_pairs() {
1249                                        if key == "access_token" {
1250                                            access_token = Some(value.to_string());
1251                                        } else if key == "user_id" {
1252                                            user_id = Some(value.to_string());
1253                                        }
1254                                    }
1255
1256                                    let post_auth_url =
1257                                        http.build_url("/native_app_signin_succeeded");
1258                                    req.respond(
1259                                        tiny_http::Response::empty(302).with_header(
1260                                            tiny_http::Header::from_bytes(
1261                                                &b"Location"[..],
1262                                                post_auth_url.as_bytes(),
1263                                            )
1264                                            .unwrap(),
1265                                        ),
1266                                    )
1267                                    .context("failed to respond to login http request")?;
1268                                    return Ok((
1269                                        user_id
1270                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1271                                        access_token.ok_or_else(|| {
1272                                            anyhow!("missing access_token parameter")
1273                                        })?,
1274                                    ));
1275                                }
1276                            }
1277
1278                            Err(anyhow!("didn't receive login redirect"))
1279                        })
1280                        .await?;
1281
1282                    let access_token = private_key
1283                        .decrypt_string(&access_token)
1284                        .context("failed to decrypt access token")?;
1285
1286                    Ok(Credentials::User {
1287                        user_id: user_id.parse()?,
1288                        access_token,
1289                    })
1290                })
1291                .await?;
1292
1293            cx.update(|cx| cx.activate(true))?;
1294            Ok(credentials)
1295        })
1296    }
1297
1298    async fn authenticate_as_admin(
1299        http: Arc<HttpClientWithUrl>,
1300        login: String,
1301        mut api_token: String,
1302    ) -> Result<Credentials> {
1303        #[derive(Deserialize)]
1304        struct AuthenticatedUserResponse {
1305            user: User,
1306        }
1307
1308        #[derive(Deserialize)]
1309        struct User {
1310            id: u64,
1311        }
1312
1313        // Use the collab server's admin API to retrieve the id
1314        // of the impersonated user.
1315        let mut url = Self::get_rpc_url(http.clone(), None).await?;
1316        url.set_path("/user");
1317        url.set_query(Some(&format!("github_login={login}")));
1318        let request = Request::get(url.as_str())
1319            .header("Authorization", format!("token {api_token}"))
1320            .body("".into())?;
1321
1322        let mut response = http.send(request).await?;
1323        let mut body = String::new();
1324        response.body_mut().read_to_string(&mut body).await?;
1325        if !response.status().is_success() {
1326            Err(anyhow!(
1327                "admin user request failed {} - {}",
1328                response.status().as_u16(),
1329                body,
1330            ))?;
1331        }
1332        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1333
1334        // Use the admin API token to authenticate as the impersonated user.
1335        api_token.insert_str(0, "ADMIN_TOKEN:");
1336        Ok(Credentials::User {
1337            user_id: response.user.id,
1338            access_token: api_token,
1339        })
1340    }
1341
1342    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1343        self.state.write().credentials = None;
1344        self.disconnect(&cx);
1345
1346        if self.has_credentials(cx).await {
1347            self.credentials_provider
1348                .delete_credentials(cx)
1349                .await
1350                .log_err();
1351        }
1352    }
1353
1354    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1355        self.peer.teardown();
1356        self.set_status(Status::SignedOut, cx);
1357    }
1358
1359    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1360        self.peer.teardown();
1361        self.set_status(Status::ConnectionLost, cx);
1362    }
1363
1364    fn connection_id(&self) -> Result<ConnectionId> {
1365        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1366            Ok(connection_id)
1367        } else {
1368            Err(anyhow!("not connected"))
1369        }
1370    }
1371
1372    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1373        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1374        self.peer.send(self.connection_id()?, message)
1375    }
1376
1377    pub fn request<T: RequestMessage>(
1378        &self,
1379        request: T,
1380    ) -> impl Future<Output = Result<T::Response>> {
1381        self.request_envelope(request)
1382            .map_ok(|envelope| envelope.payload)
1383    }
1384
1385    pub fn request_stream<T: RequestMessage>(
1386        &self,
1387        request: T,
1388    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1389        let client_id = self.id.load(Ordering::SeqCst);
1390        log::debug!(
1391            "rpc request start. client_id:{}. name:{}",
1392            client_id,
1393            T::NAME
1394        );
1395        let response = self
1396            .connection_id()
1397            .map(|conn_id| self.peer.request_stream(conn_id, request));
1398        async move {
1399            let response = response?.await;
1400            log::debug!(
1401                "rpc request finish. client_id:{}. name:{}",
1402                client_id,
1403                T::NAME
1404            );
1405            response
1406        }
1407    }
1408
1409    pub fn request_envelope<T: RequestMessage>(
1410        &self,
1411        request: T,
1412    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1413        let client_id = self.id();
1414        log::debug!(
1415            "rpc request start. client_id:{}. name:{}",
1416            client_id,
1417            T::NAME
1418        );
1419        let response = self
1420            .connection_id()
1421            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1422        async move {
1423            let response = response?.await;
1424            log::debug!(
1425                "rpc request finish. client_id:{}. name:{}",
1426                client_id,
1427                T::NAME
1428            );
1429            response
1430        }
1431    }
1432
1433    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1434        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1435        self.peer.respond(receipt, response)
1436    }
1437
1438    fn respond_with_error<T: RequestMessage>(
1439        &self,
1440        receipt: Receipt<T>,
1441        error: proto::Error,
1442    ) -> Result<()> {
1443        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1444        self.peer.respond_with_error(receipt, error)
1445    }
1446
1447    fn handle_message(
1448        self: &Arc<Client>,
1449        message: Box<dyn AnyTypedEnvelope>,
1450        cx: &AsyncAppContext,
1451    ) {
1452        let mut state = self.state.write();
1453        let type_name = message.payload_type_name();
1454        let payload_type_id = message.payload_type_id();
1455        let sender_id = message.original_sender_id();
1456
1457        let mut subscriber = None;
1458
1459        if let Some(handle) = state
1460            .models_by_message_type
1461            .get(&payload_type_id)
1462            .and_then(|handle| handle.upgrade())
1463        {
1464            subscriber = Some(handle);
1465        } else if let Some((extract_entity_id, entity_type_id)) =
1466            state.entity_id_extractors.get(&payload_type_id).zip(
1467                state
1468                    .entity_types_by_message_type
1469                    .get(&payload_type_id)
1470                    .copied(),
1471            )
1472        {
1473            let entity_id = (extract_entity_id)(message.as_ref());
1474
1475            match state
1476                .entities_by_type_and_remote_id
1477                .get_mut(&(entity_type_id, entity_id))
1478            {
1479                Some(WeakSubscriber::Pending(pending)) => {
1480                    pending.push(message);
1481                    return;
1482                }
1483                Some(weak_subscriber) => match weak_subscriber {
1484                    WeakSubscriber::Entity { handle } => {
1485                        subscriber = handle.upgrade();
1486                    }
1487
1488                    WeakSubscriber::Pending(_) => {}
1489                },
1490                _ => {}
1491            }
1492        }
1493
1494        let subscriber = if let Some(subscriber) = subscriber {
1495            subscriber
1496        } else {
1497            log::info!("unhandled message {}", type_name);
1498            self.peer.respond_with_unhandled_message(message).log_err();
1499            return;
1500        };
1501
1502        let handler = state.message_handlers.get(&payload_type_id).cloned();
1503        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1504        // It also ensures we don't hold the lock while yielding back to the executor, as
1505        // that might cause the executor thread driving this future to block indefinitely.
1506        drop(state);
1507
1508        if let Some(handler) = handler {
1509            let future = handler(subscriber, message, self, cx.clone());
1510            let client_id = self.id();
1511            log::debug!(
1512                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1513                client_id,
1514                sender_id,
1515                type_name
1516            );
1517            cx.spawn(move |_| async move {
1518                    match future.await {
1519                        Ok(()) => {
1520                            log::debug!(
1521                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1522                                client_id,
1523                                sender_id,
1524                                type_name
1525                            );
1526                        }
1527                        Err(error) => {
1528                            log::error!(
1529                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1530                                client_id,
1531                                sender_id,
1532                                type_name,
1533                                error
1534                            );
1535                        }
1536                    }
1537                })
1538                .detach();
1539        } else {
1540            log::info!("unhandled message {}", type_name);
1541            self.peer.respond_with_unhandled_message(message).log_err();
1542        }
1543    }
1544
    /// Returns this client's telemetry handle.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1548}
1549
/// JSON record that `DevelopmentCredentialsProvider` reads from and writes to its file.
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
    /// Id of the signed-in user.
    user_id: u64,
    /// Access token, stored in plain text.
    access_token: String,
}
1555
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
    /// File holding the JSON-serialized `DevelopmentCredentials`.
    path: PathBuf,
}
1566
1567impl CredentialsProvider for DevelopmentCredentialsProvider {
1568    fn read_credentials<'a>(
1569        &'a self,
1570        _cx: &'a AsyncAppContext,
1571    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1572        async move {
1573            if IMPERSONATE_LOGIN.is_some() {
1574                return None;
1575            }
1576
1577            let json = std::fs::read(&self.path).log_err()?;
1578
1579            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1580
1581            Some(Credentials::User {
1582                user_id: credentials.user_id,
1583                access_token: credentials.access_token,
1584            })
1585        }
1586        .boxed_local()
1587    }
1588
1589    fn write_credentials<'a>(
1590        &'a self,
1591        user_id: u64,
1592        access_token: String,
1593        _cx: &'a AsyncAppContext,
1594    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1595        async move {
1596            let json = serde_json::to_string(&DevelopmentCredentials {
1597                user_id,
1598                access_token,
1599            })?;
1600
1601            std::fs::write(&self.path, json)?;
1602
1603            Ok(())
1604        }
1605        .boxed_local()
1606    }
1607
1608    fn delete_credentials<'a>(
1609        &'a self,
1610        _cx: &'a AsyncAppContext,
1611    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1612        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1613    }
1614}
1615
/// A credentials provider that stores credentials in the system keychain.
///
/// Entries are keyed by the configured `ClientSettings::server_url`.
struct KeychainCredentialsProvider;
1618
1619impl CredentialsProvider for KeychainCredentialsProvider {
1620    fn read_credentials<'a>(
1621        &'a self,
1622        cx: &'a AsyncAppContext,
1623    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1624        async move {
1625            if IMPERSONATE_LOGIN.is_some() {
1626                return None;
1627            }
1628
1629            let (user_id, access_token) = cx
1630                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1631                .log_err()?
1632                .await
1633                .log_err()??;
1634
1635            Some(Credentials::User {
1636                user_id: user_id.parse().ok()?,
1637                access_token: String::from_utf8(access_token).ok()?,
1638            })
1639        }
1640        .boxed_local()
1641    }
1642
1643    fn write_credentials<'a>(
1644        &'a self,
1645        user_id: u64,
1646        access_token: String,
1647        cx: &'a AsyncAppContext,
1648    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1649        async move {
1650            cx.update(move |cx| {
1651                cx.write_credentials(
1652                    &ClientSettings::get_global(cx).server_url,
1653                    &user_id.to_string(),
1654                    access_token.as_bytes(),
1655                )
1656            })?
1657            .await
1658        }
1659        .boxed_local()
1660    }
1661
1662    fn delete_credentials<'a>(
1663        &'a self,
1664        cx: &'a AsyncAppContext,
1665    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1666        async move {
1667            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1668                .await
1669        }
1670        .boxed_local()
1671    }
1672}
1673
/// Name of the custom URL scheme for Zed links (`zed://…`), without the `://` separator.
pub static ZED_URL_SCHEME: &str = "zed";
1676
1677/// Parses the given link into a Zed link.
1678///
1679/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1680/// Returns [`None`] otherwise.
1681pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1682    let server_url = &ClientSettings::get_global(cx).server_url;
1683    if let Some(stripped) = link
1684        .strip_prefix(server_url)
1685        .and_then(|result| result.strip_prefix('/'))
1686    {
1687        return Some(stripped);
1688    }
1689    if let Some(stripped) = link
1690        .strip_prefix(ZED_URL_SCHEME)
1691        .and_then(|result| result.strip_prefix("://"))
1692    {
1693        return Some(stripped);
1694    }
1695
1696    None
1697}
1698
#[cfg(test)]
mod tests {
    // Tests for connection/reconnection status transitions, authentication overrides and
    // message-handler subscriptions, driven through `FakeServer` and fake clocks/HTTP.
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http::FakeHttpClient;
    use parking_lot::Mutex;
    use settings::SettingsStore;
    use std::future;

    // After a server-side disconnect, the client reconnects on its own once connections are
    // allowed again. It reuses cached credentials while they stay valid (auth count unchanged)
    // and re-authenticates once the server rolls the access token (auth count increments).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones; wait until a failed reconnect is reported.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock so a pending reconnect can run.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes is abandoned after `CONNECTION_TIMEOUT`,
    // both on the initial connect (`ConnectionError`, and `authenticate_and_connect` returns
    // an error) and on a reconnect (`Reconnecting` followed by `ReconnectionError`).
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        // Authentication succeeds immediately with fixed credentials...
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // ...but establishing the connection never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        // The timed-out attempt must surface as an error to the caller.
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        // Install a never-resolving connection override again so the reconnect attempt hangs.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first authentication is still
    // pending drops the first in-flight authentication future. The deferred guard bumps
    // `dropped_auth_count`, even though the first spawned task handle is still held
    // (shadowing `_authenticate` does not drop it).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        // Number of times authentication started / number of authentication futures dropped.
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this authentication future is dropped before completing.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // A single model message handler receives messages for every subscribed entity id.
    // Dropping the subscription for one id (3) must not affect delivery for the others (1, 2).
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signal the channel matching the id of the model the message was routed to. This
        // presumably relies on `JoinProject.project_id` being the routing entity id (confirm
        // against its `EntityMessage` impl).
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Dropping a message handler's subscription and then registering a new handler for the
    // same message type and model routes subsequent messages to the new handler.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while it is running: here the handler takes the
    // subscription out of the model (the returned `Option` is dropped at the end of the
    // statement) and must still complete and signal `done_tx`.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone().downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        // Hand the subscription to the model so the handler above can drop it.
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as a message-handler target in these tests.
    #[derive(Default)]
    struct TestModel {
        // Identifies which model a routed message reached.
        id: usize,
        // Optionally holds the model's own handler subscription so a handler can drop it.
        subscription: Option<Subscription>,
    }

    /// Installs a test `SettingsStore` as a global and then calls `init_settings`.
    /// `init_settings` presumably registers the settings this crate reads, such as
    /// `ClientSettings` — confirm against its definition.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}