client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4mod socks;
   5pub mod telemetry;
   6pub mod user;
   7
   8use anyhow::{anyhow, bail, Context as _, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    client::IntoClientRequest,
  12    error::Error as WebsocketError,
  13    http::{HeaderValue, Request, StatusCode},
  14};
  15use chrono::{DateTime, Utc};
  16use clock::SystemClock;
  17use collections::HashMap;
  18use futures::{
  19    channel::oneshot,
  20    future::{BoxFuture, LocalBoxFuture},
  21    AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
  22};
  23use gpui::{
  24    actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
  25};
  26use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
  27use parking_lot::RwLock;
  28use postage::watch;
  29use proto::ProtoClient;
  30use rand::prelude::*;
  31use release_channel::{AppVersion, ReleaseChannel};
  32use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  33use schemars::JsonSchema;
  34use serde::{Deserialize, Serialize};
  35use settings::{Settings, SettingsSources};
  36use socks::connect_socks_proxy_stream;
  37use std::fmt;
  38use std::pin::Pin;
  39use std::{
  40    any::TypeId,
  41    convert::TryFrom,
  42    fmt::Write as _,
  43    future::Future,
  44    marker::PhantomData,
  45    path::PathBuf,
  46    sync::{
  47        atomic::{AtomicU64, Ordering},
  48        Arc, LazyLock, Weak,
  49    },
  50    time::{Duration, Instant},
  51};
  52use telemetry::Telemetry;
  53use thiserror::Error;
  54use url::Url;
  55use util::{ResultExt, TryFutureExt};
  56
  57pub use rpc::*;
  58pub use telemetry_events::Event;
  59pub use user::*;
  60
  61#[derive(Debug, Clone, Eq, PartialEq)]
  62pub struct DevServerToken(pub String);
  63
  64impl fmt::Display for DevServerToken {
  65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  66        write!(f, "{}", self.0)
  67    }
  68}
  69
  70static ZED_SERVER_URL: LazyLock<Option<String>> =
  71    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
  72static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
  73
  74/// An environment variable whose presence indicates that the development auth
  75/// provider should be used.
  76///
  77/// Only works in development. Setting this environment variable in other release
  78/// channels is a no-op.
  79pub static ZED_DEVELOPMENT_AUTH: LazyLock<bool> = LazyLock::new(|| {
  80    std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty())
  81});
  82pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
  83    std::env::var("ZED_IMPERSONATE")
  84        .ok()
  85        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  86});
  87
  88pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
  89    std::env::var("ZED_ADMIN_API_TOKEN")
  90        .ok()
  91        .and_then(|s| if s.is_empty() { None } else { Some(s) })
  92});
  93
  94pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
  95    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
  96
  97pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
  98    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
  99
 100pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
 101pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
 102pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
 103
