client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod socks;
   5pub mod telemetry;
   6pub mod user;
   7
   8use anyhow::{anyhow, bail, Context as _, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use chrono::{DateTime, Utc};
  16use clock::SystemClock;
  17use futures::{
  18    channel::oneshot, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
  19    TryFutureExt as _, TryStreamExt,
  20};
  21use gpui::{actions, AppContext, AsyncAppContext, Global, Model, Task, WeakModel};
  22use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
  23use parking_lot::RwLock;
  24use postage::watch;
  25use proto::{AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet};
  26use rand::prelude::*;
  27use release_channel::{AppVersion, ReleaseChannel};
  28use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::{Settings, SettingsSources};
  32use socks::connect_socks_proxy_stream;
  33use std::fmt;
  34use std::pin::Pin;
  35use std::{
  36    any::TypeId,
  37    convert::TryFrom,
  38    fmt::Write as _,
  39    future::Future,
  40    marker::PhantomData,
  41    path::PathBuf,
  42    sync::{
  43        atomic::{AtomicU64, Ordering},
  44        Arc, LazyLock, Weak,
  45    },
  46    time::{Duration, Instant},
  47};
  48use telemetry::Telemetry;
  49use thiserror::Error;
  50use url::Url;
  51use util::{ResultExt, TryFutureExt};
  52
  53pub use rpc::*;
  54pub use telemetry_events::Event;
  55pub use user::*;
  56
  57#[derive(Debug, Clone, Eq, PartialEq)]
  58pub struct DevServerToken(pub String);
  59
  60impl fmt::Display for DevServerToken {
  61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66static ZED_SERVER_URL: LazyLock<Option<String>> =
  67    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  68static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  69
  70/// An environment variable whose presence indicates that the development auth
  71/// provider should be used.
  72///
  73/// Only works in development. Setting this environment variable in other release
  74/// channels is a no-op.
  75pub static ZED_DEVELOPMENT_AUTH: LazyLock<bool> = LazyLock::new(|| {
  76    std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty())
  77});
  78pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  79    std::env::var("ZED_IMPERSONATE")
  80        .ok()
  81        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  82});
  83
  84pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  85    std::env::var("ZED_ADMIN_API_TOKEN")
  86        .ok()
  87        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  88});
  89
  90pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  91    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  92
  93pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  94    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  95
  96pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
  97pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
  98pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
  99
// Declares the `SignIn`, `SignOut`, and `Reconnect` actions in the `client`
// namespace; their handlers are registered by `init`.
actions!(client, [SignIn, SignOut, Reconnect]);
 101
// User-editable shape of the client settings. Plain `//` comments are used on
// purpose: `///` docs on a `JsonSchema` type presumably end up in the generated
// settings schema.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // Optional server URL; merged into `ClientSettings` by `ClientSettings::load`.
    server_url: Option<String>,
}
 106
/// Resolved client settings, produced by merging all settings sources in
/// `ClientSettings::load`.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Server URL; a set `ZED_SERVER_URL` environment variable overrides the
    /// merged value.
    pub server_url: String,
}
 111
 112impl Settings for ClientSettings {
 113    const KEY: Option<&'static str> = None;
 114
 115    type FileContent = ClientSettingsContent;
 116
 117    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 118        let mut result = sources.json_merge::<Self>()?;
 119        if let Some(server_url) = &*ZED_SERVER_URL {
 120            result.server_url.clone_from(server_url)
 121        }
 122        Ok(result)
 123    }
 124}
 125
// User-editable proxy settings. Plain `//` comments are used because `///` docs
// on a `JsonSchema` type presumably end up in the generated settings schema.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProxySettingsContent {
    // Proxy URL; `ProxySettings::load` prefers the user's value over the default's.
    proxy: Option<String>,
}
 130
/// Resolved proxy configuration.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Proxy handed to the HTTP client in `Client::production`; `None` when
    /// neither the user nor the default settings specify one.
    pub proxy: Option<String>,
}
 135
 136impl Settings for ProxySettings {
 137    const KEY: Option<&'static str> = None;
 138
 139    type FileContent = ProxySettingsContent;
 140
 141    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 142        Ok(Self {
 143            proxy: sources
 144                .user
 145                .and_then(|value| value.proxy.clone())
 146                .or(sources.default.proxy.clone()),
 147        })
 148    }
 149}
 150
 151pub fn init_settings(cx: &mut AppContext) {
 152    TelemetrySettings::register(cx);
 153    ClientSettings::register(cx);
 154    ProxySettings::register(cx);
 155}
 156
 157pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 158    let client = Arc::downgrade(client);
 159    cx.on_action({
 160        let client = client.clone();
 161        move |_: &SignIn, cx| {
 162            if let Some(client) = client.upgrade() {
 163                cx.spawn(
 164                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 165                )
 166                .detach();
 167            }
 168        }
 169    });
 170
 171    cx.on_action({
 172        let client = client.clone();
 173        move |_: &SignOut, cx| {
 174            if let Some(client) = client.upgrade() {
 175                cx.spawn(|cx| async move {
 176                    client.sign_out(&cx).await;
 177                })
 178                .detach();
 179            }
 180        }
 181    });
 182
 183    cx.on_action({
 184        let client = client.clone();
 185        move |_: &Reconnect, cx| {
 186            if let Some(client) = client.upgrade() {
 187                cx.spawn(|cx| async move {
 188                    client.reconnect(&cx);
 189                })
 190                .detach();
 191            }
 192        }
 193    });
 194}
 195
/// Wrapper that stores the app-wide [`Client`] as a gpui global; see
/// `Client::global` and `Client::set_global`.
struct GlobalClient(Arc<Client>);

// Marker impl that lets `GlobalClient` be stored with `cx.set_global` and read
// back with `cx.global`.
impl Global for GlobalClient {}
 199
