client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel_store;
   5pub mod telemetry;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{
  15    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  16    TryStreamExt,
  17};
  18use gpui::{
  19    actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
  20    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
  21    WeakViewHandle,
  22};
  23use lazy_static::lazy_static;
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use std::{
  31    any::TypeId,
  32    collections::HashMap,
  33    convert::TryFrom,
  34    fmt::Write as _,
  35    future::Future,
  36    marker::PhantomData,
  37    path::PathBuf,
  38    sync::{Arc, Weak},
  39    time::{Duration, Instant},
  40};
  41use telemetry::Telemetry;
  42use thiserror::Error;
  43use url::Url;
  44use util::channel::ReleaseChannel;
  45use util::http::HttpClient;
  46use util::{ResultExt, TryFutureExt};
  47
  48pub use channel_store::*;
  49pub use rpc::*;
  50pub use telemetry::ClickhouseEvent;
  51pub use user::*;
  52
  53lazy_static! {
  54    pub static ref ZED_SERVER_URL: String =
  55        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  56    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  57        .ok()
  58        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  59    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  60        .ok()
  61        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  62    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  63        .ok()
  64        .and_then(|v| v.parse().ok());
  65    pub static ref ZED_APP_PATH: Option<PathBuf> =
  66        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  67}
  68
/// Client token constant.
// NOTE(review): where this token is sent is not visible in this part of the file — confirm
// before treating it as an actual secret.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// First delay used by the reconnect loop before it starts backing off.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time budget for establishing a connection and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

// Declares the `SignIn` and `SignOut` actions in the `client` namespace.
actions!(client, [SignIn, SignOut]);
  74
/// Registers [`TelemetrySettings`] with the settings system.
pub fn init_settings(cx: &mut AppContext) {
    settings::register::<TelemetrySettings>(cx);
}
  78
  79pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
  80    init_settings(cx);
  81
  82    let client = Arc::downgrade(client);
  83    cx.add_global_action({
  84        let client = client.clone();
  85        move |_: &SignIn, cx| {
  86            if let Some(client) = client.upgrade() {
  87                cx.spawn(
  88                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  89                )
  90                .detach();
  91            }
  92        }
  93    });
  94    cx.add_global_action({
  95        let client = client.clone();
  96        move |_: &SignOut, cx| {
  97            if let Some(client) = client.upgrade() {
  98                cx.spawn(|cx| async move {
  99                    client.disconnect(&cx);
 100                })
 101                .detach();
 102            }
 103        }
 104    });
 105}
 106
/// RPC client: owns the peer connection, the HTTP client, telemetry and the mutable
/// connection/subscription state.
pub struct Client {
    /// Identifier of this client; included in status log messages.
    id: usize,
    /// Peer that connections are added to once established.
    peer: Arc<Peer>,
    /// HTTP client shared with telemetry and exposed via `http_client()`.
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, handler registries and the reconnect task.
    state: RwLock<ClientState>,

    /// Test-only override for the authentication step (set via `override_authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only override for connection establishment (set via
    /// `override_establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 136
/// Reasons establishing a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors that are not one of the two above.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    // Note: despite the variant name, this wraps tungstenite's re-exported `http::Error`,
    // not the websocket error type (which is converted by the manual `From` impl).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 152
 153impl From<WebsocketError> for EstablishConnectionError {
 154    fn from(error: WebsocketError) -> Self {
 155        if let WebsocketError::Http(response) = &error {
 156            match response.status() {
 157                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 158                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 159                _ => {}
 160            }
 161        }
 162        EstablishConnectionError::Other(error.into())
 163    }
 164}
 165
 166impl EstablishConnectionError {
 167    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 168        Self::Other(error.into())
 169    }
 170}
 171
/// Connection state of a [`Client`], published through its status watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial state; no connection attempt in progress.
    SignedOut,
    /// The server rejected the connection with "upgrade required"; further connection
    /// attempts fail immediately while in this state.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Establishing the connection, starting from `SignedOut`.
    Connecting,
    /// The most recent authentication or connection attempt failed.
    ConnectionError,
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// Setting this status starts the background reconnect task.
    ConnectionLost,
    /// Like `Authenticating`, but entered from any state other than `SignedOut`.
    Reauthenticating,
    /// Like `Connecting`, but entered from any state other than `SignedOut`.
    Reconnecting,
    /// A reconnect attempt failed; the reconnect task waits until `next_reconnection`
    /// before trying again.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 190
 191impl Status {
 192    pub fn is_connected(&self) -> bool {
 193        matches!(self, Self::Connected { .. })
 194    }
 195
 196    pub fn is_signed_out(&self) -> bool {
 197        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 198    }
 199}
 200
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current user, if any are held in memory.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Background reconnect task; replaced or dropped as the status changes.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the reconnect back-off delay.
    reconnect_interval: Duration,
    /// Subscribers keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Models registered through `add_message_handler`, keyed by message type.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that handles each entity message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handlers keyed by message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 225
/// Entry stored for a remote entity in `ClientState::entities_by_type_and_remote_id`.
enum WeakSubscriber {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
    /// Placeholder created by `subscribe_to_entity`; holds messages until `set_model`
    /// replaces it with a model and replays them.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
 231
/// Target handed to a message handler: a strong model handle or a weak view handle.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyWeakViewHandle),
}
 236
/// User id and access token used to establish a connection.
// NOTE(review): the derived `Debug` impl prints `access_token` verbatim — confirm this type
// is never formatted into logs.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
 242
 243impl Default for ClientState {
 244    fn default() -> Self {
 245        Self {
 246            credentials: None,
 247            status: watch::channel_with(Status::SignedOut),
 248            entity_id_extractors: Default::default(),
 249            _reconnect_task: None,
 250            reconnect_interval: Duration::from_secs(5),
 251            models_by_message_type: Default::default(),
 252            entities_by_type_and_remote_id: Default::default(),
 253            entity_types_by_message_type: Default::default(),
 254            message_handlers: Default::default(),
 255        }
 256    }
 257}
 258
