client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod socks;
   5pub mod telemetry;
   6pub mod user;
   7pub mod zed_urls;
   8
   9use anyhow::{anyhow, bail, Context as _, Result};
  10use async_recursion::async_recursion;
  11use async_tungstenite::tungstenite::{
  12    client::IntoClientRequest,
  13    error::Error as WebsocketError,
  14    http::{HeaderValue, Request, StatusCode},
  15};
  16use chrono::{DateTime, Utc};
  17use clock::SystemClock;
  18use futures::{
  19    channel::oneshot, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
  20    TryFutureExt as _, TryStreamExt,
  21};
  22use gpui::{actions, AppContext, AsyncAppContext, Global, Model, Task, WeakModel};
  23use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use release_channel::{AppVersion, ReleaseChannel};
  28use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::{Settings, SettingsSources};
  32use socks::connect_socks_proxy_stream;
  33use std::fmt;
  34use std::pin::Pin;
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        atomic::{AtomicU64, Ordering},
  44        Arc, LazyLock, Weak,
  45    },
  46    time::{Duration, Instant},
  47};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use url::Url;
  51use util::{ResultExt, TryFutureExt};
  52
  53pub use rpc::*;
  54pub use telemetry_events::Event;
  55pub use user::*;
  56
  57#[derive(Debug, Clone, Eq, PartialEq)]
  58pub struct DevServerToken(pub String);
  59
  60impl fmt::Display for DevServerToken {
  61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66static ZED_SERVER_URL: LazyLock<Option<String>> =
  67    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  68static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  69
  70/// An environment variable whose presence indicates that the development auth
  71/// provider should be used.
  72///
  73/// Only works in development. Setting this environment variable in other release
  74/// channels is a no-op.
  75pub static ZED_DEVELOPMENT_AUTH: LazyLock<bool> = LazyLock::new(|| {
  76    std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty())
  77});
  78pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  79    std::env::var("ZED_IMPERSONATE")
  80        .ok()
  81        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  82});
  83
  84pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  85    std::env::var("ZED_ADMIN_API_TOKEN")
  86        .ok()
  87        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  88});
  89
  90pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  91    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  92
  93pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  94    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  95
  96pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  97pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
  98pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  99
// Declares the `SignIn`, `SignOut`, and `Reconnect` actions (namespace `client`);
// `init` installs the handlers that run them against the shared `Client`.
actions!(client, [SignIn, SignOut, Reconnect]);
 101
// Raw client settings as read from settings files. Plain `//` comments are used
// here so the JSON schema produced by the `JsonSchema` derive stays unchanged.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // URL of the Zed server, if configured in a settings file.
    server_url: Option<String>,
}

/// Resolved client settings, produced by `<ClientSettings as Settings>::load`.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// URL of the Zed server. A `ZED_SERVER_URL` environment variable, when set,
    /// replaces whatever the settings files specified.
    pub server_url: String,
}
 111
 112impl Settings for ClientSettings {
 113    const KEY: Option<&'static str> = None;
 114
 115    type FileContent = ClientSettingsContent;
 116
 117    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 118        let mut result = sources.json_merge::<Self>()?;
 119        if let Some(server_url) = &*ZED_SERVER_URL {
 120            result.server_url.clone_from(server_url)
 121        }
 122        Ok(result)
 123    }
 124}
 125
// Raw proxy setting as read from settings files. Plain `//` comments are used
// here so the JSON schema produced by the `JsonSchema` derive stays unchanged.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProxySettingsContent {
    // Proxy address string, if configured; not validated here.
    proxy: Option<String>,
}

/// Resolved proxy settings, produced by `<ProxySettings as Settings>::load`.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Proxy to use, if one is configured. The string is not validated here.
    pub proxy: Option<String>,
}
 135
 136impl Settings for ProxySettings {
 137    const KEY: Option<&'static str> = None;
 138
 139    type FileContent = ProxySettingsContent;
 140
 141    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 142        Ok(Self {
 143            proxy: sources
 144                .user
 145                .or(sources.server)
 146                .and_then(|value| value.proxy.clone())
 147                .or(sources.default.proxy.clone()),
 148        })
 149    }
 150}
 151
/// Registers the settings types defined in this module (`TelemetrySettings`,
/// `ClientSettings`, and `ProxySettings`) with the app.
pub fn init_settings(cx: &mut AppContext) {
    TelemetrySettings::register(cx);
    ClientSettings::register(cx);
    ProxySettings::register(cx);
}
 157
 158pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 159    let client = Arc::downgrade(client);
 160    cx.on_action({
 161        let client = client.clone();
 162        move |_: &SignIn, cx| {
 163            if let Some(client) = client.upgrade() {
 164                cx.spawn(
 165                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 166                )
 167                .detach();
 168            }
 169        }
 170    });
 171
 172    cx.on_action({
 173        let client = client.clone();
 174        move |_: &SignOut, cx| {
 175            if let Some(client) = client.upgrade() {
 176                cx.spawn(|cx| async move {
 177                    client.sign_out(&cx).await;
 178                })
 179                .detach();
 180            }
 181        }
 182    });
 183
 184    cx.on_action({
 185        let client = client.clone();
 186        move |_: &Reconnect, cx| {
 187            if let Some(client) = client.upgrade() {
 188                cx.spawn(|cx| async move {
 189                    client.reconnect(&cx);
 190                })
 191                .detach();
 192            }
 193        }
 194    });
 195}
 196
/// Newtype that lets the shared [`Client`] be stored as a gpui global; see
/// [`Client::global`] and [`Client::set_global`].
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 200
/// Client for the Zed server. It owns the RPC peer, the HTTP client, telemetry,
/// credential storage, the connection status, and the registry of message handlers.
pub struct Client {
    /// Id reported in logs. Starts at `0`, and `authenticate_and_connect` sets it
    /// to the signed-in user's id.
    id: AtomicU64,
    /// RPC peer; `set_connection` adds the server connection to it.
    peer: Arc<Peer>,
    /// HTTP client bound to the server URL (see `Client::production`).
    http: Arc<HttpClientWithUrl>,
    /// Telemetry sink; its authenticated-user info is cleared on sign-out.
    telemetry: Arc<Telemetry>,
    /// Where credentials are read from, written to, and deleted (chosen in `Client::new`).
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// Credentials, connection status channel, and reconnect task.
    state: RwLock<ClientState>,
    /// Registered message/request handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,

