client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use futures::{
  14    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  15    TryStreamExt,
  16};
  17use gpui::{
  18    actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
  19    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
  20    WeakViewHandle,
  21};
  22use lazy_static::lazy_static;
  23use parking_lot::RwLock;
  24use postage::watch;
  25use rand::prelude::*;
  26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use std::{
  30    any::TypeId,
  31    collections::HashMap,
  32    convert::TryFrom,
  33    fmt::Write as _,
  34    future::Future,
  35    marker::PhantomData,
  36    path::PathBuf,
  37    sync::{Arc, Weak},
  38    time::{Duration, Instant},
  39};
  40use telemetry::Telemetry;
  41use thiserror::Error;
  42use url::Url;
  43use util::channel::ReleaseChannel;
  44use util::http::HttpClient;
  45use util::{ResultExt, TryFutureExt};
  46
  47pub use rpc::*;
  48pub use telemetry::ClickhouseEvent;
  49pub use user::*;
  50
  51lazy_static! {
  52    pub static ref ZED_SERVER_URL: String =
  53        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  54    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  55        .ok()
  56        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  57    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  58        .ok()
  59        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  60    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  61        .ok()
  62        .and_then(|v| v.parse().ok());
  63    pub static ref ZED_APP_PATH: Option<PathBuf> =
  64        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  65}
  66
  67pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  68pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  69pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  70
// Declares the `SignIn` and `SignOut` actions in the `client` namespace; `init` registers
// their global handlers.
actions!(client, [SignIn, SignOut]);
  72
  73pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
  74    let client = Arc::downgrade(client);
  75    settings::register_setting::<TelemetrySettings>(cx);
  76
  77    cx.add_global_action({
  78        let client = client.clone();
  79        move |_: &SignIn, cx| {
  80            if let Some(client) = client.upgrade() {
  81                cx.spawn(
  82                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  83                )
  84                .detach();
  85            }
  86        }
  87    });
  88    cx.add_global_action({
  89        let client = client.clone();
  90        move |_: &SignOut, cx| {
  91            if let Some(client) = client.upgrade() {
  92                cx.spawn(|cx| async move {
  93                    client.disconnect(&cx);
  94                })
  95                .detach();
  96            }
  97        }
  98    });
  99}
 100
/// Client for the Zed server. It owns the RPC peer, the HTTP client, telemetry, and the
/// lock-protected connection and routing state.
pub struct Client {
    // Identifier included in `set_status` log lines; `new` sets it to 0, and tests can change
    // it with `set_id`.
    id: usize,
    // RPC peer that `set_connection` registers the server connection with.
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    // Credentials, the status channel, and all message/entity routing tables, behind one lock.
    state: RwLock<ClientState>,

    // Test-only override installed by `override_authenticate`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    // Test-only override installed by `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 130
/// Reasons an attempt to establish a server connection can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake returned HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake returned HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including non-401/426 websocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the project's HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    // NOTE(review): despite the name, this wraps `tungstenite::http::Error` (an HTTP
    // request/response building error), not `tungstenite::Error`.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 146
 147impl From<WebsocketError> for EstablishConnectionError {
 148    fn from(error: WebsocketError) -> Self {
 149        if let WebsocketError::Http(response) = &error {
 150            match response.status() {
 151                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 152                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 153                _ => {}
 154            }
 155        }
 156        EstablishConnectionError::Other(error.into())
 157    }
 158}
 159
 160impl EstablishConnectionError {
 161    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 162        Self::Other(error.into())
 163    }
 164}
 165
/// Connection status published on the client's status watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial status. Also set when keychain credentials are rejected as unauthorized.
    SignedOut,
    /// Connecting was refused with HTTP 426 (Upgrade Required).
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Establishing a connection, starting from `SignedOut`.
    Connecting,
    /// The most recent authentication or connection attempt failed.
    ConnectionError,
    /// Connected. Carries the peer id from the server hello and the `ConnectionId` returned by
    /// `Peer::add_connection`.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection dropped. Setting this via `set_status` starts the reconnect loop.
    ConnectionLost,
    /// Obtaining credentials again after an error or a lost connection.
    Reauthenticating,
    /// Re-establishing a connection after an error or a lost connection.
    Reconnecting,
    /// A reconnect attempt failed; the loop waits until `next_reconnection` before retrying.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 184
 185impl Status {
 186    pub fn is_connected(&self) -> bool {
 187        matches!(self, Self::Connected { .. })
 188    }
 189
 190    pub fn is_signed_out(&self) -> bool {
 191        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 192    }
 193}
 194
