client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use futures::{
  14    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  15    TryStreamExt,
  16};
  17use gpui::{
  18    actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
  19    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
  20    WeakViewHandle,
  21};
  22use lazy_static::lazy_static;
  23use parking_lot::RwLock;
  24use postage::watch;
  25use rand::prelude::*;
  26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use std::{
  30    any::TypeId,
  31    collections::HashMap,
  32    convert::TryFrom,
  33    fmt::Write as _,
  34    future::Future,
  35    marker::PhantomData,
  36    path::PathBuf,
  37    sync::{atomic::AtomicU64, Arc, Weak},
  38    time::{Duration, Instant},
  39};
  40use telemetry::Telemetry;
  41use thiserror::Error;
  42use url::Url;
  43use util::channel::ReleaseChannel;
  44use util::http::HttpClient;
  45use util::{ResultExt, TryFutureExt};
  46
  47pub use rpc::*;
  48pub use telemetry::ClickhouseEvent;
  49pub use user::*;
  50
  51lazy_static! {
  52    pub static ref ZED_SERVER_URL: String =
  53        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  54    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  55        .ok()
  56        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  57    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  58        .ok()
  59        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  60    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  61        .ok()
  62        .and_then(|v| v.parse().ok());
  63    pub static ref ZED_APP_PATH: Option<PathBuf> =
  64        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  65    pub static ref ZED_ALWAYS_ACTIVE: bool =
  66        std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
  67}
  68
  69pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  70pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  71pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  72
  73actions!(client, [SignIn, SignOut, Reconnect]);
  74
  75pub fn init_settings(cx: &mut AppContext) {
  76    settings::register::<TelemetrySettings>(cx);
  77}
  78
  79pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
  80    init_settings(cx);
  81
  82    let client = Arc::downgrade(client);
  83    cx.add_global_action({
  84        let client = client.clone();
  85        move |_: &SignIn, cx| {
  86            if let Some(client) = client.upgrade() {
  87                cx.spawn(
  88                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  89                )
  90                .detach();
  91            }
  92        }
  93    });
  94    cx.add_global_action({
  95        let client = client.clone();
  96        move |_: &SignOut, cx| {
  97            if let Some(client) = client.upgrade() {
  98                cx.spawn(|cx| async move {
  99                    client.disconnect(&cx);
 100                })
 101                .detach();
 102            }
 103        }
 104    });
 105    cx.add_global_action({
 106        let client = client.clone();
 107        move |_: &Reconnect, cx| {
 108            if let Some(client) = client.upgrade() {
 109                cx.spawn(|cx| async move {
 110                    client.reconnect(&cx);
 111                })
 112                .detach();
 113            }
 114        }
 115    });
 116}
 117
 118pub struct Client {
 119    id: AtomicU64,
 120    peer: Arc<Peer>,
 121    http: Arc<dyn HttpClient>,
 122    telemetry: Arc<Telemetry>,
 123    state: RwLock<ClientState>,
 124
 125    #[allow(clippy::type_complexity)]
 126    #[cfg(any(test, feature = "test-support"))]
 127    authenticate: RwLock<
 128        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 129    >,
 130
 131    #[allow(clippy::type_complexity)]
 132    #[cfg(any(test, feature = "test-support"))]
 133    establish_connection: RwLock<
 134        Option<
 135            Box<
 136                dyn 'static
 137                    + Send
 138                    + Sync
 139                    + Fn(
 140                        &Credentials,
 141                        &AsyncAppContext,
 142                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 143            >,
 144        >,
 145    >,
 146}
 147
 148#[derive(Error, Debug)]
 149pub enum EstablishConnectionError {
 150    #[error("upgrade required")]
 151    UpgradeRequired,
 152    #[error("unauthorized")]
 153    Unauthorized,
 154    #[error("{0}")]
 155    Other(#[from] anyhow::Error),
 156    #[error("{0}")]
 157    Http(#[from] util::http::Error),
 158    #[error("{0}")]
 159    Io(#[from] std::io::Error),
 160    #[error("{0}")]
 161    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 162}
 163
 164impl From<WebsocketError> for EstablishConnectionError {
 165    fn from(error: WebsocketError) -> Self {
 166        if let WebsocketError::Http(response) = &error {
 167            match response.status() {
 168                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 169                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 170                _ => {}
 171            }
 172        }
 173        EstablishConnectionError::Other(error.into())
 174    }
 175}
 176
 177impl EstablishConnectionError {
 178    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 179        Self::Other(error.into())
 180    }
 181}
 182
 183#[derive(Copy, Clone, Debug, PartialEq)]
 184pub enum Status {
 185    SignedOut,
 186    UpgradeRequired,
 187    Authenticating,
 188    Connecting,
 189    ConnectionError,
 190    Connected {
 191        peer_id: PeerId,
 192        connection_id: ConnectionId,
 193    },
 194    ConnectionLost,
 195    Reauthenticating,
 196    Reconnecting,
 197    ReconnectionError {
 198        next_reconnection: Instant,
 199    },
 200}
 201
 202impl Status {
 203    pub fn is_connected(&self) -> bool {
 204        matches!(self, Self::Connected { .. })
 205    }
 206
 207    pub fn is_signed_out(&self) -> bool {
 208        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 209    }
 210}
 211
 212struct ClientState {
 213    credentials: Option<Credentials>,
 214    status: (watch::Sender<Status>, watch::Receiver<Status>),
 215    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 216    _reconnect_task: Option<Task<()>>,
 217    reconnect_interval: Duration,
 218    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
 219    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
 220    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 221    #[allow(clippy::type_complexity)]
 222    message_handlers: HashMap<
 223        TypeId,
 224        Arc<
 225            dyn Send
 226                + Sync
 227                + Fn(
 228                    Subscriber,
 229                    Box<dyn AnyTypedEnvelope>,
 230                    &Arc<Client>,
 231                    AsyncAppContext,
 232                ) -> LocalBoxFuture<'static, Result<()>>,
 233        >,
 234    >,
 235}
 236
 237enum WeakSubscriber {
 238    Model(AnyWeakModelHandle),
 239    View(AnyWeakViewHandle),
 240    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
 241}
 242
 243enum Subscriber {
 244    Model(AnyModelHandle),
 245    View(AnyWeakViewHandle),
 246}
 247
 248#[derive(Clone, Debug)]
 249pub struct Credentials {
 250    pub user_id: u64,
 251    pub access_token: String,
 252}
 253
 254impl Default for ClientState {
 255    fn default() -> Self {
 256        Self {
 257            credentials: None,
 258            status: watch::channel_with(Status::SignedOut),
 259            entity_id_extractors: Default::default(),
 260            _reconnect_task: None,
 261            reconnect_interval: Duration::from_secs(5),
 262            models_by_message_type: Default::default(),
 263            entities_by_type_and_remote_id: Default::default(),
 264            entity_types_by_message_type: Default::default(),
 265            message_handlers: Default::default(),
 266        }
 267    }
 268}
 269
 270pub enum Subscription {
 271    Entity {
 272        client: Weak<Client>,
 273        id: (TypeId, u64),
 274    },
 275    Message {
 276        client: Weak<Client>,
 277        id: TypeId,
 278    },
 279}
 280
 281impl Drop for Subscription {
 282    fn drop(&mut self) {
 283        match self {
 284            Subscription::Entity { client, id } => {
 285                if let Some(client) = client.upgrade() {
 286                    let mut state = client.state.write();
 287                    let _ = state.entities_by_type_and_remote_id.remove(id);
 288                }
 289            }
 290            Subscription::Message { client, id } => {
 291                if let Some(client) = client.upgrade() {
 292                    let mut state = client.state.write();
 293                    let _ = state.entity_types_by_message_type.remove(id);
 294                    let _ = state.message_handlers.remove(id);
 295                }
 296            }
 297        }
 298    }
 299}
 300
 301pub struct PendingEntitySubscription<T: Entity> {
 302    client: Arc<Client>,
 303    remote_id: u64,
 304    _entity_type: PhantomData<T>,
 305    consumed: bool,
 306}
 307
 308impl<T: Entity> PendingEntitySubscription<T> {
 309    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 310        self.consumed = true;
 311        let mut state = self.client.state.write();
 312        let id = (TypeId::of::<T>(), self.remote_id);
 313        let Some(WeakSubscriber::Pending(messages)) =
 314            state.entities_by_type_and_remote_id.remove(&id)
 315        else {
 316            unreachable!()
 317        };
 318
 319        state
 320            .entities_by_type_and_remote_id
 321            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 322        drop(state);
 323        for message in messages {
 324            self.client.handle_message(message, cx);
 325        }
 326        Subscription::Entity {
 327            client: Arc::downgrade(&self.client),
 328            id,
 329        }
 330    }
 331}
 332
 333impl<T: Entity> Drop for PendingEntitySubscription<T> {
 334    fn drop(&mut self) {
 335        if !self.consumed {
 336            let mut state = self.client.state.write();
 337            if let Some(WeakSubscriber::Pending(messages)) = state
 338                .entities_by_type_and_remote_id
 339                .remove(&(TypeId::of::<T>(), self.remote_id))
 340            {
 341                for message in messages {
 342                    log::info!("unhandled message {}", message.payload_type_name());
 343                }
 344            }
 345        }
 346    }
 347}
 348
 349#[derive(Copy, Clone)]
 350pub struct TelemetrySettings {
 351    pub diagnostics: bool,
 352    pub metrics: bool,
 353}
 354
 355#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 356pub struct TelemetrySettingsContent {
 357    pub diagnostics: Option<bool>,
 358    pub metrics: Option<bool>,
 359}
 360
 361impl settings::Setting for TelemetrySettings {
 362    const KEY: Option<&'static str> = Some("telemetry");
 363
 364    type FileContent = TelemetrySettingsContent;
 365
 366    fn load(
 367        default_value: &Self::FileContent,
 368        user_values: &[&Self::FileContent],
 369        _: &AppContext,
 370    ) -> Result<Self> {
 371        Ok(Self {
 372            diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
 373                default_value
 374                    .diagnostics
 375                    .ok_or_else(Self::missing_default)?,
 376            ),
 377            metrics: user_values
 378                .first()
 379                .and_then(|v| v.metrics)
 380                .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
 381        })
 382    }
 383}
 384
 385impl Client {
 386    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 387        Arc::new(Self {
 388            id: AtomicU64::new(0),
 389            peer: Peer::new(0),
 390            telemetry: Telemetry::new(http.clone(), cx),
 391            http,
 392            state: Default::default(),
 393
 394            #[cfg(any(test, feature = "test-support"))]
 395            authenticate: Default::default(),
 396            #[cfg(any(test, feature = "test-support"))]
 397            establish_connection: Default::default(),
 398        })
 399    }
 400
 401    pub fn id(&self) -> u64 {
 402        self.id.load(std::sync::atomic::Ordering::SeqCst)
 403    }
 404
 405    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 406        self.http.clone()
 407    }
 408
 409    pub fn set_id(&self, id: u64) -> &Self {
 410        self.id.store(id, std::sync::atomic::Ordering::SeqCst);
 411        self
 412    }
 413
 414    #[cfg(any(test, feature = "test-support"))]
 415    pub fn teardown(&self) {
 416        let mut state = self.state.write();
 417        state._reconnect_task.take();
 418        state.message_handlers.clear();
 419        state.models_by_message_type.clear();
 420        state.entities_by_type_and_remote_id.clear();
 421        state.entity_id_extractors.clear();
 422        self.peer.teardown();
 423    }
 424
 425    #[cfg(any(test, feature = "test-support"))]
 426    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 427    where
 428        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 429    {
 430        *self.authenticate.write() = Some(Box::new(authenticate));
 431        self
 432    }
 433
 434    #[cfg(any(test, feature = "test-support"))]
 435    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 436    where
 437        F: 'static
 438            + Send
 439            + Sync
 440            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 441    {
 442        *self.establish_connection.write() = Some(Box::new(connect));
 443        self
 444    }
 445
 446    pub fn user_id(&self) -> Option<u64> {
 447        self.state
 448            .read()
 449            .credentials
 450            .as_ref()
 451            .map(|credentials| credentials.user_id)
 452    }
 453
 454    pub fn peer_id(&self) -> Option<PeerId> {
 455        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 456            Some(*peer_id)
 457        } else {
 458            None
 459        }
 460    }
 461
 462    pub fn status(&self) -> watch::Receiver<Status> {
 463        self.state.read().status.1.clone()
 464    }
 465
    /// Publishes `status` on the status channel and applies its side effects.
    ///
    /// - `Connected`: clears the reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop that retries
    ///   `authenticate_and_connect` with a randomized, capped back-off.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and the
    ///   reconnect task.
    ///
    /// The state write lock is held for the whole body, including the spawn and the
    /// telemetry call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Deterministic jitter in test builds, real entropy otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failure left the client in
                        // `ConnectionError`. Any other status (e.g. signed out or
                        // upgrade required) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 511
 512    pub fn add_view_for_remote_entity<T: View>(
 513        self: &Arc<Self>,
 514        remote_id: u64,
 515        cx: &mut ViewContext<T>,
 516    ) -> Subscription {
 517        let id = (TypeId::of::<T>(), remote_id);
 518        self.state
 519            .write()
 520            .entities_by_type_and_remote_id
 521            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 522        Subscription::Entity {
 523            client: Arc::downgrade(self),
 524            id,
 525        }
 526    }
 527
 528    pub fn subscribe_to_entity<T: Entity>(
 529        self: &Arc<Self>,
 530        remote_id: u64,
 531    ) -> Result<PendingEntitySubscription<T>> {
 532        let id = (TypeId::of::<T>(), remote_id);
 533
 534        let mut state = self.state.write();
 535        if state.entities_by_type_and_remote_id.contains_key(&id) {
 536            return Err(anyhow!("already subscribed to entity"));
 537        } else {
 538            state
 539                .entities_by_type_and_remote_id
 540                .insert(id, WeakSubscriber::Pending(Default::default()));
 541            Ok(PendingEntitySubscription {
 542                client: self.clone(),
 543                remote_id,
 544                consumed: false,
 545                _entity_type: PhantomData,
 546            })
 547        }
 548    }
 549
 550    #[track_caller]
 551    pub fn add_message_handler<M, E, H, F>(
 552        self: &Arc<Self>,
 553        model: ModelHandle<E>,
 554        handler: H,
 555    ) -> Subscription
 556    where
 557        M: EnvelopedMessage,
 558        E: Entity,
 559        H: 'static
 560            + Send
 561            + Sync
 562            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 563        F: 'static + Future<Output = Result<()>>,
 564    {
 565        let message_type_id = TypeId::of::<M>();
 566
 567        let mut state = self.state.write();
 568        state
 569            .models_by_message_type
 570            .insert(message_type_id, model.downgrade().into_any());
 571
 572        let prev_handler = state.message_handlers.insert(
 573            message_type_id,
 574            Arc::new(move |handle, envelope, client, cx| {
 575                let handle = if let Subscriber::Model(handle) = handle {
 576                    handle
 577                } else {
 578                    unreachable!();
 579                };
 580                let model = handle.downcast::<E>().unwrap();
 581                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 582                handler(model, *envelope, client.clone(), cx).boxed_local()
 583            }),
 584        );
 585        if prev_handler.is_some() {
 586            let location = std::panic::Location::caller();
 587            panic!(
 588                "{}:{} registered handler for the same message {} twice",
 589                location.file(),
 590                location.line(),
 591                std::any::type_name::<M>()
 592            );
 593        }
 594
 595        Subscription::Message {
 596            client: Arc::downgrade(self),
 597            id: message_type_id,
 598        }
 599    }
 600
 601    pub fn add_request_handler<M, E, H, F>(
 602        self: &Arc<Self>,
 603        model: ModelHandle<E>,
 604        handler: H,
 605    ) -> Subscription
 606    where
 607        M: RequestMessage,
 608        E: Entity,
 609        H: 'static
 610            + Send
 611            + Sync
 612            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 613        F: 'static + Future<Output = Result<M::Response>>,
 614    {
 615        self.add_message_handler(model, move |handle, envelope, this, cx| {
 616            Self::respond_to_request(
 617                envelope.receipt(),
 618                handler(handle, envelope, this.clone(), cx),
 619                this,
 620            )
 621        })
 622    }
 623
 624    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 625    where
 626        M: EntityMessage,
 627        E: View,
 628        H: 'static
 629            + Send
 630            + Sync
 631            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 632        F: 'static + Future<Output = Result<()>>,
 633    {
 634        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 635            if let Subscriber::View(handle) = handle {
 636                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 637            } else {
 638                unreachable!();
 639            }
 640        })
 641    }
 642
 643    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 644    where
 645        M: EntityMessage,
 646        E: Entity,
 647        H: 'static
 648            + Send
 649            + Sync
 650            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 651        F: 'static + Future<Output = Result<()>>,
 652    {
 653        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 654            if let Subscriber::Model(handle) = handle {
 655                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 656            } else {
 657                unreachable!();
 658            }
 659        })
 660    }
 661
 662    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 663    where
 664        M: EntityMessage,
 665        E: Entity,
 666        H: 'static
 667            + Send
 668            + Sync
 669            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 670        F: 'static + Future<Output = Result<()>>,
 671    {
 672        let model_type_id = TypeId::of::<E>();
 673        let message_type_id = TypeId::of::<M>();
 674
 675        let mut state = self.state.write();
 676        state
 677            .entity_types_by_message_type
 678            .insert(message_type_id, model_type_id);
 679        state
 680            .entity_id_extractors
 681            .entry(message_type_id)
 682            .or_insert_with(|| {
 683                |envelope| {
 684                    envelope
 685                        .as_any()
 686                        .downcast_ref::<TypedEnvelope<M>>()
 687                        .unwrap()
 688                        .payload
 689                        .remote_entity_id()
 690                }
 691            });
 692        let prev_handler = state.message_handlers.insert(
 693            message_type_id,
 694            Arc::new(move |handle, envelope, client, cx| {
 695                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 696                handler(handle, *envelope, client.clone(), cx).boxed_local()
 697            }),
 698        );
 699        if prev_handler.is_some() {
 700            panic!("registered handler for the same message twice");
 701        }
 702    }
 703
 704    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 705    where
 706        M: EntityMessage + RequestMessage,
 707        E: Entity,
 708        H: 'static
 709            + Send
 710            + Sync
 711            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 712        F: 'static + Future<Output = Result<M::Response>>,
 713    {
 714        self.add_model_message_handler(move |entity, envelope, client, cx| {
 715            Self::respond_to_request::<M, _>(
 716                envelope.receipt(),
 717                handler(entity, envelope, client.clone(), cx),
 718                client,
 719            )
 720        })
 721    }
 722
 723    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 724    where
 725        M: EntityMessage + RequestMessage,
 726        E: View,
 727        H: 'static
 728            + Send
 729            + Sync
 730            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 731        F: 'static + Future<Output = Result<M::Response>>,
 732    {
 733        self.add_view_message_handler(move |entity, envelope, client, cx| {
 734            Self::respond_to_request::<M, _>(
 735                envelope.receipt(),
 736                handler(entity, envelope, client.clone(), cx),
 737                client,
 738            )
 739        })
 740    }
 741
 742    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 743        receipt: Receipt<T>,
 744        response: F,
 745        client: Arc<Self>,
 746    ) -> Result<()> {
 747        match response.await {
 748            Ok(response) => {
 749                client.respond(receipt, response)?;
 750                Ok(())
 751            }
 752            Err(error) => {
 753                client.respond_with_error(
 754                    receipt,
 755                    proto::Error {
 756                        message: format!("{:?}", error),
 757                    },
 758                )?;
 759                Err(error)
 760            }
 761        }
 762    }
 763
 764    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 765        read_credentials_from_keychain(cx).is_some()
 766    }
 767
    /// Obtains credentials if needed and connects to the server.
    ///
    /// Returns `Ok(())` immediately when already connected, connecting or reconnecting.
    /// Credentials come from memory, then (if `try_keychain`) the keychain, then the
    /// interactive `authenticate` flow. A single `CONNECTION_TIMEOUT` timer bounds both
    /// establishing the connection and waiting in `set_connection`.
    ///
    /// # Errors
    ///
    /// - `UpgradeRequired` if the client is (or becomes) upgrade-required.
    /// - "authentication canceled" if the status changes while authenticating.
    /// - Authentication, connection and timeout failures; these set `ConnectionError`.
    ///
    /// `#[async_recursion]` is needed because an `Unauthorized` result for keychain
    /// credentials deletes them and retries with `try_keychain = false`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `true` means a fresh sign-in; `false` means a retry after an error, a lost
        // connection, or an in-progress (re)authentication. In the last case this call
        // proceeds and its status change cancels the other call's interactive
        // authentication (see the `status_rx` branch below).
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            // `?` converts the error into `anyhow::Error`.
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Appears to consume the value already in the channel, so that only a
            // *subsequent* status change cancels authentication below.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every branch above either set `credentials` or returned.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer for the whole handshake: the same fused future is polled both
        // while establishing the connection and while waiting for the hello message.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only new credentials, and never while impersonating.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry once
                            // without the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 871
    /// Registers `conn` with the peer, waits for the server's `Hello`, and marks the client
    /// as `Status::Connected`.
    ///
    /// Once the hello has arrived, two detached foreground tasks are spawned: one feeds every
    /// further incoming message to `handle_message`, and one awaits the connection's I/O task
    /// and updates the status when it finishes.
    ///
    /// # Errors
    ///
    /// If the incoming stream ends before any message arrives, the first message is not a
    /// `proto::Hello`, or the hello carries no peer id, the connection is disconnected from
    /// the peer and the error is returned.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        // The peer is handed a timer factory backed by the background executor.
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The server's first message must be a `Hello` that carries this client's peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Don't leave a connection without a valid hello registered with the peer.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        // Dispatch all remaining incoming messages on the foreground executor.
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        // When the I/O task ends, a clean finish signs out, but only if this connection is
        // still the current one. An I/O error marks the connection as lost.
        // NOTE(review): the error path does not check that this connection is still current
        // the way the `Ok` path does.
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
 967
 968    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 969        #[cfg(any(test, feature = "test-support"))]
 970        if let Some(callback) = self.authenticate.read().as_ref() {
 971            return callback(cx);
 972        }
 973
 974        self.authenticate_with_browser(cx)
 975    }
 976
 977    fn establish_connection(
 978        self: &Arc<Self>,
 979        credentials: &Credentials,
 980        cx: &AsyncAppContext,
 981    ) -> Task<Result<Connection, EstablishConnectionError>> {
 982        #[cfg(any(test, feature = "test-support"))]
 983        if let Some(callback) = self.establish_connection.read().as_ref() {
 984            return callback(credentials, cx);
 985        }
 986
 987        self.establish_websocket_connection(credentials, cx)
 988    }
 989
 990    async fn get_rpc_url(
 991        http: Arc<dyn HttpClient>,
 992        release_channel: Option<ReleaseChannel>,
 993    ) -> Result<Url> {
 994        let mut url = format!("{}/rpc", *ZED_SERVER_URL);
 995        if let Some(preview_param) =
 996            release_channel.and_then(|channel| channel.release_query_param())
 997        {
 998            url += "?";
 999            url += preview_param;
1000        }
1001        let response = http.get(&url, Default::default(), false).await?;
1002
1003        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
1004        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
1005        // which requires authorization via an HTTP header.
1006        //
1007        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
1008        // of a collab server. In that case, a request to the /rpc endpoint will
1009        // return an 'unauthorized' response.
1010        let collab_url = if response.status().is_redirection() {
1011            response
1012                .headers()
1013                .get("Location")
1014                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1015                .to_str()
1016                .map_err(EstablishConnectionError::other)?
1017                .to_string()
1018        } else if response.status() == StatusCode::UNAUTHORIZED {
1019            url
1020        } else {
1021            Err(anyhow!(
1022                "unexpected /rpc response status {}",
1023                response.status()
1024            ))?
1025        };
1026
1027        Url::parse(&collab_url).context("invalid rpc url")
1028    }
1029
1030    fn establish_websocket_connection(
1031        self: &Arc<Self>,
1032        credentials: &Credentials,
1033        cx: &AsyncAppContext,
1034    ) -> Task<Result<Connection, EstablishConnectionError>> {
1035        let release_channel = cx.read(|cx| {
1036            if cx.has_global::<ReleaseChannel>() {
1037                Some(*cx.global::<ReleaseChannel>())
1038            } else {
1039                None
1040            }
1041        });
1042
1043        let request = Request::builder()
1044            .header(
1045                "Authorization",
1046                format!("{} {}", credentials.user_id, credentials.access_token),
1047            )
1048            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1049
1050        let http = self.http.clone();
1051        cx.background().spawn(async move {
1052            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
1053            let rpc_host = rpc_url
1054                .host_str()
1055                .zip(rpc_url.port_or_known_default())
1056                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1057            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1058
1059            log::info!("connected to rpc endpoint {}", rpc_url);
1060
1061            match rpc_url.scheme() {
1062                "https" => {
1063                    rpc_url.set_scheme("wss").unwrap();
1064                    let request = request.uri(rpc_url.as_str()).body(())?;
1065                    let (stream, _) =
1066                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1067                    Ok(Connection::new(
1068                        stream
1069                            .map_err(|error| anyhow!(error))
1070                            .sink_map_err(|error| anyhow!(error)),
1071                    ))
1072                }
1073                "http" => {
1074                    rpc_url.set_scheme("ws").unwrap();
1075                    let request = request.uri(rpc_url.as_str()).body(())?;
1076                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1077                    Ok(Connection::new(
1078                        stream
1079                            .map_err(|error| anyhow!(error))
1080                            .sink_map_err(|error| anyhow!(error)),
1081                    ))
1082                }
1083                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1084            }
1085        })
1086    }
1087
1088    pub fn authenticate_with_browser(
1089        self: &Arc<Self>,
1090        cx: &AsyncAppContext,
1091    ) -> Task<Result<Credentials>> {
1092        let platform = cx.platform();
1093        let executor = cx.background();
1094        let http = self.http.clone();
1095
1096        executor.clone().spawn(async move {
1097            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1098            // zed server to encrypt the user's access token, so that it can'be intercepted by
1099            // any other app running on the user's device.
1100            let (public_key, private_key) =
1101                rpc::auth::keypair().expect("failed to generate keypair for auth");
1102            let public_key_string =
1103                String::try_from(public_key).expect("failed to serialize public key for auth");
1104
1105            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1106                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1107            }
1108
1109            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1110            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1111            let port = server.server_addr().port();
1112
1113            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1114            // that the user is signing in from a Zed app running on the same device.
1115            let mut url = format!(
1116                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1117                *ZED_SERVER_URL, port, public_key_string
1118            );
1119
1120            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1121                log::info!("impersonating user @{}", impersonate_login);
1122                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1123            }
1124
1125            platform.open_url(&url);
1126
1127            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1128            // access token from the query params.
1129            //
1130            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1131            // custom URL scheme instead of this local HTTP server.
1132            let (user_id, access_token) = executor
1133                .spawn(async move {
1134                    for _ in 0..100 {
1135                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1136                            let path = req.url();
1137                            let mut user_id = None;
1138                            let mut access_token = None;
1139                            let url = Url::parse(&format!("http://example.com{}", path))
1140                                .context("failed to parse login notification url")?;
1141                            for (key, value) in url.query_pairs() {
1142                                if key == "access_token" {
1143                                    access_token = Some(value.to_string());
1144                                } else if key == "user_id" {
1145                                    user_id = Some(value.to_string());
1146                                }
1147                            }
1148
1149                            let post_auth_url =
1150                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1151                            req.respond(
1152                                tiny_http::Response::empty(302).with_header(
1153                                    tiny_http::Header::from_bytes(
1154                                        &b"Location"[..],
1155                                        post_auth_url.as_bytes(),
1156                                    )
1157                                    .unwrap(),
1158                                ),
1159                            )
1160                            .context("failed to respond to login http request")?;
1161                            return Ok((
1162                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1163                                access_token
1164                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1165                            ));
1166                        }
1167                    }
1168
1169                    Err(anyhow!("didn't receive login redirect"))
1170                })
1171                .await?;
1172
1173            let access_token = private_key
1174                .decrypt_string(&access_token)
1175                .context("failed to decrypt access token")?;
1176            platform.activate(true);
1177
1178            Ok(Credentials {
1179                user_id: user_id.parse()?,
1180                access_token,
1181            })
1182        })
1183    }
1184
1185    async fn authenticate_as_admin(
1186        http: Arc<dyn HttpClient>,
1187        login: String,
1188        mut api_token: String,
1189    ) -> Result<Credentials> {
1190        #[derive(Deserialize)]
1191        struct AuthenticatedUserResponse {
1192            user: User,
1193        }
1194
1195        #[derive(Deserialize)]
1196        struct User {
1197            id: u64,
1198        }
1199
1200        // Use the collab server's admin API to retrieve the id
1201        // of the impersonated user.
1202        let mut url = Self::get_rpc_url(http.clone(), None).await?;
1203        url.set_path("/user");
1204        url.set_query(Some(&format!("github_login={login}")));
1205        let request = Request::get(url.as_str())
1206            .header("Authorization", format!("token {api_token}"))
1207            .body("".into())?;
1208
1209        let mut response = http.send(request).await?;
1210        let mut body = String::new();
1211        response.body_mut().read_to_string(&mut body).await?;
1212        if !response.status().is_success() {
1213            Err(anyhow!(
1214                "admin user request failed {} - {}",
1215                response.status().as_u16(),
1216                body,
1217            ))?;
1218        }
1219        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1220
1221        // Use the admin API token to authenticate as the impersonated user.
1222        api_token.insert_str(0, "ADMIN_TOKEN:");
1223        Ok(Credentials {
1224            user_id: response.user.id,
1225            access_token: api_token,
1226        })
1227    }
1228
    /// Tears down the peer, then sets the client status to `Status::SignedOut`.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1233
    /// Tears down the peer, then sets the client status to `Status::ConnectionLost`.
    ///
    /// NOTE(review): presumably the `ConnectionLost` status is what triggers the actual
    /// reconnection attempt elsewhere. Confirm against the status handling.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1238