    // The boxed closure types are inherently verbose; allowing the lint keeps
    // the full signature visible at the field.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    /// Test-only override installed by `override_authenticate`; its consumer is
    /// outside this view (presumably replaces the real authentication flow).
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    /// Test-only override installed by `override_establish_connection`; presumably
    /// replaces the real connection step (consumer not visible here).
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    #[cfg(any(test, feature = "test-support"))]
    /// Test-only RPC URL override installed by `override_rpc_url`.
    rpc_url: RwLock<Option<Url>>,
}
 235
/// Errors that can occur while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced from an HTTP 426 (Upgrade Required) websocket handshake response;
    /// `authenticate_and_connect` maps it to `Status::UpgradeRequired`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced from an HTTP 401 (Unauthorized) websocket handshake response.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors that are not 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A header value (e.g. authorization) could not be encoded.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// I/O failure while connecting.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite its name, this wraps an `http::Error`, not a websocket error.
    /// Tungstenite errors are converted via `From<WebsocketError>` instead.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 251
 252impl From<WebsocketError> for EstablishConnectionError {
 253    fn from(error: WebsocketError) -> Self {
 254        if let WebsocketError::Http(response) = &error {
 255            match response.status() {
 256                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 257                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 258                _ => {}
 259            }
 260        }
 261        EstablishConnectionError::Other(error.into())
 262    }
 263}
 264
 265impl EstablishConnectionError {
 266    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 267        Self::Other(error.into())
 268    }
 269}
 270
/// Connection lifecycle state of a [`Client`], observable through [`Client::status`].
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial status of a new client.
    SignedOut,
    /// The server answered with HTTP 426; `authenticate_and_connect` refuses to run.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Establishing a connection after `Authenticating`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected to the server.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection dropped; setting this status starts the reconnect loop.
    ConnectionLost,
    /// Obtaining credentials again, starting from a non-signed-out state.
    Reauthenticating,
    /// Establishing a connection after `Reauthenticating`.
    Reconnecting,
    /// A reconnect attempt failed; the loop will retry at `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 289
 290impl Status {
 291    pub fn is_connected(&self) -> bool {
 292        matches!(self, Self::Connected { .. })
 293    }
 294
 295    pub fn is_signed_out(&self) -> bool {
 296        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 297    }
 298}
 299
/// Mutable connection state, stored behind `Client::state`.
struct ClientState {
    /// Credentials currently in use, if any.
    credentials: Option<Credentials>,
    /// Watch channel holding the current [`Status`]. `.0` publishes updates
    /// (`Client::set_status`), and `.1` is cloned for observers (`Client::status`).
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect loop spawned on `Status::ConnectionLost`. It is cleared on connect,
    /// sign-out, and upgrade-required. Dropping the task presumably cancels it;
    /// confirm against gpui `Task` semantics.
    _reconnect_task: Option<Task<()>>,
}
 305
/// Credentials used to authenticate with the server.
///
/// NOTE(review): the derived `Debug` prints the raw token / access token; avoid
/// logging values of this type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Credentials {
    /// A dev server authenticating with its token.
    DevServer { token: DevServerToken },
    /// A user authenticating with their id and access token.
    User { user_id: u64, access_token: String },
}
 311
 312impl Credentials {
 313    pub fn authorization_header(&self) -> String {
 314        match self {
 315            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 316            Credentials::User {
 317                user_id,
 318                access_token,
 319            } => format!("{} {}", user_id, access_token),
 320        }
 321    }
 322}
 323
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain). `Client::new` picks the implementation.
/// Only user credentials are ever written through this trait; dev-server tokens
/// are set directly via `Client::set_dev_server_token`.
trait CredentialsProvider {
    /// Reads the credentials from the provider, yielding `None` if none are stored.
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the user credentials (`user_id` plus `access_token`) to the provider.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider (done when stored credentials
    /// are rejected by the server).
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
 349
 350impl Default for ClientState {
 351    fn default() -> Self {
 352        Self {
 353            credentials: None,
 354            status: watch::channel_with(Status::SignedOut),
 355            _reconnect_task: None,
 356        }
 357    }
 358}
 359
/// Handle for a registered handler. Dropping it unregisters the handler, provided
/// the client is still alive.
pub enum Subscription {
    /// Subscription for one remote entity, keyed by `(entity TypeId, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Subscription for a message type, keyed by the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 370
 371impl Drop for Subscription {
 372    fn drop(&mut self) {
 373        match self {
 374            Subscription::Entity { client, id } => {
 375                if let Some(client) = client.upgrade() {
 376                    let mut state = client.handler_set.lock();
 377                    let _ = state.entities_by_type_and_remote_id.remove(id);
 378                }
 379            }
 380            Subscription::Message { client, id } => {
 381                if let Some(client) = client.upgrade() {
 382                    let mut state = client.handler_set.lock();
 383                    let _ = state.entity_types_by_message_type.remove(id);
 384                    let _ = state.message_handlers.remove(id);
 385                }
 386            }
 387        }
 388    }
 389}
 390
/// A reserved but not yet bound entity subscription, returned by
/// [`Client::subscribe_to_entity`].
///
/// Call [`PendingEntitySubscription::set_model`] to bind it. If it is dropped while
/// unbound, the reservation is removed and any queued messages are logged as unhandled.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    // Records the entity type this subscription is for without storing a `T`.
    _entity_type: PhantomData<T>,
    // Set by `set_model` so `Drop` knows the reservation was handed off.
    consumed: bool,
}
 397
 398impl<T: 'static> PendingEntitySubscription<T> {
 399    pub fn set_model(mut self, model: &Model<T>, cx: &AsyncAppContext) -> Subscription {
 400        self.consumed = true;
 401        let mut handlers = self.client.handler_set.lock();
 402        let id = (TypeId::of::<T>(), self.remote_id);
 403        let Some(EntityMessageSubscriber::Pending(messages)) =
 404            handlers.entities_by_type_and_remote_id.remove(&id)
 405        else {
 406            unreachable!()
 407        };
 408
 409        handlers.entities_by_type_and_remote_id.insert(
 410            id,
 411            EntityMessageSubscriber::Entity {
 412                handle: model.downgrade().into(),
 413            },
 414        );
 415        drop(handlers);
 416        for message in messages {
 417            let client_id = self.client.id();
 418            let type_name = message.payload_type_name();
 419            let sender_id = message.original_sender_id();
 420            log::debug!(
 421                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 422                client_id,
 423                sender_id,
 424                type_name
 425            );
 426            self.client.handle_message(message, cx);
 427        }
 428        Subscription::Entity {
 429            client: Arc::downgrade(&self.client),
 430            id,
 431        }
 432    }
 433}
 434
 435impl<T: 'static> Drop for PendingEntitySubscription<T> {
 436    fn drop(&mut self) {
 437        if !self.consumed {
 438            let mut state = self.client.handler_set.lock();
 439            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 440                .entities_by_type_and_remote_id
 441                .remove(&(TypeId::of::<T>(), self.remote_id))
 442            {
 443                for message in messages {
 444                    log::info!("unhandled message {}", message.payload_type_name());
 445                }
 446            }
 447        }
 448    }
 449}
 450
