client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context as _, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use clock::SystemClock;
  14use collections::HashMap;
  15use futures::{
  16    channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
  17    TryFutureExt as _, TryStreamExt,
  18};
  19use gpui::{
  20    actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
  21};
  22use lazy_static::lazy_static;
  23use parking_lot::RwLock;
  24use postage::watch;
  25use rand::prelude::*;
  26use release_channel::{AppVersion, ReleaseChannel};
  27use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30
  31use settings::{Settings, SettingsStore};
  32use std::{
  33    any::TypeId,
  34    convert::TryFrom,
  35    fmt::Write as _,
  36    future::Future,
  37    marker::PhantomData,
  38    path::PathBuf,
  39    sync::{atomic::AtomicU64, Arc, Weak},
  40    time::{Duration, Instant},
  41};
  42use telemetry::Telemetry;
  43use thiserror::Error;
  44use url::Url;
  45use util::http::{HttpClient, HttpClientWithUrl};
  46use util::{ResultExt, TryFutureExt};
  47
  48pub use rpc::*;
  49pub use telemetry_events::Event;
  50pub use user::*;
  51
lazy_static! {
    /// Value of the `ZED_SERVER_URL` environment variable, if set.
    /// `ClientSettings::load` uses it to override the configured server URL.
    static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
    /// Value of the `ZED_RPC_URL` environment variable, if set.
    static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
    /// Value of the `ZED_IMPERSONATE` environment variable; an empty value is treated as unset.
    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
        .ok()
        .and_then(|s| if s.is_empty() { None } else { Some(s) });
    /// Value of the `ZED_ADMIN_API_TOKEN` environment variable; an empty value is treated as unset.
    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
        .ok()
        .and_then(|s| if s.is_empty() { None } else { Some(s) });
    /// Path taken from the `ZED_APP_PATH` environment variable, if set.
    pub static ref ZED_APP_PATH: Option<PathBuf> =
        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
    /// `true` when the `ZED_ALWAYS_ACTIVE` environment variable is set to a non-empty value.
    pub static ref ZED_ALWAYS_ACTIVE: bool =
        std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
}
  66
/// Delay before the first retry of the reconnect loop started by `Client::set_status`;
/// the loop grows the delay after every failed attempt.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Duration of the timer that `Client::authenticate_and_connect` races against
/// establishing the connection and receiving the server hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

// Declares the `SignIn`, `SignOut` and `Reconnect` action types handled in `init`.
actions!(client, [SignIn, SignOut, Reconnect]);
  71
// On-disk shape of the client settings (`ClientSettings::FileContent`).
// Plain `//` comments are used here on purpose so nothing is added to the derived items.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // Optional server URL; presumably merged into `ClientSettings::server_url`
    // by `load_via_json_merge` — confirm against the settings crate.
    server_url: Option<String>,
}
  76
/// Resolved client settings produced by `<ClientSettings as Settings>::load`.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Server URL after merging settings files; replaced by `ZED_SERVER_URL` when that
    /// environment variable is set.
    pub server_url: String,
}
  81
  82impl Settings for ClientSettings {
  83    const KEY: Option<&'static str> = None;
  84
  85    type FileContent = ClientSettingsContent;
  86
  87    fn load(
  88        default_value: &Self::FileContent,
  89        user_values: &[&Self::FileContent],
  90        _: &mut AppContext,
  91    ) -> Result<Self>
  92    where
  93        Self: Sized,
  94    {
  95        let mut result = Self::load_via_json_merge(default_value, user_values)?;
  96        if let Some(server_url) = &*ZED_SERVER_URL {
  97            result.server_url = server_url.clone()
  98        }
  99        Ok(result)
 100    }
 101}
 102
 103pub fn init_settings(cx: &mut AppContext) {
 104    TelemetrySettings::register(cx);
 105    cx.update_global(|store: &mut SettingsStore, cx| {
 106        store.register_setting::<ClientSettings>(cx);
 107    });
 108}
 109
 110pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
 111    let client = Arc::downgrade(client);
 112    cx.on_action({
 113        let client = client.clone();
 114        move |_: &SignIn, cx| {
 115            if let Some(client) = client.upgrade() {
 116                cx.spawn(
 117                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
 118                )
 119                .detach();
 120            }
 121        }
 122    });
 123
 124    cx.on_action({
 125        let client = client.clone();
 126        move |_: &SignOut, cx| {
 127            if let Some(client) = client.upgrade() {
 128                cx.spawn(|cx| async move {
 129                    client.disconnect(&cx);
 130                })
 131                .detach();
 132            }
 133        }
 134    });
 135
 136    cx.on_action({
 137        let client = client.clone();
 138        move |_: &Reconnect, cx| {
 139            if let Some(client) = client.upgrade() {
 140                cx.spawn(|cx| async move {
 141                    client.reconnect(&cx);
 142                })
 143                .detach();
 144            }
 145        }
 146    });
 147}
 148
/// Wrapper that stores the shared `Arc<Client>` as an app global;
/// written by `Client::set_global` and read by `Client::global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
 152
/// Connection client: holds the RPC peer, the HTTP client, telemetry, and the
/// mutable connection/subscription state behind a lock.
pub struct Client {
    /// Numeric id reported by `id()`; `authenticate_and_connect` sets it to the
    /// authenticated user's id.
    id: AtomicU64,
    /// RPC peer that connections are added to in `set_connection`.
    peer: Arc<Peer>,
    /// HTTP client handed out by `http_client()` and shared with telemetry.
    http: Arc<HttpClientWithUrl>,
    /// Telemetry sink; its authenticated-user info is cleared when signing out.
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, and message-handler registries.
    state: RwLock<ClientState>,

    /// Test-only hook installed by `override_authenticate`
    /// (presumably consulted in place of the real authentication flow — confirm).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only hook installed by `override_establish_connection`
    /// (presumably consulted in place of the real connection setup — confirm).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 182
/// Reasons establishing a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP `UPGRADE_REQUIRED`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP `UNAUTHORIZED`.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Failure reported by the HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// Underlying I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from tungstenite's `http` module (e.g. while building the request — confirm).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 198
 199impl From<WebsocketError> for EstablishConnectionError {
 200    fn from(error: WebsocketError) -> Self {
 201        if let WebsocketError::Http(response) = &error {
 202            match response.status() {
 203                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 204                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 205                _ => {}
 206            }
 207        }
 208        EstablishConnectionError::Other(error.into())
 209    }
 210}
 211
 212impl EstablishConnectionError {
 213    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 214        Self::Other(error.into())
 215    }
 216}
 217
