client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use futures::{
  14    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  15    TryStreamExt,
  16};
  17use gpui::{
  18    actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
  19    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
  20    WeakViewHandle,
  21};
  22use lazy_static::lazy_static;
  23use parking_lot::RwLock;
  24use postage::watch;
  25use rand::prelude::*;
  26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use std::{
  30    any::TypeId,
  31    collections::HashMap,
  32    convert::TryFrom,
  33    fmt::Write as _,
  34    future::Future,
  35    marker::PhantomData,
  36    path::PathBuf,
  37    sync::{atomic::AtomicU64, Arc, Weak},
  38    time::{Duration, Instant},
  39};
  40use telemetry::Telemetry;
  41use thiserror::Error;
  42use url::Url;
  43use util::channel::ReleaseChannel;
  44use util::http::HttpClient;
  45use util::{ResultExt, TryFutureExt};
  46
  47pub use rpc::*;
  48pub use telemetry::ClickhouseEvent;
  49pub use user::*;
  50
  51lazy_static! {
  52    pub static ref ZED_SERVER_URL: String =
  53        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  54    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  55        .ok()
  56        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  57    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  58        .ok()
  59        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  60    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  61        .ok()
  62        .and_then(|v| v.parse().ok());
  63    pub static ref ZED_APP_PATH: Option<PathBuf> =
  64        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  65    pub static ref ZED_ALWAYS_ACTIVE: bool =
  66        std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
  67}
  68
  69pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  70pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  71pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  72