/// Handle for a registration on a [`Client`]; dropping it removes the registration.
pub enum Subscription {
    /// Registration of an entity under (entity type, remote entity id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for a message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 269
 270impl Drop for Subscription {
 271    fn drop(&mut self) {
 272        match self {
 273            Subscription::Entity { client, id } => {
 274                if let Some(client) = client.upgrade() {
 275                    let mut state = client.state.write();
 276                    let _ = state.entities_by_type_and_remote_id.remove(id);
 277                }
 278            }
 279            Subscription::Message { client, id } => {
 280                if let Some(client) = client.upgrade() {
 281                    let mut state = client.state.write();
 282                    let _ = state.entity_types_by_message_type.remove(id);
 283                    let _ = state.message_handlers.remove(id);
 284                }
 285            }
 286        }
 287    }
 288}
 289
/// Subscription to a remote entity whose local model does not exist yet; created by
/// `Client::subscribe_to_entity` and completed by `set_model`.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the registered model in place.
    consumed: bool,
}
 296
 297impl<T: Entity> PendingEntitySubscription<T> {
 298    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 299        self.consumed = true;
 300        let mut state = self.client.state.write();
 301        let id = (TypeId::of::<T>(), self.remote_id);
 302        let Some(WeakSubscriber::Pending(messages)) =
 303            state.entities_by_type_and_remote_id.remove(&id)
 304        else {
 305            unreachable!()
 306        };
 307
 308        state
 309            .entities_by_type_and_remote_id
 310            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 311        drop(state);
 312        for message in messages {
 313            self.client.handle_message(message, cx);
 314        }
 315        Subscription::Entity {
 316            client: Arc::downgrade(&self.client),
 317            id,
 318        }
 319    }
 320}
 321
 322impl<T: Entity> Drop for PendingEntitySubscription<T> {
 323    fn drop(&mut self) {
 324        if !self.consumed {
 325            let mut state = self.client.state.write();
 326            if let Some(WeakSubscriber::Pending(messages)) = state
 327                .entities_by_type_and_remote_id
 328                .remove(&(TypeId::of::<T>(), self.remote_id))
 329            {
 330                for message in messages {
 331                    log::info!("unhandled message {}", message.payload_type_name());
 332                }
 333            }
 334        }
 335    }
 336}
 337
/// Resolved telemetry settings (see the `settings::Setting` impl for how they are loaded).
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    pub diagnostics: bool,
    pub metrics: bool,
}
 343