1239    fn connection_id(&self) -> Result<ConnectionId> {
1240        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1241            Ok(connection_id)
1242        } else {
1243            Err(anyhow!("not connected"))
1244        }
1245    }
1246
1247    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1248        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1249        self.peer.send(self.connection_id()?, message)
1250    }
1251
1252    pub fn request<T: RequestMessage>(
1253        &self,
1254        request: T,
1255    ) -> impl Future<Output = Result<T::Response>> {
1256        self.request_envelope(request)
1257            .map_ok(|envelope| envelope.payload)
1258    }
1259
1260    pub fn request_envelope<T: RequestMessage>(
1261        &self,
1262        request: T,
1263    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1264        let client_id = self.id();
1265        log::debug!(
1266            "rpc request start. client_id:{}. name:{}",
1267            client_id,
1268            T::NAME
1269        );
1270        let response = self
1271            .connection_id()
1272            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1273        async move {
1274            let response = response?.await;
1275            log::debug!(
1276                "rpc request finish. client_id:{}. name:{}",
1277                client_id,
1278                T::NAME
1279            );
1280            response
1281        }
1282    }
1283
1284    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1285        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1286        self.peer.respond(receipt, response)
1287    }
1288
1289    fn respond_with_error<T: RequestMessage>(
1290        &self,
1291        receipt: Receipt<T>,
1292        error: proto::Error,
1293    ) -> Result<()> {
1294        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1295        self.peer.respond_with_error(receipt, error)
1296    }
1297
1298    fn handle_message(
1299        self: &Arc<Client>,
1300        message: Box<dyn AnyTypedEnvelope>,
1301        cx: &AsyncAppContext,
1302    ) {
1303        let mut state = self.state.write();
1304        let type_name = message.payload_type_name();
1305        let payload_type_id = message.payload_type_id();
1306        let sender_id = message.original_sender_id();
1307
1308        let mut subscriber = None;
1309
1310        if let Some(message_model) = state
1311            .models_by_message_type
1312            .get(&payload_type_id)
1313            .and_then(|model| model.upgrade(cx))
1314        {
1315            subscriber = Some(Subscriber::Model(message_model));
1316        } else if let Some((extract_entity_id, entity_type_id)) =
1317            state.entity_id_extractors.get(&payload_type_id).zip(
1318                state
1319                    .entity_types_by_message_type
1320                    .get(&payload_type_id)
1321                    .copied(),
1322            )
1323        {
1324            let entity_id = (extract_entity_id)(message.as_ref());
1325
1326            match state
1327                .entities_by_type_and_remote_id
1328                .get_mut(&(entity_type_id, entity_id))
1329            {
1330                Some(WeakSubscriber::Pending(pending)) => {
1331                    pending.push(message);
1332                    return;
1333                }
1334                Some(weak_subscriber @ _) => match weak_subscriber {
1335                    WeakSubscriber::Model(handle) => {
1336                        subscriber = handle.upgrade(cx).map(Subscriber::Model);
1337                    }
1338                    WeakSubscriber::View(handle) => {
1339                        subscriber = Some(Subscriber::View(handle.clone()));
1340                    }
1341                    WeakSubscriber::Pending(_) => {}
1342                },
1343                _ => {}
1344            }
1345        }
1346
1347        let subscriber = if let Some(subscriber) = subscriber {
1348            subscriber
1349        } else {
1350            log::info!("unhandled message {}", type_name);
1351            self.peer.respond_with_unhandled_message(message).log_err();
1352            return;
1353        };
1354
1355        let handler = state.message_handlers.get(&payload_type_id).cloned();
1356        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1357        // It also ensures we don't hold the lock while yielding back to the executor, as
1358        // that might cause the executor thread driving this future to block indefinitely.
1359        drop(state);
1360
1361        if let Some(handler) = handler {
1362            let future = handler(subscriber, message, &self, cx.clone());
1363            let client_id = self.id();
1364            log::debug!(
1365                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1366                client_id,
1367                sender_id,
1368                type_name
1369            );
1370            cx.foreground()
1371                .spawn(async move {
1372                    match future.await {
1373                        Ok(()) => {
1374                            log::debug!(
1375                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1376                                client_id,
1377                                sender_id,
1378                                type_name
1379                            );
1380                        }
1381                        Err(error) => {
1382                            log::error!(
1383                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1384                                client_id,
1385                                sender_id,
1386                                type_name,
1387                                error
1388                            );
1389                        }
1390                    }
1391                })
1392                .detach();
1393        } else {
1394            log::info!("unhandled message {}", type_name);
1395            self.peer.respond_with_unhandled_message(message).log_err();
1396        }
1397    }
1398
    /// Returns the telemetry handle stored on this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1402}