/// Client for the collaboration server.
///
/// Owns the RPC peer, the HTTP client, telemetry, credential persistence, the
/// connection status, and the registered message handlers.
pub struct Client {
    /// Client id used in log lines; `authenticate_and_connect` sets it to the
    /// signed-in user's id via `set_id`.
    id: AtomicU64,
    /// RPC peer that connections are added to (see `set_connection`).
    peer: Arc<Peer>,
    /// HTTP client; also handed to `Telemetry` in `Client::new`.
    http: Arc<HttpClientWithUrl>,
    telemetry: Arc<Telemetry>,
    /// Where credentials are read from, written to, and deleted from (keychain,
    /// or a development file when `ZED_DEVELOPMENT_AUTH` is used).
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// Current credentials, status watch channel, and reconnect task.
    state: RwLock<ClientState>,
    /// Registered message/request handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,

    // Test-only override installed by `override_authenticate`; the boxed callback
    // type is long, hence the clippy allowance.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    // Test-only override installed by `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    // Test-only RPC URL installed by `override_rpc_url`.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
 234
/// Errors produced while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server demands a newer client (e.g. a websocket handshake answered
    /// with HTTP 426).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The credentials were rejected (e.g. a handshake answered with HTTP 401).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors not mapped to the variants above.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the HTTP client.
    #[error("{0}")]
    Http(#[from] http_client::Error),
    /// A value could not be used as an HTTP header.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from the `http` crate re-exported by tungstenite.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 252
 253impl From<WebsocketError> for EstablishConnectionError {
 254    fn from(error: WebsocketError) -> Self {
 255        if let WebsocketError::Http(response) = &error {
 256            match response.status() {
 257                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 258                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 259                _ => {}
 260            }
 261        }
 262        EstablishConnectionError::Other(error.into())
 263    }
 264}
 265
 266impl EstablishConnectionError {
 267    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 268        Self::Other(error.into())
 269    }
 270}
 271
/// Connection lifecycle state, broadcast through the client's status watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial state of a new client.
    SignedOut,
    /// The server requires a newer client; `authenticate_and_connect` refuses to
    /// proceed from this state.
    UpgradeRequired,
    /// Obtaining credentials for a connection attempt that started signed out.
    Authenticating,
    /// Connecting with credentials in hand after `Authenticating`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected to the server.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// An established connection was lost; setting this status starts the
    /// reconnect loop in `set_status`.
    ConnectionLost,
    /// Obtaining credentials again after a loss or error.
    Reauthenticating,
    /// Reconnecting with credentials in hand after `Reauthenticating`.
    Reconnecting,
    /// A reconnect attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 290
 291impl Status {
 292    pub fn is_connected(&self) -> bool {
 293        matches!(self, Self::Connected { .. })
 294    }
 295
 296    pub fn is_signed_out(&self) -> bool {
 297        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 298    }
 299}
 300
/// Mutable connection state guarded by `Client::state`.
struct ClientState {
    /// Credentials for the next or current connection. Set by
    /// `set_dev_server_token` and after a successful connection; cleared when
    /// the server answers `Unauthorized`.
    credentials: Option<Credentials>,
    /// Watch channel holding the current [`Status`]; `Client::status` hands out
    /// clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect task spawned on `ConnectionLost`. It is held only to keep the
    /// task alive and is cleared on `Connected`, `SignedOut`, and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
}
 306
/// Credentials the client authenticates with.
///
/// NOTE(review): the derived `Debug` prints `access_token` and the dev-server
/// token verbatim; avoid logging `Credentials` with `{:?}`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Credentials {
    /// A dev server identified by its token.
    DevServer { token: DevServerToken },
    /// A user identified by id, together with their access token.
    User { user_id: u64, access_token: String },
}
 312
 313impl Credentials {
 314    pub fn authorization_header(&self) -> String {
 315        match self {
 316            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 317            Credentials::User {
 318                user_id,
 319                access_token,
 320            } => format!("{} {}", user_id, access_token),
 321        }
 322    }
 323}
 324
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain). `Client::new` chooses the keychain
/// provider, or a file-based development provider when development auth is
/// enabled in the `Dev` release channel.
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// `None` means no credentials are available (`Client::has_credentials`
    /// reports exactly this).
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (id plus access token) are persisted through this method.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    ///
    /// Used, for example, when the server rejects provider-sourced credentials
    /// as unauthorized.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
 350
 351impl Default for ClientState {
 352    fn default() -> Self {
 353        Self {
 354            credentials: None,
 355            status: watch::channel_with(Status::SignedOut),
 356            _reconnect_task: None,
 357        }
 358    }
 359}
 360
/// Handle for a registered subscription. Dropping it unregisters the
/// subscription from the client, if the client is still alive.
pub enum Subscription {
    /// Subscription to messages for one entity, keyed by the entity's type and remote id.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Handler registration for a message type, keyed by the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 371
 372impl Drop for Subscription {
 373    fn drop(&mut self) {
 374        match self {
 375            Subscription::Entity { client, id } => {
 376                if let Some(client) = client.upgrade() {
 377                    let mut state = client.handler_set.lock();
 378                    let _ = state.entities_by_type_and_remote_id.remove(id);
 379                }
 380            }
 381            Subscription::Message { client, id } => {
 382                if let Some(client) = client.upgrade() {
 383                    let mut state = client.handler_set.lock();
 384                    let _ = state.entity_types_by_message_type.remove(id);
 385                    let _ = state.message_handlers.remove(id);
 386                }
 387            }
 388        }
 389    }
 390}
 391
/// An entity subscription reserved by [`Client::subscribe_to_entity`] that is
/// not yet bound to a model.
///
/// Call [`PendingEntitySubscription::set_model`] to bind it. Dropping it unbound
/// removes the reservation and logs any messages that were queued for it.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    /// Remote id of the entity.
    remote_id: u64,
    /// Ties the subscription to the entity type `T` without storing a `T`.
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the reservation alone.
    consumed: bool,
}
 398
 399impl<T: 'static> PendingEntitySubscription<T> {
 400    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
 401        self.consumed = true;
 402        let mut handlers = self.client.handler_set.lock();
 403        let id = (TypeId::of::<T>(), self.remote_id);
 404        let Some(EntityMessageSubscriber::Pending(messages)) =
 405            handlers.entities_by_type_and_remote_id.remove(&id)
 406        else {
 407            unreachable!()
 408        };
 409
 410        handlers.entities_by_type_and_remote_id.insert(
 411            id,
 412            EntityMessageSubscriber::Entity {
 413                handle: model.downgrade().into(),
 414            },
 415        );
 416        drop(handlers);
 417        for message in messages {
 418            let client_id = self.client.id();
 419            let type_name = message.payload_type_name();
 420            let sender_id = message.original_sender_id();
 421            log::debug!(
 422                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 423                client_id,
 424                sender_id,
 425                type_name
 426            );
 427            self.client.handle_message(message, cx);
 428        }
 429        Subscription::Entity {
 430            client: Arc::downgrade(&self.client),
 431            id,
 432        }
 433    }
 434}
 435
 436impl<T: 'static> Drop for PendingEntitySubscription<T> {
 437    fn drop(&mut self) {
 438        if !self.consumed {
 439            let mut state = self.client.handler_set.lock();
 440            if let Some(EntityMessageSubscriber::Pending(messages)) = state
 441                .entities_by_type_and_remote_id
 442                .remove(&(TypeId::of::<T>(), self.remote_id))
 443            {
 444                for message in messages {
 445                    log::info!("unhandled message {}", message.payload_type_name());
 446                }
 447            }
 448        }
 449    }
 450}
 451