/// Resolved telemetry preferences; see `TelemetrySettingsContent` for the
/// user-facing description of each flag.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}

// NOTE: schemars uses the `///` comments on this type as JSON-schema descriptions,
// so they double as user-facing settings documentation. Edit them with that in mind.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 469
 470impl settings::Settings for TelemetrySettings {
 471    const KEY: Option<&'static str> = Some("telemetry");
 472
 473    type FileContent = TelemetrySettingsContent;
 474
 475    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 476        Ok(Self {
 477            diagnostics: sources
 478                .user
 479                .as_ref()
 480                .or(sources.server.as_ref())
 481                .and_then(|v| v.diagnostics)
 482                .unwrap_or(
 483                    sources
 484                        .default
 485                        .diagnostics
 486                        .ok_or_else(Self::missing_default)?,
 487                ),
 488            metrics: sources
 489                .user
 490                .as_ref()
 491                .or(sources.server.as_ref())
 492                .and_then(|v| v.metrics)
 493                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 494        })
 495    }
 496}
 497
 498impl Client {
 499    pub fn new(
 500        clock: Arc<dyn SystemClock>,
 501        http: Arc<HttpClientWithUrl>,
 502        cx: &mut AppContext,
 503    ) -> Arc<Self> {
 504        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 505            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 506            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 507            | None => false,
 508        };
 509
 510        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 511            if use_zed_development_auth {
 512                Arc::new(DevelopmentCredentialsProvider {
 513                    path: paths::config_dir().join("development_auth"),
 514                })
 515            } else {
 516                Arc::new(KeychainCredentialsProvider)
 517            };
 518
 519        Arc::new(Self {
 520            id: AtomicU64::new(0),
 521            peer: Peer::new(0),
 522            telemetry: Telemetry::new(clock, http.clone(), cx),
 523            http,
 524            credentials_provider,
 525            state: Default::default(),
 526            handler_set: Default::default(),
 527
 528            #[cfg(any(test, feature = "test-support"))]
 529            authenticate: Default::default(),
 530            #[cfg(any(test, feature = "test-support"))]
 531            establish_connection: Default::default(),
 532            #[cfg(any(test, feature = "test-support"))]
 533            rpc_url: RwLock::default(),
 534        })
 535    }
 536
 537    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 538        let clock = Arc::new(clock::RealSystemClock);
 539        let http = Arc::new(HttpClientWithUrl::new_uri(
 540            cx.http_client(),
 541            &ClientSettings::get_global(cx).server_url,
 542            cx.http_client().proxy().cloned(),
 543        ));
 544        Self::new(clock, http, cx)
 545    }
 546
 547    pub fn id(&self) -> u64 {
 548        self.id.load(Ordering::SeqCst)
 549    }
 550
 551    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 552        self.http.clone()
 553    }
 554
 555    pub fn set_id(&self, id: u64) -> &Self {
 556        self.id.store(id, Ordering::SeqCst);
 557        self
 558    }
 559
 560    #[cfg(any(test, feature = "test-support"))]
 561    pub fn teardown(&self) {
 562        let mut state = self.state.write();
 563        state._reconnect_task.take();
 564        self.handler_set.lock().clear();
 565        self.peer.teardown();
 566    }
 567
 568    #[cfg(any(test, feature = "test-support"))]
 569    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 570    where
 571        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 572    {
 573        *self.authenticate.write() = Some(Box::new(authenticate));
 574        self
 575    }
 576
 577    #[cfg(any(test, feature = "test-support"))]
 578    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 579    where
 580        F: 'static
 581            + Send
 582            + Sync
 583            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 584    {
 585        *self.establish_connection.write() = Some(Box::new(connect));
 586        self
 587    }
 588
 589    #[cfg(any(test, feature = "test-support"))]
 590    pub fn override_rpc_url(&self, url: Url) -> &Self {
 591        *self.rpc_url.write() = Some(url);
 592        self
 593    }
 594
 595    pub fn global(cx: &AppContext) -> Arc<Self> {
 596        cx.global::<GlobalClient>().0.clone()
 597    }
 598    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 599        cx.set_global(GlobalClient(client))
 600    }
 601
 602    pub fn user_id(&self) -> Option<u64> {
 603        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 604            Some(*user_id)
 605        } else {
 606            None
 607        }
 608    }
 609
 610    pub fn peer_id(&self) -> Option<PeerId> {
 611        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 612            Some(*peer_id)
 613        } else {
 614            None
 615        }
 616    }
 617
 618    pub fn status(&self) -> watch::Receiver<Status> {
 619        self.state.read().status.1.clone()
 620    }
 621
    /// Publishes `status` to every `status()` observer and applies its side effects
    /// while holding the state write lock:
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop that keeps retrying
    ///   `authenticate_and_connect` for as long as each failed attempt leaves the
    ///   status at `ConnectionError`. Between attempts it publishes
    ///   `ReconnectionError` and waits a jittered delay clamped to
    ///   `[INITIAL_RECONNECTION_DELAY, MAX_RECONNECTION_DELAY]`.
    /// - `SignedOut` / `UpgradeRequired`: clears telemetry user info and drops any
    ///   reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed under test so the backoff sequence is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Only keep retrying plain connection failures; any other
                        // resulting status (e.g. signed out) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Jitter: scale the previous delay by a random factor in
                            // [0.5, 2.5], then clamp to the configured bounds.
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 667
 668    pub fn subscribe_to_entity<T>(
 669        self: &Arc<Self>,
 670        remote_id: u64,
 671    ) -> Result<PendingEntitySubscription<T>>
 672    where
 673        T: 'static,
 674    {
 675        let id = (TypeId::of::<T>(), remote_id);
 676
 677        let mut state = self.handler_set.lock();
 678        if state.entities_by_type_and_remote_id.contains_key(&id) {
 679            return Err(anyhow!("already subscribed to entity"));
 680        }
 681
 682        state
 683            .entities_by_type_and_remote_id
 684            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 685
 686        Ok(PendingEntitySubscription {
 687            client: self.clone(),
 688            remote_id,
 689            consumed: false,
 690            _entity_type: PhantomData,
 691        })
 692    }
 693
 694    #[track_caller]
 695    pub fn add_message_handler<M, E, H, F>(
 696        self: &Arc<Self>,
 697        entity: WeakModel<E>,
 698        handler: H,
 699    ) -> Subscription
 700    where
 701        M: EnvelopedMessage,
 702        E: 'static,
 703        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 704        F: 'static + Future<Output = Result<()>>,
 705    {
 706        self.add_message_handler_impl(entity, move |model, message, _, cx| {
 707            handler(model, message, cx)
 708        })
 709    }
 710
 711    fn add_message_handler_impl<M, E, H, F>(
 712        self: &Arc<Self>,
 713        entity: WeakModel<E>,
 714        handler: H,
 715    ) -> Subscription
 716    where
 717        M: EnvelopedMessage,
 718        E: 'static,
 719        H: 'static
 720            + Sync
 721            + Fn(Model<E>, TypedEnvelope<M>, AnyProtoClient, AsyncAppContext) -> F
 722            + Send
 723            + Sync,
 724        F: 'static + Future<Output = Result<()>>,
 725    {
 726        let message_type_id = TypeId::of::<M>();
 727        let mut state = self.handler_set.lock();
 728        state
 729            .models_by_message_type
 730            .insert(message_type_id, entity.into());
 731
 732        let prev_handler = state.message_handlers.insert(
 733            message_type_id,
 734            Arc::new(move |subscriber, envelope, client, cx| {
 735                let subscriber = subscriber.downcast::<E>().unwrap();
 736                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 737                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 738            }),
 739        );
 740        if prev_handler.is_some() {
 741            let location = std::panic::Location::caller();
 742            panic!(
 743                "{}:{} registered handler for the same message {} twice",
 744                location.file(),
 745                location.line(),
 746                std::any::type_name::<M>()
 747            );
 748        }
 749
 750        Subscription::Message {
 751            client: Arc::downgrade(self),
 752            id: message_type_id,
 753        }
 754    }
 755
 756    pub fn add_request_handler<M, E, H, F>(
 757        self: &Arc<Self>,
 758        model: WeakModel<E>,
 759        handler: H,
 760    ) -> Subscription
 761    where
 762        M: RequestMessage,
 763        E: 'static,
 764        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 765        F: 'static + Future<Output = Result<M::Response>>,
 766    {
 767        self.add_message_handler_impl(model, move |handle, envelope, this, cx| {
 768            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 769        })
 770    }
 771
 772    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 773        receipt: Receipt<T>,
 774        response: F,
 775        client: AnyProtoClient,
 776    ) -> Result<()> {
 777        match response.await {
 778            Ok(response) => {
 779                client.send_response(receipt.message_id, response)?;
 780                Ok(())
 781            }
 782            Err(error) => {
 783                client.send_response(receipt.message_id, error.to_proto())?;
 784                Err(error)
 785            }
 786        }
 787    }
 788
 789    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 790        self.credentials_provider
 791            .read_credentials(cx)
 792            .await
 793            .is_some()
 794    }
 795
 796    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 797        self.state.write().credentials = Some(Credentials::DevServer { token });
 798        self
 799    }
 800
    /// Authenticates if needed, then connects to the server, moving `status()` through
    /// the (re)authenticating and (re)connecting states.
    ///
    /// It returns `Ok(())` immediately if the client is already connected or connecting.
    /// It fails with `UpgradeRequired` if the server previously demanded an upgrade.
    ///
    /// Credentials are taken, in order, from:
    /// 1. the in-memory state,
    /// 2. the credentials provider (only when `try_provider` is `true`),
    /// 3. `self.authenticate`, which is abandoned if the status changes while it runs.
    ///
    /// After a successful connection, newly obtained user credentials are persisted
    /// through the provider, unless they came from it or `ZED_IMPERSONATE` is set. If
    /// provider-supplied credentials are rejected as unauthorized, they are deleted and
    /// the whole process is retried once without the provider. Establishing the
    /// connection and receiving the server hello share a single `CONNECTION_TIMEOUT`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            // A connection exists or is being established already: nothing to do.
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            let mut status_rx = self.status();
            // Skip the receiver's initial value (presumably the current status) so the
            // select below only reacts to a subsequent status change, which cancels
            // authentication.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path that leaves `credentials` as `None` has returned above.
        let credentials = credentials.unwrap();
        if let Credentials::User { user_id, .. } = &credentials {
            self.set_id(*user_id);
        }

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer bounds both connection establishment and the wait for the hello.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only credentials that did not come from the provider,
                        // and never while impersonating another user.
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            if let Credentials::User{user_id, access_token} = credentials {
                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                            }
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_provider {
                            // Stored credentials are stale: forget them and retry once,
                            // this time without consulting the provider.
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 909
    /// Registers `conn` with the peer, waits for the server's `Hello` message, and on success
    /// sets the status to `Status::Connected` with the assigned peer id and connection id.
    ///
    /// After the handshake two detached tasks are spawned via `cx`:
    /// - one that feeds every further incoming message to `handle_message`, yielding between
    ///   messages;
    /// - one that awaits the connection's I/O task and then sets `SignedOut` (clean finish while
    ///   the status is still this exact connection) or `ConnectionLost` (I/O error).
    ///
    /// # Errors
    ///
    /// Returns an error, after calling `peer.disconnect` on the new connection, if the incoming
    /// stream ends before any message, the first message is not a `proto::Hello`, or the
    /// `Hello` carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        // The closure lets the peer create background-executor timers for a given duration
        // (presumably for its internal timeouts/keepalives — TODO confirm in `Peer`).
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor; the resulting task is awaited
        // by the status-watching task spawned at the end of this function.
        let handle_io = executor.spawn(handle_io);

        // The very first message from the server must be a `Hello` carrying our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        // On a failed handshake, remove the connection from the peer before propagating.
        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch all subsequent incoming messages until the stream ends.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // When the I/O task finishes: a clean finish signs out only if the status still refers
        // to this exact connection (presumably so a newer connection is not overwritten); an
        // I/O error marks the connection as lost.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1007
1008    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1009        #[cfg(any(test, feature = "test-support"))]
1010        if let Some(callback) = self.authenticate.read().as_ref() {
1011            return callback(cx);
1012        }
1013
1014        self.authenticate_with_browser(cx)
1015    }
1016
1017    fn establish_connection(
1018        self: &Arc<Self>,
1019        credentials: &Credentials,
1020        cx: &AsyncAppContext,
1021    ) -> Task<Result<Connection, EstablishConnectionError>> {
1022        #[cfg(any(test, feature = "test-support"))]
1023        if let Some(callback) = self.establish_connection.read().as_ref() {
1024            return callback(credentials, cx);
1025        }
1026
1027        self.establish_websocket_connection(credentials, cx)
1028    }
1029
1030    fn rpc_url(
1031        &self,
1032        http: Arc<HttpClientWithUrl>,
1033        release_channel: Option<ReleaseChannel>,
1034    ) -> impl Future<Output = Result<url::Url>> {
1035        #[cfg(any(test, feature = "test-support"))]
1036        let url_override = self.rpc_url.read().clone();
1037
1038        async move {
1039            #[cfg(any(test, feature = "test-support"))]
1040            if let Some(url) = url_override {
1041                return Ok(url);
1042            }
1043
1044            if let Some(url) = &*ZED_RPC_URL {
1045                return Url::parse(url).context("invalid rpc url");
1046            }
1047
1048            let mut url = http.build_url("/rpc");
1049            if let Some(preview_param) =
1050                release_channel.and_then(|channel| channel.release_query_param())
1051            {
1052                url += "?";
1053                url += preview_param;
1054            }
1055
1056            let response = http.get(&url, Default::default(), false).await?;
1057            let collab_url = if response.status().is_redirection() {
1058                response
1059                    .headers()
1060                    .get("Location")
1061                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1062                    .to_str()
1063                    .map_err(EstablishConnectionError::other)?
1064                    .to_string()
1065            } else {
1066                Err(anyhow!(
1067                    "unexpected /rpc response status {}",
1068                    response.status()
1069                ))?
1070            };
1071
1072            Url::parse(&collab_url).context("invalid rpc url")
1073        }
1074    }
1075
    /// Opens a WebSocket connection to the collab RPC endpoint, authenticated with
    /// `credentials`.
    ///
    /// The URL comes from `rpc_url`. A stream to its host and port is opened with
    /// `connect_socks_proxy_stream`, passing the HTTP client's proxy (if any). The URL scheme
    /// is then switched from `https`/`http` to `wss`/`ws`.
    ///
    /// The upgrade request carries these headers:
    /// - `Authorization`
    /// - `x-zed-protocol-version`
    /// - `x-zed-app-version`
    /// - `x-zed-release-channel`
    ///
    /// For `https` URLs, TLS uses rustls with the platform's native root certificates; errors
    /// while loading them are logged as warnings.
    ///
    /// The release channel and app version are read from `cx` up front. When unavailable, the
    /// channel header becomes `"unknown"` and the app version is an empty string.
    ///
    /// # Errors
    ///
    /// Fails if any of the following happens:
    /// - the URL cannot be resolved, uses a scheme other than `http`/`https`, or lacks a host;
    /// - the stream cannot be opened;
    /// - a header value is invalid;
    /// - the TLS or WebSocket handshake fails.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Both lookups tolerate a missing global / released app context.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let credentials = credentials.clone();
        // Build the URL-resolving future now (it may read test overrides eagerly) and await it
        // on the background executor.
        let rpc_url = self.rpc_url(http, release_channel);
        cx.background_executor().spawn(async move {
            use HttpOrHttps::*;

            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // http/https and ws/wss are all "special" URL schemes, so this switch cannot fail.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                "Authorization",
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );

            match url_scheme {
                Https => {
                    // Trust the platform's native root certificates for the TLS handshake.
                    let client_config = {
                        let mut root_store = rustls::RootCertStore::empty();

                        let root_certs = rustls_native_certs::load_native_certs();
                        for error in root_certs.errors {
                            log::warn!("error loading native certs: {:?}", error);
                        }
                        root_store.add_parsable_certificates(
                            &root_certs
                                .certs
                                .into_iter()
                                .map(|cert| cert.as_ref().to_owned())
                                .collect::<Vec<_>>(),
                        );
                        rustls::ClientConfig::builder()
                            .with_safe_defaults()
                            .with_root_certificates(root_store)
                            .with_no_client_auth()
                    };

                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls_with_connector(
                            request,
                            stream,
                            Some(client_config.into()),
                        )
                        .await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                Http => {
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
            }
        })
    }