/// Mutable state of a `Client`, kept behind `Client::state`.
struct ClientState {
    // Stored after a connection is successfully established; taken when the server answers
    // with `Unauthorized`.
    credentials: Option<Credentials>,
    // Status watch channel: `set_status` writes through the sender, and `status()` hands out
    // clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    // Per message type: a function that reads the remote entity id from a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    // Reconnect loop spawned on `ConnectionLost`. Dropped on `Connected`, `SignedOut`, and
    // `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    // Upper bound for the reconnect back-off delay.
    reconnect_interval: Duration,
    // Subscribers keyed by (entity TypeId, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    // Model registered by `add_message_handler`, keyed by message TypeId.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    // Entity TypeId targeted by each entity message TypeId (see `add_entity_message_handler`).
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    // Type-erased message handlers keyed by message TypeId.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 219
/// Weakly held recipient registered for a remote entity.
enum WeakSubscriber {
    /// Model attached by `PendingEntitySubscription::set_model`.
    Model(AnyWeakModelHandle),
    /// View registered by `Client::add_view_for_remote_entity`.
    View(AnyWeakViewHandle),
    /// Slot reserved by `Client::subscribe_to_entity`. Messages held here (presumably queued
    /// by `handle_message`, which is not visible here) are replayed by `set_model`, or logged
    /// as unhandled when the pending subscription is dropped.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
 225
/// Recipient passed to a type-erased message handler.
enum Subscriber {
    /// Strong handle to the target model.
    Model(AnyModelHandle),
    /// Weak handle to the target view.
    View(AnyWeakViewHandle),
}
 230
 231#[derive(Clone, Debug)]
 232pub struct Credentials {
 233    pub user_id: u64,
 234    pub access_token: String,
 235}
 236
 237impl Default for ClientState {
 238    fn default() -> Self {
 239        Self {
 240            credentials: None,
 241            status: watch::channel_with(Status::SignedOut),
 242            entity_id_extractors: Default::default(),
 243            _reconnect_task: None,
 244            reconnect_interval: Duration::from_secs(5),
 245            models_by_message_type: Default::default(),
 246            entities_by_type_and_remote_id: Default::default(),
 247            entity_types_by_message_type: Default::default(),
 248            message_handlers: Default::default(),
 249        }
 250    }
 251}
 252
/// Registration handle returned by the client; dropping it removes the registration.
pub enum Subscription {
    /// Entity registration under `(entity TypeId, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message handler registration under the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 263
 264impl Drop for Subscription {
 265    fn drop(&mut self) {
 266        match self {
 267            Subscription::Entity { client, id } => {
 268                if let Some(client) = client.upgrade() {
 269                    let mut state = client.state.write();
 270                    let _ = state.entities_by_type_and_remote_id.remove(id);
 271                }
 272            }
 273            Subscription::Message { client, id } => {
 274                if let Some(client) = client.upgrade() {
 275                    let mut state = client.state.write();
 276                    let _ = state.entity_types_by_message_type.remove(id);
 277                    let _ = state.message_handlers.remove(id);
 278                }
 279            }
 280        }
 281    }
 282}
 283
/// Subscription to a remote entity of type `T` whose model has not been attached yet.
///
/// It is completed with `set_model`. If it is dropped unconsumed, the reserved slot is removed
/// and any buffered messages are logged.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    // Set by `set_model` so that `Drop` leaves the (now attached) entry alone.
    consumed: bool,
}
 290
 291impl<T: Entity> PendingEntitySubscription<T> {
 292    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 293        self.consumed = true;
 294        let mut state = self.client.state.write();
 295        let id = (TypeId::of::<T>(), self.remote_id);
 296        let Some(WeakSubscriber::Pending(messages)) =
 297            state.entities_by_type_and_remote_id.remove(&id)
 298        else {
 299            unreachable!()
 300        };
 301
 302        state
 303            .entities_by_type_and_remote_id
 304            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 305        drop(state);
 306        for message in messages {
 307            self.client.handle_message(message, cx);
 308        }
 309        Subscription::Entity {
 310            client: Arc::downgrade(&self.client),
 311            id,
 312        }
 313    }
 314}
 315
 316impl<T: Entity> Drop for PendingEntitySubscription<T> {
 317    fn drop(&mut self) {
 318        if !self.consumed {
 319            let mut state = self.client.state.write();
 320            if let Some(WeakSubscriber::Pending(messages)) = state
 321                .entities_by_type_and_remote_id
 322                .remove(&(TypeId::of::<T>(), self.remote_id))
 323            {
 324                for message in messages {
 325                    log::info!("unhandled message {}", message.payload_type_name());
 326                }
 327            }
 328        }
 329    }
 330}
 331
/// Resolved telemetry preferences, loaded from the `telemetry` settings key.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    // Diagnostics opt-in flag; its consumers are not visible in this file section.
    pub diagnostics: bool,
    // Metrics opt-in flag; passed along when reporting telemetry events.
    pub metrics: bool,
}
 337