/// File representation of the `telemetry` settings; each field may be omitted.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
 349
 350impl settings::Setting for TelemetrySettings {
 351    const KEY: Option<&'static str> = Some("telemetry");
 352
 353    type FileContent = TelemetrySettingsContent;
 354
 355    fn load(
 356        default_value: &Self::FileContent,
 357        user_values: &[&Self::FileContent],
 358        _: &AppContext,
 359    ) -> Result<Self> {
 360        Ok(Self {
 361            diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
 362                default_value
 363                    .diagnostics
 364                    .ok_or_else(Self::missing_default)?,
 365            ),
 366            metrics: user_values
 367                .first()
 368                .and_then(|v| v.metrics)
 369                .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
 370        })
 371    }
 372}
 373
 374impl Client {
 375    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 376        Arc::new(Self {
 377            id: 0,
 378            peer: Peer::new(0),
 379            telemetry: Telemetry::new(http.clone(), cx),
 380            http,
 381            state: Default::default(),
 382
 383            #[cfg(any(test, feature = "test-support"))]
 384            authenticate: Default::default(),
 385            #[cfg(any(test, feature = "test-support"))]
 386            establish_connection: Default::default(),
 387        })
 388    }
 389
    /// Returns this client's id.
    pub fn id(&self) -> usize {
        self.id
    }
 393
 394    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 395        self.http.clone()
 396    }
 397
    /// Test-only: overrides the client id and returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_id(&mut self, id: usize) -> &Self {
        self.id = id;
        self
    }
 403
 404    #[cfg(any(test, feature = "test-support"))]
 405    pub fn teardown(&self) {
 406        let mut state = self.state.write();
 407        state._reconnect_task.take();
 408        state.message_handlers.clear();
 409        state.models_by_message_type.clear();
 410        state.entities_by_type_and_remote_id.clear();
 411        state.entity_id_extractors.clear();
 412        self.peer.teardown();
 413    }
 414
 415    #[cfg(any(test, feature = "test-support"))]
 416    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 417    where
 418        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 419    {
 420        *self.authenticate.write() = Some(Box::new(authenticate));
 421        self
 422    }
 423
 424    #[cfg(any(test, feature = "test-support"))]
 425    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 426    where
 427        F: 'static
 428            + Send
 429            + Sync
 430            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 431    {
 432        *self.establish_connection.write() = Some(Box::new(connect));
 433        self
 434    }
 435
 436    pub fn user_id(&self) -> Option<u64> {
 437        self.state
 438            .read()
 439            .credentials
 440            .as_ref()
 441            .map(|credentials| credentials.user_id)
 442    }
 443
 444    pub fn peer_id(&self) -> Option<PeerId> {
 445        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 446            Some(*peer_id)
 447        } else {
 448            None
 449        }
 450    }
 451
 452    pub fn status(&self) -> watch::Receiver<Status> {
 453        self.state.read().status.1.clone()
 454    }
 455
    /// Publishes `status` on the status channel and applies its side effects:
    ///
    /// * `Connected` drops any reconnect task.
    /// * `ConnectionLost` spawns a reconnect task that retries with a randomized,
    ///   capped back-off.
    /// * `SignedOut` / `UpgradeRequired` clear the authenticated user in telemetry and
    ///   drop any reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        // The state write lock is held for the remainder of this function.
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Test builds use a fixed seed so the back-off sequence is reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Only keep retrying after a connection error; any other resulting
                        // status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at the
                            // configured reconnect interval.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 501
 502    pub fn add_view_for_remote_entity<T: View>(
 503        self: &Arc<Self>,
 504        remote_id: u64,
 505        cx: &mut ViewContext<T>,
 506    ) -> Subscription {
 507        let id = (TypeId::of::<T>(), remote_id);
 508        self.state
 509            .write()
 510            .entities_by_type_and_remote_id
 511            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 512        Subscription::Entity {
 513            client: Arc::downgrade(self),
 514            id,
 515        }
 516    }
 517
 518    pub fn subscribe_to_entity<T: Entity>(
 519        self: &Arc<Self>,
 520        remote_id: u64,
 521    ) -> Result<PendingEntitySubscription<T>> {
 522        let id = (TypeId::of::<T>(), remote_id);
 523
 524        let mut state = self.state.write();
 525        if state.entities_by_type_and_remote_id.contains_key(&id) {
 526            return Err(anyhow!("already subscribed to entity"));
 527        } else {
 528            state
 529                .entities_by_type_and_remote_id
 530                .insert(id, WeakSubscriber::Pending(Default::default()));
 531            Ok(PendingEntitySubscription {
 532                client: self.clone(),
 533                remote_id,
 534                consumed: false,
 535                _entity_type: PhantomData,
 536            })
 537        }
 538    }
 539
 540    pub fn add_message_handler<M, E, H, F>(
 541        self: &Arc<Self>,
 542        model: ModelHandle<E>,
 543        handler: H,
 544    ) -> Subscription
 545    where
 546        M: EnvelopedMessage,
 547        E: Entity,
 548        H: 'static
 549            + Send
 550            + Sync
 551            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 552        F: 'static + Future<Output = Result<()>>,
 553    {
 554        let message_type_id = TypeId::of::<M>();
 555
 556        let mut state = self.state.write();
 557        state
 558            .models_by_message_type
 559            .insert(message_type_id, model.downgrade().into_any());
 560
 561        let prev_handler = state.message_handlers.insert(
 562            message_type_id,
 563            Arc::new(move |handle, envelope, client, cx| {
 564                let handle = if let Subscriber::Model(handle) = handle {
 565                    handle
 566                } else {
 567                    unreachable!();
 568                };
 569                let model = handle.downcast::<E>().unwrap();
 570                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 571                handler(model, *envelope, client.clone(), cx).boxed_local()
 572            }),
 573        );
 574        if prev_handler.is_some() {
 575            panic!("registered handler for the same message twice");
 576        }
 577
 578        Subscription::Message {
 579            client: Arc::downgrade(self),
 580            id: message_type_id,
 581        }
 582    }
 583
 584    pub fn add_request_handler<M, E, H, F>(
 585        self: &Arc<Self>,
 586        model: ModelHandle<E>,
 587        handler: H,
 588    ) -> Subscription
 589    where
 590        M: RequestMessage,
 591        E: Entity,
 592        H: 'static
 593            + Send
 594            + Sync
 595            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 596        F: 'static + Future<Output = Result<M::Response>>,
 597    {
 598        self.add_message_handler(model, move |handle, envelope, this, cx| {
 599            Self::respond_to_request(
 600                envelope.receipt(),
 601                handler(handle, envelope, this.clone(), cx),
 602                this,
 603            )
 604        })
 605    }
 606
 607    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 608    where
 609        M: EntityMessage,
 610        E: View,
 611        H: 'static
 612            + Send
 613            + Sync
 614            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 615        F: 'static + Future<Output = Result<()>>,
 616    {
 617        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 618            if let Subscriber::View(handle) = handle {
 619                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 620            } else {
 621                unreachable!();
 622            }
 623        })
 624    }
 625
 626    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 627    where
 628        M: EntityMessage,
 629        E: Entity,
 630        H: 'static
 631            + Send
 632            + Sync
 633            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 634        F: 'static + Future<Output = Result<()>>,
 635    {
 636        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 637            if let Subscriber::Model(handle) = handle {
 638                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 639            } else {
 640                unreachable!();
 641            }
 642        })
 643    }
 644
 645    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 646    where
 647        M: EntityMessage,
 648        E: Entity,
 649        H: 'static
 650            + Send
 651            + Sync
 652            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 653        F: 'static + Future<Output = Result<()>>,
 654    {
 655        let model_type_id = TypeId::of::<E>();
 656        let message_type_id = TypeId::of::<M>();
 657
 658        let mut state = self.state.write();
 659        state
 660            .entity_types_by_message_type
 661            .insert(message_type_id, model_type_id);
 662        state
 663            .entity_id_extractors
 664            .entry(message_type_id)
 665            .or_insert_with(|| {
 666                |envelope| {
 667                    envelope
 668                        .as_any()
 669                        .downcast_ref::<TypedEnvelope<M>>()
 670                        .unwrap()
 671                        .payload
 672                        .remote_entity_id()
 673                }
 674            });
 675        let prev_handler = state.message_handlers.insert(
 676            message_type_id,
 677            Arc::new(move |handle, envelope, client, cx| {
 678                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 679                handler(handle, *envelope, client.clone(), cx).boxed_local()
 680            }),
 681        );
 682        if prev_handler.is_some() {
 683            panic!("registered handler for the same message twice");
 684        }
 685    }
 686
 687    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 688    where
 689        M: EntityMessage + RequestMessage,
 690        E: Entity,
 691        H: 'static
 692            + Send
 693            + Sync
 694            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 695        F: 'static + Future<Output = Result<M::Response>>,
 696    {
 697        self.add_model_message_handler(move |entity, envelope, client, cx| {
 698            Self::respond_to_request::<M, _>(
 699                envelope.receipt(),
 700                handler(entity, envelope, client.clone(), cx),
 701                client,
 702            )
 703        })
 704    }
 705
 706    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 707    where
 708        M: EntityMessage + RequestMessage,
 709        E: View,
 710        H: 'static
 711            + Send
 712            + Sync
 713            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 714        F: 'static + Future<Output = Result<M::Response>>,
 715    {
 716        self.add_view_message_handler(move |entity, envelope, client, cx| {
 717            Self::respond_to_request::<M, _>(
 718                envelope.receipt(),
 719                handler(entity, envelope, client.clone(), cx),
 720                client,
 721            )
 722        })
 723    }
 724
 725    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 726        receipt: Receipt<T>,
 727        response: F,
 728        client: Arc<Self>,
 729    ) -> Result<()> {
 730        match response.await {
 731            Ok(response) => {
 732                client.respond(receipt, response)?;
 733                Ok(())
 734            }
 735            Err(error) => {
 736                client.respond_with_error(
 737                    receipt,
 738                    proto::Error {
 739                        message: format!("{:?}", error),
 740                    },
 741                )?;
 742                Err(error)
 743            }
 744        }
 745    }
 746
    /// Returns `true` when `read_credentials_from_keychain` yields credentials.
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
 750
    /// Obtains credentials (in-memory, then keychain if `try_keychain`, then interactive
    /// authentication) and establishes a connection, updating the status as it goes.
    ///
    /// Returns `Ok(())` immediately when already connected or when a connection attempt is
    /// in progress (`Connecting` / `Reconnecting`).
    ///
    /// # Errors
    ///
    /// * `UpgradeRequired` when the status already is, or the server reports, upgrade
    ///   required.
    /// * `Unauthorized` when the server rejects credentials that did not come from the
    ///   keychain.
    /// * An error when authentication fails or is canceled, or when establishing the
    ///   connection or waiting for the server's hello message times out.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `was_disconnected` selects between the Authenticating/Connecting and the
        // Reauthenticating/Reconnecting status pairs below.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Presumably consumes the current status value so that the `next()` in the
            // select below only resolves on a later status change — TODO confirm against
            // the watch channel's semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                // A status update while authenticating cancels the attempt.
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path that leaves `credentials` empty has returned above.
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single timer covers both establishing the connection and waiting for the hello
        // message inside `set_connection`.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist freshly obtained credentials, except when impersonating.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // The keychain credentials were rejected: delete them and
                            // retry once without consulting the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 853
 854    async fn set_connection(
 855        self: &Arc<Self>,
 856        conn: Connection,
 857        cx: &AsyncAppContext,
 858    ) -> Result<()> {
 859        let executor = cx.background();
 860        log::info!("add connection to peer");
 861        let (connection_id, handle_io, mut incoming) = self
 862            .peer
 863            .add_connection(conn, move |duration| executor.timer(duration));
 864        let handle_io = cx.background().spawn(handle_io);
 865
 866        let peer_id = async {
 867            log::info!("waiting for server hello");
 868            let message = incoming
 869                .next()
 870                .await
 871                .ok_or_else(|| anyhow!("no hello message received"))?;
 872            log::info!("got server hello");
 873            let hello_message_type_name = message.payload_type_name().to_string();
 874            let hello = message
 875                .into_any()
 876                .downcast::<TypedEnvelope<proto::Hello>>()
 877                .map_err(|_| {
 878                    anyhow!(
 879                        "invalid hello message received: {:?}",
 880                        hello_message_type_name
 881                    )
 882                })?;
 883            let peer_id = hello
 884                .payload
 885                .peer_id
 886                .ok_or_else(|| anyhow!("invalid peer id"))?;
 887            Ok(peer_id)
 888        };
 889
 890        let peer_id = match peer_id.await {
 891            Ok(peer_id) => peer_id,
 892            Err(error) => {
 893                self.peer.disconnect(connection_id);
 894                return Err(error);
 895            }
 896        };
 897
 898        log::info!(
 899            "set status to connected (connection id: {:?}, peer id: {:?})",
 900            connection_id,
 901            peer_id
 902        );
 903        self.set_status(
 904            Status::Connected {
 905                peer_id,
 906                connection_id,
 907            },
 908            cx,
 909        );
 910        cx.foreground()
 911            .spawn({
 912                let cx = cx.clone();
 913                let this = self.clone();
 914                async move {
 915                    while let Some(message) = incoming.next().await {
 916                        this.handle_message(message, &cx);
 917                        // Don't starve the main thread when receiving lots of messages at once.
 918                        smol::future::yield_now().await;
 919                    }
 920                }
 921            })
 922            .detach();
 923
 924        let this = self.clone();
 925        let cx = cx.clone();
 926        cx.foreground()
 927            .spawn(async move {
 928                match handle_io.await {
 929                    Ok(()) => {
 930                        if this.status().borrow().clone()
 931                            == (Status::Connected {
 932                                connection_id,
 933                                peer_id,
 934                            })
 935                        {
 936                            this.set_status(Status::SignedOut, &cx);
 937                        }
 938                    }
 939                    Err(err) => {
 940                        log::error!("connection error: {:?}", err);
 941                        this.set_status(Status::ConnectionLost, &cx);
 942                    }
 943                }
 944            })
 945            .detach();
 946
 947        Ok(())
 948    }
 949
 950    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 951        #[cfg(any(test, feature = "test-support"))]
 952        if let Some(callback) = self.authenticate.read().as_ref() {
 953            return callback(cx);
 954        }
 955
 956        self.authenticate_with_browser(cx)
 957    }
 958
 959    fn establish_connection(
 960        self: &Arc<Self>,
 961        credentials: &Credentials,
 962        cx: &AsyncAppContext,
 963    ) -> Task<Result<Connection, EstablishConnectionError>> {
 964        #[cfg(any(test, feature = "test-support"))]
 965        if let Some(callback) = self.establish_connection.read().as_ref() {
 966            return callback(credentials, cx);
 967        }
 968
 969        self.establish_websocket_connection(credentials, cx)
 970    }
 971
 972    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 973        let preview_param = if is_preview { "?preview=1" } else { "" };
 974        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 975        let response = http.get(&url, Default::default(), false).await?;
 976
 977        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 978        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 979        // which requires authorization via an HTTP header.
 980        //
 981        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 982        // of a collab server. In that case, a request to the /rpc endpoint will
 983        // return an 'unauthorized' response.
 984        let collab_url = if response.status().is_redirection() {
 985            response
 986                .headers()
 987                .get("Location")
 988                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 989                .to_str()
 990                .map_err(EstablishConnectionError::other)?
 991                .to_string()
 992        } else if response.status() == StatusCode::UNAUTHORIZED {
 993            url
 994        } else {
 995            Err(anyhow!(
 996                "unexpected /rpc response status {}",
 997                response.status()
 998            ))?
 999        };