1192
1193    pub fn authenticate_with_browser(
1194        self: &Arc<Self>,
1195        cx: &AsyncAppContext,
1196    ) -> Task<Result<Credentials>> {
1197        let http = self.http.clone();
1198        let this = self.clone();
1199        cx.spawn(|cx| async move {
1200            let background = cx.background_executor().clone();
1201
1202            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1203            cx.update(|cx| {
1204                cx.spawn(move |cx| async move {
1205                    let url = open_url_rx.await?;
1206                    cx.update(|cx| cx.open_url(&url))
1207                })
1208                .detach_and_log_err(cx);
1209            })
1210            .log_err();
1211
1212            let credentials = background
1213                .clone()
1214                .spawn(async move {
1215                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1216                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1217                    // any other app running on the user's device.
1218                    let (public_key, private_key) =
1219                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1220                    let public_key_string = String::try_from(public_key)
1221                        .expect("failed to serialize public key for auth");
1222
1223                    if let Some((login, token)) =
1224                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1225                    {
1226                        eprintln!("authenticate as admin {login}, {token}");
1227
1228                        return this
1229                            .authenticate_as_admin(http, login.clone(), token.clone())
1230                            .await;
1231                    }
1232
1233                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1234                    let server =
1235                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1236                    let port = server.server_addr().port();
1237
1238                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1239                    // that the user is signing in from a Zed app running on the same device.
1240                    let mut url = http.build_url(&format!(
1241                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1242                        port, public_key_string
1243                    ));
1244
1245                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1246                        log::info!("impersonating user @{}", impersonate_login);
1247                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1248                    }
1249
1250                    open_url_tx.send(url).log_err();
1251
1252                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1253                    // access token from the query params.
1254                    //
1255                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1256                    // custom URL scheme instead of this local HTTP server.
1257                    let (user_id, access_token) = background
1258                        .spawn(async move {
1259                            for _ in 0..100 {
1260                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1261                                    let path = req.url();
1262                                    let mut user_id = None;
1263                                    let mut access_token = None;
1264                                    let url = Url::parse(&format!("http://example.com{}", path))
1265                                        .context("failed to parse login notification url")?;
1266                                    for (key, value) in url.query_pairs() {
1267                                        if key == "access_token" {
1268                                            access_token = Some(value.to_string());
1269                                        } else if key == "user_id" {
1270                                            user_id = Some(value.to_string());
1271                                        }
1272                                    }
1273
1274                                    let post_auth_url =
1275                                        http.build_url("/native_app_signin_succeeded");
1276                                    req.respond(
1277                                        tiny_http::Response::empty(302).with_header(
1278                                            tiny_http::Header::from_bytes(
1279                                                &b"Location"[..],
1280                                                post_auth_url.as_bytes(),
1281                                            )
1282                                            .unwrap(),
1283                                        ),
1284                                    )
1285                                    .context("failed to respond to login http request")?;
1286                                    return Ok((
1287                                        user_id
1288                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1289                                        access_token.ok_or_else(|| {
1290                                            anyhow!("missing access_token parameter")
1291                                        })?,
1292                                    ));
1293                                }
1294                            }
1295
1296                            Err(anyhow!("didn't receive login redirect"))
1297                        })
1298                        .await?;
1299
1300                    let access_token = private_key
1301                        .decrypt_string(&access_token)
1302                        .context("failed to decrypt access token")?;
1303
1304                    Ok(Credentials::User {
1305                        user_id: user_id.parse()?,
1306                        access_token,
1307                    })
1308                })
1309                .await?;
1310
1311            cx.update(|cx| cx.activate(true))?;
1312            Ok(credentials)
1313        })
1314    }
1315
1316    async fn authenticate_as_admin(
1317        self: &Arc<Self>,
1318        http: Arc<HttpClientWithUrl>,
1319        login: String,
1320        mut api_token: String,
1321    ) -> Result<Credentials> {
1322        #[derive(Deserialize)]
1323        struct AuthenticatedUserResponse {
1324            user: User,
1325        }
1326
1327        #[derive(Deserialize)]
1328        struct User {
1329            id: u64,
1330        }
1331
1332        let github_user = {
1333            #[derive(Deserialize)]
1334            struct GithubUser {
1335                id: i32,
1336                login: String,
1337                created_at: DateTime<Utc>,
1338            }
1339
1340            let request = {
1341                let mut request_builder =
1342                    Request::get(&format!("https://api.github.com/users/{login}"));
1343                if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
1344                    request_builder =
1345                        request_builder.header("Authorization", format!("Bearer {}", github_token));
1346                }
1347
1348                request_builder.body(AsyncBody::empty())?
1349            };
1350
1351            let mut response = http
1352                .send(request)
1353                .await
1354                .context("error fetching GitHub user")?;
1355
1356            let mut body = Vec::new();
1357            response
1358                .body_mut()
1359                .read_to_end(&mut body)
1360                .await
1361                .context("error reading GitHub user")?;
1362
1363            if !response.status().is_success() {
1364                let text = String::from_utf8_lossy(body.as_slice());
1365                bail!(
1366                    "status error {}, response: {text:?}",
1367                    response.status().as_u16()
1368                );
1369            }
1370
1371            serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
1372                log::error!("Error deserializing: {:?}", err);
1373                log::error!(
1374                    "GitHub API response text: {:?}",
1375                    String::from_utf8_lossy(body.as_slice())
1376                );
1377                anyhow!("error deserializing GitHub user")
1378            })?
1379        };
1380
1381        let query_params = [
1382            ("github_login", &github_user.login),
1383            ("github_user_id", &github_user.id.to_string()),
1384            (
1385                "github_user_created_at",
1386                &github_user.created_at.to_rfc3339(),
1387            ),
1388        ];
1389
1390        // Use the collab server's admin API to retrieve the ID
1391        // of the impersonated user.
1392        let mut url = self.rpc_url(http.clone(), None).await?;
1393        url.set_path("/user");
1394        url.set_query(Some(
1395            &query_params
1396                .iter()
1397                .map(|(key, value)| {
1398                    format!(
1399                        "{}={}",
1400                        key,
1401                        url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
1402                    )
1403                })
1404                .collect::<Vec<String>>()
1405                .join("&"),
1406        ));
1407        let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
1408            .header("Authorization", format!("token {api_token}"))
1409            .body("".into())?;
1410
1411        let mut response = http.send(request).await?;
1412        let mut body = String::new();
1413        response.body_mut().read_to_string(&mut body).await?;
1414        if !response.status().is_success() {
1415            Err(anyhow!(
1416                "admin user request failed {} - {}",
1417                response.status().as_u16(),
1418                body,
1419            ))?;
1420        }
1421        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1422
1423        // Use the admin API token to authenticate as the impersonated user.
1424        api_token.insert_str(0, "ADMIN_TOKEN:");
1425        Ok(Credentials::User {
1426            user_id: response.user.id,
1427            access_token: api_token,
1428        })
1429    }
1430
1431    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1432        self.state.write().credentials = None;
1433        self.disconnect(cx);
1434
1435        if self.has_credentials(cx).await {
1436            self.credentials_provider
1437                .delete_credentials(cx)
1438                .await
1439                .log_err();
1440        }
1441    }
1442
    /// Calls `peer.teardown()` (presumably closing every open connection) and then sets the
    /// status to `SignedOut`.
    ///
    /// Stored credentials are left untouched.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1447
    /// Calls `peer.teardown()` and then sets the status to `ConnectionLost`.
    ///
    /// NOTE(review): the actual reconnect is presumably driven by whatever observes the
    /// `ConnectionLost` status elsewhere in this type — confirm.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1452