// Declares the `SignIn`, `SignOut`, and `Reconnect` actions in the `client`
// namespace; `init` installs the handlers for them.
actions!(client, [SignIn, SignOut, Reconnect]);
 105
 106#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
 107pub struct ClientSettingsContent {
 108    server_url: Option<String>,
 109}
 110
 111#[derive(Deserialize)]
 112pub struct ClientSettings {
 113    pub server_url: String,
 114}
 115
 116impl Settings for ClientSettings {
 117    const KEY: Option<&'static str> = None;
 118
 119    type FileContent = ClientSettingsContent;
 120
 121    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 122        let mut result = sources.json_merge::<Self>()?;
 123        if let Some(server_url) = &*ZED_SERVER_URL {
 124            result.server_url.clone_from(&server_url)
 125        }
 126        Ok(result)
 127    }
 128}
 129
 130#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 131pub struct ProxySettingsContent {
 132    proxy: Option<String>,
 133}
 134
 135#[derive(Deserialize, Default)]
 136pub struct ProxySettings {
 137    pub proxy: Option<String>,
 138}
 139
 140impl Settings for ProxySettings {
 141    const KEY: Option<&'static str> = None;
 142
 143    type FileContent = ProxySettingsContent;
 144
 145    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 146        Ok(Self {
 147            proxy: sources
 148                .user
 149                .and_then(|value| value.proxy.clone())
 150                .or(sources.default.proxy.clone()),
 151        })
 152    }
 153}
 154
 155pub fn init_settings(cx: &mut AppContext) {
 156    TelemetrySettings::register(cx);
 157    ClientSettings::register(cx);
 158    ProxySettings::register(cx);
 159}
 160
 161pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 162    let client = Arc::downgrade(client);
 163    cx.on_action({
 164        let client = client.clone();
 165        move |_: &SignIn, cx| {
 166            if let Some(client) = client.upgrade() {
 167                cx.spawn(
 168                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 169                )
 170                .detach();
 171            }
 172        }
 173    });
 174
 175    cx.on_action({
 176        let client = client.clone();
 177        move |_: &SignOut, cx| {
 178            if let Some(client) = client.upgrade() {
 179                cx.spawn(|cx| async move {
 180                    client.sign_out(&cx).await;
 181                })
 182                .detach();
 183            }
 184        }
 185    });
 186
 187    cx.on_action({
 188        let client = client.clone();
 189        move |_: &Reconnect, cx| {
 190            if let Some(client) = client.upgrade() {
 191                cx.spawn(|cx| async move {
 192                    client.reconnect(&cx);
 193                })
 194                .detach();
 195            }
 196        }
 197    });
 198}
 199
 200struct GlobalClient(Arc<Client>);
 201
 202impl Global for GlobalClient {}
 203
 204pub struct Client {
 205    id: AtomicU64,
 206    peer: Arc<Peer>,
 207    http: Arc<HttpClientWithUrl>,
 208    telemetry: Arc<Telemetry>,
 209    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
 210    state: RwLock<ClientState>,
 211
 212    #[allow(clippy::type_complexity)]
 213    #[cfg(any(test, feature = "test-support"))]
 214    authenticate: RwLock<
 215        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 216    >,
 217
 218    #[allow(clippy::type_complexity)]
 219    #[cfg(any(test, feature = "test-support"))]
 220    establish_connection: RwLock<
 221        Option<
 222            Box<
 223                dyn 'static
 224                    + Send
 225                    + Sync
 226                    + Fn(
 227                        &Credentials,
 228                        &AsyncAppContext,
 229                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 230            >,
 231        >,
 232    >,
 233
 234    #[cfg(any(test, feature = "test-support"))]
 235    rpc_url: RwLock<Option<Url>>,
 236}
 237
 238#[derive(Error, Debug)]
 239pub enum EstablishConnectionError {
 240    #[error("upgrade required")]
 241    UpgradeRequired,
 242    #[error("unauthorized")]
 243    Unauthorized,
 244    #[error("{0}")]
 245    Other(#[from] anyhow::Error),
 246    #[error("{0}")]
 247    Http(#[from] http_client::Error),
 248    #[error("{0}")]
 249    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
 250    #[error("{0}")]
 251    Io(#[from] std::io::Error),
 252    #[error("{0}")]
 253    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 254}
 255
 256impl From<WebsocketError> for EstablishConnectionError {
 257    fn from(error: WebsocketError) -> Self {
 258        if let WebsocketError::Http(response) = &error {
 259            match response.status() {
 260                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 261                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 262                _ => {}
 263            }
 264        }
 265        EstablishConnectionError::Other(error.into())
 266    }
 267}
 268
 269impl EstablishConnectionError {
 270    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 271        Self::Other(error.into())
 272    }
 273}
 274
 275#[derive(Copy, Clone, Debug, PartialEq)]
 276pub enum Status {
 277    SignedOut,
 278    UpgradeRequired,
 279    Authenticating,
 280    Connecting,
 281    ConnectionError,
 282    Connected {
 283        peer_id: PeerId,
 284        connection_id: ConnectionId,
 285    },
 286    ConnectionLost,
 287    Reauthenticating,
 288    Reconnecting,
 289    ReconnectionError {
 290        next_reconnection: Instant,
 291    },
 292}
 293
 294impl Status {
 295    pub fn is_connected(&self) -> bool {
 296        matches!(self, Self::Connected { .. })
 297    }
 298
 299    pub fn is_signed_out(&self) -> bool {
 300        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 301    }
 302}
 303
 304struct ClientState {
 305    credentials: Option<Credentials>,
 306    status: (watch::Sender<Status>, watch::Receiver<Status>),
 307    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 308    _reconnect_task: Option<Task<()>>,
 309    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
 310    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
 311    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 312    #[allow(clippy::type_complexity)]
 313    message_handlers: HashMap<
 314        TypeId,
 315        Arc<
 316            dyn Send
 317                + Sync
 318                + Fn(
 319                    AnyModel,
 320                    Box<dyn AnyTypedEnvelope>,
 321                    &Arc<Client>,
 322                    AsyncAppContext,
 323                ) -> LocalBoxFuture<'static, Result<()>>,
 324        >,
 325    >,
 326}
 327
 328enum WeakSubscriber {
 329    Entity { handle: AnyWeakModel },
 330    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
 331}
 332
 333#[derive(Clone, Debug, Eq, PartialEq)]
 334pub enum Credentials {
 335    DevServer { token: DevServerToken },
 336    User { user_id: u64, access_token: String },
 337}
 338
 339impl Credentials {
 340    pub fn authorization_header(&self) -> String {
 341        match self {
 342            Credentials::DevServer { token } => format!("dev-server-token {}", token),
 343            Credentials::User {
 344                user_id,
 345                access_token,
 346            } => format!("{} {}", user_id, access_token),
 347        }
 348    }
 349}
 350
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain).
///
/// `Client::new` chooses `DevelopmentCredentialsProvider` (backed by the config
/// directory's `development_auth` path) on the `Dev` release channel when
/// `ZED_DEVELOPMENT_AUTH` is set, and `KeychainCredentialsProvider` otherwise.
/// The returned futures borrow `self` and `cx` for `'a` and are not required to be
/// `Send`.
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// Resolves to `None` when no credentials are available; `Client::has_credentials`
    /// treats `None` as "nothing stored".
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (`user_id` + `access_token`) can be persisted through
    /// this method; a dev-server token cannot be passed in.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
 376
 377impl Default for ClientState {
 378    fn default() -> Self {
 379        Self {
 380            credentials: None,
 381            status: watch::channel_with(Status::SignedOut),
 382            entity_id_extractors: Default::default(),
 383            _reconnect_task: None,
 384            models_by_message_type: Default::default(),
 385            entities_by_type_and_remote_id: Default::default(),
 386            entity_types_by_message_type: Default::default(),
 387            message_handlers: Default::default(),
 388        }
 389    }
 390}
 391
 392pub enum Subscription {
 393    Entity {
 394        client: Weak<Client>,
 395        id: (TypeId, u64),
 396    },
 397    Message {
 398        client: Weak<Client>,
 399        id: TypeId,
 400    },
 401}
 402
 403impl Drop for Subscription {
 404    fn drop(&mut self) {
 405        match self {
 406            Subscription::Entity { client, id } => {
 407                if let Some(client) = client.upgrade() {
 408                    let mut state = client.state.write();
 409                    let _ = state.entities_by_type_and_remote_id.remove(id);
 410                }
 411            }
 412            Subscription::Message { client, id } => {
 413                if let Some(client) = client.upgrade() {
 414                    let mut state = client.state.write();
 415                    let _ = state.entity_types_by_message_type.remove(id);
 416                    let _ = state.message_handlers.remove(id);
 417                }
 418            }
 419        }
 420    }
 421}
 422
 423pub struct PendingEntitySubscription<T: 'static> {
 424    client: Arc<Client>,
 425    remote_id: u64,
 426    _entity_type: PhantomData<T>,
 427    consumed: bool,
 428}
 429
 430impl<T: 'static> PendingEntitySubscription<T> {
 431    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
 432        self.consumed = true;
 433        let mut state = self.client.state.write();
 434        let id = (TypeId::of::<T>(), self.remote_id);
 435        let Some(WeakSubscriber::Pending(messages)) =
 436            state.entities_by_type_and_remote_id.remove(&id)
 437        else {
 438            unreachable!()
 439        };
 440
 441        state.entities_by_type_and_remote_id.insert(
 442            id,
 443            WeakSubscriber::Entity {
 444                handle: model.downgrade().into(),
 445            },
 446        );
 447        drop(state);
 448        for message in messages {
 449            let client_id = self.client.id();
 450            let type_name = message.payload_type_name();
 451            let sender_id = message.original_sender_id();
 452            log::debug!(
 453                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
 454                client_id,
 455                sender_id,
 456                type_name
 457            );
 458            self.client.handle_message(message, cx);
 459        }
 460        Subscription::Entity {
 461            client: Arc::downgrade(&self.client),
 462            id,
 463        }
 464    }
 465}
 466
 467impl<T: 'static> Drop for PendingEntitySubscription<T> {
 468    fn drop(&mut self) {
 469        if !self.consumed {
 470            let mut state = self.client.state.write();
 471            if let Some(WeakSubscriber::Pending(messages)) = state
 472                .entities_by_type_and_remote_id
 473                .remove(&(TypeId::of::<T>(), self.remote_id))
 474            {
 475                for message in messages {
 476                    log::info!("unhandled message {}", message.payload_type_name());
 477                }
 478            }
 479        }
 480    }
 481}
 482
 483#[derive(Copy, Clone)]
 484pub struct TelemetrySettings {
 485    pub diagnostics: bool,
 486    pub metrics: bool,
 487}
 488
 489/// Control what info is collected by Zed.
 490#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 491pub struct TelemetrySettingsContent {
 492    /// Send debug info like crash reports.
 493    ///
 494    /// Default: true
 495    pub diagnostics: Option<bool>,
 496    /// Send anonymized usage data like what languages you're using Zed with.
 497    ///
 498    /// Default: true
 499    pub metrics: Option<bool>,
 500}
 501
 502impl settings::Settings for TelemetrySettings {
 503    const KEY: Option<&'static str> = Some("telemetry");
 504
 505    type FileContent = TelemetrySettingsContent;
 506
 507    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
 508        Ok(Self {
 509            diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
 510                sources
 511                    .default
 512                    .diagnostics
 513                    .ok_or_else(Self::missing_default)?,
 514            ),
 515            metrics: sources
 516                .user
 517                .as_ref()
 518                .and_then(|v| v.metrics)
 519                .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
 520        })
 521    }
 522}
 523
 524impl Client {
 525    pub fn new(
 526        clock: Arc<dyn SystemClock>,
 527        http: Arc<HttpClientWithUrl>,
 528        cx: &mut AppContext,
 529    ) -> Arc<Self> {
 530        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
 531            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
 532            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
 533            | None => false,
 534        };
 535
 536        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
 537            if use_zed_development_auth {
 538                Arc::new(DevelopmentCredentialsProvider {
 539                    path: paths::config_dir().join("development_auth"),
 540                })
 541            } else {
 542                Arc::new(KeychainCredentialsProvider)
 543            };
 544
 545        Arc::new(Self {
 546            id: AtomicU64::new(0),
 547            peer: Peer::new(0),
 548            telemetry: Telemetry::new(clock, http.clone(), cx),
 549            http,
 550            credentials_provider,
 551            state: Default::default(),
 552
 553            #[cfg(any(test, feature = "test-support"))]
 554            authenticate: Default::default(),
 555            #[cfg(any(test, feature = "test-support"))]
 556            establish_connection: Default::default(),
 557            #[cfg(any(test, feature = "test-support"))]
 558            rpc_url: RwLock::default(),
 559        })
 560    }
 561
 562    pub fn production(cx: &mut AppContext) -> Arc<Self> {
 563        let user_agent = format!(
 564            "Zed/{} ({}; {})",
 565            AppVersion::global(cx),
 566            std::env::consts::OS,
 567            std::env::consts::ARCH
 568        );
 569        let clock = Arc::new(clock::RealSystemClock);
 570        let http = Arc::new(HttpClientWithUrl::new(
 571            &ClientSettings::get_global(cx).server_url,
 572            Some(user_agent),
 573            ProxySettings::get_global(cx).proxy.clone(),
 574        ));
 575        Self::new(clock, http.clone(), cx)
 576    }
 577
 578    pub fn id(&self) -> u64 {
 579        self.id.load(Ordering::SeqCst)
 580    }
 581
 582    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 583        self.http.clone()
 584    }
 585
 586    pub fn set_id(&self, id: u64) -> &Self {
 587        self.id.store(id, Ordering::SeqCst);
 588        self
 589    }
 590
 591    #[cfg(any(test, feature = "test-support"))]
 592    pub fn teardown(&self) {
 593        let mut state = self.state.write();
 594        state._reconnect_task.take();
 595        state.message_handlers.clear();
 596        state.models_by_message_type.clear();
 597        state.entities_by_type_and_remote_id.clear();
 598        state.entity_id_extractors.clear();
 599        self.peer.teardown();
 600    }
 601
 602    #[cfg(any(test, feature = "test-support"))]
 603    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 604    where
 605        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 606    {
 607        *self.authenticate.write() = Some(Box::new(authenticate));
 608        self
 609    }
 610
 611    #[cfg(any(test, feature = "test-support"))]
 612    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 613    where
 614        F: 'static
 615            + Send
 616            + Sync
 617            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 618    {
 619        *self.establish_connection.write() = Some(Box::new(connect));
 620        self
 621    }
 622
 623    #[cfg(any(test, feature = "test-support"))]
 624    pub fn override_rpc_url(&self, url: Url) -> &Self {
 625        *self.rpc_url.write() = Some(url);
 626        self
 627    }
 628
 629    pub fn global(cx: &AppContext) -> Arc<Self> {
 630        cx.global::<GlobalClient>().0.clone()
 631    }
 632    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 633        cx.set_global(GlobalClient(client))
 634    }
 635
 636    pub fn user_id(&self) -> Option<u64> {
 637        if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
 638            Some(*user_id)
 639        } else {
 640            None
 641        }
 642    }
 643
 644    pub fn peer_id(&self) -> Option<PeerId> {
 645        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 646            Some(*peer_id)
 647        } else {
 648            None
 649        }
 650    }
 651
 652    pub fn status(&self) -> watch::Receiver<Status> {
 653        self.state.read().status.1.clone()
 654    }
 655
    /// Publishes `status` to every receiver returned by [`Client::status`] and
    /// reacts to the transition:
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task (see below).
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and drops
    ///   any pending reconnect task.
    ///
    /// The state write lock is held for the entire call, including while the
    /// reconnect task is spawned and telemetry is updated.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                // The task holds a strong `Arc<Client>` while it is stored in this
                // client's own state, so the client cannot be freed while the task is
                // pending.
                let this = self.clone();
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed under test so the randomized delays are deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failed attempt left the status at
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Multiply by a random factor in 0.5..=2.5, then clamp to
                            // [INITIAL_RECONNECTION_DELAY, MAX_RECONNECTION_DELAY].
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 701
 702    pub fn subscribe_to_entity<T>(
 703        self: &Arc<Self>,
 704        remote_id: u64,
 705    ) -> Result<PendingEntitySubscription<T>>
 706    where
 707        T: 'static,
 708    {
 709        let id = (TypeId::of::<T>(), remote_id);
 710
 711        let mut state = self.state.write();
 712        if state.entities_by_type_and_remote_id.contains_key(&id) {
 713            return Err(anyhow!("already subscribed to entity"));
 714        }
 715
 716        state
 717            .entities_by_type_and_remote_id
 718            .insert(id, WeakSubscriber::Pending(Default::default()));
 719
 720        Ok(PendingEntitySubscription {
 721            client: self.clone(),
 722            remote_id,
 723            consumed: false,
 724            _entity_type: PhantomData,
 725        })
 726    }
 727
 728    #[track_caller]
 729    pub fn add_message_handler<M, E, H, F>(
 730        self: &Arc<Self>,
 731        entity: WeakModel<E>,
 732        handler: H,
 733    ) -> Subscription
 734    where
 735        M: EnvelopedMessage,
 736        E: 'static,
 737        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 738        F: 'static + Future<Output = Result<()>>,
 739    {
 740        self.add_message_handler_impl(entity, move |model, message, _, cx| {
 741            handler(model, message, cx)
 742        })
 743    }
 744
 745    fn add_message_handler_impl<M, E, H, F>(
 746        self: &Arc<Self>,
 747        entity: WeakModel<E>,
 748        handler: H,
 749    ) -> Subscription
 750    where
 751        M: EnvelopedMessage,
 752        E: 'static,
 753        H: 'static
 754            + Sync
 755            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 756            + Send
 757            + Sync,
 758        F: 'static + Future<Output = Result<()>>,
 759    {
 760        let message_type_id = TypeId::of::<M>();
 761        let mut state = self.state.write();
 762        state
 763            .models_by_message_type
 764            .insert(message_type_id, entity.into());
 765
 766        let prev_handler = state.message_handlers.insert(
 767            message_type_id,
 768            Arc::new(move |subscriber, envelope, client, cx| {
 769                let subscriber = subscriber.downcast::<E>().unwrap();
 770                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 771                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 772            }),
 773        );
 774        if prev_handler.is_some() {
 775            let location = std::panic::Location::caller();
 776            panic!(
 777                "{}:{} registered handler for the same message {} twice",
 778                location.file(),
 779                location.line(),
 780                std::any::type_name::<M>()
 781            );
 782        }
 783
 784        Subscription::Message {
 785            client: Arc::downgrade(self),
 786            id: message_type_id,
 787        }
 788    }
 789
 790    pub fn add_request_handler<M, E, H, F>(
 791        self: &Arc<Self>,
 792        model: WeakModel<E>,
 793        handler: H,
 794    ) -> Subscription
 795    where
 796        M: RequestMessage,
 797        E: 'static,
 798        H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 799        F: 'static + Future<Output = Result<M::Response>>,
 800    {
 801        self.add_message_handler_impl(model, move |handle, envelope, this, cx| {
 802            Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
 803        })
 804    }
 805
 806    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 807    where
 808        M: EntityMessage,
 809        E: 'static,
 810        H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 811        F: 'static + Future<Output = Result<()>>,
 812    {
 813        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, _, cx| {
 814            handler(subscriber.downcast::<E>().unwrap(), message, cx)
 815        })
 816    }
 817
 818    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 819    where
 820        M: EntityMessage,
 821        E: 'static,
 822        H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 823        F: 'static + Future<Output = Result<()>>,
 824    {
 825        let model_type_id = TypeId::of::<E>();
 826        let message_type_id = TypeId::of::<M>();
 827
 828        let mut state = self.state.write();
 829        state
 830            .entity_types_by_message_type
 831            .insert(message_type_id, model_type_id);
 832        state
 833            .entity_id_extractors
 834            .entry(message_type_id)
 835            .or_insert_with(|| {
 836                |envelope| {
 837                    envelope
 838                        .as_any()
 839                        .downcast_ref::<TypedEnvelope<M>>()
 840                        .unwrap()
 841                        .payload
 842                        .remote_entity_id()
 843                }
 844            });
 845        let prev_handler = state.message_handlers.insert(
 846            message_type_id,
 847            Arc::new(move |handle, envelope, client, cx| {
 848                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 849                handler(handle, *envelope, client.clone(), cx).boxed_local()
 850            }),
 851        );
 852        if prev_handler.is_some() {
 853            panic!("registered handler for the same message twice");
 854        }
 855    }
 856
 857    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 858    where
 859        M: EntityMessage + RequestMessage,
 860        E: 'static,
 861        H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
 862        F: 'static + Future<Output = Result<M::Response>>,
 863    {
 864        self.add_entity_message_handler::<M, E, _, _>(move |entity, envelope, client, cx| {
 865            Self::respond_to_request::<M, _>(
 866                envelope.receipt(),
 867                handler(entity.downcast::<E>().unwrap(), envelope, cx),
 868                client,
 869            )
 870        })
 871    }
 872
 873    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 874        receipt: Receipt<T>,
 875        response: F,
 876        client: Arc<Self>,
 877    ) -> Result<()> {
 878        match response.await {
 879            Ok(response) => {
 880                client.respond(receipt, response)?;
 881                Ok(())
 882            }
 883            Err(error) => {
 884                client.respond_with_error(receipt, error.to_proto())?;
 885                Err(error)
 886            }
 887        }
 888    }
 889
 890    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
 891        self.credentials_provider
 892            .read_credentials(cx)
 893            .await
 894            .is_some()
 895    }
 896
 897    pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
 898        self.state.write().credentials = Some(Credentials::DevServer { token });
 899        self
 900    }
 901
    /// Authenticates (if needed) and connects to the collab server, updating [`Status`]
    /// along the way.
    ///
    /// Returns `Ok(())` without doing anything if the client is already connected,
    /// connecting, or reconnecting. Fails immediately with
    /// `EstablishConnectionError::UpgradeRequired` if the status says an upgrade is required.
    ///
    /// Credentials are looked up in this order:
    /// 1. the in-memory state;
    /// 2. the credentials provider, when `try_provider` is set;
    /// 3. the interactive [`Self::authenticate`] flow.
    ///
    /// The interactive flow is abandoned with an "authentication canceled" error if the status
    /// changes while it is pending.
    ///
    /// Connection establishment and the wait for the server's hello share a single
    /// `CONNECTION_TIMEOUT` timer. After a successful connection, user credentials are written
    /// to the provider unless they were read from it or `IMPERSONATE_LOGIN` is set.
    ///
    /// If the server rejects credentials that came from the provider, they are deleted and
    /// the whole flow runs again with `try_provider = false`. That second attempt cannot
    /// recurse again.
    // `async_recursion` boxes the future so the `Unauthorized` branch below can call this
    // function recursively.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Classify the attempt: a fresh sign-in vs. re-authenticating a previously
        // established session. Already-active states short-circuit.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Tracks whether the credentials came from the provider, which decides below whether
        // to persist them on success and whether to delete them when the server rejects them.
        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            let mut status_rx = self.status();
            // Assumes the first `next()` on a fresh watch receiver yields the current status.
            // Consuming it means the `select_biased!` below only reacts to a later status
            // change, e.g. a sign-out that cancels the interactive flow.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path above either filled `credentials` or returned early.
        let credentials = credentials.unwrap();
        if let Credentials::User { user_id, .. } = &credentials {
            self.set_id(*user_id);
        }

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer bounds both establishing the connection and waiting for the server's
        // hello in `set_connection`.
        // NOTE(review): the timer keeps running while `write_credentials` is awaited below, so
        // a slow credential store reduces the time left for the hello handshake.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            if let Credentials::User{user_id, access_token} = credentials {
                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                            }
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        // Drop the rejected credentials. Provider-sourced ones may just be
                        // stale, so delete them and retry through interactive authentication.
                        self.state.write().credentials.take();
                        if read_from_provider {
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
1010
    /// Registers `conn` with the peer, performs the hello handshake, and marks the client
    /// [`Status::Connected`].
    ///
    /// After the handshake, two detached tasks are spawned:
    /// - one dispatches every incoming message through [`Self::handle_message`];
    /// - one awaits the connection's I/O future and then updates the status.
    ///
    /// The second task sets `SignedOut` if I/O completes cleanly while this connection is
    /// still the connected one, and `ConnectionLost` if I/O fails.
    ///
    /// # Errors
    /// Disconnects the peer connection and fails if the first incoming message is missing,
    /// is not a `proto::Hello`, or carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor. The task handle is kept so
        // the status-watching task below can observe how the I/O ended.
        let handle_io = executor.spawn(handle_io);

        // The server must open with a `Hello` that tells us our peer id.
        // NOTE(review): if this future is dropped while awaiting the hello (e.g. by the
        // caller's timeout), the `peer.disconnect` in the error arm below never runs. Confirm
        // the peer connection is cleaned up elsewhere.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Message dispatch loop. It runs until the incoming stream ends.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // Translate the end of the connection's I/O into a status change.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if this exact connection is still the active one; a
                        // newer connection may have replaced it in the meantime.
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1108
1109    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1110        #[cfg(any(test, feature = "test-support"))]
1111        if let Some(callback) = self.authenticate.read().as_ref() {
1112            return callback(cx);
1113        }
1114
1115        self.authenticate_with_browser(cx)
1116    }
1117
1118    fn establish_connection(
1119        self: &Arc<Self>,
1120        credentials: &Credentials,
1121        cx: &AsyncAppContext,
1122    ) -> Task<Result<Connection, EstablishConnectionError>> {
1123        #[cfg(any(test, feature = "test-support"))]
1124        if let Some(callback) = self.establish_connection.read().as_ref() {
1125            return callback(credentials, cx);
1126        }
1127
1128        self.establish_websocket_connection(credentials, cx)
1129    }
1130
1131    fn rpc_url(
1132        &self,
1133        http: Arc<HttpClientWithUrl>,
1134        release_channel: Option<ReleaseChannel>,
1135    ) -> impl Future<Output = Result<Url>> {
1136        #[cfg(any(test, feature = "test-support"))]
1137        let url_override = self.rpc_url.read().clone();
1138
1139        async move {
1140            #[cfg(any(test, feature = "test-support"))]
1141            if let Some(url) = url_override {
1142                return Ok(url);
1143            }
1144
1145            if let Some(url) = &*ZED_RPC_URL {
1146                return Url::parse(url).context("invalid rpc url");
1147            }
1148
1149            let mut url = http.build_url("/rpc");
1150            if let Some(preview_param) =
1151                release_channel.and_then(|channel| channel.release_query_param())
1152            {
1153                url += "?";
1154                url += preview_param;
1155            }
1156
1157            let response = http.get(&url, Default::default(), false).await?;
1158            let collab_url = if response.status().is_redirection() {
1159                response
1160                    .headers()
1161                    .get("Location")
1162                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1163                    .to_str()
1164                    .map_err(EstablishConnectionError::other)?
1165                    .to_string()
1166            } else {
1167                Err(anyhow!(
1168                    "unexpected /rpc response status {}",
1169                    response.status()
1170                ))?
1171            };
1172
1173            Url::parse(&collab_url).context("invalid rpc url")
1174        }
1175    }
1176
    /// Opens a WebSocket connection to the collab server's RPC endpoint on the background
    /// executor.
    ///
    /// Steps:
    /// 1. Read the release channel and app version from the app context before spawning,
    ///    falling back to `None` / an empty string if that fails.
    /// 2. Resolve the RPC URL via [`Self::rpc_url`]. Its scheme must be `http` or `https`.
    /// 3. Open the transport stream with `connect_socks_proxy_stream`, passing the HTTP
    ///    client's proxy (if any).
    /// 4. Send the upgrade request with authorization, protocol-version, app-version and
    ///    release-channel headers.
    ///
    /// `https` endpoints are upgraded over TLS as `wss`; `http` endpoints use plain `ws`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        cx.background_executor().spawn(async move {
            use HttpOrHttps::*;

            // Remembers the original transport, because the URL scheme is rewritten to
            // `ws`/`wss` below.
            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // Map the HTTP scheme onto the matching WebSocket scheme. The `unwrap` relies on
            // switching between these schemes being accepted by `Url::set_scheme`.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = rpc_url.into_client_request()?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                "Authorization",
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(&release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );

            // The two branches produce different stream types (TLS vs. plain), so each wraps
            // its own stream in a `Connection`.
            match url_scheme {
                Https => {
                    let (stream, _) =
                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                Http => {
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
            }
        })
    }