/// Connection status published through the client's watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial status of a fresh `ClientState`.
    SignedOut,
    /// The server answered the connection attempt with "upgrade required".
    UpgradeRequired,
    /// Credentials are being obtained, starting from `SignedOut`.
    Authenticating,
    /// A connection is being established, starting from `SignedOut`.
    Connecting,
    /// Authentication or connection establishment failed.
    ConnectionError,
    /// The server hello was received on the given connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// Setting this status starts the background reconnect loop in `Client::set_status`.
    ConnectionLost,
    /// Like `Authenticating`, but entered from a status other than `SignedOut`.
    Reauthenticating,
    /// Like `Connecting`, but entered from a status other than `SignedOut`.
    Reconnecting,
    /// A reconnect attempt failed; the loop will retry at `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 236
 237impl Status {
 238    pub fn is_connected(&self) -> bool {
 239        matches!(self, Self::Connected { .. })
 240    }
 241
 242    pub fn is_signed_out(&self) -> bool {
 243        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 244    }
 245}
 246
/// Mutable state of a `Client`, guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current user; stored once a connection is established and
    /// cleared when the server rejects them as unauthorized.
    credentials: Option<Credentials>,
    /// Watch channel for the status: `set_status` writes the sender, `status()` clones the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that reads the remote entity id out of an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; cleared on `Connected`, `SignedOut`
    /// and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound applied to the reconnect back-off delay.
    reconnect_interval: Duration,
    /// Subscribers keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Weak model registered per message type by `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    /// Entity type registered per message type by `add_entity_message_handler`.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type; registering a second handler for the
    /// same type panics.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyModel,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 271
/// Entry of `ClientState::entities_by_type_and_remote_id`.
enum WeakSubscriber {
    /// A model has been attached (see `PendingEntitySubscription::set_model`).
    Entity { handle: AnyWeakModel },
    /// No model attached yet; holds the envelopes that `set_model` later passes to
    /// `Client::handle_message`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
 276
/// Identity used to authenticate with the server.
///
/// NOTE(review): the derived `Debug` prints `access_token` verbatim — avoid logging
/// values of this type.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Id of the user; also stored as the client id once authenticated.
    pub user_id: u64,
    /// Access token presented when establishing a connection.
    pub access_token: String,
}
 282
 283impl Default for ClientState {
 284    fn default() -> Self {
 285        Self {
 286            credentials: None,
 287            status: watch::channel_with(Status::SignedOut),
 288            entity_id_extractors: Default::default(),
 289            _reconnect_task: None,
 290            reconnect_interval: Duration::from_secs(5),
 291            models_by_message_type: Default::default(),
 292            entities_by_type_and_remote_id: Default::default(),
 293            entity_types_by_message_type: Default::default(),
 294            message_handlers: Default::default(),
 295        }
 296    }
 297}
 298
/// Handle for a registration made on a `Client`; dropping it unregisters the
/// corresponding entry (see the `Drop` impl).
pub enum Subscription {
    /// Registration of an entity, keyed by (entity type, remote entity id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a message handler, keyed by message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 309
 310impl Drop for Subscription {
 311    fn drop(&mut self) {
 312        match self {
 313            Subscription::Entity { client, id } => {
 314                if let Some(client) = client.upgrade() {
 315                    let mut state = client.state.write();
 316                    let _ = state.entities_by_type_and_remote_id.remove(id);
 317                }
 318            }
 319            Subscription::Message { client, id } => {
 320                if let Some(client) = client.upgrade() {
 321                    let mut state = client.state.write();
 322                    let _ = state.entity_types_by_message_type.remove(id);
 323                    let _ = state.message_handlers.remove(id);
 324                }
 325            }
 326        }
 327    }
 328}
 329
/// Reservation returned by `Client::subscribe_to_entity` before a model exists for
/// the remote entity. Call `set_model` to attach one; dropping the reservation
/// without doing so removes it again.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    /// Remote id of the entity being subscribed to.
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the registered entry in place.
    consumed: bool,
}
 336
 337impl<T: 'static> PendingEntitySubscription<T> {
 338    pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
 339        self.consumed = true;
 340        let mut state = self.client.state.write();
 341        let id = (TypeId::of::<T>(), self.remote_id);
 342        let Some(WeakSubscriber::Pending(messages)) =
 343            state.entities_by_type_and_remote_id.remove(&id)
 344        else {
 345            unreachable!()
 346        };
 347
 348        state.entities_by_type_and_remote_id.insert(
 349            id,
 350            WeakSubscriber::Entity {
 351                handle: model.downgrade().into(),
 352            },
 353        );
 354        drop(state);
 355        for message in messages {
 356            self.client.handle_message(message, cx);
 357        }
 358        Subscription::Entity {
 359            client: Arc::downgrade(&self.client),
 360            id,
 361        }
 362    }
 363}
 364
 365impl<T: 'static> Drop for PendingEntitySubscription<T> {
 366    fn drop(&mut self) {
 367        if !self.consumed {
 368            let mut state = self.client.state.write();
 369            if let Some(WeakSubscriber::Pending(messages)) = state
 370                .entities_by_type_and_remote_id
 371                .remove(&(TypeId::of::<T>(), self.remote_id))
 372            {
 373                for message in messages {
 374                    log::info!("unhandled message {}", message.payload_type_name());
 375                }
 376            }
 377        }
 378    }
 379}
 380
/// Resolved telemetry settings (stored under the `"telemetry"` settings key).
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Resolved value of the `diagnostics` setting.
    pub diagnostics: bool,
    /// Resolved value of the `metrics` setting.
    pub metrics: bool,
}
 386