1453    fn connection_id(&self) -> Result<ConnectionId> {
1454        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1455            Ok(connection_id)
1456        } else {
1457            Err(anyhow!("not connected"))
1458        }
1459    }
1460
1461    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1462        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1463        self.peer.send(self.connection_id()?, message)
1464    }
1465
1466    pub fn request<T: RequestMessage>(
1467        &self,
1468        request: T,
1469    ) -> impl Future<Output = Result<T::Response>> {
1470        self.request_envelope(request)
1471            .map_ok(|envelope| envelope.payload)
1472    }
1473
1474    pub fn request_stream<T: RequestMessage>(
1475        &self,
1476        request: T,
1477    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1478        let client_id = self.id.load(Ordering::SeqCst);
1479        log::debug!(
1480            "rpc request start. client_id:{}. name:{}",
1481            client_id,
1482            T::NAME
1483        );
1484        let response = self
1485            .connection_id()
1486            .map(|conn_id| self.peer.request_stream(conn_id, request));
1487        async move {
1488            let response = response?.await;
1489            log::debug!(
1490                "rpc request finish. client_id:{}. name:{}",
1491                client_id,
1492                T::NAME
1493            );
1494            response
1495        }
1496    }
1497
1498    pub fn request_envelope<T: RequestMessage>(
1499        &self,
1500        request: T,
1501    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1502        let client_id = self.id();
1503        log::debug!(
1504            "rpc request start. client_id:{}. name:{}",
1505            client_id,
1506            T::NAME
1507        );
1508        let response = self
1509            .connection_id()
1510            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1511        async move {
1512            let response = response?.await;
1513            log::debug!(
1514                "rpc request finish. client_id:{}. name:{}",
1515                client_id,
1516                T::NAME
1517            );
1518            response
1519        }
1520    }
1521
1522    pub fn request_dynamic(
1523        &self,
1524        envelope: proto::Envelope,
1525        request_type: &'static str,
1526    ) -> impl Future<Output = Result<proto::Envelope>> {
1527        let client_id = self.id();
1528        log::debug!(
1529            "rpc request start. client_id:{}. name:{}",
1530            client_id,
1531            request_type
1532        );
1533        let response = self
1534            .connection_id()
1535            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1536        async move {
1537            let response = response?.await;
1538            log::debug!(
1539                "rpc request finish. client_id:{}. name:{}",
1540                client_id,
1541                request_type
1542            );
1543            Ok(response?.0)
1544        }
1545    }
1546
1547    fn handle_message(
1548        self: &Arc<Client>,
1549        message: Box<dyn AnyTypedEnvelope>,
1550        cx: &AsyncAppContext,
1551    ) {
1552        let sender_id = message.sender_id();
1553        let request_id = message.message_id();
1554        let type_name = message.payload_type_name();
1555        let original_sender_id = message.original_sender_id();
1556
1557        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1558            &self.handler_set,
1559            message,
1560            self.clone().into(),
1561            cx.clone(),
1562        ) {
1563            let client_id = self.id();
1564            log::debug!(
1565                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1566                client_id,
1567                original_sender_id,
1568                type_name
1569            );
1570            cx.spawn(move |_| async move {
1571                match future.await {
1572                    Ok(()) => {
1573                        log::debug!(
1574                            "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1575                            client_id,
1576                            original_sender_id,
1577                            type_name
1578                        );
1579                    }
1580                    Err(error) => {
1581                        log::error!(
1582                            "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1583                            client_id,
1584                            original_sender_id,
1585                            type_name,
1586                            error
1587                        );
1588                    }
1589                }
1590            })
1591            .detach();
1592        } else {
1593            log::info!("unhandled message {}", type_name);
1594            self.peer
1595                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1596                .log_err();
1597        }
1598    }
1599
    /// Returns a reference to this client's shared `Telemetry` handle.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1603}