/// On-disk shape of the `telemetry` settings key. Absent fields fall back to the defaults.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
 343
 344impl settings::Setting for TelemetrySettings {
 345    const KEY: Option<&'static str> = Some("telemetry");
 346
 347    type FileContent = TelemetrySettingsContent;
 348
 349    fn load(
 350        default_value: &Self::FileContent,
 351        user_values: &[&Self::FileContent],
 352        _: &AppContext,
 353    ) -> Self {
 354        Self {
 355            diagnostics: user_values
 356                .first()
 357                .and_then(|v| v.diagnostics)
 358                .unwrap_or(default_value.diagnostics.unwrap()),
 359            metrics: user_values
 360                .first()
 361                .and_then(|v| v.metrics)
 362                .unwrap_or(default_value.metrics.unwrap()),
 363        }
 364    }
 365}
 366
 367impl Client {
 368    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 369        Arc::new(Self {
 370            id: 0,
 371            peer: Peer::new(0),
 372            telemetry: Telemetry::new(http.clone(), cx),
 373            http,
 374            state: Default::default(),
 375
 376            #[cfg(any(test, feature = "test-support"))]
 377            authenticate: Default::default(),
 378            #[cfg(any(test, feature = "test-support"))]
 379            establish_connection: Default::default(),
 380        })
 381    }
 382
    /// Returns this client's id. `new` sets it to 0, and it appears in `set_status` log lines.
    pub fn id(&self) -> usize {
        self.id
    }
 386
 387    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 388        self.http.clone()
 389    }
 390
    /// Test-only: overrides the id reported by `id()`. Returns `&Self` so calls can be chained.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_id(&mut self, id: usize) -> &Self {
        self.id = id;
        self
    }
 396
 397    #[cfg(any(test, feature = "test-support"))]
 398    pub fn teardown(&self) {
 399        let mut state = self.state.write();
 400        state._reconnect_task.take();
 401        state.message_handlers.clear();
 402        state.models_by_message_type.clear();
 403        state.entities_by_type_and_remote_id.clear();
 404        state.entity_id_extractors.clear();
 405        self.peer.teardown();
 406    }
 407
 408    #[cfg(any(test, feature = "test-support"))]
 409    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 410    where
 411        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 412    {
 413        *self.authenticate.write() = Some(Box::new(authenticate));
 414        self
 415    }
 416
 417    #[cfg(any(test, feature = "test-support"))]
 418    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 419    where
 420        F: 'static
 421            + Send
 422            + Sync
 423            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 424    {
 425        *self.establish_connection.write() = Some(Box::new(connect));
 426        self
 427    }
 428
 429    pub fn user_id(&self) -> Option<u64> {
 430        self.state
 431            .read()
 432            .credentials
 433            .as_ref()
 434            .map(|credentials| credentials.user_id)
 435    }
 436
 437    pub fn peer_id(&self) -> Option<PeerId> {
 438        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 439            Some(*peer_id)
 440        } else {
 441            None
 442        }
 443    }
 444
 445    pub fn status(&self) -> watch::Receiver<Status> {
 446        self.state.read().status.1.clone()
 447    }
 448
    /// Publishes `status` on the status channel, then reacts to it.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a task that retries `authenticate_and_connect`. The retry
    ///   delay starts at `INITIAL_RECONNECTION_DELAY` and is multiplied by a random factor in
    ///   `[1, 2]` after each attempt, capped at `reconnect_interval`. Retrying continues while
    ///   attempts leave the status at `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired`: clears telemetry's authenticated-user info and drops
    ///   any pending reconnect task.
    ///
    /// The state write lock is held for the whole call, including the telemetry update.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed under test so the back-off jitter is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Only keep retrying for plain connection errors. Any other status
                        // (e.g. SignedOut or UpgradeRequired) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized growth, capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 494
 495    pub fn add_view_for_remote_entity<T: View>(
 496        self: &Arc<Self>,
 497        remote_id: u64,
 498        cx: &mut ViewContext<T>,
 499    ) -> Subscription {
 500        let id = (TypeId::of::<T>(), remote_id);
 501        self.state
 502            .write()
 503            .entities_by_type_and_remote_id
 504            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 505        Subscription::Entity {
 506            client: Arc::downgrade(self),
 507            id,
 508        }
 509    }
 510
 511    pub fn subscribe_to_entity<T: Entity>(
 512        self: &Arc<Self>,
 513        remote_id: u64,
 514    ) -> Result<PendingEntitySubscription<T>> {
 515        let id = (TypeId::of::<T>(), remote_id);
 516
 517        let mut state = self.state.write();
 518        if state.entities_by_type_and_remote_id.contains_key(&id) {
 519            return Err(anyhow!("already subscribed to entity"));
 520        } else {
 521            state
 522                .entities_by_type_and_remote_id
 523                .insert(id, WeakSubscriber::Pending(Default::default()));
 524            Ok(PendingEntitySubscription {
 525                client: self.clone(),
 526                remote_id,
 527                consumed: false,
 528                _entity_type: PhantomData,
 529            })
 530        }
 531    }
 532
 533    pub fn add_message_handler<M, E, H, F>(
 534        self: &Arc<Self>,
 535        model: ModelHandle<E>,
 536        handler: H,
 537    ) -> Subscription
 538    where
 539        M: EnvelopedMessage,
 540        E: Entity,
 541        H: 'static
 542            + Send
 543            + Sync
 544            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 545        F: 'static + Future<Output = Result<()>>,
 546    {
 547        let message_type_id = TypeId::of::<M>();
 548
 549        let mut state = self.state.write();
 550        state
 551            .models_by_message_type
 552            .insert(message_type_id, model.downgrade().into_any());
 553
 554        let prev_handler = state.message_handlers.insert(
 555            message_type_id,
 556            Arc::new(move |handle, envelope, client, cx| {
 557                let handle = if let Subscriber::Model(handle) = handle {
 558                    handle
 559                } else {
 560                    unreachable!();
 561                };
 562                let model = handle.downcast::<E>().unwrap();
 563                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 564                handler(model, *envelope, client.clone(), cx).boxed_local()
 565            }),
 566        );
 567        if prev_handler.is_some() {
 568            panic!("registered handler for the same message twice");
 569        }
 570
 571        Subscription::Message {
 572            client: Arc::downgrade(self),
 573            id: message_type_id,
 574        }
 575    }
 576
 577    pub fn add_request_handler<M, E, H, F>(
 578        self: &Arc<Self>,
 579        model: ModelHandle<E>,
 580        handler: H,
 581    ) -> Subscription
 582    where
 583        M: RequestMessage,
 584        E: Entity,
 585        H: 'static
 586            + Send
 587            + Sync
 588            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 589        F: 'static + Future<Output = Result<M::Response>>,
 590    {
 591        self.add_message_handler(model, move |handle, envelope, this, cx| {
 592            Self::respond_to_request(
 593                envelope.receipt(),
 594                handler(handle, envelope, this.clone(), cx),
 595                this,
 596            )
 597        })
 598    }
 599
 600    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 601    where
 602        M: EntityMessage,
 603        E: View,
 604        H: 'static
 605            + Send
 606            + Sync
 607            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 608        F: 'static + Future<Output = Result<()>>,
 609    {
 610        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 611            if let Subscriber::View(handle) = handle {
 612                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 613            } else {
 614                unreachable!();
 615            }
 616        })
 617    }
 618
 619    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 620    where
 621        M: EntityMessage,
 622        E: Entity,
 623        H: 'static
 624            + Send
 625            + Sync
 626            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 627        F: 'static + Future<Output = Result<()>>,
 628    {
 629        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 630            if let Subscriber::Model(handle) = handle {
 631                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 632            } else {
 633                unreachable!();
 634            }
 635        })
 636    }
 637
 638    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 639    where
 640        M: EntityMessage,
 641        E: Entity,
 642        H: 'static
 643            + Send
 644            + Sync
 645            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 646        F: 'static + Future<Output = Result<()>>,
 647    {
 648        let model_type_id = TypeId::of::<E>();
 649        let message_type_id = TypeId::of::<M>();
 650
 651        let mut state = self.state.write();
 652        state
 653            .entity_types_by_message_type
 654            .insert(message_type_id, model_type_id);
 655        state
 656            .entity_id_extractors
 657            .entry(message_type_id)
 658            .or_insert_with(|| {
 659                |envelope| {
 660                    envelope
 661                        .as_any()
 662                        .downcast_ref::<TypedEnvelope<M>>()
 663                        .unwrap()
 664                        .payload
 665                        .remote_entity_id()
 666                }
 667            });
 668        let prev_handler = state.message_handlers.insert(
 669            message_type_id,
 670            Arc::new(move |handle, envelope, client, cx| {
 671                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 672                handler(handle, *envelope, client.clone(), cx).boxed_local()
 673            }),
 674        );
 675        if prev_handler.is_some() {
 676            panic!("registered handler for the same message twice");
 677        }
 678    }
 679
 680    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 681    where
 682        M: EntityMessage + RequestMessage,
 683        E: Entity,
 684        H: 'static
 685            + Send
 686            + Sync
 687            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 688        F: 'static + Future<Output = Result<M::Response>>,
 689    {
 690        self.add_model_message_handler(move |entity, envelope, client, cx| {
 691            Self::respond_to_request::<M, _>(
 692                envelope.receipt(),
 693                handler(entity, envelope, client.clone(), cx),
 694                client,
 695            )
 696        })
 697    }
 698
 699    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 700    where
 701        M: EntityMessage + RequestMessage,
 702        E: View,
 703        H: 'static
 704            + Send
 705            + Sync
 706            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 707        F: 'static + Future<Output = Result<M::Response>>,
 708    {
 709        self.add_view_message_handler(move |entity, envelope, client, cx| {
 710            Self::respond_to_request::<M, _>(
 711                envelope.receipt(),
 712                handler(entity, envelope, client.clone(), cx),
 713                client,
 714            )
 715        })
 716    }
 717
 718    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 719        receipt: Receipt<T>,
 720        response: F,
 721        client: Arc<Self>,
 722    ) -> Result<()> {
 723        match response.await {
 724            Ok(response) => {
 725                client.respond(receipt, response)?;
 726                Ok(())
 727            }
 728            Err(error) => {
 729                client.respond_with_error(
 730                    receipt,
 731                    proto::Error {
 732                        message: format!("{:?}", error),
 733                    },
 734                )?;
 735                Err(error)
 736            }
 737        }
 738    }
 739
 740    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 741        read_credentials_from_keychain(cx).is_some()
 742    }
 743
    /// Obtains credentials if needed and connects to the server, updating the status as it goes.
    ///
    /// Returns `Ok(())` immediately when the status is already `Connected`, `Connecting`, or
    /// `Reconnecting`. Fails immediately when the status is `UpgradeRequired`.
    ///
    /// Credentials are looked up in this order:
    /// 1. in-memory state;
    /// 2. the keychain, only if `try_keychain` is true;
    /// 3. `authenticate`.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted from the platform
    /// store and the whole flow is retried with `try_keychain = false`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `true`: first connection (Authenticating/Connecting).
        // `false`: recovery (Reauthenticating/Reconnecting).
        // Some statuses return early instead.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            // Already connected, or a connection attempt is in flight: nothing to do.
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Prefer in-memory credentials; fall back to the keychain only when allowed.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                cx.read(|cx| {
                    self.telemetry().report_mixpanel_event(
                        "read credentials from keychain",
                        Default::default(),
                        *settings::get_setting::<TelemetrySettings>(None, cx),
                    );
                });
            }
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Skip the first value the receiver yields (presumably the current status), so
            // the select below only wakes on a later status change.
            let _ = status_rx.next().await;
            // A status change while `authenticate` is pending cancels the attempt.
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer bounds both establishing the connection and waiting for the server hello
        // in `set_connection`.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Remember the credentials. Persist them to the keychain unless they
                        // came from it, or an impersonated login is in use.
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry without the
                            // keychain, which falls through to `authenticate`.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 855
    /// Adopts `conn` as the client's connection to the server.
    ///
    /// Registers the connection with the peer (spawning its I/O future on the background
    /// executor) and waits for the server's `Hello` message, whose payload carries this
    /// client's peer id. If no valid `Hello` arrives, the connection is disconnected from the
    /// peer and the error is returned.
    ///
    /// On success the status becomes `Connected` and two foreground tasks are detached:
    /// - one forwards every further incoming message to [`Self::handle_message`];
    /// - one awaits the I/O future. A clean shutdown sets `SignedOut`, but only if the status
    ///   still refers to this same connection. An I/O error sets `ConnectionLost`.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first message on a fresh connection must be the server's `Hello`, which tells us
        // the peer id the server assigned to this client.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                // Don't leave a half-initialized connection registered with the peer.
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        // Dispatch all subsequent incoming messages on the foreground executor.
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        // Watch the I/O future to update the status once the connection ends.
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
 951
 952    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 953        #[cfg(any(test, feature = "test-support"))]
 954        if let Some(callback) = self.authenticate.read().as_ref() {
 955            return callback(cx);
 956        }
 957
 958        self.authenticate_with_browser(cx)
 959    }
 960
 961    fn establish_connection(
 962        self: &Arc<Self>,
 963        credentials: &Credentials,
 964        cx: &AsyncAppContext,
 965    ) -> Task<Result<Connection, EstablishConnectionError>> {
 966        #[cfg(any(test, feature = "test-support"))]
 967        if let Some(callback) = self.establish_connection.read().as_ref() {
 968            return callback(credentials, cx);
 969        }
 970
 971        self.establish_websocket_connection(credentials, cx)
 972    }
 973
 974    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 975        let preview_param = if is_preview { "?preview=1" } else { "" };
 976        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 977        let response = http.get(&url, Default::default(), false).await?;
 978
 979        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 980        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 981        // which requires authorization via an HTTP header.
 982        //
 983        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 984        // of a collab server. In that case, a request to the /rpc endpoint will
 985        // return an 'unauthorized' response.
 986        let collab_url = if response.status().is_redirection() {
 987            response
 988                .headers()
 989                .get("Location")
 990                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 991                .to_str()
 992                .map_err(EstablishConnectionError::other)?
 993                .to_string()
 994        } else if response.status() == StatusCode::UNAUTHORIZED {
 995            url
 996        } else {
 997            Err(anyhow!(
 998                "unexpected /rpc response status {}",
 999                response.status()
1000            ))?
1001        };
1002
1003        Url::parse(&collab_url).context("invalid rpc url")
1004    }
1005
1006    fn establish_websocket_connection(
1007        self: &Arc<Self>,
1008        credentials: &Credentials,
1009        cx: &AsyncAppContext,
1010    ) -> Task<Result<Connection, EstablishConnectionError>> {
1011        let is_preview = cx.read(|cx| {
1012            if cx.has_global::<ReleaseChannel>() {
1013                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
1014            } else {
1015                false
1016            }
1017        });
1018
1019        let request = Request::builder()
1020            .header(
1021                "Authorization",
1022                format!("{} {}", credentials.user_id, credentials.access_token),
1023            )
1024            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1025
1026        let http = self.http.clone();
1027        cx.background().spawn(async move {
1028            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
1029            let rpc_host = rpc_url
1030                .host_str()
1031                .zip(rpc_url.port_or_known_default())
1032                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1033            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1034
1035            log::info!("connected to rpc endpoint {}", rpc_url);
1036
1037            match rpc_url.scheme() {
1038                "https" => {
1039                    rpc_url.set_scheme("wss").unwrap();
1040                    let request = request.uri(rpc_url.as_str()).body(())?;
1041                    let (stream, _) =
1042                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1043                    Ok(Connection::new(
1044                        stream
1045                            .map_err(|error| anyhow!(error))
1046                            .sink_map_err(|error| anyhow!(error)),
1047                    ))
1048                }
1049                "http" => {
1050                    rpc_url.set_scheme("ws").unwrap();
1051                    let request = request.uri(rpc_url.as_str()).body(())?;
1052                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1053                    Ok(Connection::new(
1054                        stream
1055                            .map_err(|error| anyhow!(error))
1056                            .sink_map_err(|error| anyhow!(error)),
1057                    ))
1058                }
1059                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1060            }
1061        })
1062    }
1063
1064    pub fn authenticate_with_browser(
1065        self: &Arc<Self>,
1066        cx: &AsyncAppContext,
1067    ) -> Task<Result<Credentials>> {
1068        let platform = cx.platform();
1069        let executor = cx.background();
1070        let telemetry = self.telemetry.clone();
1071        let http = self.http.clone();
1072
1073        let telemetry_settings =
1074            cx.read(|cx| *settings::get_setting::<TelemetrySettings>(None, cx));
1075
1076        executor.clone().spawn(async move {
1077            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1078            // zed server to encrypt the user's access token, so that it can'be intercepted by
1079            // any other app running on the user's device.
1080            let (public_key, private_key) =
1081                rpc::auth::keypair().expect("failed to generate keypair for auth");
1082            let public_key_string =
1083                String::try_from(public_key).expect("failed to serialize public key for auth");
1084
1085            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1086                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1087            }
1088
1089            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1090            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1091            let port = server.server_addr().port();
1092
1093            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1094            // that the user is signing in from a Zed app running on the same device.
1095            let mut url = format!(
1096                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1097                *ZED_SERVER_URL, port, public_key_string
1098            );
1099
1100            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1101                log::info!("impersonating user @{}", impersonate_login);
1102                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1103            }
1104
1105            platform.open_url(&url);
1106
1107            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1108            // access token from the query params.
1109            //
1110            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1111            // custom URL scheme instead of this local HTTP server.
1112            let (user_id, access_token) = executor
1113                .spawn(async move {
1114                    for _ in 0..100 {
1115                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1116                            let path = req.url();
1117                            let mut user_id = None;
1118                            let mut access_token = None;
1119                            let url = Url::parse(&format!("http://example.com{}", path))
1120                                .context("failed to parse login notification url")?;
1121                            for (key, value) in url.query_pairs() {
1122                                if key == "access_token" {
1123                                    access_token = Some(value.to_string());
1124                                } else if key == "user_id" {
1125                                    user_id = Some(value.to_string());
1126                                }
1127                            }
1128
1129                            let post_auth_url =
1130                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1131                            req.respond(
1132                                tiny_http::Response::empty(302).with_header(
1133                                    tiny_http::Header::from_bytes(
1134                                        &b"Location"[..],
1135                                        post_auth_url.as_bytes(),
1136                                    )
1137                                    .unwrap(),
1138                                ),
1139                            )
1140                            .context("failed to respond to login http request")?;
1141                            return Ok((
1142                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1143                                access_token
1144                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1145                            ));
1146                        }
1147                    }
1148
1149                    Err(anyhow!("didn't receive login redirect"))
1150                })
1151                .await?;
1152
1153            let access_token = private_key
1154                .decrypt_string(&access_token)
1155                .context("failed to decrypt access token")?;
1156            platform.activate(true);
1157
1158            telemetry.report_mixpanel_event(
1159                "authenticate with browser",
1160                Default::default(),
1161                telemetry_settings,
1162            );
1163
1164            Ok(Credentials {
1165                user_id: user_id.parse()?,
1166                access_token,
1167            })
1168        })
1169    }
1170
1171    async fn authenticate_as_admin(
1172        http: Arc<dyn HttpClient>,
1173        login: String,
1174        mut api_token: String,
1175    ) -> Result<Credentials> {
1176        #[derive(Deserialize)]
1177        struct AuthenticatedUserResponse {
1178            user: User,
1179        }
1180
1181        #[derive(Deserialize)]
1182        struct User {
1183            id: u64,
1184        }
1185
1186        // Use the collab server's admin API to retrieve the id
1187        // of the impersonated user.
1188        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1189        url.set_path("/user");
1190        url.set_query(Some(&format!("github_login={login}")));
1191        let request = Request::get(url.as_str())
1192            .header("Authorization", format!("token {api_token}"))
1193            .body("".into())?;
1194
1195        let mut response = http.send(request).await?;
1196        let mut body = String::new();
1197        response.body_mut().read_to_string(&mut body).await?;
1198        if !response.status().is_success() {
1199            Err(anyhow!(
1200                "admin user request failed {} - {}",
1201                response.status().as_u16(),
1202                body,
1203            ))?;
1204        }
1205        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1206
1207        // Use the admin API token to authenticate as the impersonated user.
1208        api_token.insert_str(0, "ADMIN_TOKEN:");
1209        Ok(Credentials {
1210            user_id: response.user.id,
1211            access_token: api_token,
1212        })
1213    }
1214
1215    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1216        self.peer.teardown();
1217        self.set_status(Status::SignedOut, cx);
1218    }
1219
1220    fn connection_id(&self) -> Result<ConnectionId> {
1221        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1222            Ok(connection_id)
1223        } else {
1224            Err(anyhow!("not connected"))
1225        }
1226    }
1227
1228    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1229        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1230        self.peer.send(self.connection_id()?, message)
1231    }
1232
1233    pub fn request<T: RequestMessage>(
1234        &self,
1235        request: T,
1236    ) -> impl Future<Output = Result<T::Response>> {
1237        self.request_envelope(request)
1238            .map_ok(|envelope| envelope.payload)
1239    }
1240
1241    pub fn request_envelope<T: RequestMessage>(
1242        &self,
1243        request: T,
1244    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1245        let client_id = self.id;
1246        log::debug!(
1247            "rpc request start. client_id:{}. name:{}",
1248            client_id,
1249            T::NAME
1250        );
1251        let response = self
1252            .connection_id()
1253            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1254        async move {
1255            let response = response?.await;
1256            log::debug!(
1257                "rpc request finish. client_id:{}. name:{}",
1258                client_id,
1259                T::NAME
1260            );
1261            response
1262        }
1263    }
1264
1265    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1266        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1267        self.peer.respond(receipt, response)
1268    }
1269
1270    fn respond_with_error<T: RequestMessage>(
1271        &self,
1272        receipt: Receipt<T>,
1273        error: proto::Error,
1274    ) -> Result<()> {
1275        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1276        self.peer.respond_with_error(receipt, error)
1277    }
1278
1279    fn handle_message(
1280        self: &Arc<Client>,
1281        message: Box<dyn AnyTypedEnvelope>,
1282        cx: &AsyncAppContext,
1283    ) {
1284        let mut state = self.state.write();
1285        let type_name = message.payload_type_name();
1286        let payload_type_id = message.payload_type_id();
1287        let sender_id = message.original_sender_id();
1288
1289        let mut subscriber = None;
1290
1291        if let Some(message_model) = state
1292            .models_by_message_type
1293            .get(&payload_type_id)
1294            .and_then(|model| model.upgrade(cx))
1295        {
1296            subscriber = Some(Subscriber::Model(message_model));
1297        } else if let Some((extract_entity_id, entity_type_id)) =
1298            state.entity_id_extractors.get(&payload_type_id).zip(
1299                state
1300                    .entity_types_by_message_type
1301                    .get(&payload_type_id)
1302                    .copied(),
1303            )
1304        {
1305            let entity_id = (extract_entity_id)(message.as_ref());
1306
1307            match state
1308                .entities_by_type_and_remote_id
1309                .get_mut(&(entity_type_id, entity_id))
1310            {
1311                Some(WeakSubscriber::Pending(pending)) => {
1312                    pending.push(message);
1313                    return;
1314                }
1315                Some(weak_subscriber @ _) => match weak_subscriber {
1316                    WeakSubscriber::Model(handle) => {
1317                        subscriber = handle.upgrade(cx).map(Subscriber::Model);
1318                    }
1319                    WeakSubscriber::View(handle) => {
1320                        subscriber = Some(Subscriber::View(handle.clone()));
1321                    }
1322                    WeakSubscriber::Pending(_) => {}
1323                },
1324                _ => {}
1325            }
1326        }
1327
1328        let subscriber = if let Some(subscriber) = subscriber {
1329            subscriber
1330        } else {
1331            log::info!("unhandled message {}", type_name);
1332            self.peer.respond_with_unhandled_message(message).log_err();
1333            return;
1334        };
1335
1336        let handler = state.message_handlers.get(&payload_type_id).cloned();
1337        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1338        // It also ensures we don't hold the lock while yielding back to the executor, as
1339        // that might cause the executor thread driving this future to block indefinitely.
1340        drop(state);
1341
1342        if let Some(handler) = handler {
1343            let future = handler(subscriber, message, &self, cx.clone());
1344            let client_id = self.id;
1345            log::debug!(
1346                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1347                client_id,
1348                sender_id,
1349                type_name
1350            );
1351            cx.foreground()
1352                .spawn(async move {
1353                    match future.await {
1354                        Ok(()) => {
1355                            log::debug!(
1356                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1357                                client_id,
1358                                sender_id,
1359                                type_name
1360                            );
1361                        }
1362                        Err(error) => {
1363                            log::error!(
1364                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1365                                client_id,
1366                                sender_id,
1367                                type_name,
1368                                error
1369                            );
1370                        }
1371                    }
1372                })
1373                .detach();
1374        } else {
1375            log::info!("unhandled message {}", type_name);
1376            self.peer.respond_with_unhandled_message(message).log_err();
1377        }
1378    }
1379
1380    pub fn telemetry(&self) -> &Arc<Telemetry> {
1381        &self.telemetry
1382    }
1383}
1384
1385fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1386    if IMPERSONATE_LOGIN.is_some() {
1387        return None;
1388    }
1389
1390    let (user_id, access_token) = cx
1391        .platform()
1392        .read_credentials(&ZED_SERVER_URL)
1393        .log_err()
1394        .flatten()?;
1395    Some(Credentials {
1396        user_id: user_id.parse().ok()?,
1397        access_token: String::from_utf8(access_token).ok()?,
1398    })
1399}
1400
1401fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1402    cx.platform().write_credentials(
1403        &ZED_SERVER_URL,
1404        &credentials.user_id.to_string(),
1405        credentials.access_token.as_bytes(),
1406    )
1407}
1408
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` if the prefix is missing, the id is not a
/// `u64`, or the access-token segment is missing or empty. Any path segments after the access
/// token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next()? {
        "" => None,
        token => Some((id, token.to_owned())),
    }
}
1425
1426#[cfg(test)]
1427mod tests {
1428    use super::*;
1429    use crate::test::FakeServer;
1430    use gpui::{executor::Deterministic, TestAppContext};
1431    use parking_lot::Mutex;
1432    use std::future;
1433    use util::http::FakeHttpClient;
1434
    /// After the server drops the connection, the client keeps retrying. It reconnects with its
    /// cached credentials, and re-authenticates only once the server has rolled the access
    /// token (tracked via `server.auth_count()`).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, so the reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Advance the clock so the next reconnection attempt fires.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }
1469
    /// A connection attempt that never completes is abandoned after `CONNECTION_TIMEOUT`.
    /// The initial connection reports `ConnectionError` (and `authenticate_and_connect` fails);
    /// a reconnection reports `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // A connection future that never resolves.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }
1539
    /// Starting a second `authenticate_and_connect` while the first is still authenticating
    /// drops the first authentication future: `dropped_auth_count` goes from 0 to 1 while
    /// `auth_count` goes from 1 to 2.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        // Each authentication attempt increments `auth_count`, never completes, and increments
        // `dropped_auth_count` when its future is dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }
1581
1582    #[test]
1583    fn test_encode_and_decode_worktree_url() {
1584        let url = encode_worktree_url(5, "deadbeef");
1585        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1586        assert_eq!(
1587            decode_worktree_url(&format!("\n {}\t", url)),
1588            Some((5, "deadbeef".to_string()))
1589        );
1590        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1591    }
1592
    /// Messages addressed to an entity id reach the model subscribed under that id. Dropping
    /// the subscription for one id (3) must not affect delivery for other ids (1 and 2).
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signal the channel matching the id of the model the message was routed to.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }
1647
1648    #[gpui::test]
1649    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1650        cx.foreground().forbid_parking();
1651
1652        let user_id = 5;
1653        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1654        let server = FakeServer::for_client(user_id, &client, cx).await;
1655
1656        let model = cx.add_model(|_| Model::default());
1657        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1658        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1659        let subscription1 = client.add_message_handler(
1660            model.clone(),
1661            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1662                done_tx1.try_send(()).unwrap();
1663                async { Ok(()) }
1664            },
1665        );
1666        drop(subscription1);
1667        let _subscription2 = client.add_message_handler(
1668            model.clone(),
1669            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1670                done_tx2.try_send(()).unwrap();
1671                async { Ok(()) }
1672            },
1673        );
1674        server.send(proto::Ping {});
1675        done_rx2.next().await.unwrap();
1676    }
1677
1678    #[gpui::test]
1679    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1680        cx.foreground().forbid_parking();
1681
1682        let user_id = 5;
1683        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1684        let server = FakeServer::for_client(user_id, &client, cx).await;
1685
1686        let model = cx.add_model(|_| Model::default());
1687        let (done_tx, mut done_rx) = smol::channel::unbounded();
1688        let subscription = client.add_message_handler(
1689            model.clone(),
1690            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1691                model.update(&mut cx, |model, _| model.subscription.take());
1692                done_tx.try_send(()).unwrap();
1693                async { Ok(()) }
1694            },
1695        );
1696        model.update(cx, |model, _| {
1697            model.subscription = Some(subscription);
1698        });
1699        server.send(proto::Ping {});
1700        done_rx.next().await.unwrap();
1701    }
1702
1703    #[derive(Default)]
1704    struct Model {
1705        id: usize,
1706        subscription: Option<Subscription>,
1707    }
1708
1709    impl Entity for Model {
1710        type Event = ();
1711    }
1712}