// On-disk shape of `TelemetrySettings`. Both fields are optional here;
// `TelemetrySettings::load` requires the default content to provide them.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
 399
 400impl settings::Settings for TelemetrySettings {
 401    const KEY: Option<&'static str> = Some("telemetry");
 402
 403    type FileContent = TelemetrySettingsContent;
 404
 405    fn load(
 406        default_value: &Self::FileContent,
 407        user_values: &[&Self::FileContent],
 408        _: &mut AppContext,
 409    ) -> Result<Self> {
 410        Ok(Self {
 411            diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
 412                default_value
 413                    .diagnostics
 414                    .ok_or_else(Self::missing_default)?,
 415            ),
 416            metrics: user_values
 417                .first()
 418                .and_then(|v| v.metrics)
 419                .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
 420        })
 421    }
 422}
 423
 424impl Client {
 425    pub fn new(
 426        clock: Arc<dyn SystemClock>,
 427        http: Arc<HttpClientWithUrl>,
 428        cx: &mut AppContext,
 429    ) -> Arc<Self> {
 430        Arc::new(Self {
 431            id: AtomicU64::new(0),
 432            peer: Peer::new(0),
 433            telemetry: Telemetry::new(clock, http.clone(), cx),
 434            http,
 435            state: Default::default(),
 436
 437            #[cfg(any(test, feature = "test-support"))]
 438            authenticate: Default::default(),
 439            #[cfg(any(test, feature = "test-support"))]
 440            establish_connection: Default::default(),
 441        })
 442    }
 443
 444    pub fn id(&self) -> u64 {
 445        self.id.load(std::sync::atomic::Ordering::SeqCst)
 446    }
 447
 448    pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
 449        self.http.clone()
 450    }
 451
 452    pub fn set_id(&self, id: u64) -> &Self {
 453        self.id.store(id, std::sync::atomic::Ordering::SeqCst);
 454        self
 455    }
 456
 457    #[cfg(any(test, feature = "test-support"))]
 458    pub fn teardown(&self) {
 459        let mut state = self.state.write();
 460        state._reconnect_task.take();
 461        state.message_handlers.clear();
 462        state.models_by_message_type.clear();
 463        state.entities_by_type_and_remote_id.clear();
 464        state.entity_id_extractors.clear();
 465        self.peer.teardown();
 466    }
 467
 468    #[cfg(any(test, feature = "test-support"))]
 469    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 470    where
 471        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 472    {
 473        *self.authenticate.write() = Some(Box::new(authenticate));
 474        self
 475    }
 476
 477    #[cfg(any(test, feature = "test-support"))]
 478    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 479    where
 480        F: 'static
 481            + Send
 482            + Sync
 483            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 484    {
 485        *self.establish_connection.write() = Some(Box::new(connect));
 486        self
 487    }
 488
 489    pub fn global(cx: &AppContext) -> Arc<Self> {
 490        cx.global::<GlobalClient>().0.clone()
 491    }
 492    pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
 493        cx.set_global(GlobalClient(client))
 494    }
 495
 496    pub fn user_id(&self) -> Option<u64> {
 497        self.state
 498            .read()
 499            .credentials
 500            .as_ref()
 501            .map(|credentials| credentials.user_id)
 502    }
 503
 504    pub fn peer_id(&self) -> Option<PeerId> {
 505        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 506            Some(*peer_id)
 507        } else {
 508            None
 509        }
 510    }
 511
 512    pub fn status(&self) -> watch::Receiver<Status> {
 513        self.state.read().status.1.clone()
 514    }
 515
    /// Publishes `status` on the status channel and performs the bookkeeping tied
    /// to it:
    ///
    /// * `Connected` drops the stored reconnect task.
    /// * `ConnectionLost` spawns a reconnect loop and stores its task.
    /// * `SignedOut` / `UpgradeRequired` clear the telemetry user info and drop the
    ///   stored reconnect task.
    ///
    /// The state write lock is held for the whole call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Test builds use a fixed seed so the back-off sequence is reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    // Keep retrying while attempts end in `ConnectionError`; any other
                    // resulting status stops the loop.
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 561
 562    pub fn subscribe_to_entity<T>(
 563        self: &Arc<Self>,
 564        remote_id: u64,
 565    ) -> Result<PendingEntitySubscription<T>>
 566    where
 567        T: 'static,
 568    {
 569        let id = (TypeId::of::<T>(), remote_id);
 570
 571        let mut state = self.state.write();
 572        if state.entities_by_type_and_remote_id.contains_key(&id) {
 573            return Err(anyhow!("already subscribed to entity"));
 574        }
 575
 576        state
 577            .entities_by_type_and_remote_id
 578            .insert(id, WeakSubscriber::Pending(Default::default()));
 579
 580        Ok(PendingEntitySubscription {
 581            client: self.clone(),
 582            remote_id,
 583            consumed: false,
 584            _entity_type: PhantomData,
 585        })
 586    }
 587
 588    #[track_caller]
 589    pub fn add_message_handler<M, E, H, F>(
 590        self: &Arc<Self>,
 591        entity: WeakModel<E>,
 592        handler: H,
 593    ) -> Subscription
 594    where
 595        M: EnvelopedMessage,
 596        E: 'static,
 597        H: 'static
 598            + Sync
 599            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 600            + Send
 601            + Sync,
 602        F: 'static + Future<Output = Result<()>>,
 603    {
 604        let message_type_id = TypeId::of::<M>();
 605        let mut state = self.state.write();
 606        state
 607            .models_by_message_type
 608            .insert(message_type_id, entity.into());
 609
 610        let prev_handler = state.message_handlers.insert(
 611            message_type_id,
 612            Arc::new(move |subscriber, envelope, client, cx| {
 613                let subscriber = subscriber.downcast::<E>().unwrap();
 614                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 615                handler(subscriber, *envelope, client.clone(), cx).boxed_local()
 616            }),
 617        );
 618        if prev_handler.is_some() {
 619            let location = std::panic::Location::caller();
 620            panic!(
 621                "{}:{} registered handler for the same message {} twice",
 622                location.file(),
 623                location.line(),
 624                std::any::type_name::<M>()
 625            );
 626        }
 627
 628        Subscription::Message {
 629            client: Arc::downgrade(self),
 630            id: message_type_id,
 631        }
 632    }
 633
 634    pub fn add_request_handler<M, E, H, F>(
 635        self: &Arc<Self>,
 636        model: WeakModel<E>,
 637        handler: H,
 638    ) -> Subscription
 639    where
 640        M: RequestMessage,
 641        E: 'static,
 642        H: 'static
 643            + Sync
 644            + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
 645            + Send
 646            + Sync,
 647        F: 'static + Future<Output = Result<M::Response>>,
 648    {
 649        self.add_message_handler(model, move |handle, envelope, this, cx| {
 650            Self::respond_to_request(
 651                envelope.receipt(),
 652                handler(handle, envelope, this.clone(), cx),
 653                this,
 654            )
 655        })
 656    }
 657
 658    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 659    where
 660        M: EntityMessage,
 661        E: 'static,
 662        H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 663        F: 'static + Future<Output = Result<()>>,
 664    {
 665        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
 666            handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
 667        })
 668    }
 669
 670    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 671    where
 672        M: EntityMessage,
 673        E: 'static,
 674        H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 675        F: 'static + Future<Output = Result<()>>,
 676    {
 677        let model_type_id = TypeId::of::<E>();
 678        let message_type_id = TypeId::of::<M>();
 679
 680        let mut state = self.state.write();
 681        state
 682            .entity_types_by_message_type
 683            .insert(message_type_id, model_type_id);
 684        state
 685            .entity_id_extractors
 686            .entry(message_type_id)
 687            .or_insert_with(|| {
 688                |envelope| {
 689                    envelope
 690                        .as_any()
 691                        .downcast_ref::<TypedEnvelope<M>>()
 692                        .unwrap()
 693                        .payload
 694                        .remote_entity_id()
 695                }
 696            });
 697        let prev_handler = state.message_handlers.insert(
 698            message_type_id,
 699            Arc::new(move |handle, envelope, client, cx| {
 700                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 701                handler(handle, *envelope, client.clone(), cx).boxed_local()
 702            }),
 703        );
 704        if prev_handler.is_some() {
 705            panic!("registered handler for the same message twice");
 706        }
 707    }
 708
 709    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 710    where
 711        M: EntityMessage + RequestMessage,
 712        E: 'static,
 713        H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
 714        F: 'static + Future<Output = Result<M::Response>>,
 715    {
 716        self.add_model_message_handler(move |entity, envelope, client, cx| {
 717            Self::respond_to_request::<M, _>(
 718                envelope.receipt(),
 719                handler(entity, envelope, client.clone(), cx),
 720                client,
 721            )
 722        })
 723    }
 724
 725    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 726        receipt: Receipt<T>,
 727        response: F,
 728        client: Arc<Self>,
 729    ) -> Result<()> {
 730        match response.await {
 731            Ok(response) => {
 732                client.respond(receipt, response)?;
 733                Ok(())
 734            }
 735            Err(error) => {
 736                client.respond_with_error(receipt, error.to_proto())?;
 737                Err(error)
 738            }
 739        }
 740    }
 741
 742    pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 743        read_credentials_from_keychain(cx).await.is_some()
 744    }
 745
    /// Obtains credentials if necessary and establishes a connection, publishing
    /// status changes along the way.
    ///
    /// Returns `Ok(())` immediately when the client is already connected or a
    /// connection attempt is in progress, and fails immediately when the status is
    /// `UpgradeRequired`.
    ///
    /// Credentials come from, in order: the client's state, the keychain (only when
    /// `try_keychain` is true), and finally `self.authenticate`.
    ///
    /// # Errors
    ///
    /// Fails when authentication fails or is canceled, when the connection cannot be
    /// established, or when `CONNECTION_TIMEOUT` elapses before the connection is
    /// established and the server hello is received.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Only a start from `SignedOut` counts as a fresh sign-in; it selects the
        // `Authenticating`/`Connecting` statuses rather than the `Re...` ones.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx).await;
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Presumably consumes the receiver's current value so that the select
            // below only reacts to a later status change — confirm against the
            // `postage::watch` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                // A status change while authenticating cancels the attempt.
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path that leaves `credentials` empty has returned above.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer bounds both establishing the connection and waiting for the
        // server hello in `set_connection`.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist credentials unless they came from the keychain or an
                        // impersonation login is configured.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(credentials, cx).await.log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // The keychain credentials were rejected: delete them and
                            // retry once without consulting the keychain.
                            delete_credentials_from_keychain(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 850
 851    async fn set_connection(
 852        self: &Arc<Self>,
 853        conn: Connection,
 854        cx: &AsyncAppContext,
 855    ) -> Result<()> {
 856        let executor = cx.background_executor();
 857        log::info!("add connection to peer");
 858        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
 859            let executor = executor.clone();
 860            move |duration| executor.timer(duration)
 861        });
 862        let handle_io = executor.spawn(handle_io);
 863
 864        let peer_id = async {
 865            log::info!("waiting for server hello");
 866            let message = incoming
 867                .next()
 868                .await
 869                .ok_or_else(|| anyhow!("no hello message received"))?;
 870            log::info!("got server hello");
 871            let hello_message_type_name = message.payload_type_name().to_string();
 872            let hello = message
 873                .into_any()
 874                .downcast::<TypedEnvelope<proto::Hello>>()
 875                .map_err(|_| {
 876                    anyhow!(
 877                        "invalid hello message received: {:?}",
 878                        hello_message_type_name
 879                    )
 880                })?;
 881            let peer_id = hello
 882                .payload
 883                .peer_id
 884                .ok_or_else(|| anyhow!("invalid peer id"))?;
 885            Ok(peer_id)
 886        };
 887
 888        let peer_id = match peer_id.await {
 889            Ok(peer_id) => peer_id,
 890            Err(error) => {
 891                self.peer.disconnect(connection_id);
 892                return Err(error);
 893            }
 894        };
 895
 896        log::info!(
 897            "set status to connected (connection id: {:?}, peer id: {:?})",
 898            connection_id,
 899            peer_id
 900        );
 901        self.set_status(
 902            Status::Connected {
 903                peer_id,
 904                connection_id,
 905            },
 906            cx,
 907        );
 908
 909        cx.spawn({
 910            let this = self.clone();
 911            |cx| {
 912                async move {
 913                    while let Some(message) = incoming.next().await {
 914                        this.handle_message(message, &cx);
 915                        // Don't starve the main thread when receiving lots of messages at once.
 916                        smol::future::yield_now().await;
 917                    }
 918                }
 919            }
 920        })
 921        .detach();
 922
 923        cx.spawn({
 924            let this = self.clone();
 925            move |cx| async move {
 926                match handle_io.await {
 927                    Ok(()) => {
 928                        if *this.status().borrow()
 929                            == (Status::Connected {
 930                                connection_id,
 931                                peer_id,
 932                            })
 933                        {
 934                            this.set_status(Status::SignedOut, &cx);
 935                        }
 936                    }
 937                    Err(err) => {
 938                        log::error!("connection error: {:?}", err);
 939                        this.set_status(Status::ConnectionLost, &cx);
 940                    }
 941                }
 942            }
 943        })
 944        .detach();
 945
 946        Ok(())
 947    }
 948
 949    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 950        #[cfg(any(test, feature = "test-support"))]
 951        if let Some(callback) = self.authenticate.read().as_ref() {
 952            return callback(cx);
 953        }
 954
 955        self.authenticate_with_browser(cx)
 956    }
 957
 958    fn establish_connection(
 959        self: &Arc<Self>,
 960        credentials: &Credentials,
 961        cx: &AsyncAppContext,
 962    ) -> Task<Result<Connection, EstablishConnectionError>> {
 963        #[cfg(any(test, feature = "test-support"))]
 964        if let Some(callback) = self.establish_connection.read().as_ref() {
 965            return callback(credentials, cx);
 966        }
 967
 968        self.establish_websocket_connection(credentials, cx)
 969    }
 970
 971    async fn get_rpc_url(
 972        http: Arc<HttpClientWithUrl>,
 973        release_channel: Option<ReleaseChannel>,
 974    ) -> Result<Url> {
 975        if let Some(url) = &*ZED_RPC_URL {
 976            return Url::parse(url).context("invalid rpc url");
 977        }
 978
 979        let mut url = http.build_url("/rpc");
 980        if let Some(preview_param) =
 981            release_channel.and_then(|channel| channel.release_query_param())
 982        {
 983            url += "?";
 984            url += preview_param;
 985        }
 986        let response = http.get(&url, Default::default(), false).await?;
 987        let collab_url = if response.status().is_redirection() {
 988            response
 989                .headers()
 990                .get("Location")
 991                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 992                .to_str()
 993                .map_err(EstablishConnectionError::other)?
 994                .to_string()
 995        } else {
 996            Err(anyhow!(
 997                "unexpected /rpc response status {}",
 998                response.status()
 999            ))?
1000        };
1001
1002        Url::parse(&collab_url).context("invalid rpc url")
1003    }
1004
1005    fn establish_websocket_connection(
1006        self: &Arc<Self>,
1007        credentials: &Credentials,
1008        cx: &AsyncAppContext,
1009    ) -> Task<Result<Connection, EstablishConnectionError>> {
1010        let release_channel = cx
1011            .update(|cx| ReleaseChannel::try_global(cx))
1012            .ok()
1013            .flatten();
1014        let app_version = cx
1015            .update(|cx| AppVersion::global(cx).to_string())
1016            .ok()
1017            .unwrap_or_default();
1018
1019        let request = Request::builder()
1020            .header(
1021                "Authorization",
1022                format!("{} {}", credentials.user_id, credentials.access_token),
1023            )
1024            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
1025            .header("x-zed-app-version", app_version)
1026            .header(
1027                "x-zed-release-channel",
1028                release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
1029            );
1030
1031        let http = self.http.clone();
1032        cx.background_executor().spawn(async move {
1033            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
1034            let rpc_host = rpc_url
1035                .host_str()
1036                .zip(rpc_url.port_or_known_default())
1037                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1038            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1039
1040            log::info!("connected to rpc endpoint {}", rpc_url);
1041
1042            match rpc_url.scheme() {
1043                "https" => {
1044                    rpc_url.set_scheme("wss").unwrap();
1045                    let request = request.uri(rpc_url.as_str()).body(())?;
1046                    let (stream, _) =
1047                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
1048                    Ok(Connection::new(
1049                        stream
1050                            .map_err(|error| anyhow!(error))
1051                            .sink_map_err(|error| anyhow!(error)),
1052                    ))
1053                }
1054                "http" => {
1055                    rpc_url.set_scheme("ws").unwrap();
1056                    let request = request.uri(rpc_url.as_str()).body(())?;
1057                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1058                    Ok(Connection::new(
1059                        stream
1060                            .map_err(|error| anyhow!(error))
1061                            .sink_map_err(|error| anyhow!(error)),
1062                    ))
1063                }
1064                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1065            }
1066        })
1067    }
1068
1069    pub fn authenticate_with_browser(
1070        self: &Arc<Self>,
1071        cx: &AsyncAppContext,
1072    ) -> Task<Result<Credentials>> {
1073        let http = self.http.clone();
1074        cx.spawn(|cx| async move {
1075            let background = cx.background_executor().clone();
1076
1077            let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1078            cx.update(|cx| {
1079                cx.spawn(move |cx| async move {
1080                    let url = open_url_rx.await?;
1081                    cx.update(|cx| cx.open_url(&url))
1082                })
1083                .detach_and_log_err(cx);
1084            })
1085            .log_err();
1086
1087            let credentials = background
1088                .clone()
1089                .spawn(async move {
1090                    // Generate a pair of asymmetric encryption keys. The public key will be used by the
1091                    // zed server to encrypt the user's access token, so that it can'be intercepted by
1092                    // any other app running on the user's device.
1093                    let (public_key, private_key) =
1094                        rpc::auth::keypair().expect("failed to generate keypair for auth");
1095                    let public_key_string = String::try_from(public_key)
1096                        .expect("failed to serialize public key for auth");
1097
1098                    if let Some((login, token)) =
1099                        IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1100                    {
1101                        return Self::authenticate_as_admin(http, login.clone(), token.clone())
1102                            .await;
1103                    }
1104
1105                    // Start an HTTP server to receive the redirect from Zed's sign-in page.
1106                    let server =
1107                        tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1108                    let port = server.server_addr().port();
1109
1110                    // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1111                    // that the user is signing in from a Zed app running on the same device.
1112                    let mut url = http.build_url(&format!(
1113                        "/native_app_signin?native_app_port={}&native_app_public_key={}",
1114                        port, public_key_string
1115                    ));
1116
1117                    if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1118                        log::info!("impersonating user @{}", impersonate_login);
1119                        write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1120                    }
1121
1122                    open_url_tx.send(url).log_err();
1123
1124                    // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1125                    // access token from the query params.
1126                    //
1127                    // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1128                    // custom URL scheme instead of this local HTTP server.
1129                    let (user_id, access_token) = background
1130                        .spawn(async move {
1131                            for _ in 0..100 {
1132                                if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1133                                    let path = req.url();
1134                                    let mut user_id = None;
1135                                    let mut access_token = None;
1136                                    let url = Url::parse(&format!("http://example.com{}", path))
1137                                        .context("failed to parse login notification url")?;
1138                                    for (key, value) in url.query_pairs() {
1139                                        if key == "access_token" {
1140                                            access_token = Some(value.to_string());
1141                                        } else if key == "user_id" {
1142                                            user_id = Some(value.to_string());
1143                                        }
1144                                    }
1145
1146                                    let post_auth_url =
1147                                        http.build_url("/native_app_signin_succeeded");
1148                                    req.respond(
1149                                        tiny_http::Response::empty(302).with_header(
1150                                            tiny_http::Header::from_bytes(
1151                                                &b"Location"[..],
1152                                                post_auth_url.as_bytes(),
1153                                            )
1154                                            .unwrap(),
1155                                        ),
1156                                    )
1157                                    .context("failed to respond to login http request")?;
1158                                    return Ok((
1159                                        user_id
1160                                            .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1161                                        access_token.ok_or_else(|| {
1162                                            anyhow!("missing access_token parameter")
1163                                        })?,
1164                                    ));
1165                                }
1166                            }
1167
1168                            Err(anyhow!("didn't receive login redirect"))
1169                        })
1170                        .await?;
1171
1172                    let access_token = private_key
1173                        .decrypt_string(&access_token)
1174                        .context("failed to decrypt access token")?;
1175
1176                    Ok(Credentials {
1177                        user_id: user_id.parse()?,
1178                        access_token,
1179                    })
1180                })
1181                .await?;
1182
1183            cx.update(|cx| cx.activate(true))?;
1184            Ok(credentials)
1185        })
1186    }
1187
1188    async fn authenticate_as_admin(
1189        http: Arc<HttpClientWithUrl>,
1190        login: String,
1191        mut api_token: String,
1192    ) -> Result<Credentials> {
1193        #[derive(Deserialize)]
1194        struct AuthenticatedUserResponse {
1195            user: User,
1196        }
1197
1198        #[derive(Deserialize)]
1199        struct User {
1200            id: u64,
1201        }
1202
1203        // Use the collab server's admin API to retrieve the id
1204        // of the impersonated user.
1205        let mut url = Self::get_rpc_url(http.clone(), None).await?;
1206        url.set_path("/user");
1207        url.set_query(Some(&format!("github_login={login}")));
1208        let request = Request::get(url.as_str())
1209            .header("Authorization", format!("token {api_token}"))
1210            .body("".into())?;
1211
1212        let mut response = http.send(request).await?;
1213        let mut body = String::new();
1214        response.body_mut().read_to_string(&mut body).await?;
1215        if !response.status().is_success() {
1216            Err(anyhow!(
1217                "admin user request failed {} - {}",
1218                response.status().as_u16(),
1219                body,
1220            ))?;
1221        }
1222        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1223
1224        // Use the admin API token to authenticate as the impersonated user.
1225        api_token.insert_str(0, "ADMIN_TOKEN:");
1226        Ok(Credentials {
1227            user_id: response.user.id,
1228            access_token: api_token,
1229        })
1230    }
1231
1232    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1233        self.peer.teardown();
1234        self.set_status(Status::SignedOut, cx);
1235    }
1236
1237    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1238        self.peer.teardown();
1239        self.set_status(Status::ConnectionLost, cx);
1240    }
1241
1242    fn connection_id(&self) -> Result<ConnectionId> {
1243        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1244            Ok(connection_id)
1245        } else {
1246            Err(anyhow!("not connected"))
1247        }
1248    }
1249
1250    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1251        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1252        self.peer.send(self.connection_id()?, message)
1253    }
1254
1255    pub fn request<T: RequestMessage>(
1256        &self,
1257        request: T,
1258    ) -> impl Future<Output = Result<T::Response>> {
1259        self.request_envelope(request)
1260            .map_ok(|envelope| envelope.payload)
1261    }
1262
1263    pub fn request_envelope<T: RequestMessage>(
1264        &self,
1265        request: T,
1266    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1267        let client_id = self.id();
1268        log::debug!(
1269            "rpc request start. client_id:{}. name:{}",
1270            client_id,
1271            T::NAME
1272        );
1273        let response = self
1274            .connection_id()
1275            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1276        async move {
1277            let response = response?.await;
1278            log::debug!(
1279                "rpc request finish. client_id:{}. name:{}",
1280                client_id,
1281                T::NAME
1282            );
1283            response
1284        }
1285    }
1286
1287    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1288        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1289        self.peer.respond(receipt, response)
1290    }
1291
1292    fn respond_with_error<T: RequestMessage>(
1293        &self,
1294        receipt: Receipt<T>,
1295        error: proto::Error,
1296    ) -> Result<()> {
1297        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1298        self.peer.respond_with_error(receipt, error)
1299    }
1300
1301    fn handle_message(
1302        self: &Arc<Client>,
1303        message: Box<dyn AnyTypedEnvelope>,
1304        cx: &AsyncAppContext,
1305    ) {
1306        let mut state = self.state.write();
1307        let type_name = message.payload_type_name();
1308        let payload_type_id = message.payload_type_id();
1309        let sender_id = message.original_sender_id();
1310
1311        let mut subscriber = None;
1312
1313        if let Some(handle) = state
1314            .models_by_message_type
1315            .get(&payload_type_id)
1316            .and_then(|handle| handle.upgrade())
1317        {
1318            subscriber = Some(handle);
1319        } else if let Some((extract_entity_id, entity_type_id)) =
1320            state.entity_id_extractors.get(&payload_type_id).zip(
1321                state
1322                    .entity_types_by_message_type
1323                    .get(&payload_type_id)
1324                    .copied(),
1325            )
1326        {
1327            let entity_id = (extract_entity_id)(message.as_ref());
1328
1329            match state
1330                .entities_by_type_and_remote_id
1331                .get_mut(&(entity_type_id, entity_id))
1332            {
1333                Some(WeakSubscriber::Pending(pending)) => {
1334                    pending.push(message);
1335                    return;
1336                }
1337                Some(weak_subscriber) => match weak_subscriber {
1338                    WeakSubscriber::Entity { handle } => {
1339                        subscriber = handle.upgrade();
1340                    }
1341
1342                    WeakSubscriber::Pending(_) => {}
1343                },
1344                _ => {}
1345            }
1346        }
1347
1348        let subscriber = if let Some(subscriber) = subscriber {
1349            subscriber
1350        } else {
1351            log::info!("unhandled message {}", type_name);
1352            self.peer.respond_with_unhandled_message(message).log_err();
1353            return;
1354        };
1355
1356        let handler = state.message_handlers.get(&payload_type_id).cloned();
1357        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1358        // It also ensures we don't hold the lock while yielding back to the executor, as
1359        // that might cause the executor thread driving this future to block indefinitely.
1360        drop(state);
1361
1362        if let Some(handler) = handler {
1363            let future = handler(subscriber, message, self, cx.clone());
1364            let client_id = self.id();
1365            log::debug!(
1366                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1367                client_id,
1368                sender_id,
1369                type_name
1370            );
1371            cx.spawn(move |_| async move {
1372                    match future.await {
1373                        Ok(()) => {
1374                            log::debug!(
1375                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1376                                client_id,
1377                                sender_id,
1378                                type_name
1379                            );
1380                        }
1381                        Err(error) => {
1382                            log::error!(
1383                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1384                                client_id,
1385                                sender_id,
1386                                type_name,
1387                                error
1388                            );
1389                        }
1390                    }
1391                })
1392                .detach();
1393        } else {
1394            log::info!("unhandled message {}", type_name);
1395            self.peer.respond_with_unhandled_message(message).log_err();
1396        }
1397    }
1398
1399    pub fn telemetry(&self) -> &Arc<Telemetry> {
1400        &self.telemetry
1401    }
1402}
1403
1404async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1405    if IMPERSONATE_LOGIN.is_some() {
1406        return None;
1407    }
1408
1409    let (user_id, access_token) = cx
1410        .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1411        .log_err()?
1412        .await
1413        .log_err()??;
1414
1415    Some(Credentials {
1416        user_id: user_id.parse().ok()?,
1417        access_token: String::from_utf8(access_token).ok()?,
1418    })
1419}
1420
1421async fn write_credentials_to_keychain(
1422    credentials: Credentials,
1423    cx: &AsyncAppContext,
1424) -> Result<()> {
1425    cx.update(move |cx| {
1426        cx.write_credentials(
1427            &ClientSettings::get_global(cx).server_url,
1428            &credentials.user_id.to_string(),
1429            credentials.access_token.as_bytes(),
1430        )
1431    })?
1432    .await
1433}
1434
1435async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
1436    cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1437        .await
1438}
1439
/// The scheme used by `zed://` URLs.
1441pub static ZED_URL_SCHEME: &str = "zed";
1442
1443/// Parses the given link into a Zed link.
1444///
1445/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1446/// Returns [`None`] otherwise.
1447pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1448    let server_url = &ClientSettings::get_global(cx).server_url;
1449    if let Some(stripped) = link
1450        .strip_prefix(server_url)
1451        .and_then(|result| result.strip_prefix('/'))
1452    {
1453        return Some(stripped);
1454    }
1455    if let Some(stripped) = link
1456        .strip_prefix(ZED_URL_SCHEME)
1457        .and_then(|result| result.strip_prefix("://"))
1458    {
1459        return Some(stripped);
1460    }
1461
1462    None
1463}
1464
1465#[cfg(test)]
1466mod tests {
1467    use super::*;
1468    use crate::test::FakeServer;
1469
1470    use clock::FakeSystemClock;
1471    use gpui::{BackgroundExecutor, Context, TestAppContext};
1472    use parking_lot::Mutex;
1473    use settings::SettingsStore;
1474    use std::future;
1475    use util::http::FakeHttpClient;
1476
1477    #[gpui::test(iterations = 10)]
1478    async fn test_reconnection(cx: &mut TestAppContext) {
1479        init_test(cx);
1480        let user_id = 5;
1481        let client = cx.update(|cx| {
1482            Client::new(
1483                Arc::new(FakeSystemClock::default()),
1484                FakeHttpClient::with_404_response(),
1485                cx,
1486            )
1487        });
1488        let server = FakeServer::for_client(user_id, &client, cx).await;
1489        let mut status = client.status();
1490        assert!(matches!(
1491            status.next().await,
1492            Some(Status::Connected { .. })
1493        ));
1494        assert_eq!(server.auth_count(), 1);
1495
1496        server.forbid_connections();
1497        server.disconnect();
1498        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1499
1500        server.allow_connections();
1501        cx.executor().advance_clock(Duration::from_secs(10));
1502        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1503        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1504
1505        server.forbid_connections();
1506        server.disconnect();
1507        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1508
1509        // Clear cached credentials after authentication fails
1510        server.roll_access_token();
1511        server.allow_connections();
1512        cx.executor().run_until_parked();
1513        cx.executor().advance_clock(Duration::from_secs(10));
1514        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1515        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1516    }
1517
1518    #[gpui::test(iterations = 10)]
1519    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
1520        init_test(cx);
1521        let user_id = 5;
1522        let client = cx.update(|cx| {
1523            Client::new(
1524                Arc::new(FakeSystemClock::default()),
1525                FakeHttpClient::with_404_response(),
1526                cx,
1527            )
1528        });
1529        let mut status = client.status();
1530
1531        // Time out when client tries to connect.
1532        client.override_authenticate(move |cx| {
1533            cx.background_executor().spawn(async move {
1534                Ok(Credentials {
1535                    user_id,
1536                    access_token: "token".into(),
1537                })
1538            })
1539        });
1540        client.override_establish_connection(|_, cx| {
1541            cx.background_executor().spawn(async move {
1542                future::pending::<()>().await;
1543                unreachable!()
1544            })
1545        });
1546        let auth_and_connect = cx.spawn({
1547            let client = client.clone();
1548            |cx| async move { client.authenticate_and_connect(false, &cx).await }
1549        });
1550        executor.run_until_parked();
1551        assert!(matches!(status.next().await, Some(Status::Connecting)));
1552
1553        executor.advance_clock(CONNECTION_TIMEOUT);
1554        assert!(matches!(
1555            status.next().await,
1556            Some(Status::ConnectionError { .. })
1557        ));
1558        auth_and_connect.await.unwrap_err();
1559
1560        // Allow the connection to be established.
1561        let server = FakeServer::for_client(user_id, &client, cx).await;
1562        assert!(matches!(
1563            status.next().await,
1564            Some(Status::Connected { .. })
1565        ));
1566
1567        // Disconnect client.
1568        server.forbid_connections();
1569        server.disconnect();
1570        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1571
1572        // Time out when re-establishing the connection.
1573        server.allow_connections();
1574        client.override_establish_connection(|_, cx| {
1575            cx.background_executor().spawn(async move {
1576                future::pending::<()>().await;
1577                unreachable!()
1578            })
1579        });
1580        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1581        assert!(matches!(
1582            status.next().await,
1583            Some(Status::Reconnecting { .. })
1584        ));
1585
1586        executor.advance_clock(CONNECTION_TIMEOUT);
1587        assert!(matches!(
1588            status.next().await,
1589            Some(Status::ReconnectionError { .. })
1590        ));
1591    }
1592
1593    #[gpui::test(iterations = 10)]
1594    async fn test_authenticating_more_than_once(
1595        cx: &mut TestAppContext,
1596        executor: BackgroundExecutor,
1597    ) {
1598        init_test(cx);
1599        let auth_count = Arc::new(Mutex::new(0));
1600        let dropped_auth_count = Arc::new(Mutex::new(0));
1601        let client = cx.update(|cx| {
1602            Client::new(
1603                Arc::new(FakeSystemClock::default()),
1604                FakeHttpClient::with_404_response(),
1605                cx,
1606            )
1607        });
1608        client.override_authenticate({
1609            let auth_count = auth_count.clone();
1610            let dropped_auth_count = dropped_auth_count.clone();
1611            move |cx| {
1612                let auth_count = auth_count.clone();
1613                let dropped_auth_count = dropped_auth_count.clone();
1614                cx.background_executor().spawn(async move {
1615                    *auth_count.lock() += 1;
1616                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1617                    future::pending::<()>().await;
1618                    unreachable!()
1619                })
1620            }
1621        });
1622
1623        let _authenticate = cx.spawn({
1624            let client = client.clone();
1625            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
1626        });
1627        executor.run_until_parked();
1628        assert_eq!(*auth_count.lock(), 1);
1629        assert_eq!(*dropped_auth_count.lock(), 0);
1630
1631        let _authenticate = cx.spawn({
1632            let client = client.clone();
1633            |cx| async move { client.authenticate_and_connect(false, &cx).await }
1634        });
1635        executor.run_until_parked();
1636        assert_eq!(*auth_count.lock(), 2);
1637        assert_eq!(*dropped_auth_count.lock(), 1);
1638    }
1639
1640    #[gpui::test]
1641    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1642        init_test(cx);
1643        let user_id = 5;
1644        let client = cx.update(|cx| {
1645            Client::new(
1646                Arc::new(FakeSystemClock::default()),
1647                FakeHttpClient::with_404_response(),
1648                cx,
1649            )
1650        });
1651        let server = FakeServer::for_client(user_id, &client, cx).await;
1652
1653        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1654        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1655        client.add_model_message_handler(
1656            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
1657                match model.update(&mut cx, |model, _| model.id).unwrap() {
1658                    1 => done_tx1.try_send(()).unwrap(),
1659                    2 => done_tx2.try_send(()).unwrap(),
1660                    _ => unreachable!(),
1661                }
1662                async { Ok(()) }
1663            },
1664        );
1665        let model1 = cx.new_model(|_| TestModel {
1666            id: 1,
1667            subscription: None,
1668        });
1669        let model2 = cx.new_model(|_| TestModel {
1670            id: 2,
1671            subscription: None,
1672        });
1673        let model3 = cx.new_model(|_| TestModel {
1674            id: 3,
1675            subscription: None,
1676        });
1677
1678        let _subscription1 = client
1679            .subscribe_to_entity(1)
1680            .unwrap()
1681            .set_model(&model1, &mut cx.to_async());
1682        let _subscription2 = client
1683            .subscribe_to_entity(2)
1684            .unwrap()
1685            .set_model(&model2, &mut cx.to_async());
1686        // Ensure dropping a subscription for the same entity type still allows receiving of
1687        // messages for other entity IDs of the same type.
1688        let subscription3 = client
1689            .subscribe_to_entity(3)
1690            .unwrap()
1691            .set_model(&model3, &mut cx.to_async());
1692        drop(subscription3);
1693
1694        server.send(proto::JoinProject { project_id: 1 });
1695        server.send(proto::JoinProject { project_id: 2 });
1696        done_rx1.next().await.unwrap();
1697        done_rx2.next().await.unwrap();
1698    }
1699
1700    #[gpui::test]
1701    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1702        init_test(cx);
1703        let user_id = 5;
1704        let client = cx.update(|cx| {
1705            Client::new(
1706                Arc::new(FakeSystemClock::default()),
1707                FakeHttpClient::with_404_response(),
1708                cx,
1709            )
1710        });
1711        let server = FakeServer::for_client(user_id, &client, cx).await;
1712
1713        let model = cx.new_model(|_| TestModel::default());
1714        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1715        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1716        let subscription1 = client.add_message_handler(
1717            model.downgrade(),
1718            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1719                done_tx1.try_send(()).unwrap();
1720                async { Ok(()) }
1721            },
1722        );
1723        drop(subscription1);
1724        let _subscription2 = client.add_message_handler(
1725            model.downgrade(),
1726            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1727                done_tx2.try_send(()).unwrap();
1728                async { Ok(()) }
1729            },
1730        );
1731        server.send(proto::Ping {});
1732        done_rx2.next().await.unwrap();
1733    }
1734
1735    #[gpui::test]
1736    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1737        init_test(cx);
1738        let user_id = 5;
1739        let client = cx.update(|cx| {
1740            Client::new(
1741                Arc::new(FakeSystemClock::default()),
1742                FakeHttpClient::with_404_response(),
1743                cx,
1744            )
1745        });
1746        let server = FakeServer::for_client(user_id, &client, cx).await;
1747
1748        let model = cx.new_model(|_| TestModel::default());
1749        let (done_tx, mut done_rx) = smol::channel::unbounded();
1750        let subscription = client.add_message_handler(
1751            model.clone().downgrade(),
1752            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1753                model
1754                    .update(&mut cx, |model, _| model.subscription.take())
1755                    .unwrap();
1756                done_tx.try_send(()).unwrap();
1757                async { Ok(()) }
1758            },
1759        );
1760        model.update(cx, |model, _| {
1761            model.subscription = Some(subscription);
1762        });
1763        server.send(proto::Ping {});
1764        done_rx.next().await.unwrap();
1765    }
1766
1767    #[derive(Default)]
1768    struct TestModel {
1769        id: usize,
1770        subscription: Option<Subscription>,
1771    }
1772
1773    fn init_test(cx: &mut TestAppContext) {
1774        cx.update(|cx| {
1775            let settings_store = SettingsStore::test(cx);
1776            cx.set_global(settings_store);
1777            init_settings(cx);
1778        });
1779    }
1780}