1403
1404fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1405    if IMPERSONATE_LOGIN.is_some() {
1406        return None;
1407    }
1408
1409    let (user_id, access_token) = cx
1410        .platform()
1411        .read_credentials(&ZED_SERVER_URL)
1412        .log_err()
1413        .flatten()?;
1414    Some(Credentials {
1415        user_id: user_id.parse().ok()?,
1416        access_token: String::from_utf8(access_token).ok()?,
1417    })
1418}
1419
1420fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1421    cx.platform().write_credentials(
1422        &ZED_SERVER_URL,
1423        &credentials.user_id.to_string(),
1424        credentials.access_token.as_bytes(),
1425    )
1426}
1427
/// Prefix shared by every worktree URL produced by [`encode_worktree_url`].
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Encodes a worktree id and access token as `zed://worktrees/{id}/{access_token}`.
///
/// [`decode_worktree_url`] reverses this for any non-empty `access_token`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the first `/` following the id is the
/// access token, so tokens that themselves contain `/` round-trip intact.
///
/// Returns `None` if the prefix is missing, the id is not a `u64`, or the token is absent
/// or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1444
1445#[cfg(test)]
1446mod tests {
1447    use super::*;
1448    use crate::test::FakeServer;
1449    use gpui::{executor::Deterministic, TestAppContext};
1450    use parking_lot::Mutex;
1451    use std::future;
1452    use util::http::FakeHttpClient;
1453
    /// After a dropped connection the client reconnects using its cached credentials. It only
    /// re-authenticates once the server no longer accepts the cached token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while the fake server refuses new ones, and wait for the
        // client to report a failed reconnection.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock so a reconnection attempt runs.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }
1488
    /// Both the initial connection attempt and a later reconnection attempt fail once
    /// `CONNECTION_TIMEOUT` elapses while establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Establishing the connection never resolves, so only the timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }
1558
    /// Starting a second `authenticate_and_connect` while the first is still authenticating
    /// drops the first authentication future. The `util::defer` guard counts the drops.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        // Each authentication attempt increments `auth_count`, then hangs forever; its
        // deferred guard increments `dropped_auth_count` when the future is dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }
1600
1601    #[test]
1602    fn test_encode_and_decode_worktree_url() {
1603        let url = encode_worktree_url(5, "deadbeef");
1604        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1605        assert_eq!(
1606            decode_worktree_url(&format!("\n {}\t", url)),
1607            Some((5, "deadbeef".to_string()))
1608        );
1609        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1610    }
1611
    /// Messages addressed to entities 1 and 2 reach their models' handler, even after the
    /// subscription for entity 3 (same message type) has been dropped.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signal on the channel matching the id of the model that received the message.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }
1666
1667    #[gpui::test]
1668    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1669        cx.foreground().forbid_parking();
1670
1671        let user_id = 5;
1672        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1673        let server = FakeServer::for_client(user_id, &client, cx).await;
1674
1675        let model = cx.add_model(|_| Model::default());
1676        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1677        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1678        let subscription1 = client.add_message_handler(
1679            model.clone(),
1680            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1681                done_tx1.try_send(()).unwrap();
1682                async { Ok(()) }
1683            },
1684        );
1685        drop(subscription1);
1686        let _subscription2 = client.add_message_handler(
1687            model.clone(),
1688            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1689                done_tx2.try_send(()).unwrap();
1690                async { Ok(()) }
1691            },
1692        );
1693        server.send(proto::Ping {});
1694        done_rx2.next().await.unwrap();
1695    }
1696
1697    #[gpui::test]
1698    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1699        cx.foreground().forbid_parking();
1700
1701        let user_id = 5;
1702        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1703        let server = FakeServer::for_client(user_id, &client, cx).await;
1704
1705        let model = cx.add_model(|_| Model::default());
1706        let (done_tx, mut done_rx) = smol::channel::unbounded();
1707        let subscription = client.add_message_handler(
1708            model.clone(),
1709            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1710                model.update(&mut cx, |model, _| model.subscription.take());
1711                done_tx.try_send(()).unwrap();
1712                async { Ok(()) }
1713            },
1714        );
1715        model.update(cx, |model, _| {
1716            model.subscription = Some(subscription);
1717        });
1718        server.send(proto::Ping {});
1719        done_rx.next().await.unwrap();
1720    }
1721
1722    #[derive(Default)]
1723    struct Model {
1724        id: usize,
1725        subscription: Option<Subscription>,
1726    }
1727
1728    impl Entity for Model {
1729        type Event = ();
1730    }
1731}