1268
1269    pub fn authenticate_with_browser(
1270        self: &Arc<Self>,
1271        cx: &AsyncAppContext,
1272    ) -> Task<Result<Credentials>> {
1273        let http = self.http.clone();
1274        let this = self.clone();
1275        cx.spawn(|cx| async move {
1276            let background = cx.background_executor().clone();
1277
1278            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1279            cx.update(|cx| {
1280                cx.spawn(move |cx| async move {
1281                    let url = open_url_rx.await?;
1282                    cx.update(|cx| cx.open_url(&url))
1283                })
1284                .detach_and_log_err(cx);
1285            })
1286            .log_err();
1287
1288            let credentials = background
1289                .clone()
1290                .spawn(async move {
1291                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1292                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1293                    // any other app running on the user's device.
1294                    let (public_key, private_key) =
1295                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1296                    let public_key_string = String::try_from(public_key)
1297                        .expect("failed to serialize public key for auth");
1298
1299                    if let Some((login, token)) =
1300                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1301                    {
1302                        eprintln!("authenticate as admin {login}, {token}");
1303
1304                        return this
1305                            .authenticate_as_admin(http, login.clone(), token.clone())
1306                            .await;
1307                    }
1308
1309                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1310                    let server =
1311                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1312                    let port = server.server_addr().port();
1313
1314                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1315                    // that the user is signing in from a Zed app running on the same device.
1316                    let mut url = http.build_url(&format!(
1317                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1318                        port, public_key_string
1319                    ));
1320
1321                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1322                        log::info!("impersonating user @{}", impersonate_login);
1323                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1324                    }
1325
1326                    open_url_tx.send(url).log_err();
1327
1328                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1329                    // access token from the query params.
1330                    //
1331                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1332                    // custom URL scheme instead of this local HTTP server.
1333                    let (user_id, access_token) = background
1334                        .spawn(async move {
1335                            for _ in 0..100 {
1336                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1337                                    let path = req.url();
1338                                    let mut user_id = None;
1339                                    let mut access_token = None;
1340                                    let url = Url::parse(&format!("http://example.com{}", path))
1341                                        .context("failed to parse login notification url")?;
1342                                    for (key, value) in url.query_pairs() {
1343                                        if key == "access_token" {
1344                                            access_token = Some(value.to_string());
1345                                        } else if key == "user_id" {
1346                                            user_id = Some(value.to_string());
1347                                        }
1348                                    }
1349
1350                                    let post_auth_url =
1351                                        http.build_url("/native_app_signin_succeeded");
1352                                    req.respond(
1353                                        tiny_http::Response::empty(302).with_header(
1354                                            tiny_http::Header::from_bytes(
1355                                                &b"Location"[..],
1356                                                post_auth_url.as_bytes(),
1357                                            )
1358                                            .unwrap(),
1359                                        ),
1360                                    )
1361                                    .context("failed to respond to login http request")?;
1362                                    return Ok((
1363                                        user_id
1364                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1365                                        access_token.ok_or_else(|| {
1366                                            anyhow!("missing access_token parameter")
1367                                        })?,
1368                                    ));
1369                                }
1370                            }
1371
1372                            Err(anyhow!("didn't receive login redirect"))
1373                        })
1374                        .await?;
1375
1376                    let access_token = private_key
1377                        .decrypt_string(&access_token)
1378                        .context("failed to decrypt access token")?;
1379
1380                    Ok(Credentials::User {
1381                        user_id: user_id.parse()?,
1382                        access_token,
1383                    })
1384                })
1385                .await?;
1386
1387            cx.update(|cx| cx.activate(true))?;
1388            Ok(credentials)
1389        })
1390    }
1391
1392    async fn authenticate_as_admin(
1393        self: &Arc<Self>,
1394        http: Arc<HttpClientWithUrl>,
1395        login: String,
1396        mut api_token: String,
1397    ) -> Result<Credentials> {
1398        #[derive(Deserialize)]
1399        struct AuthenticatedUserResponse {
1400            user: User,
1401        }
1402
1403        #[derive(Deserialize)]
1404        struct User {
1405            id: u64,
1406        }
1407
1408        let github_user = {
1409            #[derive(Deserialize)]
1410            struct GithubUser {
1411                id: i32,
1412                login: String,
1413                created_at: DateTime<Utc>,
1414            }
1415
1416            let request = {
1417                let mut request_builder =
1418                    Request::get(&format!("https://api.github.com/users/{login}"));
1419                if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
1420                    request_builder =
1421                        request_builder.header("Authorization", format!("Bearer {}", github_token));
1422                }
1423
1424                request_builder.body(AsyncBody::empty())?
1425            };
1426
1427            let mut response = http
1428                .send(request)
1429                .await
1430                .context("error fetching GitHub user")?;
1431
1432            let mut body = Vec::new();
1433            response
1434                .body_mut()
1435                .read_to_end(&mut body)
1436                .await
1437                .context("error reading GitHub user")?;
1438
1439            if !response.status().is_success() {
1440                let text = String::from_utf8_lossy(body.as_slice());
1441                bail!(
1442                    "status error {}, response: {text:?}",
1443                    response.status().as_u16()
1444                );
1445            }
1446
1447            let user = serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
1448                log::error!("Error deserializing: {:?}", err);
1449                log::error!(
1450                    "GitHub API response text: {:?}",
1451                    String::from_utf8_lossy(body.as_slice())
1452                );
1453                anyhow!("error deserializing GitHub user")
1454            })?;
1455
1456            user
1457        };
1458
1459        // Use the collab server's admin API to retrieve the ID
1460        // of the impersonated user.
1461        let mut url = self.rpc_url(http.clone(), None).await?;
1462        url.set_path("/user");
1463        url.set_query(Some(&format!(
1464            "github_login={login}&github_user_id={id}&github_user_created_at={created_at}",
1465            login = github_user.login,
1466            id = github_user.id,
1467            created_at = github_user.created_at.to_rfc3339()
1468        )));
1469        let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
1470            .header("Authorization", format!("token {api_token}"))
1471            .body("".into())?;
1472
1473        let mut response = http.send(request).await?;
1474        let mut body = String::new();
1475        response.body_mut().read_to_string(&mut body).await?;
1476        if !response.status().is_success() {
1477            Err(anyhow!(
1478                "admin user request failed {} - {}",
1479                response.status().as_u16(),
1480                body,
1481            ))?;
1482        }
1483        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1484
1485        // Use the admin API token to authenticate as the impersonated user.
1486        api_token.insert_str(0, "ADMIN_TOKEN:");
1487        Ok(Credentials::User {
1488            user_id: response.user.id,
1489            access_token: api_token,
1490        })
1491    }
1492
1493    pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1494        self.state.write().credentials = None;
1495        self.disconnect(&cx);
1496
1497        if self.has_credentials(cx).await {
1498            self.credentials_provider
1499                .delete_credentials(cx)
1500                .await
1501                .log_err();
1502        }
1503    }
1504
1505    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1506        self.peer.teardown();
1507        self.set_status(Status::SignedOut, cx);
1508    }
1509
1510    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1511        self.peer.teardown();
1512        self.set_status(Status::ConnectionLost, cx);
1513    }
1514
1515    fn connection_id(&self) -> Result<ConnectionId> {
1516        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1517            Ok(connection_id)
1518        } else {
1519            Err(anyhow!("not connected"))
1520        }
1521    }
1522
1523    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1524        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1525        self.peer.send(self.connection_id()?, message)
1526    }
1527
1528    pub fn send_dynamic(
1529        &self,
1530        envelope: proto::Envelope,
1531        message_type: &'static str,
1532    ) -> Result<()> {
1533        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1534        let connection_id = self.connection_id()?;
1535        self.peer.send_dynamic(connection_id, envelope)
1536    }
1537
1538    pub fn request<T: RequestMessage>(
1539        &self,
1540        request: T,
1541    ) -> impl Future<Output = Result<T::Response>> {
1542        self.request_envelope(request)
1543            .map_ok(|envelope| envelope.payload)
1544    }
1545
1546    pub fn request_stream<T: RequestMessage>(
1547        &self,
1548        request: T,
1549    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1550        let client_id = self.id.load(Ordering::SeqCst);
1551        log::debug!(
1552            "rpc request start. client_id:{}. name:{}",
1553            client_id,
1554            T::NAME
1555        );
1556        let response = self
1557            .connection_id()
1558            .map(|conn_id| self.peer.request_stream(conn_id, request));
1559        async move {
1560            let response = response?.await;
1561            log::debug!(
1562                "rpc request finish. client_id:{}. name:{}",
1563                client_id,
1564                T::NAME
1565            );
1566            response
1567        }
1568    }
1569
1570    pub fn request_envelope<T: RequestMessage>(
1571        &self,
1572        request: T,
1573    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1574        let client_id = self.id();
1575        log::debug!(
1576            "rpc request start. client_id:{}. name:{}",
1577            client_id,
1578            T::NAME
1579        );
1580        let response = self
1581            .connection_id()
1582            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1583        async move {
1584            let response = response?.await;
1585            log::debug!(
1586                "rpc request finish. client_id:{}. name:{}",
1587                client_id,
1588                T::NAME
1589            );
1590            response
1591        }
1592    }
1593
1594    pub fn request_dynamic(
1595        &self,
1596        envelope: proto::Envelope,
1597        request_type: &'static str,
1598    ) -> impl Future<Output = Result<proto::Envelope>> {
1599        let client_id = self.id();
1600        log::debug!(
1601            "rpc request start. client_id:{}. name:{}",
1602            client_id,
1603            request_type
1604        );
1605        let response = self
1606            .connection_id()
1607            .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1608        async move {
1609            let response = response?.await;
1610            log::debug!(
1611                "rpc request finish. client_id:{}. name:{}",
1612                client_id,
1613                request_type
1614            );
1615            Ok(response?.0)
1616        }
1617    }
1618
1619    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1620        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1621        self.peer.respond(receipt, response)
1622    }
1623
1624    fn respond_with_error<T: RequestMessage>(
1625        &self,
1626        receipt: Receipt<T>,
1627        error: proto::Error,
1628    ) -> Result<()> {
1629        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1630        self.peer.respond_with_error(receipt, error)
1631    }
1632
1633    fn handle_message(
1634        self: &Arc<Client>,
1635        message: Box<dyn AnyTypedEnvelope>,
1636        cx: &AsyncAppContext,
1637    ) {
1638        let mut state = self.state.write();
1639        let type_name = message.payload_type_name();
1640        let payload_type_id = message.payload_type_id();
1641        let sender_id = message.original_sender_id();
1642
1643        let mut subscriber = None;
1644
1645        if let Some(handle) = state
1646            .models_by_message_type
1647            .get(&payload_type_id)
1648            .and_then(|handle| handle.upgrade())
1649        {
1650            subscriber = Some(handle);
1651        } else if let Some((extract_entity_id, entity_type_id)) =
1652            state.entity_id_extractors.get(&payload_type_id).zip(
1653                state
1654                    .entity_types_by_message_type
1655                    .get(&payload_type_id)
1656                    .copied(),
1657            )
1658        {
1659            let entity_id = (extract_entity_id)(message.as_ref());
1660
1661            match state
1662                .entities_by_type_and_remote_id
1663                .get_mut(&(entity_type_id, entity_id))
1664            {
1665                Some(WeakSubscriber::Pending(pending)) => {
1666                    pending.push(message);
1667                    return;
1668                }
1669                Some(weak_subscriber) => match weak_subscriber {
1670                    WeakSubscriber::Entity { handle } => {
1671                        subscriber = handle.upgrade();
1672                    }
1673
1674                    WeakSubscriber::Pending(_) => {}
1675                },
1676                _ => {}
1677            }
1678        }
1679
1680        let subscriber = if let Some(subscriber) = subscriber {
1681            subscriber
1682        } else {
1683            log::info!("unhandled message {}", type_name);
1684            self.peer.respond_with_unhandled_message(message).log_err();
1685            return;
1686        };
1687
1688        let handler = state.message_handlers.get(&payload_type_id).cloned();
1689        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1690        // It also ensures we don't hold the lock while yielding back to the executor, as
1691        // that might cause the executor thread driving this future to block indefinitely.
1692        drop(state);
1693
1694        if let Some(handler) = handler {
1695            let future = handler(subscriber, message, self, cx.clone());
1696            let client_id = self.id();
1697            log::debug!(
1698                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1699                client_id,
1700                sender_id,
1701                type_name
1702            );
1703            cx.spawn(move |_| async move {
1704                    match future.await {
1705                        Ok(()) => {
1706                            log::debug!(
1707                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1708                                client_id,
1709                                sender_id,
1710                                type_name
1711                            );
1712                        }
1713                        Err(error) => {
1714                            log::error!(
1715                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1716                                client_id,
1717                                sender_id,
1718                                type_name,
1719                                error
1720                            );
1721                        }
1722                    }
1723                })
1724                .detach();
1725        } else {
1726            log::info!("unhandled message {}", type_name);
1727            self.peer.respond_with_unhandled_message(message).log_err();
1728        }
1729    }
1730
    /// Returns the shared [`Telemetry`] handle held by this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1734}