// Global actions whose handlers are registered in `init`: sign in, sign out, and
// force a reconnect.
actions!(client, [SignIn, SignOut, Reconnect]);
  74
  75pub fn init_settings(cx: &mut AppContext) {
  76    settings::register::<TelemetrySettings>(cx);
  77}
  78
  79pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
  80    init_settings(cx);
  81
  82    let client = Arc::downgrade(client);
  83    cx.add_global_action({
  84        let client = client.clone();
  85        move |_: &SignIn, cx| {
  86            if let Some(client) = client.upgrade() {
  87                cx.spawn(
  88                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  89                )
  90                .detach();
  91            }
  92        }
  93    });
  94    cx.add_global_action({
  95        let client = client.clone();
  96        move |_: &SignOut, cx| {
  97            if let Some(client) = client.upgrade() {
  98                cx.spawn(|cx| async move {
  99                    client.disconnect(&cx);
 100                })
 101                .detach();
 102            }
 103        }
 104    });
 105    cx.add_global_action({
 106        let client = client.clone();
 107        move |_: &Reconnect, cx| {
 108            if let Some(client) = client.upgrade() {
 109                cx.spawn(|cx| async move {
 110                    client.reconnect(&cx);
 111                })
 112                .detach();
 113            }
 114        }
 115    });
 116}
 117
 118pub struct Client {
 119    id: AtomicU64,
 120    peer: Arc<Peer>,
 121    http: Arc<dyn HttpClient>,
 122    telemetry: Arc<Telemetry>,
 123    state: RwLock<ClientState>,
 124
 125    #[allow(clippy::type_complexity)]
 126    #[cfg(any(test, feature = "test-support"))]
 127    authenticate: RwLock<
 128        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 129    >,
 130
 131    #[allow(clippy::type_complexity)]
 132    #[cfg(any(test, feature = "test-support"))]
 133    establish_connection: RwLock<
 134        Option<
 135            Box<
 136                dyn 'static
 137                    + Send
 138                    + Sync
 139                    + Fn(
 140                        &Credentials,
 141                        &AsyncAppContext,
 142                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 143            >,
 144        >,
 145    >,
 146}
 147
 148#[derive(Error, Debug)]
 149pub enum EstablishConnectionError {
 150    #[error("upgrade required")]
 151    UpgradeRequired,
 152    #[error("unauthorized")]
 153    Unauthorized,
 154    #[error("{0}")]
 155    Other(#[from] anyhow::Error),
 156    #[error("{0}")]
 157    Http(#[from] util::http::Error),
 158    #[error("{0}")]
 159    Io(#[from] std::io::Error),
 160    #[error("{0}")]
 161    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 162}
 163
 164impl From<WebsocketError> for EstablishConnectionError {
 165    fn from(error: WebsocketError) -> Self {
 166        if let WebsocketError::Http(response) = &error {
 167            match response.status() {
 168                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 169                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 170                _ => {}
 171            }
 172        }
 173        EstablishConnectionError::Other(error.into())
 174    }
 175}
 176
 177impl EstablishConnectionError {
 178    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 179        Self::Other(error.into())
 180    }
 181}
 182
 183#[derive(Copy, Clone, Debug, PartialEq)]
 184pub enum Status {
 185    SignedOut,
 186    UpgradeRequired,
 187    Authenticating,
 188    Connecting,
 189    ConnectionError,
 190    Connected {
 191        peer_id: PeerId,
 192        connection_id: ConnectionId,
 193    },
 194    ConnectionLost,
 195    Reauthenticating,
 196    Reconnecting,
 197    ReconnectionError {
 198        next_reconnection: Instant,
 199    },
 200}
 201
 202impl Status {
 203    pub fn is_connected(&self) -> bool {
 204        matches!(self, Self::Connected { .. })
 205    }
 206
 207    pub fn is_signed_out(&self) -> bool {
 208        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 209    }
 210}
 211
 212struct ClientState {
 213    credentials: Option<Credentials>,
 214    status: (watch::Sender<Status>, watch::Receiver<Status>),
 215    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 216    _reconnect_task: Option<Task<()>>,
 217    reconnect_interval: Duration,
 218    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
 219    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
 220    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 221    #[allow(clippy::type_complexity)]
 222    message_handlers: HashMap<
 223        TypeId,
 224        Arc<
 225            dyn Send
 226                + Sync
 227                + Fn(
 228                    Subscriber,
 229                    Box<dyn AnyTypedEnvelope>,
 230                    &Arc<Client>,
 231                    AsyncAppContext,
 232                ) -> LocalBoxFuture<'static, Result<()>>,
 233        >,
 234    >,
 235}
 236
 237enum WeakSubscriber {
 238    Model(AnyWeakModelHandle),
 239    View(AnyWeakViewHandle),
 240    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
 241}
 242
 243enum Subscriber {
 244    Model(AnyModelHandle),
 245    View(AnyWeakViewHandle),
 246}
 247
 248#[derive(Clone, Debug)]
 249pub struct Credentials {
 250    pub user_id: u64,
 251    pub access_token: String,
 252}
 253
 254impl Default for ClientState {
 255    fn default() -> Self {
 256        Self {
 257            credentials: None,
 258            status: watch::channel_with(Status::SignedOut),
 259            entity_id_extractors: Default::default(),
 260            _reconnect_task: None,
 261            reconnect_interval: Duration::from_secs(5),
 262            models_by_message_type: Default::default(),
 263            entities_by_type_and_remote_id: Default::default(),
 264            entity_types_by_message_type: Default::default(),
 265            message_handlers: Default::default(),
 266        }
 267    }
 268}
 269
 270pub enum Subscription {
 271    Entity {
 272        client: Weak<Client>,
 273        id: (TypeId, u64),
 274    },
 275    Message {
 276        client: Weak<Client>,
 277        id: TypeId,
 278    },
 279}
 280
 281impl Drop for Subscription {
 282    fn drop(&mut self) {
 283        match self {
 284            Subscription::Entity { client, id } => {
 285                if let Some(client) = client.upgrade() {
 286                    let mut state = client.state.write();
 287                    let _ = state.entities_by_type_and_remote_id.remove(id);
 288                }
 289            }
 290            Subscription::Message { client, id } => {
 291                if let Some(client) = client.upgrade() {
 292                    let mut state = client.state.write();
 293                    let _ = state.entity_types_by_message_type.remove(id);
 294                    let _ = state.message_handlers.remove(id);
 295                }
 296            }
 297        }
 298    }
 299}
 300
 301pub struct PendingEntitySubscription<T: Entity> {
 302    client: Arc<Client>,
 303    remote_id: u64,
 304    _entity_type: PhantomData<T>,
 305    consumed: bool,
 306}
 307
 308impl<T: Entity> PendingEntitySubscription<T> {
 309    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 310        self.consumed = true;
 311        let mut state = self.client.state.write();
 312        let id = (TypeId::of::<T>(), self.remote_id);
 313        let Some(WeakSubscriber::Pending(messages)) =
 314            state.entities_by_type_and_remote_id.remove(&id)
 315        else {
 316            unreachable!()
 317        };
 318
 319        state
 320            .entities_by_type_and_remote_id
 321            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 322        drop(state);
 323        for message in messages {
 324            self.client.handle_message(message, cx);
 325        }
 326        Subscription::Entity {
 327            client: Arc::downgrade(&self.client),
 328            id,
 329        }
 330    }
 331}
 332
 333impl<T: Entity> Drop for PendingEntitySubscription<T> {
 334    fn drop(&mut self) {
 335        if !self.consumed {
 336            let mut state = self.client.state.write();
 337            if let Some(WeakSubscriber::Pending(messages)) = state
 338                .entities_by_type_and_remote_id
 339                .remove(&(TypeId::of::<T>(), self.remote_id))
 340            {
 341                for message in messages {
 342                    log::info!("unhandled message {}", message.payload_type_name());
 343                }
 344            }
 345        }
 346    }
 347}
 348
 349#[derive(Copy, Clone)]
 350pub struct TelemetrySettings {
 351    pub diagnostics: bool,
 352    pub metrics: bool,
 353}
 354
 355#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 356pub struct TelemetrySettingsContent {
 357    pub diagnostics: Option<bool>,
 358    pub metrics: Option<bool>,
 359}
 360
 361impl settings::Setting for TelemetrySettings {
 362    const KEY: Option<&'static str> = Some("telemetry");
 363
 364    type FileContent = TelemetrySettingsContent;
 365
 366    fn load(
 367        default_value: &Self::FileContent,
 368        user_values: &[&Self::FileContent],
 369        _: &AppContext,
 370    ) -> Result<Self> {
 371        Ok(Self {
 372            diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
 373                default_value
 374                    .diagnostics
 375                    .ok_or_else(Self::missing_default)?,
 376            ),
 377            metrics: user_values
 378                .first()
 379                .and_then(|v| v.metrics)
 380                .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
 381        })
 382    }
 383}
 384
 385impl Client {
 386    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 387        Arc::new(Self {
 388            id: AtomicU64::new(0),
 389            peer: Peer::new(0),
 390            telemetry: Telemetry::new(http.clone(), cx),
 391            http,
 392            state: Default::default(),
 393
 394            #[cfg(any(test, feature = "test-support"))]
 395            authenticate: Default::default(),
 396            #[cfg(any(test, feature = "test-support"))]
 397            establish_connection: Default::default(),
 398        })
 399    }
 400
 401    pub fn id(&self) -> u64 {
 402        self.id.load(std::sync::atomic::Ordering::SeqCst)
 403    }
 404
 405    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 406        self.http.clone()
 407    }
 408
 409    pub fn set_id(&self, id: u64) -> &Self {
 410        self.id.store(id, std::sync::atomic::Ordering::SeqCst);
 411        self
 412    }
 413
 414    #[cfg(any(test, feature = "test-support"))]
 415    pub fn teardown(&self) {
 416        let mut state = self.state.write();
 417        state._reconnect_task.take();
 418        state.message_handlers.clear();
 419        state.models_by_message_type.clear();
 420        state.entities_by_type_and_remote_id.clear();
 421        state.entity_id_extractors.clear();
 422        self.peer.teardown();
 423    }
 424
 425    #[cfg(any(test, feature = "test-support"))]
 426    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 427    where
 428        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 429    {
 430        *self.authenticate.write() = Some(Box::new(authenticate));
 431        self
 432    }
 433
 434    #[cfg(any(test, feature = "test-support"))]
 435    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 436    where
 437        F: 'static
 438            + Send
 439            + Sync
 440            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 441    {
 442        *self.establish_connection.write() = Some(Box::new(connect));
 443        self
 444    }
 445
 446    pub fn user_id(&self) -> Option<u64> {
 447        self.state
 448            .read()
 449            .credentials
 450            .as_ref()
 451            .map(|credentials| credentials.user_id)
 452    }
 453
 454    pub fn peer_id(&self) -> Option<PeerId> {
 455        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 456            Some(*peer_id)
 457        } else {
 458            None
 459        }
 460    }
 461
 462    pub fn status(&self) -> watch::Receiver<Status> {
 463        self.state.read().status.1.clone()
 464    }
 465
    /// Publishes `status` to every `status()` receiver and reacts to the transition.
    ///
    /// - `Connected` drops the stored reconnect task (presumably cancelling it).
    /// - `ConnectionLost` spawns a reconnect task. It retries
    ///   `authenticate_and_connect` with a delay that grows by a random factor in
    ///   `[1, 2]`, capped at `reconnect_interval`. It stops as soon as a failed
    ///   attempt leaves the status at anything other than `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired` clear the telemetry user info and drop the
    ///   reconnect task.
    ///
    /// The state write lock is held for the rest of the function after it is
    /// acquired. That includes spawning the reconnect task and the telemetry update.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed under test so the reconnect jitter is reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only after plain connection errors. Any other
                        // resulting status (e.g. signed out, upgrade required) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], never beyond
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 511
 512    pub fn add_view_for_remote_entity<T: View>(
 513        self: &Arc<Self>,
 514        remote_id: u64,
 515        cx: &mut ViewContext<T>,
 516    ) -> Subscription {
 517        let id = (TypeId::of::<T>(), remote_id);
 518        self.state
 519            .write()
 520            .entities_by_type_and_remote_id
 521            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 522        Subscription::Entity {
 523            client: Arc::downgrade(self),
 524            id,
 525        }
 526    }
 527
 528    pub fn subscribe_to_entity<T: Entity>(
 529        self: &Arc<Self>,
 530        remote_id: u64,
 531    ) -> Result<PendingEntitySubscription<T>> {
 532        let id = (TypeId::of::<T>(), remote_id);
 533
 534        let mut state = self.state.write();
 535        if state.entities_by_type_and_remote_id.contains_key(&id) {
 536            return Err(anyhow!("already subscribed to entity"));
 537        } else {
 538            state
 539                .entities_by_type_and_remote_id
 540                .insert(id, WeakSubscriber::Pending(Default::default()));
 541            Ok(PendingEntitySubscription {
 542                client: self.clone(),
 543                remote_id,
 544                consumed: false,
 545                _entity_type: PhantomData,
 546            })
 547        }
 548    }
 549
 550    #[track_caller]
 551    pub fn add_message_handler<M, E, H, F>(
 552        self: &Arc<Self>,
 553        model: ModelHandle<E>,
 554        handler: H,
 555    ) -> Subscription
 556    where
 557        M: EnvelopedMessage,
 558        E: Entity,
 559        H: 'static
 560            + Send
 561            + Sync
 562            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 563        F: 'static + Future<Output = Result<()>>,
 564    {
 565        let message_type_id = TypeId::of::<M>();
 566
 567        let mut state = self.state.write();
 568        state
 569            .models_by_message_type
 570            .insert(message_type_id, model.downgrade().into_any());
 571
 572        let prev_handler = state.message_handlers.insert(
 573            message_type_id,
 574            Arc::new(move |handle, envelope, client, cx| {
 575                let handle = if let Subscriber::Model(handle) = handle {
 576                    handle
 577                } else {
 578                    unreachable!();
 579                };
 580                let model = handle.downcast::<E>().unwrap();
 581                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 582                handler(model, *envelope, client.clone(), cx).boxed_local()
 583            }),
 584        );
 585        if prev_handler.is_some() {
 586            let location = std::panic::Location::caller();
 587            panic!(
 588                "{}:{} registered handler for the same message {} twice",
 589                location.file(),
 590                location.line(),
 591                std::any::type_name::<M>()
 592            );
 593        }
 594
 595        Subscription::Message {
 596            client: Arc::downgrade(self),
 597            id: message_type_id,
 598        }
 599    }
 600
 601    pub fn add_request_handler<M, E, H, F>(
 602        self: &Arc<Self>,
 603        model: ModelHandle<E>,
 604        handler: H,
 605    ) -> Subscription
 606    where
 607        M: RequestMessage,
 608        E: Entity,
 609        H: 'static
 610            + Send
 611            + Sync
 612            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 613        F: 'static + Future<Output = Result<M::Response>>,
 614    {
 615        self.add_message_handler(model, move |handle, envelope, this, cx| {
 616            Self::respond_to_request(
 617                envelope.receipt(),
 618                handler(handle, envelope, this.clone(), cx),
 619                this,
 620            )
 621        })
 622    }
 623
 624    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 625    where
 626        M: EntityMessage,
 627        E: View,
 628        H: 'static
 629            + Send
 630            + Sync
 631            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 632        F: 'static + Future<Output = Result<()>>,
 633    {
 634        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 635            if let Subscriber::View(handle) = handle {
 636                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 637            } else {
 638                unreachable!();
 639            }
 640        })
 641    }
 642
 643    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 644    where
 645        M: EntityMessage,
 646        E: Entity,
 647        H: 'static
 648            + Send
 649            + Sync
 650            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 651        F: 'static + Future<Output = Result<()>>,
 652    {
 653        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 654            if let Subscriber::Model(handle) = handle {
 655                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 656            } else {
 657                unreachable!();
 658            }
 659        })
 660    }
 661
 662    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 663    where
 664        M: EntityMessage,
 665        E: Entity,
 666        H: 'static
 667            + Send
 668            + Sync
 669            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 670        F: 'static + Future<Output = Result<()>>,
 671    {
 672        let model_type_id = TypeId::of::<E>();
 673        let message_type_id = TypeId::of::<M>();
 674
 675        let mut state = self.state.write();
 676        state
 677            .entity_types_by_message_type
 678            .insert(message_type_id, model_type_id);
 679        state
 680            .entity_id_extractors
 681            .entry(message_type_id)
 682            .or_insert_with(|| {
 683                |envelope| {
 684                    envelope
 685                        .as_any()
 686                        .downcast_ref::<TypedEnvelope<M>>()
 687                        .unwrap()
 688                        .payload
 689                        .remote_entity_id()
 690                }
 691            });
 692        let prev_handler = state.message_handlers.insert(
 693            message_type_id,
 694            Arc::new(move |handle, envelope, client, cx| {
 695                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 696                handler(handle, *envelope, client.clone(), cx).boxed_local()
 697            }),
 698        );
 699        if prev_handler.is_some() {
 700            panic!("registered handler for the same message twice");
 701        }
 702    }
 703
 704    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 705    where
 706        M: EntityMessage + RequestMessage,
 707        E: Entity,
 708        H: 'static
 709            + Send
 710            + Sync
 711            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 712        F: 'static + Future<Output = Result<M::Response>>,
 713    {
 714        self.add_model_message_handler(move |entity, envelope, client, cx| {
 715            Self::respond_to_request::<M, _>(
 716                envelope.receipt(),
 717                handler(entity, envelope, client.clone(), cx),
 718                client,
 719            )
 720        })
 721    }
 722
 723    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 724    where
 725        M: EntityMessage + RequestMessage,
 726        E: View,
 727        H: 'static
 728            + Send
 729            + Sync
 730            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 731        F: 'static + Future<Output = Result<M::Response>>,
 732    {
 733        self.add_view_message_handler(move |entity, envelope, client, cx| {
 734            Self::respond_to_request::<M, _>(
 735                envelope.receipt(),
 736                handler(entity, envelope, client.clone(), cx),
 737                client,
 738            )
 739        })
 740    }
 741
 742    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 743        receipt: Receipt<T>,
 744        response: F,
 745        client: Arc<Self>,
 746    ) -> Result<()> {
 747        match response.await {
 748            Ok(response) => {
 749                client.respond(receipt, response)?;
 750                Ok(())
 751            }
 752            Err(error) => {
 753                client.respond_with_error(
 754                    receipt,
 755                    proto::Error {
 756                        message: format!("{:?}", error),
 757                    },
 758                )?;
 759                Err(error)
 760            }
 761        }
 762    }
 763
 764    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 765        read_credentials_from_keychain(cx).is_some()
 766    }
 767
    /// Obtains credentials if necessary and connects to the server, updating `Status`.
    ///
    /// Returns `Ok(())` immediately if the status is already `Connected`,
    /// `Connecting`, or `Reconnecting`. Fails immediately if it is `UpgradeRequired`.
    ///
    /// Credentials come from memory first, then from the keychain when
    /// `try_keychain` is true, and finally from `authenticate`. Any further status
    /// change during interactive authentication cancels it.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted and
    /// the flow is retried once without the keychain.
    ///
    /// A single [`CONNECTION_TIMEOUT`] deadline covers both opening the connection
    /// and receiving the server's hello.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Distinguish a fresh sign-in from a re-authentication. Bail out early if a
        // connection exists or is already being opened, or if an upgrade is required.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // The first `next()` on a fresh receiver is expected to yield the current
            // status. Consuming it means the select below wakes only on a later
            // status change, which is treated as cancelling authentication.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer shared by both selects below, so the deadline spans connecting
        // *and* waiting for the server hello.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Remember the credentials once the server accepts them, even
                        // before the hello arrives. Persist them to the keychain only
                        // if they did not come from it and we are not impersonating.
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry once.
                            // The retry skips the keychain, so it cannot recurse again.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 871
    /// Registers `conn` with the peer, waits for the server's `Hello`, and marks the client as
    /// `Status::Connected`.
    ///
    /// On success, two detached foreground tasks are spawned:
    /// * one that feeds every further incoming message to `handle_message`, yielding between
    ///   messages;
    /// * one that awaits the connection's I/O task and then sets the status to `SignedOut` when
    ///   the I/O finished cleanly while this exact connection was still the current one, or to
    ///   `ConnectionLost` when the I/O task returned an error.
    ///
    /// # Errors
    ///
    /// After disconnecting `connection_id` from the peer, returns an error if the incoming
    /// stream ends before any message arrives, if the first message is not a `proto::Hello`,
    /// or if the hello carries no peer id.
    ///
    /// NOTE(review): the caller races this future against a timeout; if it is dropped while
    /// still waiting for the hello, the `self.peer.disconnect(connection_id)` cleanup below never
    /// runs. Confirm whether the peer releases the connection on its own in that case.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        // The peer receives a timer factory backed by the background executor (presumably for
        // its internal timeouts/keepalives — TODO confirm).
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        // Drive the connection's I/O on the background executor; the resulting task is awaited
        // further down to detect when the connection ends.
        let handle_io = cx.background().spawn(handle_io);

        // The first message on a fresh connection must be the server's `Hello`, which carries
        // the peer id assigned to this client.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        // A failed handshake must not leave a half-registered connection in the peer.
        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        // Dispatch all remaining incoming messages on the foreground executor.
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        // Watch for the end of the connection and update the status accordingly.
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
 967
 968    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 969        #[cfg(any(test, feature = "test-support"))]
 970        if let Some(callback) = self.authenticate.read().as_ref() {
 971            return callback(cx);
 972        }
 973
 974        self.authenticate_with_browser(cx)
 975    }
 976
 977    fn establish_connection(
 978        self: &Arc<Self>,
 979        credentials: &Credentials,
 980        cx: &AsyncAppContext,
 981    ) -> Task<Result<Connection, EstablishConnectionError>> {
 982        #[cfg(any(test, feature = "test-support"))]
 983        if let Some(callback) = self.establish_connection.read().as_ref() {
 984            return callback(credentials, cx);
 985        }
 986
 987        self.establish_websocket_connection(credentials, cx)
 988    }
 989
 990    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 991        let preview_param = if is_preview { "?preview=1" } else { "" };
 992        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 993        let response = http.get(&url, Default::default(), false).await?;
 994
 995        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 996        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 997        // which requires authorization via an HTTP header.
 998        //
 999        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