/// Resolved telemetry preferences; see [`TelemetrySettingsContent`] for the
/// user-facing options.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
 457
// NOTE: the `///` docs in this struct presumably double as JSON schema
// descriptions (via `JsonSchema`), so they are kept verbatim. Unset fields fall
// back to the default settings' values in `TelemetrySettings::load`.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 470
 471impl settings::Settings for TelemetrySettings {
 472    const KEY: Option<&'static str> = Some("telemetry");
 473
 474    type FileContent = TelemetrySettingsContent;
 475
 476    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 477        Ok(Self {
 478            diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
 479                sources
 480                    .default
 481                    .diagnostics
 482                    .ok_or_else(Self::missing_default)?,
 483            ),
 484            metrics: sources
 485                .user
 486                .as_ref()
 487                .and_then(|v| v.metrics)
 488                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 489        })
 490    }
 491}
 492
 493impl Client {
 494    pub fn new(
 495        clock: Arc<dyn SystemClock>,
 496        http: Arc<HttpClientWithUrl>,
 497        cx: &mut AppContext,
 498    ) -> Arc<Self> {
 499        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 500            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 501            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 502            | None => false,
 503        };
 504
 505        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 506            if use_zed_development_auth {
 507                Arc::new(DevelopmentCredentialsProvider {
 508                    path: paths::config_dir().join("development_auth"),
 509                })
 510            } else {
 511                Arc::new(KeychainCredentialsProvider)
 512            };
 513
 514        Arc::new(Self {
 515            id: AtomicU64::new(0),
 516            peer: Peer::new(0),
 517            telemetry: Telemetry::new(clock, http.clone(), cx),
 518            http,
 519            credentials_provider,
 520            state: Default::default(),
 521            handler_set: Default::default(),
 522
 523            #[cfg(any(test, feature = "test-support"))]
 524            authenticate: Default::default(),
 525            #[cfg(any(test, feature = "test-support"))]
 526            establish_connection: Default::default(),
 527            #[cfg(any(test, feature = "test-support"))]
 528            rpc_url: RwLock::default(),
 529        })
 530    }
 531
 532    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 533        let user_agent = format!(
 534            "Zed/{} ({}; {})",
 535            AppVersion::global(cx),
 536            std::env::consts::OS,
 537            std::env::consts::ARCH
 538        );
 539        let clock = Arc::new(clock::RealSystemClock);
 540        let http = Arc::new(HttpClientWithUrl::new(
 541            &ClientSettings::get_global(cx).server_url,
 542            Some(user_agent),
 543            ProxySettings::get_global(cx).proxy.clone(),
 544        ));
 545        Self::new(clock, http.clone(), cx)
 546    }
 547
 548    pub fn id(&self) -> u64 {
 549        self.id.load(Ordering::SeqCst)
 550    }
 551
 552    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 553        self.http.clone()
 554    }
 555
 556    pub fn set_id(&self, id: u64) -> &Self {
 557        self.id.store(id, Ordering::SeqCst);
 558        self
 559    }
 560
 561    #[cfg(any(test, feature = "test-support"))]
 562    pub fn teardown(&self) {
 563        let mut state = self.state.write();
 564        state._reconnect_task.take();
 565        self.handler_set.lock().clear();
 566        self.peer.teardown();
 567    }
 568
 569    #[cfg(any(test, feature = "test-support"))]
 570    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 571    where
 572        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 573    {
 574        *self.authenticate.write() = Some(Box::new(authenticate));
 575        self
 576    }
 577
 578    #[cfg(any(test, feature = "test-support"))]
 579    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 580    where
 581        F: 'static
 582            + Send
 583            + Sync
 584            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 585    {
 586        *self.establish_connection.write() = Some(Box::new(connect));
 587        self
 588    }
 589
 590    #[cfg(any(test, feature = "test-support"))]
 591    pub fn override_rpc_url(&self, url: Url) -> &Self {
 592        *self.rpc_url.write() = Some(url);
 593        self
 594    }
 595
 596    pub fn global(cx: &AppContext) -> Arc<Self> {
 597        cx.global::<GlobalClient>().0.clone()
 598    }
 599    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 600        cx.set_global(GlobalClient(client))
 601    }
 602
 603    pub fn user_id(&self) -> Option<u64> {
 604        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 605            Some(*user_id)
 606        } else {
 607            None
 608        }
 609    }
 610
 611    pub fn peer_id(&self) -> Option<PeerId> {
 612        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 613            Some(*peer_id)
 614        } else {
 615            None
 616        }
 617    }
 618
 619    pub fn status(&self) -> watch::Receiver<Status> {
 620        self.state.read().status.1.clone()
 621    }
 622
    /// Publishes `status` on the status watch channel and reacts to it.
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task that repeatedly calls
    ///   `authenticate_and_connect`. Between attempts it waits a randomized
    ///   delay clamped to `INITIAL_RECONNECTION_DELAY..=MAX_RECONNECTION_DELAY`.
    ///   It stops on success, or as soon as a failed attempt leaves the status at
    ///   anything other than `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and
    ///   drops the reconnect task.
    ///
    /// The state write lock is held from publishing the status until the end of
    /// the call, including the telemetry update.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                // The task lives in `state`; replacing or clearing that field drops it
                // (in gpui, dropping a `Task` presumably cancels it).
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // A fixed seed in test builds keeps the backoff deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while failed attempts leave the client in
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Jitter: scale by a random factor in [0.5, 2.5], then clamp.
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 668
 669    pub fn subscribe_to_entity<T>(
 670        self: &Arc<Self>,
 671        remote_id: u64,
 672    ) -> Result<PendingEntitySubscription<T>>
 673    where
 674        T: 'static,
 675    {
 676        let id = (TypeId::of::<T>(), remote_id);
 677
 678        let mut state = self.handler_set.lock();
 679        if state.entities_by_type_and_remote_id.contains_key(&id) {
 680            return Err(anyhow!("already subscribed to entity"));
 681        }
 682
 683        state
 684            .entities_by_type_and_remote_id
 685            .insert(id, EntityMessageSubscriber::Pending(Default::default()));
 686
 687        Ok(PendingEntitySubscription {
 688            client: self.clone(),
 689            remote_id,
 690            consumed: false,
 691            _entity_type: PhantomData,
 692        })
 693    }
 694
 695    #[track_caller]
 696    pub fn add_message_handler<M, E, H, F>(
 697        self: &Arc<Self>,
 698        entity: WeakModel<E>,
 699        handler: H,
 700    ) -> Subscription
 701    where
 702        M: EnvelopedMessage,
 703        E: 'static,
 704        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 705        F: 'static + Future<Output = Result<()>>,
 706    {
 707        self.add_message_handler_impl(entity, move |model, message, _, cx| {
 708            handler(model, message, cx)
 709        })
 710    }
 711
 712    fn add_message_handler_impl<M, E, H, F>(
 713        self: &Arc<Self>,
 714        entity: WeakModel<E>,
 715        handler: H,
 716    ) -> Subscription
 717    where
 718        M: EnvelopedMessage,
 719        E: 'static,
 720        H: 'static
 721            + Sync
 722            + Fn(Model<E>, TypedEnvelope<M>, AnyProtoClient, AsyncAppContext) -> F
 723            + Send
 724            + Sync,
 725        F: 'static + Future<Output = Result<()>>,
 726    {
 727        let message_type_id = TypeId::of::<M>();
 728        let mut state = self.handler_set.lock();
 729        state
 730            .models_by_message_type
 731            .insert(message_type_id, entity.into());
 732
 733        let prev_handler = state.message_handlers.insert(
 734            message_type_id,
 735            Arc::new(move |subscriber, envelope, client, cx| {
 736                let subscriber = subscriber.downcast::<E>().unwrap();
 737                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 738                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 739            }),
 740        );
 741        if prev_handler.is_some() {
 742            let location = std::panic::Location::caller();
 743            panic!(
 744                "{}:{} registered handler for the same message {} twice",
 745                location.file(),
 746                location.line(),
 747                std::any::type_name::<M>()
 748            );
 749        }
 750
 751        Subscription::Message {
 752            client: Arc::downgrade(self),
 753            id: message_type_id,
 754        }
 755    }
 756
 757    pub fn add_request_handler<M, E, H, F>(
 758        self: &Arc<Self>,
 759        model: WeakModel<E>,
 760        handler: H,
 761    ) -> Subscription
 762    where
 763        M: RequestMessage,
 764        E: 'static,
 765        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 766        F: 'static + Future<Output = Result<M::Response>>,
 767    {
 768        self.add_message_handler_impl(model, move |handle, envelope, this, cx| {
 769            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 770        })
 771    }
 772
 773    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 774        receipt: Receipt<T>,
 775        response: F,
 776        client: AnyProtoClient,
 777    ) -> Result<()> {
 778        match response.await {
 779            Ok(response) => {
 780                client.send_response(receipt.message_id, response)?;
 781                Ok(())
 782            }
 783            Err(error) => {
 784                client.send_response(receipt.message_id, error.to_proto())?;
 785                Err(error)
 786            }
 787        }
 788    }
 789
 790    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 791        self.credentials_provider
 792            .read_credentials(cx)
 793            .await
 794            .is_some()
 795    }
 796
 797    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 798        self.state.write().credentials = Some(Credentials::DevServer { token });
 799        self
 800    }
 801
    /// Authenticates if needed, then connects to the server.
    ///
    /// Returns `Ok(())` immediately when the client is already connected,
    /// connecting, or reconnecting. Fails immediately when an upgrade is required.
    ///
    /// Credentials are taken, in order, from:
    /// 1. the in-memory client state;
    /// 2. the credentials provider, only when `try_provider` is true;
    /// 3. an interactive `authenticate` call.
    ///
    /// After a successful connection, newly obtained user credentials are written
    /// to the provider, unless they came from the provider or `ZED_IMPERSONATE`
    /// is set.
    ///
    /// A single `CONNECTION_TIMEOUT` timer covers both establishing the
    /// connection and `set_connection` (which waits for the server hello).
    ///
    /// If the server answers `Unauthorized` to provider-sourced credentials:
    /// they are deleted, the status is set to `SignedOut`, and the whole flow is
    /// retried recursively without the provider.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `was_disconnected` selects the status sequence:
        // Authenticating/Connecting from a signed-out start,
        // Reauthenticating/Reconnecting otherwise.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            // NOTE(review): the first `next()` presumably yields the receiver's
            // current value (postage watch semantics). If so, the select below waits
            // for the *next* status change, which cancels authentication. Confirm.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path above that leaves `credentials` as `None` has already returned.
        let credentials = credentials.unwrap();
        if let Credentials::User { user_id, .. } = &credentials {
            self.set_id(*user_id);
        }

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer, shared by the connection attempt and the wait in `set_connection`.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only credentials that did not come from the provider,
                        // and never while impersonating.
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            if let Credentials::User{user_id, access_token} = credentials {
                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                            }
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_provider {
                            // Stale stored credentials: forget them and retry interactively.
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 910
 911    async fn set_connection(
 912        self: &Arc<Self>,
 913        conn: Connection,
 914        cx: &AsyncAppContext,
 915    ) -> Result<()> {
 916        let executor = cx.background_executor();
 917        log::info!("add connection to peer");
 918        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
 919            let executor = executor.clone();
 920            move |duration| executor.timer(duration)
 921        });
 922        let handle_io = executor.spawn(handle_io);
 923
 924        let peer_id = async {
 925            log::info!("waiting for server hello");
 926            let message = incoming
 927                .next()
 928                .await
 929                .ok_or_else(|| anyhow!("no hello message received"))?;
 930            log::info!("got server hello");
 931            let hello_message_type_name = message.payload_type_name().to_string();
 932            let hello = message
 933                .into_any()
 934                .downcast::<TypedEnvelope<proto::Hello>>()
 935                .map_err(|_| {
 936                    anyhow!(
 937                        "invalid hello message received: {:?}",
 938                        hello_message_type_name
 939                    )
 940                })?;
 941            let peer_id = hello
 942                .payload
 943                .peer_id
 944                .ok_or_else(|| anyhow!("invalid peer id"))?;
 945            Ok(peer_id)
 946        };
 947
 948        let peer_id = match peer_id.await {
 949            Ok(peer_id) => peer_id,
 950            Err(error) => {
 951                self.peer.disconnect(connection_id);
 952                return Err(error);
 953            }
 954        };
 955
 956        log::info!(
 957            "set status to connected (connection id: {:?}, peer id: {:?})",
 958            connection_id,
 959            peer_id
 960        );
 961        self.set_status(
 962            Status::Connected {
 963                peer_id,
 964                connection_id,
 965            },
 966            cx,
 967        );
 968
 969        cx.spawn({
 970            let this = self.clone();
 971            |cx| {
 972                async move {
 973                    while let Some(message) = incoming.next().await {
 974                        this.handle_message(message, &cx);
 975                        // Don't starve the main thread when receiving lots of messages at once.
 976                        smol::future::yield_now().await;
 977                    }
 978                }
 979            }
 980        })
 981        .detach();
 982
 983        cx.spawn({
 984            let this = self.clone();
 985            move |cx| async move {
 986                match handle_io.await {
 987                    Ok(()) => {
 988                        if *this.status().borrow()
 989                            == (Status::Connected {
 990                                connection_id,
 991                                peer_id,
 992                            })
 993                        {
 994                            this.set_status(Status::SignedOut, &cx);
 995                        }
 996                    }
 997                    Err(err) => {
 998                        log::error!("connection error: {:?}", err);
 999                        this.set_status(Status::ConnectionLost, &cx);
1000                    }
1001                }
1002            }
1003        })
1004        .detach();
1005
1006        Ok(())
1007    }
1008
1009    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1010        #[cfg(any(test, feature = "test-support"))]
1011        if let Some(callback) = self.authenticate.read().as_ref() {
1012            return callback(cx);
1013        }
1014
1015        self.authenticate_with_browser(cx)
1016    }
1017
1018    fn establish_connection(
1019        self: &Arc<Self>,
1020        credentials: &Credentials,
1021        cx: &AsyncAppContext,
1022    ) -> Task<Result<Connection, EstablishConnectionError>> {
1023        #[cfg(any(test, feature = "test-support"))]
1024        if let Some(callback) = self.establish_connection.read().as_ref() {
1025            return callback(credentials, cx);
1026        }
1027
1028        self.establish_websocket_connection(credentials, cx)
1029    }
1030
1031    fn rpc_url(
1032        &self,
1033        http: Arc<HttpClientWithUrl>,
1034        release_channel: Option<ReleaseChannel>,
1035    ) -> impl Future<Output = Result<Url>> {
1036        #[cfg(any(test, feature = "test-support"))]
1037        let url_override = self.rpc_url.read().clone();
1038
1039        async move {
1040            #[cfg(any(test, feature = "test-support"))]
1041            if let Some(url) = url_override {
1042                return Ok(url);
1043            }
1044
1045            if let Some(url) = &*ZED_RPC_URL {
1046                return Url::parse(url).context("invalid rpc url");
1047            }
1048
1049            let mut url = http.build_url("/rpc");
1050            if let Some(preview_param) =
1051                release_channel.and_then(|channel| channel.release_query_param())
1052            {
1053                url += "?";
1054                url += preview_param;
1055            }
1056
1057            let response = http.get(&url, Default::default(), false).await?;
1058            let collab_url = if response.status().is_redirection() {
1059                response
1060                    .headers()
1061                    .get("Location")
1062                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1063                    .to_str()
1064                    .map_err(EstablishConnectionError::other)?
1065                    .to_string()
1066            } else {
1067                Err(anyhow!(
1068                    "unexpected /rpc response status {}",
1069                    response.status()
1070                ))?
1071            };
1072
1073            Url::parse(&collab_url).context("invalid rpc url")
1074        }
1075    }
1076
1077    fn establish_websocket_connection(
1078        self: &Arc<Self>,
1079        credentials: &Credentials,
1080        cx: &AsyncAppContext,
1081    ) -> Task<Result<Connection, EstablishConnectionError>> {
1082        let release_channel = cx
1083            .update(|cx| ReleaseChannel::try_global(cx))
1084            .ok()
1085            .flatten();
1086        let app_version = cx
1087            .update(|cx| AppVersion::global(cx).to_string())
1088            .ok()
1089            .unwrap_or_default();
1090
1091        let http = self.http.clone();
1092        let proxy = http.proxy().cloned();
1093        let credentials = credentials.clone();
1094        let rpc_url = self.rpc_url(http, release_channel);
1095        cx.background_executor().spawn(async move {
1096            use HttpOrHttps::*;
1097
1098            #[derive(Debug)]
1099            enum HttpOrHttps {
1100                Http,
1101                Https,
1102            }
1103
1104            let mut rpc_url = rpc_url.await?;
1105            let url_scheme = match rpc_url.scheme() {
1106                "https" => Https,
1107                "http" => Http,
1108                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1109            };
1110            let rpc_host = rpc_url
1111                .host_str()
1112                .zip(rpc_url.port_or_known_default())
1113                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1114            let stream = connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?;
1115
1116            log::info!("connected to rpc endpoint {}", rpc_url);
1117
1118            rpc_url
1119                .set_scheme(match url_scheme {
1120                    Https => "wss",
1121                    Http => "ws",
1122                })
1123                .unwrap();
1124
1125            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
1126            // for us from the RPC URL.
1127            //
1128            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
1129            let mut request = rpc_url.into_client_request()?;
1130
1131            // We then modify the request to add our desired headers.
1132            let request_headers = request.headers_mut();
1133            request_headers.insert(
1134                "Authorization",
1135                HeaderValue::from_str(&credentials.authorization_header())?,
1136            );
1137            request_headers.insert(
1138                "x-zed-protocol-version",
1139                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
1140            );
1141            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
1142            request_headers.insert(
1143                "x-zed-release-channel",
1144                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
1145            );
1146
1147            match url_scheme {
1148                Https => {
1149                    let (stream, _) =
1150                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
1151                    Ok(Connection::new(
1152                        stream
1153                            .map_err(|error| anyhow!(error))
1154                            .sink_map_err(|error| anyhow!(error)),
1155                    ))
1156                }
1157                Http => {
1158                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1159                    Ok(Connection::new(
1160                        stream
1161                            .map_err(|error| anyhow!(error))
1162                            .sink_map_err(|error| anyhow!(error)),
1163                    ))
1164                }
1165            }
1166        })
1167    }
1168
1169    pub fn authenticate_with_browser(
1170        self: &Arc<Self>,
1171        cx: &AsyncAppContext,
1172    ) -> Task<Result<Credentials>> {
1173        let http = self.http.clone();
1174        let this = self.clone();
1175        cx.spawn(|cx| async move {
1176            let background = cx.background_executor().clone();
1177
1178            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1179            cx.update(|cx| {
1180                cx.spawn(move |cx| async move {
1181                    let url = open_url_rx.await?;
1182                    cx.update(|cx| cx.open_url(&url))
1183                })
1184                .detach_and_log_err(cx);
1185            })
1186            .log_err();
1187
1188            let credentials = background
1189                .clone()
1190                .spawn(async move {
1191                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1192                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1193                    // any other app running on the user's device.
1194                    let (public_key, private_key) =
1195                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1196                    let public_key_string = String::try_from(public_key)
1197                        .expect("failed to serialize public key for auth");
1198
1199                    if let Some((login, token)) =
1200                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1201                    {
1202                        eprintln!("authenticate as admin {login}, {token}");
1203
1204                        return this
1205                            .authenticate_as_admin(http, login.clone(), token.clone())
1206                            .await;
1207                    }
1208
1209                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1210                    let server =
1211                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1212                    let port = server.server_addr().port();
1213
1214                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1215                    // that the user is signing in from a Zed app running on the same device.
1216                    let mut url = http.build_url(&format!(
1217                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1218                        port, public_key_string
1219                    ));
1220
1221                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1222                        log::info!("impersonating user @{}", impersonate_login);
1223                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1224                    }
1225
1226                    open_url_tx.send(url).log_err();
1227
1228                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1229                    // access token from the query params.
1230                    //
1231                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1232                    // custom URL scheme instead of this local HTTP server.
1233                    let (user_id, access_token) = background
1234                        .spawn(async move {
1235                            for _ in 0..100 {
1236                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1237                                    let path = req.url();
1238                                    let mut user_id = None;
1239                                    let mut access_token = None;
1240                                    let url = Url::parse(&format!("http://example.com{}", path))
1241                                        .context("failed to parse login notification url")?;
1242                                    for (key, value) in url.query_pairs() {
1243                                        if key == "access_token" {
1244                                            access_token = Some(value.to_string());
1245                                        } else if key == "user_id" {
1246                                            user_id = Some(value.to_string());
1247                                        }
1248                                    }
1249
1250                                    let post_auth_url =
1251                                        http.build_url("/native_app_signin_succeeded");
1252                                    req.respond(
1253                                        tiny_http::Response::empty(302).with_header(
1254                                            tiny_http::Header::from_bytes(
1255                                                &b"Location"[..],
1256                                                post_auth_url.as_bytes(),
1257                                            )
1258                                            .unwrap(),
1259                                        ),
1260                                    )
1261                                    .context("failed to respond to login http request")?;
1262                                    return Ok((
1263                                        user_id
1264                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1265                                        access_token.ok_or_else(|| {
1266                                            anyhow!("missing access_token parameter")
1267                                        })?,
1268                                    ));
1269                                }
1270                            }
1271
1272                            Err(anyhow!("didn't receive login redirect"))
1273                        })
1274                        .await?;
1275
1276                    let access_token = private_key
1277                        .decrypt_string(&access_token)
1278                        .context("failed to decrypt access token")?;
1279
1280                    Ok(Credentials::User {
1281                        user_id: user_id.parse()?,
1282                        access_token,
1283                    })
1284                })
1285                .await?;
1286
1287            cx.update(|cx| cx.activate(true))?;
1288            Ok(credentials)
1289        })
1290    }
1291
1292    async fn authenticate_as_admin(
1293        self: &Arc<Self>,
1294        http: Arc<HttpClientWithUrl>,
1295        login: String,
1296        mut api_token: String,
1297    ) -> Result<Credentials> {
1298        #[derive(Deserialize)]
1299        struct AuthenticatedUserResponse {
1300            user: User,
1301        }
1302
1303        #[derive(Deserialize)]
1304        struct User {
1305            id: u64,
1306        }
1307
1308        let github_user = {
1309            #[derive(Deserialize)]
1310            struct GithubUser {
1311                id: i32,
1312                login: String,
1313                created_at: DateTime<Utc>,
1314            }
1315
1316            let request = {
1317                let mut request_builder =
1318                    Request::get(&format!("https://api.github.com/users/{login}"));
1319                if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
1320                    request_builder =
1321                        request_builder.header("Authorization", format!("Bearer {}", github_token));
1322                }
1323
1324                request_builder.body(AsyncBody::empty())?
1325            };
1326
1327            let mut response = http
1328                .send(request)
1329                .await
1330                .context("error fetching GitHub user")?;
1331
1332            let mut body = Vec::new();
1333            response
1334                .body_mut()
1335                .read_to_end(&mut body)
1336                .await
1337                .context("error reading GitHub user")?;
1338
1339            if !response.status().is_success() {
1340                let text = String::from_utf8_lossy(body.as_slice());
1341                bail!(
1342                    "status error {}, response: {text:?}",
1343                    response.status().as_u16()
1344                );
1345            }
1346
1347            serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
1348                log::error!("Error deserializing: {:?}", err);
1349                log::error!(
1350                    "GitHub API response text: {:?}",
1351                    String::from_utf8_lossy(body.as_slice())
1352                );
1353                anyhow!("error deserializing GitHub user")
1354            })?
1355        };
1356
1357        let query_params = [
1358            ("github_login", &github_user.login),
1359            ("github_user_id", &github_user.id.to_string()),
1360            (
1361                "github_user_created_at",
1362                &github_user.created_at.to_rfc3339(),
1363            ),
1364        ];
1365
1366        // Use the collab server's admin API to retrieve the ID
1367        // of the impersonated user.
1368        let mut url = self.rpc_url(http.clone(), None).await?;
1369        url.set_path("/user");
1370        url.set_query(Some(
1371            &query_params
1372                .iter()
1373                .map(|(key, value)| {
1374                    format!(
1375                        "{}={}",
1376                        key,
1377                        url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
1378                    )
1379                })
1380                .collect::<Vec<String>>()
1381                .join("&"),
1382        ));
1383        let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
1384            .header("Authorization", format!("token {api_token}"))
1385            .body("".into())?;
1386
1387        let mut response = http.send(request).await?;
1388        let mut body = String::new();
1389        response.body_mut().read_to_string(&mut body).await?;
1390        if !response.status().is_success() {
1391            Err(anyhow!(
1392                "admin user request failed {} - {}",
1393                response.status().as_u16(),
1394                body,
1395            ))?;
1396        }
1397        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1398
1399        // Use the admin API token to authenticate as the impersonated user.
1400        api_token.insert_str(0, "ADMIN_TOKEN:");
1401        Ok(Credentials::User {
1402            user_id: response.user.id,
1403            access_token: api_token,
1404        })
1405    }
1406
1407    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1408        self.state.write().credentials = None;
1409        self.disconnect(cx);
1410
1411        if self.has_credentials(cx).await {
1412            self.credentials_provider
1413                .delete_credentials(cx)
1414                .await
1415                .log_err();
1416        }
1417    }
1418
1419    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1420        self.peer.teardown();
1421        self.set_status(Status::SignedOut, cx);
1422    }
1423
1424    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1425        self.peer.teardown();
1426        self.set_status(Status::ConnectionLost, cx);
1427    }
1428
1429    fn connection_id(&self) -> Result<ConnectionId> {
1430        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1431            Ok(connection_id)
1432        } else {
1433            Err(anyhow!("not connected"))
1434        }
1435    }
1436
1437    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1438        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1439        self.peer.send(self.connection_id()?, message)
1440    }
1441
1442    pub fn request<T: RequestMessage>(
1443        &self,
1444        request: T,
1445    ) -> impl Future<Output = Result<T::Response>> {
1446        self.request_envelope(request)
1447            .map_ok(|envelope| envelope.payload)
1448    }
1449
1450    pub fn request_stream<T: RequestMessage>(
1451        &self,
1452        request: T,
1453    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1454        let client_id = self.id.load(Ordering::SeqCst);
1455        log::debug!(
1456            "rpc request start. client_id:{}. name:{}",
1457            client_id,
1458            T::NAME
1459        );
1460        let response = self
1461            .connection_id()
1462            .map(|conn_id| self.peer.request_stream(conn_id, request));
1463        async move {
1464            let response = response?.await;
1465            log::debug!(
1466                "rpc request finish. client_id:{}. name:{}",
1467                client_id,
1468                T::NAME
1469            );
1470            response
1471        }
1472    }
1473
1474    pub fn request_envelope<T: RequestMessage>(
1475        &self,
1476        request: T,
1477    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1478        let client_id = self.id();
1479        log::debug!(
1480            "rpc request start. client_id:{}. name:{}",
1481            client_id,
1482            T::NAME
1483        );
1484        let response = self
1485            .connection_id()
1486            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1487        async move {
1488            let response = response?.await;
1489            log::debug!(
1490                "rpc request finish. client_id:{}. name:{}",
1491                client_id,
1492                T::NAME
1493            );
1494            response
1495        }
1496    }
1497
1498    pub fn request_dynamic(
1499        &self,
1500        envelope: proto::Envelope,
1501        request_type: &'static str,
1502    ) -> impl Future<Output = Result<proto::Envelope>> {
1503        let client_id = self.id();
1504        log::debug!(
1505            "rpc request start. client_id:{}. name:{}",
1506            client_id,
1507            request_type
1508        );
1509        let response = self
1510            .connection_id()
1511            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1512        async move {
1513            let response = response?.await;
1514            log::debug!(
1515                "rpc request finish. client_id:{}. name:{}",
1516                client_id,
1517                request_type
1518            );
1519            Ok(response?.0)
1520        }
1521    }
1522
1523    fn handle_message(
1524        self: &Arc<Client>,
1525        message: Box<dyn AnyTypedEnvelope>,
1526        cx: &AsyncAppContext,
1527    ) {
1528        let sender_id = message.sender_id();
1529        let request_id = message.message_id();
1530        let type_name = message.payload_type_name();
1531        let original_sender_id = message.original_sender_id();
1532
1533        if let Some(future) = ProtoMessageHandlerSet::handle_message(
1534            &self.handler_set,
1535            message,
1536            self.clone().into(),
1537            cx.clone(),
1538        ) {
1539            let client_id = self.id();
1540            log::debug!(
1541                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1542                client_id,
1543                original_sender_id,
1544                type_name
1545            );
1546            cx.spawn(move |_| async move {
1547                match future.await {
1548                    Ok(()) => {
1549                        log::debug!(
1550                            "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1551                            client_id,
1552                            original_sender_id,
1553                            type_name
1554                        );
1555                    }
1556                    Err(error) => {
1557                        log::error!(
1558                            "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1559                            client_id,
1560                            original_sender_id,
1561                            type_name,
1562                            error
1563                        );
1564                    }
1565                }
1566            })
1567            .detach();
1568        } else {
1569            log::info!("unhandled message {}", type_name);
1570            self.peer
1571                .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1572                .log_err();
1573        }
1574    }
1575
1576    pub fn telemetry(&self) -> &Arc<Telemetry> {
1577        &self.telemetry
1578    }
1579}
1580
1581impl ProtoClient for Client {
1582    fn request(
1583        &self,
1584        envelope: proto::Envelope,
1585        request_type: &'static str,
1586    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1587        self.request_dynamic(envelope, request_type).boxed()
1588    }
1589
1590    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1591        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1592        let connection_id = self.connection_id()?;
1593        self.peer.send_dynamic(connection_id, envelope)
1594    }
1595
1596    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1597        log::debug!(
1598            "rpc respond. client_id:{}, name:{}",
1599            self.id(),
1600            message_type
1601        );
1602        let connection_id = self.connection_id()?;
1603        self.peer.send_dynamic(connection_id, envelope)
1604    }
1605
1606    fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1607        &self.handler_set
1608    }
1609}
1610
1611#[derive(Serialize, Deserialize)]
1612struct DevelopmentCredentials {
1613    user_id: u64,
1614    access_token: String,
1615}
1616
/// Credentials provider backed by a plain JSON file on disk.
///
/// Development use ONLY: the access token is written unencrypted, which is not an acceptable
/// way to store credentials on user machines.
///
/// It exists only to avoid repeatedly re-granting system keychain access while working
/// on Zed.
struct DevelopmentCredentialsProvider {
    /// File that holds the serialized `DevelopmentCredentials`.
    path: PathBuf,
}
1627
1628impl CredentialsProvider for DevelopmentCredentialsProvider {
1629    fn read_credentials<'a>(
1630        &'a self,
1631        _cx: &'a AsyncAppContext,
1632    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1633        async move {
1634            if IMPERSONATE_LOGIN.is_some() {
1635                return None;
1636            }
1637
1638            let json = std::fs::read(&self.path).log_err()?;
1639
1640            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1641
1642            Some(Credentials::User {
1643                user_id: credentials.user_id,
1644                access_token: credentials.access_token,
1645            })
1646        }
1647        .boxed_local()
1648    }
1649
1650    fn write_credentials<'a>(
1651        &'a self,
1652        user_id: u64,
1653        access_token: String,
1654        _cx: &'a AsyncAppContext,
1655    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1656        async move {
1657            let json = serde_json::to_string(&DevelopmentCredentials {
1658                user_id,
1659                access_token,
1660            })?;
1661
1662            std::fs::write(&self.path, json)?;
1663
1664            Ok(())
1665        }
1666        .boxed_local()
1667    }
1668
1669    fn delete_credentials<'a>(
1670        &'a self,
1671        _cx: &'a AsyncAppContext,
1672    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1673        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1674    }
1675}
1676
/// Credentials provider that keeps credentials in the operating system's keychain, keyed by
/// the configured server URL.
struct KeychainCredentialsProvider;
1679
1680impl CredentialsProvider for KeychainCredentialsProvider {
1681    fn read_credentials<'a>(
1682        &'a self,
1683        cx: &'a AsyncAppContext,
1684    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1685        async move {
1686            if IMPERSONATE_LOGIN.is_some() {
1687                return None;
1688            }
1689
1690            let (user_id, access_token) = cx
1691                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1692                .log_err()?
1693                .await
1694                .log_err()??;
1695
1696            Some(Credentials::User {
1697                user_id: user_id.parse().ok()?,
1698                access_token: String::from_utf8(access_token).ok()?,
1699            })
1700        }
1701        .boxed_local()
1702    }
1703
1704    fn write_credentials<'a>(
1705        &'a self,
1706        user_id: u64,
1707        access_token: String,
1708        cx: &'a AsyncAppContext,
1709    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1710        async move {
1711            cx.update(move |cx| {
1712                cx.write_credentials(
1713                    &ClientSettings::get_global(cx).server_url,
1714                    &user_id.to_string(),
1715                    access_token.as_bytes(),
1716                )
1717            })?
1718            .await
1719        }
1720        .boxed_local()
1721    }
1722
1723    fn delete_credentials<'a>(
1724        &'a self,
1725        cx: &'a AsyncAppContext,
1726    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1727        async move {
1728            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1729                .await
1730        }
1731        .boxed_local()
1732    }
1733}
1734
/// The URL scheme used for Zed links (`zed://…`), without the `://` separator.
pub const ZED_URL_SCHEME: &str = "zed";
1737
1738/// Parses the given link into a Zed link.
1739///
1740/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1741/// Returns [`None`] otherwise.
1742pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1743    let server_url = &ClientSettings::get_global(cx).server_url;
1744    if let Some(stripped) = link
1745        .strip_prefix(server_url)
1746        .and_then(|result| result.strip_prefix('/'))
1747    {
1748        return Some(stripped);
1749    }
1750    if let Some(stripped) = link
1751        .strip_prefix(ZED_URL_SCHEME)
1752        .and_then(|result| result.strip_prefix("://"))
1753    {
1754        return Some(stripped);
1755    }
1756
1757    None
1758}
1759
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // Reconnecting after a server-side disconnect reuses the cached credentials while the
    // access token is still valid, and re-authenticates once the server has rolled the token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Both the initial connection and a later reconnection attempt report an error status
    // once `CONNECTION_TIMEOUT` elapses while establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still authenticating
    // drops the first in-flight authentication future.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    // Counts how many authentication futures were dropped before completing.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity messages are routed to the model registered for the matching entity id.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // A handler for a message type can be registered again after the previous handler's
    // subscription has been dropped, and the new handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A message handler may drop its own subscription (stored on the model) while running.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as a message-handler target in these tests.
    #[derive(Default)]
    struct TestModel {
        /// Entity id the model represents (used to route `JoinProject` in the tests).
        id: usize,
        /// Optional subscription held by the model so a handler can drop it.
        subscription: Option<Subscription>,
    }

    /// Installs a test `SettingsStore` as a global and registers this crate's settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}