1604
1605impl ProtoClient for Client {
1606    fn request(
1607        &self,
1608        envelope: proto::Envelope,
1609        request_type: &'static str,
1610    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1611        self.request_dynamic(envelope, request_type).boxed()
1612    }
1613
1614    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1615        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1616        let connection_id = self.connection_id()?;
1617        self.peer.send_dynamic(connection_id, envelope)
1618    }
1619
1620    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1621        log::debug!(
1622            "rpc respond. client_id:{}, name:{}",
1623            self.id(),
1624            message_type
1625        );
1626        let connection_id = self.connection_id()?;
1627        self.peer.send_dynamic(connection_id, envelope)
1628    }
1629
1630    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1631        &self.handler_set
1632    }
1633
1634    fn is_via_collab(&self) -> bool {
1635        true
1636    }
1637}
1638
/// JSON shape of the credentials file read and written by `DevelopmentCredentialsProvider`.
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
    /// Id of the signed-in user.
    user_id: u64,
    /// Access token for that user, stored in plain text.
    access_token: String,
}
1644
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
    /// Location of the JSON file holding a serialized `DevelopmentCredentials`.
    path: PathBuf,
}
1655
1656impl CredentialsProvider for DevelopmentCredentialsProvider {
1657    fn read_credentials<'a>(
1658        &'a self,
1659        _cx: &'a AsyncAppContext,
1660    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1661        async move {
1662            if IMPERSONATE_LOGIN.is_some() {
1663                return None;
1664            }
1665
1666            let json = std::fs::read(&self.path).log_err()?;
1667
1668            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1669
1670            Some(Credentials::User {
1671                user_id: credentials.user_id,
1672                access_token: credentials.access_token,
1673            })
1674        }
1675        .boxed_local()
1676    }
1677
1678    fn write_credentials<'a>(
1679        &'a self,
1680        user_id: u64,
1681        access_token: String,
1682        _cx: &'a AsyncAppContext,
1683    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1684        async move {
1685            let json = serde_json::to_string(&DevelopmentCredentials {
1686                user_id,
1687                access_token,
1688            })?;
1689
1690            std::fs::write(&self.path, json)?;
1691
1692            Ok(())
1693        }
1694        .boxed_local()
1695    }
1696
1697    fn delete_credentials<'a>(
1698        &'a self,
1699        _cx: &'a AsyncAppContext,
1700    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1701        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1702    }
1703}
1704
/// A credentials provider that stores credentials in the system keychain.
///
/// Every read, write, and delete passes the configured `ClientSettings::server_url`
/// as the credential URL.
struct KeychainCredentialsProvider;
1707
1708impl CredentialsProvider for KeychainCredentialsProvider {
1709    fn read_credentials<'a>(
1710        &'a self,
1711        cx: &'a AsyncAppContext,
1712    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1713        async move {
1714            if IMPERSONATE_LOGIN.is_some() {
1715                return None;
1716            }
1717
1718            let (user_id, access_token) = cx
1719                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1720                .log_err()?
1721                .await
1722                .log_err()??;
1723
1724            Some(Credentials::User {
1725                user_id: user_id.parse().ok()?,
1726                access_token: String::from_utf8(access_token).ok()?,
1727            })
1728        }
1729        .boxed_local()
1730    }
1731
1732    fn write_credentials<'a>(
1733        &'a self,
1734        user_id: u64,
1735        access_token: String,
1736        cx: &'a AsyncAppContext,
1737    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1738        async move {
1739            cx.update(move |cx| {
1740                cx.write_credentials(
1741                    &ClientSettings::get_global(cx).server_url,
1742                    &user_id.to_string(),
1743                    access_token.as_bytes(),
1744                )
1745            })?
1746            .await
1747        }
1748        .boxed_local()
1749    }
1750
1751    fn delete_credentials<'a>(
1752        &'a self,
1753        cx: &'a AsyncAppContext,
1754    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1755        async move {
1756            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1757                .await
1758        }
1759        .boxed_local()
1760    }
1761}
1762
/// Scheme name of `zed://` URLs, without the `://` separator.
pub const ZED_URL_SCHEME: &str = "zed";
1765
1766/// Parses the given link into a Zed link.
1767///
1768/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1769/// Returns [`None`] otherwise.
1770pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1771    let server_url = &ClientSettings::get_global(cx).server_url;
1772    if let Some(stripped) = link
1773        .strip_prefix(server_url)
1774        .and_then(|result| result.strip_prefix('/'))
1775    {
1776        return Some(stripped);
1777    }
1778    if let Some(stripped) = link
1779        .strip_prefix(ZED_URL_SCHEME)
1780        .and_then(|result| result.strip_prefix("://"))
1781    {
1782        return Some(stripped);
1783    }
1784
1785    None
1786}
1787
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    /// Verifies three behaviors after the server drops the connection:
    /// - the client reconnects;
    /// - it reuses its cached credentials while they remain valid;
    /// - it re-authenticates once the server has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// Verifies that both the initial connection and a later reconnection attempt fail
    /// once `CONNECTION_TIMEOUT` elapses while establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    /// Verifies that starting a second `authenticate_and_connect` while the first is still
    /// authenticating drops the first, still-pending authentication future.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    // Counts how many authentication futures were dropped before finishing.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    /// Verifies that entity-scoped message handlers reach the model subscribed for each
    /// entity id, even after another subscription of the same type has been dropped.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    /// Verifies that a handler for a message type can be registered again after the
    /// previous handler's subscription has been dropped, and that the new one receives
    /// messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    /// Verifies that a message handler may drop its own subscription (stored on the
    /// model) while it is running.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        // `downgrade` only borrows the model, so no clone is needed here.
        let subscription = client.add_message_handler(
            model.downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as the target of message handlers in these tests.
    #[derive(Default)]
    struct TestModel {
        /// Identifier the entity-subscription test dispatches on.
        id: usize,
        /// Subscription a handler can take (and thereby drop) while running.
        subscription: Option<Subscription>,
    }

    /// Installs a test `SettingsStore` as a global and then calls `init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}