1000        // of a collab server. In that case, a request to the /rpc endpoint will
1001        // return an 'unauthorized' response.
1002        let collab_url = if response.status().is_redirection() {
1003            response
1004                .headers()
1005                .get("Location")
1006                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1007                .to_str()
1008                .map_err(EstablishConnectionError::other)?
1009                .to_string()
1010        } else if response.status() == StatusCode::UNAUTHORIZED {
1011            url
1012        } else {
1013            Err(anyhow!(
1014                "unexpected /rpc response status {}",
1015                response.status()
1016            ))?
1017        };
1018
1019        Url::parse(&collab_url).context("invalid rpc url")
1020    }
1021
1022    fn establish_websocket_connection(
1023        self: &Arc<Self>,
1024        credentials: &Credentials,
1025        cx: &AsyncAppContext,
1026    ) -> Task<Result<Connection, EstablishConnectionError>> {
1027        let use_preview_server = cx.read(|cx| {
1028            if cx.has_global::<ReleaseChannel>() {
1029                *cx.global::<ReleaseChannel>() != ReleaseChannel::Stable
1030            } else {
1031                false
1032            }
1033        });
1034
1035        let request = Request::builder()
1036            .header(
1037                "Authorization",
1038                format!("{} {}", credentials.user_id, credentials.access_token),
1039            )
1040            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1041
1042        let http = self.http.clone();
1043        cx.background().spawn(async move {
1044            let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
1045            let rpc_host = rpc_url
1046                .host_str()
1047                .zip(rpc_url.port_or_known_default())
1048                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1049            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1050
1051            log::info!("connected to rpc endpoint {}", rpc_url);
1052
1053            match rpc_url.scheme() {
1054                "https" => {
1055                    rpc_url.set_scheme("wss").unwrap();
1056                    let request = request.uri(rpc_url.as_str()).body(())?;
1057                    let (stream, _) =
1058                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1059                    Ok(Connection::new(
1060                        stream
1061                            .map_err(|error| anyhow!(error))
1062                            .sink_map_err(|error| anyhow!(error)),
1063                    ))
1064                }
1065                "http" => {
1066                    rpc_url.set_scheme("ws").unwrap();
1067                    let request = request.uri(rpc_url.as_str()).body(())?;
1068                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1069                    Ok(Connection::new(
1070                        stream
1071                            .map_err(|error| anyhow!(error))
1072                            .sink_map_err(|error| anyhow!(error)),
1073                    ))
1074                }
1075                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1076            }
1077        })
1078    }
1079
1080    pub fn authenticate_with_browser(
1081        self: &Arc<Self>,
1082        cx: &AsyncAppContext,
1083    ) -> Task<Result<Credentials>> {
1084        let platform = cx.platform();
1085        let executor = cx.background();
1086        let http = self.http.clone();
1087
1088        executor.clone().spawn(async move {
1089            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1090            // zed server to encrypt the user's access token, so that it can'be intercepted by
1091            // any other app running on the user's device.
1092            let (public_key, private_key) =
1093                rpc::auth::keypair().expect("failed to generate keypair for auth");
1094            let public_key_string =
1095                String::try_from(public_key).expect("failed to serialize public key for auth");
1096
1097            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1098                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1099            }
1100
1101            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1102            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1103            let port = server.server_addr().port();
1104
1105            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1106            // that the user is signing in from a Zed app running on the same device.
1107            let mut url = format!(
1108                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1109                *ZED_SERVER_URL, port, public_key_string
1110            );
1111
1112            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1113                log::info!("impersonating user @{}", impersonate_login);
1114                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1115            }
1116
1117            platform.open_url(&url);
1118
1119            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1120            // access token from the query params.
1121            //
1122            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1123            // custom URL scheme instead of this local HTTP server.
1124            let (user_id, access_token) = executor
1125                .spawn(async move {
1126                    for _ in 0..100 {
1127                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1128                            let path = req.url();
1129                            let mut user_id = None;
1130                            let mut access_token = None;
1131                            let url = Url::parse(&format!("http://example.com{}", path))
1132                                .context("failed to parse login notification url")?;
1133                            for (key, value) in url.query_pairs() {
1134                                if key == "access_token" {
1135                                    access_token = Some(value.to_string());
1136                                } else if key == "user_id" {
1137                                    user_id = Some(value.to_string());
1138                                }
1139                            }
1140
1141                            let post_auth_url =
1142                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1143                            req.respond(
1144                                tiny_http::Response::empty(302).with_header(
1145                                    tiny_http::Header::from_bytes(
1146                                        &b"Location"[..],
1147                                        post_auth_url.as_bytes(),
1148                                    )
1149                                    .unwrap(),
1150                                ),
1151                            )
1152                            .context("failed to respond to login http request")?;
1153                            return Ok((
1154                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1155                                access_token
1156                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1157                            ));
1158                        }
1159                    }
1160
1161                    Err(anyhow!("didn't receive login redirect"))
1162                })
1163                .await?;
1164
1165            let access_token = private_key
1166                .decrypt_string(&access_token)
1167                .context("failed to decrypt access token")?;
1168            platform.activate(true);
1169
1170            Ok(Credentials {
1171                user_id: user_id.parse()?,
1172                access_token,
1173            })
1174        })
1175    }
1176
1177    async fn authenticate_as_admin(
1178        http: Arc<dyn HttpClient>,
1179        login: String,
1180        mut api_token: String,
1181    ) -> Result<Credentials> {
1182        #[derive(Deserialize)]
1183        struct AuthenticatedUserResponse {
1184            user: User,
1185        }
1186
1187        #[derive(Deserialize)]
1188        struct User {
1189            id: u64,
1190        }
1191
1192        // Use the collab server's admin API to retrieve the id
1193        // of the impersonated user.
1194        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1195        url.set_path("/user");
1196        url.set_query(Some(&format!("github_login={login}")));
1197        let request = Request::get(url.as_str())
1198            .header("Authorization", format!("token {api_token}"))
1199            .body("".into())?;
1200
1201        let mut response = http.send(request).await?;
1202        let mut body = String::new();
1203        response.body_mut().read_to_string(&mut body).await?;
1204        if !response.status().is_success() {
1205            Err(anyhow!(
1206                "admin user request failed {} - {}",
1207                response.status().as_u16(),
1208                body,
1209            ))?;
1210        }
1211        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1212
1213        // Use the admin API token to authenticate as the impersonated user.
1214        api_token.insert_str(0, "ADMIN_TOKEN:");
1215        Ok(Credentials {
1216            user_id: response.user.id,
1217            access_token: api_token,
1218        })
1219    }
1220
1221    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1222        self.peer.teardown();
1223        self.set_status(Status::SignedOut, cx);
1224    }
1225
1226    pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1227        self.peer.teardown();
1228        self.set_status(Status::ConnectionLost, cx);
1229    }
1230
1231    fn connection_id(&self) -> Result<ConnectionId> {
1232        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1233            Ok(connection_id)
1234        } else {
1235            Err(anyhow!("not connected"))
1236        }
1237    }
1238
1239    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1240        log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1241        self.peer.send(self.connection_id()?, message)
1242    }
1243
1244    pub fn request<T: RequestMessage>(
1245        &self,
1246        request: T,
1247    ) -> impl Future<Output = Result<T::Response>> {
1248        self.request_envelope(request)
1249            .map_ok(|envelope| envelope.payload)
1250    }
1251
1252    pub fn request_envelope<T: RequestMessage>(
1253        &self,
1254        request: T,
1255    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1256        let client_id = self.id();
1257        log::debug!(
1258            "rpc request start. client_id:{}. name:{}",
1259            client_id,
1260            T::NAME
1261        );
1262        let response = self
1263            .connection_id()
1264            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1265        async move {
1266            let response = response?.await;
1267            log::debug!(
1268                "rpc request finish. client_id:{}. name:{}",
1269                client_id,
1270                T::NAME
1271            );
1272            response
1273        }
1274    }
1275
1276    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1277        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1278        self.peer.respond(receipt, response)
1279    }
1280
1281    fn respond_with_error<T: RequestMessage>(
1282        &self,
1283        receipt: Receipt<T>,
1284        error: proto::Error,
1285    ) -> Result<()> {
1286        log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1287        self.peer.respond_with_error(receipt, error)
1288    }
1289
1290    fn handle_message(
1291        self: &Arc<Client>,
1292        message: Box<dyn AnyTypedEnvelope>,
1293        cx: &AsyncAppContext,
1294    ) {
1295        let mut state = self.state.write();
1296        let type_name = message.payload_type_name();
1297        let payload_type_id = message.payload_type_id();
1298        let sender_id = message.original_sender_id();
1299
1300        let mut subscriber = None;
1301
1302        if let Some(message_model) = state
1303            .models_by_message_type
1304            .get(&payload_type_id)
1305            .and_then(|model| model.upgrade(cx))
1306        {
1307            subscriber = Some(Subscriber::Model(message_model));
1308        } else if let Some((extract_entity_id, entity_type_id)) =
1309            state.entity_id_extractors.get(&payload_type_id).zip(
1310                state
1311                    .entity_types_by_message_type
1312                    .get(&payload_type_id)
1313                    .copied(),
1314            )
1315        {
1316            let entity_id = (extract_entity_id)(message.as_ref());
1317
1318            match state
1319                .entities_by_type_and_remote_id
1320                .get_mut(&(entity_type_id, entity_id))
1321            {
1322                Some(WeakSubscriber::Pending(pending)) => {
1323                    pending.push(message);
1324                    return;
1325                }
1326                Some(weak_subscriber @ _) => match weak_subscriber {
1327                    WeakSubscriber::Model(handle) => {
1328                        subscriber = handle.upgrade(cx).map(Subscriber::Model);
1329                    }
1330                    WeakSubscriber::View(handle) => {
1331                        subscriber = Some(Subscriber::View(handle.clone()));
1332                    }
1333                    WeakSubscriber::Pending(_) => {}
1334                },
1335                _ => {}
1336            }
1337        }
1338
1339        let subscriber = if let Some(subscriber) = subscriber {
1340            subscriber
1341        } else {
1342            log::info!("unhandled message {}", type_name);
1343            self.peer.respond_with_unhandled_message(message).log_err();
1344            return;
1345        };
1346
1347        let handler = state.message_handlers.get(&payload_type_id).cloned();
1348        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1349        // It also ensures we don't hold the lock while yielding back to the executor, as
1350        // that might cause the executor thread driving this future to block indefinitely.
1351        drop(state);
1352
1353        if let Some(handler) = handler {
1354            let future = handler(subscriber, message, &self, cx.clone());
1355            let client_id = self.id();
1356            log::debug!(
1357                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1358                client_id,
1359                sender_id,
1360                type_name
1361            );
1362            cx.foreground()
1363                .spawn(async move {
1364                    match future.await {
1365                        Ok(()) => {
1366                            log::debug!(
1367                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1368                                client_id,
1369                                sender_id,
1370                                type_name
1371                            );
1372                        }
1373                        Err(error) => {
1374                            log::error!(
1375                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1376                                client_id,
1377                                sender_id,
1378                                type_name,
1379                                error
1380                            );
1381                        }
1382                    }
1383                })
1384                .detach();
1385        } else {
1386            log::info!("unhandled message {}", type_name);
1387            self.peer.respond_with_unhandled_message(message).log_err();
1388        }
1389    }
1390
1391    pub fn telemetry(&self) -> &Arc<Telemetry> {
1392        &self.telemetry
1393    }
1394}
1395
1396fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1397    if IMPERSONATE_LOGIN.is_some() {
1398        return None;
1399    }
1400
1401    let (user_id, access_token) = cx
1402        .platform()
1403        .read_credentials(&ZED_SERVER_URL)
1404        .log_err()
1405        .flatten()?;
1406    Some(Credentials {
1407        user_id: user_id.parse().ok()?,
1408        access_token: String::from_utf8(access_token).ok()?,
1409    })
1410}
1411
1412fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1413    cx.platform().write_credentials(
1414        &ZED_SERVER_URL,
1415        &credentials.user_id.to_string(),
1416        credentials.access_token.as_bytes(),
1417    )
1418}
1419
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Leading and trailing whitespace is ignored. Everything after the first `/` that follows
/// the id is the access token, so tokens containing `/` (as standard base64 can) round-trip
/// intact.
///
/// Returns `None` if the prefix is missing, the id is not a valid `u64`, or the token is
/// missing or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1436
1437#[cfg(test)]
1438mod tests {
1439    use super::*;
1440    use crate::test::FakeServer;
1441    use gpui::{executor::Deterministic, TestAppContext};
1442    use parking_lot::Mutex;
1443    use std::future;
1444    use util::http::FakeHttpClient;
1445
1446    #[gpui::test(iterations = 10)]
1447    async fn test_reconnection(cx: &mut TestAppContext) {
1448        cx.foreground().forbid_parking();
1449
1450        let user_id = 5;
1451        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1452        let server = FakeServer::for_client(user_id, &client, cx).await;
1453        let mut status = client.status();
1454        assert!(matches!(
1455            status.next().await,
1456            Some(Status::Connected { .. })
1457        ));
1458        assert_eq!(server.auth_count(), 1);
1459
1460        server.forbid_connections();
1461        server.disconnect();
1462        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1463
1464        server.allow_connections();
1465        cx.foreground().advance_clock(Duration::from_secs(10));
1466        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1467        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1468
1469        server.forbid_connections();
1470        server.disconnect();
1471        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1472
1473        // Clear cached credentials after authentication fails
1474        server.roll_access_token();
1475        server.allow_connections();
1476        cx.foreground().advance_clock(Duration::from_secs(10));
1477        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1478        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1479    }
1480
1481    #[gpui::test(iterations = 10)]
1482    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1483        deterministic.forbid_parking();
1484
1485        let user_id = 5;
1486        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1487        let mut status = client.status();
1488
1489        // Time out when client tries to connect.
1490        client.override_authenticate(move |cx| {
1491            cx.foreground().spawn(async move {
1492                Ok(Credentials {
1493                    user_id,
1494                    access_token: "token".into(),
1495                })
1496            })
1497        });
1498        client.override_establish_connection(|_, cx| {
1499            cx.foreground().spawn(async move {
1500                future::pending::<()>().await;
1501                unreachable!()
1502            })
1503        });
1504        let auth_and_connect = cx.spawn({
1505            let client = client.clone();
1506            |cx| async move { client.authenticate_and_connect(false, &cx).await }
1507        });
1508        deterministic.run_until_parked();
1509        assert!(matches!(status.next().await, Some(Status::Connecting)));
1510
1511        deterministic.advance_clock(CONNECTION_TIMEOUT);
1512        assert!(matches!(
1513            status.next().await,
1514            Some(Status::ConnectionError { .. })
1515        ));
1516        auth_and_connect.await.unwrap_err();
1517
1518        // Allow the connection to be established.
1519        let server = FakeServer::for_client(user_id, &client, cx).await;
1520        assert!(matches!(
1521            status.next().await,
1522            Some(Status::Connected { .. })
1523        ));
1524
1525        // Disconnect client.
1526        server.forbid_connections();
1527        server.disconnect();
1528        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1529
1530        // Time out when re-establishing the connection.
1531        server.allow_connections();
1532        client.override_establish_connection(|_, cx| {
1533            cx.foreground().spawn(async move {
1534                future::pending::<()>().await;
1535                unreachable!()
1536            })
1537        });
1538        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1539        assert!(matches!(
1540            status.next().await,
1541            Some(Status::Reconnecting { .. })
1542        ));
1543
1544        deterministic.advance_clock(CONNECTION_TIMEOUT);
1545        assert!(matches!(
1546            status.next().await,
1547            Some(Status::ReconnectionError { .. })
1548        ));
1549    }
1550
1551    #[gpui::test(iterations = 10)]
1552    async fn test_authenticating_more_than_once(
1553        cx: &mut TestAppContext,
1554        deterministic: Arc<Deterministic>,
1555    ) {
1556        cx.foreground().forbid_parking();
1557
1558        let auth_count = Arc::new(Mutex::new(0));
1559        let dropped_auth_count = Arc::new(Mutex::new(0));
1560        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1561        client.override_authenticate({
1562            let auth_count = auth_count.clone();
1563            let dropped_auth_count = dropped_auth_count.clone();
1564            move |cx| {
1565                let auth_count = auth_count.clone();
1566                let dropped_auth_count = dropped_auth_count.clone();
1567                cx.foreground().spawn(async move {
1568                    *auth_count.lock() += 1;
1569                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1570                    future::pending::<()>().await;
1571                    unreachable!()
1572                })
1573            }
1574        });
1575
1576        let _authenticate = cx.spawn(|cx| {
1577            let client = client.clone();
1578            async move { client.authenticate_and_connect(false, &cx).await }
1579        });
1580        deterministic.run_until_parked();
1581        assert_eq!(*auth_count.lock(), 1);
1582        assert_eq!(*dropped_auth_count.lock(), 0);
1583
1584        let _authenticate = cx.spawn(|cx| {
1585            let client = client.clone();
1586            async move { client.authenticate_and_connect(false, &cx).await }
1587        });
1588        deterministic.run_until_parked();
1589        assert_eq!(*auth_count.lock(), 2);
1590        assert_eq!(*dropped_auth_count.lock(), 1);
1591    }
1592
1593    #[test]
1594    fn test_encode_and_decode_worktree_url() {
1595        let url = encode_worktree_url(5, "deadbeef");
1596        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1597        assert_eq!(
1598            decode_worktree_url(&format!("\n {}\t", url)),
1599            Some((5, "deadbeef".to_string()))
1600        );
1601        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1602    }
1603
1604    #[gpui::test]
1605    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1606        cx.foreground().forbid_parking();
1607
1608        let user_id = 5;
1609        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1610        let server = FakeServer::for_client(user_id, &client, cx).await;
1611
1612        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1613        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1614        client.add_model_message_handler(
1615            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
1616                match model.read_with(&cx, |model, _| model.id) {
1617                    1 => done_tx1.try_send(()).unwrap(),
1618                    2 => done_tx2.try_send(()).unwrap(),
1619                    _ => unreachable!(),
1620                }
1621                async { Ok(()) }
1622            },
1623        );
1624        let model1 = cx.add_model(|_| Model {
1625            id: 1,
1626            subscription: None,
1627        });
1628        let model2 = cx.add_model(|_| Model {
1629            id: 2,
1630            subscription: None,
1631        });
1632        let model3 = cx.add_model(|_| Model {
1633            id: 3,
1634            subscription: None,
1635        });
1636
1637        let _subscription1 = client
1638            .subscribe_to_entity(1)
1639            .unwrap()
1640            .set_model(&model1, &mut cx.to_async());
1641        let _subscription2 = client
1642            .subscribe_to_entity(2)
1643            .unwrap()
1644            .set_model(&model2, &mut cx.to_async());
1645        // Ensure dropping a subscription for the same entity type still allows receiving of
1646        // messages for other entity IDs of the same type.
1647        let subscription3 = client
1648            .subscribe_to_entity(3)
1649            .unwrap()
1650            .set_model(&model3, &mut cx.to_async());
1651        drop(subscription3);
1652
1653        server.send(proto::JoinProject { project_id: 1 });
1654        server.send(proto::JoinProject { project_id: 2 });
1655        done_rx1.next().await.unwrap();
1656        done_rx2.next().await.unwrap();
1657    }
1658
1659    #[gpui::test]
1660    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1661        cx.foreground().forbid_parking();
1662
1663        let user_id = 5;
1664        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1665        let server = FakeServer::for_client(user_id, &client, cx).await;
1666
1667        let model = cx.add_model(|_| Model::default());
1668        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1669        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1670        let subscription1 = client.add_message_handler(
1671            model.clone(),
1672            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1673                done_tx1.try_send(()).unwrap();
1674                async { Ok(()) }
1675            },
1676        );
1677        drop(subscription1);
1678        let _subscription2 = client.add_message_handler(
1679            model.clone(),
1680            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1681                done_tx2.try_send(()).unwrap();
1682                async { Ok(()) }
1683            },
1684        );
1685        server.send(proto::Ping {});
1686        done_rx2.next().await.unwrap();
1687    }
1688
1689    #[gpui::test]
1690    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1691        cx.foreground().forbid_parking();
1692
1693        let user_id = 5;
1694        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1695        let server = FakeServer::for_client(user_id, &client, cx).await;
1696
1697        let model = cx.add_model(|_| Model::default());
1698        let (done_tx, mut done_rx) = smol::channel::unbounded();
1699        let subscription = client.add_message_handler(
1700            model.clone(),
1701            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1702                model.update(&mut cx, |model, _| model.subscription.take());
1703                done_tx.try_send(()).unwrap();
1704                async { Ok(()) }
1705            },
1706        );
1707        model.update(cx, |model, _| {
1708            model.subscription = Some(subscription);
1709        });
1710        server.send(proto::Ping {});
1711        done_rx.next().await.unwrap();
1712    }
1713
1714    #[derive(Default)]
1715    struct Model {
1716        id: usize,
1717        subscription: Option<Subscription>,
1718    }
1719
1720    impl Entity for Model {
1721        type Event = ();
1722    }
1723}