1000
1001        Url::parse(&collab_url).context("invalid rpc url")
1002    }
1003
1004    fn establish_websocket_connection(
1005        self: &Arc<Self>,
1006        credentials: &Credentials,
1007        cx: &AsyncAppContext,
1008    ) -> Task<Result<Connection, EstablishConnectionError>> {
1009        let is_preview = cx.read(|cx| {
1010            if cx.has_global::<ReleaseChannel>() {
1011                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
1012            } else {
1013                false
1014            }
1015        });
1016
1017        let request = Request::builder()
1018            .header(
1019                "Authorization",
1020                format!("{} {}", credentials.user_id, credentials.access_token),
1021            )
1022            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1023
1024        let http = self.http.clone();
1025        cx.background().spawn(async move {
1026            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
1027            let rpc_host = rpc_url
1028                .host_str()
1029                .zip(rpc_url.port_or_known_default())
1030                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1031            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1032
1033            log::info!("connected to rpc endpoint {}", rpc_url);
1034
1035            match rpc_url.scheme() {
1036                "https" => {
1037                    rpc_url.set_scheme("wss").unwrap();
1038                    let request = request.uri(rpc_url.as_str()).body(())?;
1039                    let (stream, _) =
1040                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1041                    Ok(Connection::new(
1042                        stream
1043                            .map_err(|error| anyhow!(error))
1044                            .sink_map_err(|error| anyhow!(error)),
1045                    ))
1046                }
1047                "http" => {
1048                    rpc_url.set_scheme("ws").unwrap();
1049                    let request = request.uri(rpc_url.as_str()).body(())?;
1050                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1051                    Ok(Connection::new(
1052                        stream
1053                            .map_err(|error| anyhow!(error))
1054                            .sink_map_err(|error| anyhow!(error)),
1055                    ))
1056                }
1057                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1058            }
1059        })
1060    }
1061
1062    pub fn authenticate_with_browser(
1063        self: &Arc<Self>,
1064        cx: &AsyncAppContext,
1065    ) -> Task<Result<Credentials>> {
1066        let platform = cx.platform();
1067        let executor = cx.background();
1068        let http = self.http.clone();
1069
1070        executor.clone().spawn(async move {
1071            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1072            // zed server to encrypt the user's access token, so that it can'be intercepted by
1073            // any other app running on the user's device.
1074            let (public_key, private_key) =
1075                rpc::auth::keypair().expect("failed to generate keypair for auth");
1076            let public_key_string =
1077                String::try_from(public_key).expect("failed to serialize public key for auth");
1078
1079            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1080                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1081            }
1082
1083            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1084            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1085            let port = server.server_addr().port();
1086
1087            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1088            // that the user is signing in from a Zed app running on the same device.
1089            let mut url = format!(
1090                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1091                *ZED_SERVER_URL, port, public_key_string
1092            );
1093
1094            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1095                log::info!("impersonating user @{}", impersonate_login);
1096                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1097            }
1098
1099            platform.open_url(&url);
1100
1101            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1102            // access token from the query params.
1103            //
1104            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1105            // custom URL scheme instead of this local HTTP server.
1106            let (user_id, access_token) = executor
1107                .spawn(async move {
1108                    for _ in 0..100 {
1109                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1110                            let path = req.url();
1111                            let mut user_id = None;
1112                            let mut access_token = None;
1113                            let url = Url::parse(&format!("http://example.com{}", path))
1114                                .context("failed to parse login notification url")?;
1115                            for (key, value) in url.query_pairs() {
1116                                if key == "access_token" {
1117                                    access_token = Some(value.to_string());
1118                                } else if key == "user_id" {
1119                                    user_id = Some(value.to_string());
1120                                }
1121                            }
1122
1123                            let post_auth_url =
1124                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1125                            req.respond(
1126                                tiny_http::Response::empty(302).with_header(
1127                                    tiny_http::Header::from_bytes(
1128                                        &b"Location"[..],
1129                                        post_auth_url.as_bytes(),
1130                                    )
1131                                    .unwrap(),
1132                                ),
1133                            )
1134                            .context("failed to respond to login http request")?;
1135                            return Ok((
1136                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1137                                access_token
1138                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1139                            ));
1140                        }
1141                    }
1142
1143                    Err(anyhow!("didn't receive login redirect"))
1144                })
1145                .await?;
1146
1147            let access_token = private_key
1148                .decrypt_string(&access_token)
1149                .context("failed to decrypt access token")?;
1150            platform.activate(true);
1151
1152            Ok(Credentials {
1153                user_id: user_id.parse()?,
1154                access_token,
1155            })
1156        })
1157    }
1158
1159    async fn authenticate_as_admin(
1160        http: Arc<dyn HttpClient>,
1161        login: String,
1162        mut api_token: String,
1163    ) -> Result<Credentials> {
1164        #[derive(Deserialize)]
1165        struct AuthenticatedUserResponse {
1166            user: User,
1167        }
1168
1169        #[derive(Deserialize)]
1170        struct User {
1171            id: u64,
1172        }
1173
1174        // Use the collab server's admin API to retrieve the id
1175        // of the impersonated user.
1176        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1177        url.set_path("/user");
1178        url.set_query(Some(&format!("github_login={login}")));
1179        let request = Request::get(url.as_str())
1180            .header("Authorization", format!("token {api_token}"))
1181            .body("".into())?;
1182
1183        let mut response = http.send(request).await?;
1184        let mut body = String::new();
1185        response.body_mut().read_to_string(&mut body).await?;
1186        if !response.status().is_success() {
1187            Err(anyhow!(
1188                "admin user request failed {} - {}",
1189                response.status().as_u16(),
1190                body,
1191            ))?;
1192        }
1193        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1194
1195        // Use the admin API token to authenticate as the impersonated user.
1196        api_token.insert_str(0, "ADMIN_TOKEN:");
1197        Ok(Credentials {
1198            user_id: response.user.id,
1199            access_token: api_token,
1200        })
1201    }
1202
1203    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1204        self.peer.teardown();
1205        self.set_status(Status::SignedOut, cx);
1206    }
1207
1208    fn connection_id(&self) -> Result<ConnectionId> {
1209        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1210            Ok(connection_id)
1211        } else {
1212            Err(anyhow!("not connected"))
1213        }
1214    }
1215
1216    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1217        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1218        self.peer.send(self.connection_id()?, message)
1219    }
1220
1221    pub fn request<T: RequestMessage>(
1222        &self,
1223        request: T,
1224    ) -> impl Future<Output = Result<T::Response>> {
1225        self.request_envelope(request)
1226            .map_ok(|envelope| envelope.payload)
1227    }
1228
1229    pub fn request_envelope<T: RequestMessage>(
1230        &self,
1231        request: T,
1232    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1233        let client_id = self.id;
1234        log::debug!(
1235            "rpc request start. client_id:{}. name:{}",
1236            client_id,
1237            T::NAME
1238        );
1239        let response = self
1240            .connection_id()
1241            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1242        async move {
1243            let response = response?.await;
1244            log::debug!(
1245                "rpc request finish. client_id:{}. name:{}",
1246                client_id,
1247                T::NAME
1248            );
1249            response
1250        }
1251    }
1252
1253    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1254        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1255        self.peer.respond(receipt, response)
1256    }
1257
1258    fn respond_with_error<T: RequestMessage>(
1259        &self,
1260        receipt: Receipt<T>,
1261        error: proto::Error,
1262    ) -> Result<()> {
1263        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1264        self.peer.respond_with_error(receipt, error)
1265    }
1266
1267    fn handle_message(
1268        self: &Arc<Client>,
1269        message: Box<dyn AnyTypedEnvelope>,
1270        cx: &AsyncAppContext,
1271    ) {
1272        let mut state = self.state.write();
1273        let type_name = message.payload_type_name();
1274        let payload_type_id = message.payload_type_id();
1275        let sender_id = message.original_sender_id();
1276
1277        let mut subscriber = None;
1278
1279        if let Some(message_model) = state
1280            .models_by_message_type
1281            .get(&payload_type_id)
1282            .and_then(|model| model.upgrade(cx))
1283        {
1284            subscriber = Some(Subscriber::Model(message_model));
1285        } else if let Some((extract_entity_id, entity_type_id)) =
1286            state.entity_id_extractors.get(&payload_type_id).zip(
1287                state
1288                    .entity_types_by_message_type
1289                    .get(&payload_type_id)
1290                    .copied(),
1291            )
1292        {
1293            let entity_id = (extract_entity_id)(message.as_ref());
1294
1295            match state
1296                .entities_by_type_and_remote_id
1297                .get_mut(&(entity_type_id, entity_id))
1298            {
1299                Some(WeakSubscriber::Pending(pending)) => {
1300                    pending.push(message);
1301                    return;
1302                }
1303                Some(weak_subscriber @ _) => match weak_subscriber {
1304                    WeakSubscriber::Model(handle) => {
1305                        subscriber = handle.upgrade(cx).map(Subscriber::Model);
1306                    }
1307                    WeakSubscriber::View(handle) => {
1308                        subscriber = Some(Subscriber::View(handle.clone()));
1309                    }
1310                    WeakSubscriber::Pending(_) => {}
1311                },
1312                _ => {}
1313            }
1314        }
1315
1316        let subscriber = if let Some(subscriber) = subscriber {
1317            subscriber
1318        } else {
1319            log::info!("unhandled message {}", type_name);
1320            self.peer.respond_with_unhandled_message(message).log_err();
1321            return;
1322        };
1323
1324        let handler = state.message_handlers.get(&payload_type_id).cloned();
1325        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1326        // It also ensures we don't hold the lock while yielding back to the executor, as
1327        // that might cause the executor thread driving this future to block indefinitely.
1328        drop(state);
1329
1330        if let Some(handler) = handler {
1331            let future = handler(subscriber, message, &self, cx.clone());
1332            let client_id = self.id;
1333            log::debug!(
1334                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1335                client_id,
1336                sender_id,
1337                type_name
1338            );
1339            cx.foreground()
1340                .spawn(async move {
1341                    match future.await {
1342                        Ok(()) => {
1343                            log::debug!(
1344                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1345                                client_id,
1346                                sender_id,
1347                                type_name
1348                            );
1349                        }
1350                        Err(error) => {
1351                            log::error!(
1352                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1353                                client_id,
1354                                sender_id,
1355                                type_name,
1356                                error
1357                            );
1358                        }
1359                    }
1360                })
1361                .detach();
1362        } else {
1363            log::info!("unhandled message {}", type_name);
1364            self.peer.respond_with_unhandled_message(message).log_err();
1365        }
1366    }
1367
1368    pub fn telemetry(&self) -> &Arc<Telemetry> {
1369        &self.telemetry
1370    }
1371}
1372
1373fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1374    if IMPERSONATE_LOGIN.is_some() {
1375        return None;
1376    }
1377
1378    let (user_id, access_token) = cx
1379        .platform()
1380        .read_credentials(&ZED_SERVER_URL)
1381        .log_err()
1382        .flatten()?;
1383    Some(Credentials {
1384        user_id: user_id.parse().ok()?,
1385        access_token: String::from_utf8(access_token).ok()?,
1386    })
1387}
1388
1389fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1390    cx.platform().write_credentials(
1391        &ZED_SERVER_URL,
1392        &credentials.user_id.to_string(),
1393        credentials.access_token.as_bytes(),
1394    )
1395}
1396
/// Prefix shared by all worktree URLs; the worktree id and access token follow it.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";
1398
1399pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
1400    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
1401}
1402
1403pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
1404    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
1405    let mut parts = path.split('/');
1406    let id = parts.next()?.parse::<u64>().ok()?;
1407    let access_token = parts.next()?;
1408    if access_token.is_empty() {
1409        return None;
1410    }
1411    Some((id, access_token.to_string()))
1412}
1413
1414#[cfg(test)]
1415mod tests {
1416    use super::*;
1417    use crate::test::FakeServer;
1418    use gpui::{executor::Deterministic, TestAppContext};
1419    use parking_lot::Mutex;
1420    use std::future;
1421    use util::http::FakeHttpClient;
1422
1423    #[gpui::test(iterations = 10)]
1424    async fn test_reconnection(cx: &mut TestAppContext) {
1425        cx.foreground().forbid_parking();
1426
1427        let user_id = 5;
1428        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1429        let server = FakeServer::for_client(user_id, &client, cx).await;
1430        let mut status = client.status();
1431        assert!(matches!(
1432            status.next().await,
1433            Some(Status::Connected { .. })
1434        ));
1435        assert_eq!(server.auth_count(), 1);
1436
1437        server.forbid_connections();
1438        server.disconnect();
1439        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1440
1441        server.allow_connections();
1442        cx.foreground().advance_clock(Duration::from_secs(10));
1443        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1444        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1445
1446        server.forbid_connections();
1447        server.disconnect();
1448        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1449
1450        // Clear cached credentials after authentication fails
1451        server.roll_access_token();
1452        server.allow_connections();
1453        cx.foreground().advance_clock(Duration::from_secs(10));
1454        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1455        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1456    }
1457
1458    #[gpui::test(iterations = 10)]
1459    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1460        deterministic.forbid_parking();
1461
1462        let user_id = 5;
1463        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1464        let mut status = client.status();
1465
1466        // Time out when client tries to connect.
1467        client.override_authenticate(move |cx| {
1468            cx.foreground().spawn(async move {
1469                Ok(Credentials {
1470                    user_id,
1471                    access_token: "token".into(),
1472                })
1473            })
1474        });
1475        client.override_establish_connection(|_, cx| {
1476            cx.foreground().spawn(async move {
1477                future::pending::<()>().await;
1478                unreachable!()
1479            })
1480        });
1481        let auth_and_connect = cx.spawn({
1482            let client = client.clone();
1483            |cx| async move { client.authenticate_and_connect(false, &cx).await }
1484        });
1485        deterministic.run_until_parked();
1486        assert!(matches!(status.next().await, Some(Status::Connecting)));
1487
1488        deterministic.advance_clock(CONNECTION_TIMEOUT);
1489        assert!(matches!(
1490            status.next().await,
1491            Some(Status::ConnectionError { .. })
1492        ));
1493        auth_and_connect.await.unwrap_err();
1494
1495        // Allow the connection to be established.
1496        let server = FakeServer::for_client(user_id, &client, cx).await;
1497        assert!(matches!(
1498            status.next().await,
1499            Some(Status::Connected { .. })
1500        ));
1501
1502        // Disconnect client.
1503        server.forbid_connections();
1504        server.disconnect();
1505        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1506
1507        // Time out when re-establishing the connection.
1508        server.allow_connections();
1509        client.override_establish_connection(|_, cx| {
1510            cx.foreground().spawn(async move {
1511                future::pending::<()>().await;
1512                unreachable!()
1513            })
1514        });
1515        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1516        assert!(matches!(
1517            status.next().await,
1518            Some(Status::Reconnecting { .. })
1519        ));
1520
1521        deterministic.advance_clock(CONNECTION_TIMEOUT);
1522        assert!(matches!(
1523            status.next().await,
1524            Some(Status::ReconnectionError { .. })
1525        ));
1526    }
1527
1528    #[gpui::test(iterations = 10)]
1529    async fn test_authenticating_more_than_once(
1530        cx: &mut TestAppContext,
1531        deterministic: Arc<Deterministic>,
1532    ) {
1533        cx.foreground().forbid_parking();
1534
1535        let auth_count = Arc::new(Mutex::new(0));
1536        let dropped_auth_count = Arc::new(Mutex::new(0));
1537        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1538        client.override_authenticate({
1539            let auth_count = auth_count.clone();
1540            let dropped_auth_count = dropped_auth_count.clone();
1541            move |cx| {
1542                let auth_count = auth_count.clone();
1543                let dropped_auth_count = dropped_auth_count.clone();
1544                cx.foreground().spawn(async move {
1545                    *auth_count.lock() += 1;
1546                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1547                    future::pending::<()>().await;
1548                    unreachable!()
1549                })
1550            }
1551        });
1552
1553        let _authenticate = cx.spawn(|cx| {
1554            let client = client.clone();
1555            async move { client.authenticate_and_connect(false, &cx).await }
1556        });
1557        deterministic.run_until_parked();
1558        assert_eq!(*auth_count.lock(), 1);
1559        assert_eq!(*dropped_auth_count.lock(), 0);
1560
1561        let _authenticate = cx.spawn(|cx| {
1562            let client = client.clone();
1563            async move { client.authenticate_and_connect(false, &cx).await }
1564        });
1565        deterministic.run_until_parked();
1566        assert_eq!(*auth_count.lock(), 2);
1567        assert_eq!(*dropped_auth_count.lock(), 1);
1568    }
1569
1570    #[test]
1571    fn test_encode_and_decode_worktree_url() {
1572        let url = encode_worktree_url(5, "deadbeef");
1573        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1574        assert_eq!(
1575            decode_worktree_url(&format!("\n {}\t", url)),
1576            Some((5, "deadbeef".to_string()))
1577        );
1578        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1579    }
1580
1581    #[gpui::test]
1582    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1583        cx.foreground().forbid_parking();
1584
1585        let user_id = 5;
1586        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1587        let server = FakeServer::for_client(user_id, &client, cx).await;
1588
1589        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1590        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1591        client.add_model_message_handler(
1592            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
1593                match model.read_with(&cx, |model, _| model.id) {
1594                    1 => done_tx1.try_send(()).unwrap(),
1595                    2 => done_tx2.try_send(()).unwrap(),
1596                    _ => unreachable!(),
1597                }
1598                async { Ok(()) }
1599            },
1600        );
1601        let model1 = cx.add_model(|_| Model {
1602            id: 1,
1603            subscription: None,
1604        });
1605        let model2 = cx.add_model(|_| Model {
1606            id: 2,
1607            subscription: None,
1608        });
1609        let model3 = cx.add_model(|_| Model {
1610            id: 3,
1611            subscription: None,
1612        });
1613
1614        let _subscription1 = client
1615            .subscribe_to_entity(1)
1616            .unwrap()
1617            .set_model(&model1, &mut cx.to_async());
1618        let _subscription2 = client
1619            .subscribe_to_entity(2)
1620            .unwrap()
1621            .set_model(&model2, &mut cx.to_async());
1622        // Ensure dropping a subscription for the same entity type still allows receiving of
1623        // messages for other entity IDs of the same type.
1624        let subscription3 = client
1625            .subscribe_to_entity(3)
1626            .unwrap()
1627            .set_model(&model3, &mut cx.to_async());
1628        drop(subscription3);
1629
1630        server.send(proto::JoinProject { project_id: 1 });
1631        server.send(proto::JoinProject { project_id: 2 });
1632        done_rx1.next().await.unwrap();
1633        done_rx2.next().await.unwrap();
1634    }
1635
1636    #[gpui::test]
1637    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1638        cx.foreground().forbid_parking();
1639
1640        let user_id = 5;
1641        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1642        let server = FakeServer::for_client(user_id, &client, cx).await;
1643
1644        let model = cx.add_model(|_| Model::default());
1645        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1646        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1647        let subscription1 = client.add_message_handler(
1648            model.clone(),
1649            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1650                done_tx1.try_send(()).unwrap();
1651                async { Ok(()) }
1652            },
1653        );
1654        drop(subscription1);
1655        let _subscription2 = client.add_message_handler(
1656            model.clone(),
1657            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1658                done_tx2.try_send(()).unwrap();
1659                async { Ok(()) }
1660            },
1661        );
1662        server.send(proto::Ping {});
1663        done_rx2.next().await.unwrap();
1664    }
1665
1666    #[gpui::test]
1667    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1668        cx.foreground().forbid_parking();
1669
1670        let user_id = 5;
1671        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1672        let server = FakeServer::for_client(user_id, &client, cx).await;
1673
1674        let model = cx.add_model(|_| Model::default());
1675        let (done_tx, mut done_rx) = smol::channel::unbounded();
1676        let subscription = client.add_message_handler(
1677            model.clone(),
1678            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1679                model.update(&mut cx, |model, _| model.subscription.take());
1680                done_tx.try_send(()).unwrap();
1681                async { Ok(()) }
1682            },
1683        );
1684        model.update(cx, |model, _| {
1685            model.subscription = Some(subscription);
1686        });
1687        server.send(proto::Ping {});
1688        done_rx.next().await.unwrap();
1689    }
1690
1691    #[derive(Default)]
1692    struct Model {
1693        id: usize,
1694        subscription: Option<Subscription>,
1695    }
1696
1697    impl Entity for Model {
1698        type Event = ();
1699    }
1700}