1735
1736impl ProtoClient for Client {
1737    fn request(
1738        &self,
1739        envelope: proto::Envelope,
1740        request_type: &'static str,
1741    ) -> BoxFuture<'static, Result<proto::Envelope>> {
1742        self.request_dynamic(envelope, request_type).boxed()
1743    }
1744
1745    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1746        self.send_dynamic(envelope, message_type)
1747    }
1748}
1749
/// JSON payload read and written by `DevelopmentCredentialsProvider`.
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
    /// Id of the user the access token belongs to.
    user_id: u64,
    /// Access token, stored in plain text.
    access_token: String,
}
1755
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
    /// Path of the JSON file the credentials are read from, written to, and deleted from.
    path: PathBuf,
}
1766
1767impl CredentialsProvider for DevelopmentCredentialsProvider {
1768    fn read_credentials<'a>(
1769        &'a self,
1770        _cx: &'a AsyncAppContext,
1771    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1772        async move {
1773            if IMPERSONATE_LOGIN.is_some() {
1774                return None;
1775            }
1776
1777            let json = std::fs::read(&self.path).log_err()?;
1778
1779            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1780
1781            Some(Credentials::User {
1782                user_id: credentials.user_id,
1783                access_token: credentials.access_token,
1784            })
1785        }
1786        .boxed_local()
1787    }
1788
1789    fn write_credentials<'a>(
1790        &'a self,
1791        user_id: u64,
1792        access_token: String,
1793        _cx: &'a AsyncAppContext,
1794    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1795        async move {
1796            let json = serde_json::to_string(&DevelopmentCredentials {
1797                user_id,
1798                access_token,
1799            })?;
1800
1801            std::fs::write(&self.path, json)?;
1802
1803            Ok(())
1804        }
1805        .boxed_local()
1806    }
1807
1808    fn delete_credentials<'a>(
1809        &'a self,
1810        _cx: &'a AsyncAppContext,
1811    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1812        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1813    }
1814}
1815
/// A credentials provider that stores credentials in the system keychain.
///
/// Keychain entries are keyed by the configured `ClientSettings::server_url`.
struct KeychainCredentialsProvider;
1818
1819impl CredentialsProvider for KeychainCredentialsProvider {
1820    fn read_credentials<'a>(
1821        &'a self,
1822        cx: &'a AsyncAppContext,
1823    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1824        async move {
1825            if IMPERSONATE_LOGIN.is_some() {
1826                return None;
1827            }
1828
1829            let (user_id, access_token) = cx
1830                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1831                .log_err()?
1832                .await
1833                .log_err()??;
1834
1835            Some(Credentials::User {
1836                user_id: user_id.parse().ok()?,
1837                access_token: String::from_utf8(access_token).ok()?,
1838            })
1839        }
1840        .boxed_local()
1841    }
1842
1843    fn write_credentials<'a>(
1844        &'a self,
1845        user_id: u64,
1846        access_token: String,
1847        cx: &'a AsyncAppContext,
1848    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1849        async move {
1850            cx.update(move |cx| {
1851                cx.write_credentials(
1852                    &ClientSettings::get_global(cx).server_url,
1853                    &user_id.to_string(),
1854                    access_token.as_bytes(),
1855                )
1856            })?
1857            .await
1858        }
1859        .boxed_local()
1860    }
1861
1862    fn delete_credentials<'a>(
1863        &'a self,
1864        cx: &'a AsyncAppContext,
1865    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1866        async move {
1867            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1868                .await
1869        }
1870        .boxed_local()
1871    }
1872}
1873
/// prefix for the zed:// url scheme
///
/// Holds only the scheme name; `parse_zed_link` expects it to be followed by `://`.
pub static ZED_URL_SCHEME: &str = "zed";
1876
1877/// Parses the given link into a Zed link.
1878///
1879/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1880/// Returns [`None`] otherwise.
1881pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1882    let server_url = &ClientSettings::get_global(cx).server_url;
1883    if let Some(stripped) = link
1884        .strip_prefix(server_url)
1885        .and_then(|result| result.strip_prefix('/'))
1886    {
1887        return Some(stripped);
1888    }
1889    if let Some(stripped) = link
1890        .strip_prefix(ZED_URL_SCHEME)
1891        .and_then(|result| result.strip_prefix("://"))
1892    {
1893        return Some(stripped);
1894    }
1895
1896    None
1897}
1898
// Tests for the client's connection lifecycle (reconnection, timeouts, repeated
// authentication) and for routing incoming messages to subscribed models.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // After a disconnect the client reconnects with its cached credentials; once the
    // server rolls the access token, the client must authenticate again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes must surface `ConnectionError` after
    // `CONNECTION_TIMEOUT`, and a stalled reconnection must surface `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still pending
    // drops the first authentication future (observed via `dropped_auth_count`).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Entity messages are routed to the model subscribed to the matching entity id.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping a handler's subscription, a new handler can be registered for the
    // same message type and model, and it is the one that receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own `Subscription` (stored on the model) while it runs and
    // still complete.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone().downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used as a message subscriber: `id` tells instances apart, and
    // `subscription` lets a handler drop its own subscription.
    #[derive(Default)]
    struct TestModel {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` global and registers this crate's settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}