client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4#[cfg(test)]
   5mod channel_store_tests;
   6
   7pub mod channel_store;
   8pub mod telemetry;
   9pub mod user;
  10
  11use anyhow::{anyhow, Context, Result};
  12use async_recursion::async_recursion;
  13use async_tungstenite::tungstenite::{
  14    error::Error as WebsocketError,
  15    http::{Request, StatusCode},
  16};
  17use futures::{
  18    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  19    TryStreamExt,
  20};
  21use gpui::{
  22    actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
  23    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
  24    WeakViewHandle,
  25};
  26use lazy_static::lazy_static;
  27use parking_lot::RwLock;
  28use postage::watch;
  29use rand::prelude::*;
  30use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use std::{
  34    any::TypeId,
  35    collections::HashMap,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{Arc, Weak},
  42    time::{Duration, Instant},
  43};
  44use telemetry::Telemetry;
  45use thiserror::Error;
  46use url::Url;
  47use util::channel::ReleaseChannel;
  48use util::http::HttpClient;
  49use util::{ResultExt, TryFutureExt};
  50
  51pub use channel_store::*;
  52pub use rpc::*;
  53pub use telemetry::ClickhouseEvent;
  54pub use user::*;
  55
  56lazy_static! {
  57    pub static ref ZED_SERVER_URL: String =
  58        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  59    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  60        .ok()
  61        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  62    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  63        .ok()
  64        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  65    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  66        .ok()
  67        .and_then(|v| v.parse().ok());
  68    pub static ref ZED_APP_PATH: Option<PathBuf> =
  69        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  70}
  71
  72pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  73pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  74pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  75
  76actions!(client, [SignIn, SignOut]);
  77
  78pub fn init_settings(cx: &mut AppContext) {
  79    settings::register::<TelemetrySettings>(cx);
  80}
  81
  82pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
  83    init_settings(cx);
  84
  85    let client = Arc::downgrade(client);
  86    cx.add_global_action({
  87        let client = client.clone();
  88        move |_: &SignIn, cx| {
  89            if let Some(client) = client.upgrade() {
  90                cx.spawn(
  91                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  92                )
  93                .detach();
  94            }
  95        }
  96    });
  97    cx.add_global_action({
  98        let client = client.clone();
  99        move |_: &SignOut, cx| {
 100            if let Some(client) = client.upgrade() {
 101                cx.spawn(|cx| async move {
 102                    client.disconnect(&cx);
 103                })
 104                .detach();
 105            }
 106        }
 107    });
 108}
 109
 110pub struct Client {
 111    id: usize,
 112    peer: Arc<Peer>,
 113    http: Arc<dyn HttpClient>,
 114    telemetry: Arc<Telemetry>,
 115    state: RwLock<ClientState>,
 116
 117    #[allow(clippy::type_complexity)]
 118    #[cfg(any(test, feature = "test-support"))]
 119    authenticate: RwLock<
 120        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 121    >,
 122
 123    #[allow(clippy::type_complexity)]
 124    #[cfg(any(test, feature = "test-support"))]
 125    establish_connection: RwLock<
 126        Option<
 127            Box<
 128                dyn 'static
 129                    + Send
 130                    + Sync
 131                    + Fn(
 132                        &Credentials,
 133                        &AsyncAppContext,
 134                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 135            >,
 136        >,
 137    >,
 138}
 139
 140#[derive(Error, Debug)]
 141pub enum EstablishConnectionError {
 142    #[error("upgrade required")]
 143    UpgradeRequired,
 144    #[error("unauthorized")]
 145    Unauthorized,
 146    #[error("{0}")]
 147    Other(#[from] anyhow::Error),
 148    #[error("{0}")]
 149    Http(#[from] util::http::Error),
 150    #[error("{0}")]
 151    Io(#[from] std::io::Error),
 152    #[error("{0}")]
 153    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 154}
 155
 156impl From<WebsocketError> for EstablishConnectionError {
 157    fn from(error: WebsocketError) -> Self {
 158        if let WebsocketError::Http(response) = &error {
 159            match response.status() {
 160                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 161                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 162                _ => {}
 163            }
 164        }
 165        EstablishConnectionError::Other(error.into())
 166    }
 167}
 168
 169impl EstablishConnectionError {
 170    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 171        Self::Other(error.into())
 172    }
 173}
 174
 175#[derive(Copy, Clone, Debug, PartialEq)]
 176pub enum Status {
 177    SignedOut,
 178    UpgradeRequired,
 179    Authenticating,
 180    Connecting,
 181    ConnectionError,
 182    Connected {
 183        peer_id: PeerId,
 184        connection_id: ConnectionId,
 185    },
 186    ConnectionLost,
 187    Reauthenticating,
 188    Reconnecting,
 189    ReconnectionError {
 190        next_reconnection: Instant,
 191    },
 192}
 193
 194impl Status {
 195    pub fn is_connected(&self) -> bool {
 196        matches!(self, Self::Connected { .. })
 197    }
 198
 199    pub fn is_signed_out(&self) -> bool {
 200        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 201    }
 202}
 203
 204struct ClientState {
 205    credentials: Option<Credentials>,
 206    status: (watch::Sender<Status>, watch::Receiver<Status>),
 207    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 208    _reconnect_task: Option<Task<()>>,
 209    reconnect_interval: Duration,
 210    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
 211    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
 212    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 213    #[allow(clippy::type_complexity)]
 214    message_handlers: HashMap<
 215        TypeId,
 216        Arc<
 217            dyn Send
 218                + Sync
 219                + Fn(
 220                    Subscriber,
 221                    Box<dyn AnyTypedEnvelope>,
 222                    &Arc<Client>,
 223                    AsyncAppContext,
 224                ) -> LocalBoxFuture<'static, Result<()>>,
 225        >,
 226    >,
 227}
 228
 229enum WeakSubscriber {
 230    Model(AnyWeakModelHandle),
 231    View(AnyWeakViewHandle),
 232    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
 233}
 234
 235enum Subscriber {
 236    Model(AnyModelHandle),
 237    View(AnyWeakViewHandle),
 238}
 239
 240#[derive(Clone, Debug)]
 241pub struct Credentials {
 242    pub user_id: u64,
 243    pub access_token: String,
 244}
 245
 246impl Default for ClientState {
 247    fn default() -> Self {
 248        Self {
 249            credentials: None,
 250            status: watch::channel_with(Status::SignedOut),
 251            entity_id_extractors: Default::default(),
 252            _reconnect_task: None,
 253            reconnect_interval: Duration::from_secs(5),
 254            models_by_message_type: Default::default(),
 255            entities_by_type_and_remote_id: Default::default(),
 256            entity_types_by_message_type: Default::default(),
 257            message_handlers: Default::default(),
 258        }
 259    }
 260}
 261
 262pub enum Subscription {
 263    Entity {
 264        client: Weak<Client>,
 265        id: (TypeId, u64),
 266    },
 267    Message {
 268        client: Weak<Client>,
 269        id: TypeId,
 270    },
 271}
 272
 273impl Drop for Subscription {
 274    fn drop(&mut self) {
 275        match self {
 276            Subscription::Entity { client, id } => {
 277                if let Some(client) = client.upgrade() {
 278                    let mut state = client.state.write();
 279                    let _ = state.entities_by_type_and_remote_id.remove(id);
 280                }
 281            }
 282            Subscription::Message { client, id } => {
 283                if let Some(client) = client.upgrade() {
 284                    let mut state = client.state.write();
 285                    let _ = state.entity_types_by_message_type.remove(id);
 286                    let _ = state.message_handlers.remove(id);
 287                }
 288            }
 289        }
 290    }
 291}
 292
 293pub struct PendingEntitySubscription<T: Entity> {
 294    client: Arc<Client>,
 295    remote_id: u64,
 296    _entity_type: PhantomData<T>,
 297    consumed: bool,
 298}
 299
 300impl<T: Entity> PendingEntitySubscription<T> {
 301    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 302        self.consumed = true;
 303        let mut state = self.client.state.write();
 304        let id = (TypeId::of::<T>(), self.remote_id);
 305        let Some(WeakSubscriber::Pending(messages)) =
 306            state.entities_by_type_and_remote_id.remove(&id)
 307        else {
 308            unreachable!()
 309        };
 310
 311        state
 312            .entities_by_type_and_remote_id
 313            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 314        drop(state);
 315        for message in messages {
 316            self.client.handle_message(message, cx);
 317        }
 318        Subscription::Entity {
 319            client: Arc::downgrade(&self.client),
 320            id,
 321        }
 322    }
 323}
 324
 325impl<T: Entity> Drop for PendingEntitySubscription<T> {
 326    fn drop(&mut self) {
 327        if !self.consumed {
 328            let mut state = self.client.state.write();
 329            if let Some(WeakSubscriber::Pending(messages)) = state
 330                .entities_by_type_and_remote_id
 331                .remove(&(TypeId::of::<T>(), self.remote_id))
 332            {
 333                for message in messages {
 334                    log::info!("unhandled message {}", message.payload_type_name());
 335                }
 336            }
 337        }
 338    }
 339}
 340
 341#[derive(Copy, Clone)]
 342pub struct TelemetrySettings {
 343    pub diagnostics: bool,
 344    pub metrics: bool,
 345}
 346
 347#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
 348pub struct TelemetrySettingsContent {
 349    pub diagnostics: Option<bool>,
 350    pub metrics: Option<bool>,
 351}
 352
 353impl settings::Setting for TelemetrySettings {
 354    const KEY: Option<&'static str> = Some("telemetry");
 355
 356    type FileContent = TelemetrySettingsContent;
 357
 358    fn load(
 359        default_value: &Self::FileContent,
 360        user_values: &[&Self::FileContent],
 361        _: &AppContext,
 362    ) -> Result<Self> {
 363        Ok(Self {
 364            diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
 365                default_value
 366                    .diagnostics
 367                    .ok_or_else(Self::missing_default)?,
 368            ),
 369            metrics: user_values
 370                .first()
 371                .and_then(|v| v.metrics)
 372                .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
 373        })
 374    }
 375}
 376
 377impl Client {
 378    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 379        Arc::new(Self {
 380            id: 0,
 381            peer: Peer::new(0),
 382            telemetry: Telemetry::new(http.clone(), cx),
 383            http,
 384            state: Default::default(),
 385
 386            #[cfg(any(test, feature = "test-support"))]
 387            authenticate: Default::default(),
 388            #[cfg(any(test, feature = "test-support"))]
 389            establish_connection: Default::default(),
 390        })
 391    }
 392
 393    pub fn id(&self) -> usize {
 394        self.id
 395    }
 396
 397    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 398        self.http.clone()
 399    }
 400
 401    #[cfg(any(test, feature = "test-support"))]
 402    pub fn set_id(&mut self, id: usize) -> &Self {
 403        self.id = id;
 404        self
 405    }
 406
 407    #[cfg(any(test, feature = "test-support"))]
 408    pub fn teardown(&self) {
 409        let mut state = self.state.write();
 410        state._reconnect_task.take();
 411        state.message_handlers.clear();
 412        state.models_by_message_type.clear();
 413        state.entities_by_type_and_remote_id.clear();
 414        state.entity_id_extractors.clear();
 415        self.peer.teardown();
 416    }
 417
 418    #[cfg(any(test, feature = "test-support"))]
 419    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 420    where
 421        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 422    {
 423        *self.authenticate.write() = Some(Box::new(authenticate));
 424        self
 425    }
 426
 427    #[cfg(any(test, feature = "test-support"))]
 428    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 429    where
 430        F: 'static
 431            + Send
 432            + Sync
 433            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 434    {
 435        *self.establish_connection.write() = Some(Box::new(connect));
 436        self
 437    }
 438
 439    pub fn user_id(&self) -> Option<u64> {
 440        self.state
 441            .read()
 442            .credentials
 443            .as_ref()
 444            .map(|credentials| credentials.user_id)
 445    }
 446
 447    pub fn peer_id(&self) -> Option<PeerId> {
 448        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 449            Some(*peer_id)
 450        } else {
 451            None
 452        }
 453    }
 454
 455    pub fn status(&self) -> watch::Receiver<Status> {
 456        self.state.read().status.1.clone()
 457    }
 458
 459    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 460        log::info!("set status on client {}: {:?}", self.id, status);
 461        let mut state = self.state.write();
 462        *state.status.0.borrow_mut() = status;
 463
 464        match status {
 465            Status::Connected { .. } => {
 466                state._reconnect_task = None;
 467            }
 468            Status::ConnectionLost => {
 469                let this = self.clone();
 470                let reconnect_interval = state.reconnect_interval;
 471                state._reconnect_task = Some(cx.spawn(|cx| async move {
 472                    #[cfg(any(test, feature = "test-support"))]
 473                    let mut rng = StdRng::seed_from_u64(0);
 474                    #[cfg(not(any(test, feature = "test-support")))]
 475                    let mut rng = StdRng::from_entropy();
 476
 477                    let mut delay = INITIAL_RECONNECTION_DELAY;
 478                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
 479                        log::error!("failed to connect {}", error);
 480                        if matches!(*this.status().borrow(), Status::ConnectionError) {
 481                            this.set_status(
 482                                Status::ReconnectionError {
 483                                    next_reconnection: Instant::now() + delay,
 484                                },
 485                                &cx,
 486                            );
 487                            cx.background().timer(delay).await;
 488                            delay = delay
 489                                .mul_f32(rng.gen_range(1.0..=2.0))
 490                                .min(reconnect_interval);
 491                        } else {
 492                            break;
 493                        }
 494                    }
 495                }));
 496            }
 497            Status::SignedOut | Status::UpgradeRequired => {
 498                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
 499                state._reconnect_task.take();
 500            }
 501            _ => {}
 502        }
 503    }
 504
 505    pub fn add_view_for_remote_entity<T: View>(
 506        self: &Arc<Self>,
 507        remote_id: u64,
 508        cx: &mut ViewContext<T>,
 509    ) -> Subscription {
 510        let id = (TypeId::of::<T>(), remote_id);
 511        self.state
 512            .write()
 513            .entities_by_type_and_remote_id
 514            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 515        Subscription::Entity {
 516            client: Arc::downgrade(self),
 517            id,
 518        }
 519    }
 520
 521    pub fn subscribe_to_entity<T: Entity>(
 522        self: &Arc<Self>,
 523        remote_id: u64,
 524    ) -> Result<PendingEntitySubscription<T>> {
 525        let id = (TypeId::of::<T>(), remote_id);
 526
 527        let mut state = self.state.write();
 528        if state.entities_by_type_and_remote_id.contains_key(&id) {
 529            return Err(anyhow!("already subscribed to entity"));
 530        } else {
 531            state
 532                .entities_by_type_and_remote_id
 533                .insert(id, WeakSubscriber::Pending(Default::default()));
 534            Ok(PendingEntitySubscription {
 535                client: self.clone(),
 536                remote_id,
 537                consumed: false,
 538                _entity_type: PhantomData,
 539            })
 540        }
 541    }
 542
 543    #[track_caller]
 544    pub fn add_message_handler<M, E, H, F>(
 545        self: &Arc<Self>,
 546        model: ModelHandle<E>,
 547        handler: H,
 548    ) -> Subscription
 549    where
 550        M: EnvelopedMessage,
 551        E: Entity,
 552        H: 'static
 553            + Send
 554            + Sync
 555            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 556        F: 'static + Future<Output = Result<()>>,
 557    {
 558        let message_type_id = TypeId::of::<M>();
 559
 560        let mut state = self.state.write();
 561        state
 562            .models_by_message_type
 563            .insert(message_type_id, model.downgrade().into_any());
 564
 565        let prev_handler = state.message_handlers.insert(
 566            message_type_id,
 567            Arc::new(move |handle, envelope, client, cx| {
 568                let handle = if let Subscriber::Model(handle) = handle {
 569                    handle
 570                } else {
 571                    unreachable!();
 572                };
 573                let model = handle.downcast::<E>().unwrap();
 574                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 575                handler(model, *envelope, client.clone(), cx).boxed_local()
 576            }),
 577        );
 578        if prev_handler.is_some() {
 579            let location = std::panic::Location::caller();
 580            panic!(
 581                "{}:{} registered handler for the same message {} twice",
 582                location.file(),
 583                location.line(),
 584                std::any::type_name::<M>()
 585            );
 586        }
 587
 588        Subscription::Message {
 589            client: Arc::downgrade(self),
 590            id: message_type_id,
 591        }
 592    }
 593
 594    pub fn add_request_handler<M, E, H, F>(
 595        self: &Arc<Self>,
 596        model: ModelHandle<E>,
 597        handler: H,
 598    ) -> Subscription
 599    where
 600        M: RequestMessage,
 601        E: Entity,
 602        H: 'static
 603            + Send
 604            + Sync
 605            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 606        F: 'static + Future<Output = Result<M::Response>>,
 607    {
 608        self.add_message_handler(model, move |handle, envelope, this, cx| {
 609            Self::respond_to_request(
 610                envelope.receipt(),
 611                handler(handle, envelope, this.clone(), cx),
 612                this,
 613            )
 614        })
 615    }
 616
 617    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 618    where
 619        M: EntityMessage,
 620        E: View,
 621        H: 'static
 622            + Send
 623            + Sync
 624            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 625        F: 'static + Future<Output = Result<()>>,
 626    {
 627        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 628            if let Subscriber::View(handle) = handle {
 629                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 630            } else {
 631                unreachable!();
 632            }
 633        })
 634    }
 635
 636    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 637    where
 638        M: EntityMessage,
 639        E: Entity,
 640        H: 'static
 641            + Send
 642            + Sync
 643            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 644        F: 'static + Future<Output = Result<()>>,
 645    {
 646        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 647            if let Subscriber::Model(handle) = handle {
 648                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 649            } else {
 650                unreachable!();
 651            }
 652        })
 653    }
 654
 655    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 656    where
 657        M: EntityMessage,
 658        E: Entity,
 659        H: 'static
 660            + Send
 661            + Sync
 662            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 663        F: 'static + Future<Output = Result<()>>,
 664    {
 665        let model_type_id = TypeId::of::<E>();
 666        let message_type_id = TypeId::of::<M>();
 667
 668        let mut state = self.state.write();
 669        state
 670            .entity_types_by_message_type
 671            .insert(message_type_id, model_type_id);
 672        state
 673            .entity_id_extractors
 674            .entry(message_type_id)
 675            .or_insert_with(|| {
 676                |envelope| {
 677                    envelope
 678                        .as_any()
 679                        .downcast_ref::<TypedEnvelope<M>>()
 680                        .unwrap()
 681                        .payload
 682                        .remote_entity_id()
 683                }
 684            });
 685        let prev_handler = state.message_handlers.insert(
 686            message_type_id,
 687            Arc::new(move |handle, envelope, client, cx| {
 688                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 689                handler(handle, *envelope, client.clone(), cx).boxed_local()
 690            }),
 691        );
 692        if prev_handler.is_some() {
 693            panic!("registered handler for the same message twice");
 694        }
 695    }
 696
 697    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 698    where
 699        M: EntityMessage + RequestMessage,
 700        E: Entity,
 701        H: 'static
 702            + Send
 703            + Sync
 704            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 705        F: 'static + Future<Output = Result<M::Response>>,
 706    {
 707        self.add_model_message_handler(move |entity, envelope, client, cx| {
 708            Self::respond_to_request::<M, _>(
 709                envelope.receipt(),
 710                handler(entity, envelope, client.clone(), cx),
 711                client,
 712            )
 713        })
 714    }
 715
 716    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 717    where
 718        M: EntityMessage + RequestMessage,
 719        E: View,
 720        H: 'static
 721            + Send
 722            + Sync
 723            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 724        F: 'static + Future<Output = Result<M::Response>>,
 725    {
 726        self.add_view_message_handler(move |entity, envelope, client, cx| {
 727            Self::respond_to_request::<M, _>(
 728                envelope.receipt(),
 729                handler(entity, envelope, client.clone(), cx),
 730                client,
 731            )
 732        })
 733    }
 734
 735    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 736        receipt: Receipt<T>,
 737        response: F,
 738        client: Arc<Self>,
 739    ) -> Result<()> {
 740        match response.await {
 741            Ok(response) => {
 742                client.respond(receipt, response)?;
 743                Ok(())
 744            }
 745            Err(error) => {
 746                client.respond_with_error(
 747                    receipt,
 748                    proto::Error {
 749                        message: format!("{:?}", error),
 750                    },
 751                )?;
 752                Err(error)
 753            }
 754        }
 755    }
 756
 757    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 758        read_credentials_from_keychain(cx).is_some()
 759    }
 760
 761    #[async_recursion(?Send)]
 762    pub async fn authenticate_and_connect(
 763        self: &Arc<Self>,
 764        try_keychain: bool,
 765        cx: &AsyncAppContext,
 766    ) -> anyhow::Result<()> {
 767        let was_disconnected = match *self.status().borrow() {
 768            Status::SignedOut => true,
 769            Status::ConnectionError
 770            | Status::ConnectionLost
 771            | Status::Authenticating { .. }
 772            | Status::Reauthenticating { .. }
 773            | Status::ReconnectionError { .. } => false,
 774            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
 775                return Ok(())
 776            }
 777            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
 778        };
 779
 780        if was_disconnected {
 781            self.set_status(Status::Authenticating, cx);
 782        } else {
 783            self.set_status(Status::Reauthenticating, cx)
 784        }
 785
 786        let mut read_from_keychain = false;
 787        let mut credentials = self.state.read().credentials.clone();
 788        if credentials.is_none() && try_keychain {
 789            credentials = read_credentials_from_keychain(cx);
 790            read_from_keychain = credentials.is_some();
 791        }
 792        if credentials.is_none() {
 793            let mut status_rx = self.status();
 794            let _ = status_rx.next().await;
 795            futures::select_biased! {
 796                authenticate = self.authenticate(cx).fuse() => {
 797                    match authenticate {
 798                        Ok(creds) => credentials = Some(creds),
 799                        Err(err) => {
 800                            self.set_status(Status::ConnectionError, cx);
 801                            return Err(err);
 802                        }
 803                    }
 804                }
 805                _ = status_rx.next().fuse() => {
 806                    return Err(anyhow!("authentication canceled"));
 807                }
 808            }
 809        }
 810        let credentials = credentials.unwrap();
 811
 812        if was_disconnected {
 813            self.set_status(Status::Connecting, cx);
 814        } else {
 815            self.set_status(Status::Reconnecting, cx);
 816        }
 817
 818        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
 819        futures::select_biased! {
 820            connection = self.establish_connection(&credentials, cx).fuse() => {
 821                match connection {
 822                    Ok(conn) => {
 823                        self.state.write().credentials = Some(credentials.clone());
 824                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
 825                            write_credentials_to_keychain(&credentials, cx).log_err();
 826                        }
 827
 828                        futures::select_biased! {
 829                            result = self.set_connection(conn, cx).fuse() => result,
 830                            _ = timeout => {
 831                                self.set_status(Status::ConnectionError, cx);
 832                                Err(anyhow!("timed out waiting on hello message from server"))
 833                            }
 834                        }
 835                    }
 836                    Err(EstablishConnectionError::Unauthorized) => {
 837                        self.state.write().credentials.take();
 838                        if read_from_keychain {
 839                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
 840                            self.set_status(Status::SignedOut, cx);
 841                            self.authenticate_and_connect(false, cx).await
 842                        } else {
 843                            self.set_status(Status::ConnectionError, cx);
 844                            Err(EstablishConnectionError::Unauthorized)?
 845                        }
 846                    }
 847                    Err(EstablishConnectionError::UpgradeRequired) => {
 848                        self.set_status(Status::UpgradeRequired, cx);
 849                        Err(EstablishConnectionError::UpgradeRequired)?
 850                    }
 851                    Err(error) => {
 852                        self.set_status(Status::ConnectionError, cx);
 853                        Err(error)?
 854                    }
 855                }
 856            }
 857            _ = &mut timeout => {
 858                self.set_status(Status::ConnectionError, cx);
 859                Err(anyhow!("timed out trying to establish connection"))
 860            }
 861        }
 862    }
 863
    /// Registers `conn` with the peer, waits for the server's `Hello` message, and then
    /// marks the client as `Connected`.
    ///
    /// On success two detached foreground tasks are spawned:
    /// * one forwards every subsequent incoming message to `handle_message`, yielding
    ///   between messages;
    /// * one awaits the connection's I/O future and updates the status when it ends
    ///   (`SignedOut` on a clean close if the status still refers to this connection,
    ///   `ConnectionLost` on an I/O error).
    ///
    /// # Errors
    /// If the incoming stream ends before any message arrives, the first message is not
    /// a `Hello`, or the `Hello` carries no peer id, the connection is disconnected from
    /// the peer and the error is returned.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        // Drive the connection's I/O on the background executor.
        let handle_io = cx.background().spawn(handle_io);

        // The first message from the server must be a `Hello` carrying our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        // Don't leave a half-established connection registered with the peer.
        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        // Dispatch loop for all messages after the hello.
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        // Watch for the connection's I/O to finish and reflect that in the status.
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if the status still describes *this* connection;
                        // if it has changed (presumably a newer connection or status
                        // update), leave it alone.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
 959
 960    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 961        #[cfg(any(test, feature = "test-support"))]
 962        if let Some(callback) = self.authenticate.read().as_ref() {
 963            return callback(cx);
 964        }
 965
 966        self.authenticate_with_browser(cx)
 967    }
 968
 969    fn establish_connection(
 970        self: &Arc<Self>,
 971        credentials: &Credentials,
 972        cx: &AsyncAppContext,
 973    ) -> Task<Result<Connection, EstablishConnectionError>> {
 974        #[cfg(any(test, feature = "test-support"))]
 975        if let Some(callback) = self.establish_connection.read().as_ref() {
 976            return callback(credentials, cx);
 977        }
 978
 979        self.establish_websocket_connection(credentials, cx)
 980    }
 981
 982    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 983        let preview_param = if is_preview { "?preview=1" } else { "" };
 984        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 985        let response = http.get(&url, Default::default(), false).await?;
 986
 987        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 988        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 989        // which requires authorization via an HTTP header.
 990        //
 991        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 992        // of a collab server. In that case, a request to the /rpc endpoint will
 993        // return an 'unauthorized' response.
 994        let collab_url = if response.status().is_redirection() {
 995            response
 996                .headers()
 997                .get("Location")
 998                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 999                .to_str()
1000                .map_err(EstablishConnectionError::other)?
1001                .to_string()
1002        } else if response.status() == StatusCode::UNAUTHORIZED {
1003            url
1004        } else {
1005            Err(anyhow!(
1006                "unexpected /rpc response status {}",
1007                response.status()
1008            ))?
1009        };
1010
1011        Url::parse(&collab_url).context("invalid rpc url")
1012    }
1013
1014    fn establish_websocket_connection(
1015        self: &Arc<Self>,
1016        credentials: &Credentials,
1017        cx: &AsyncAppContext,
1018    ) -> Task<Result<Connection, EstablishConnectionError>> {
1019        let is_preview = cx.read(|cx| {
1020            if cx.has_global::<ReleaseChannel>() {
1021                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
1022            } else {
1023                false
1024            }
1025        });
1026
1027        let request = Request::builder()
1028            .header(
1029                "Authorization",
1030                format!("{} {}", credentials.user_id, credentials.access_token),
1031            )
1032            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1033
1034        let http = self.http.clone();
1035        cx.background().spawn(async move {
1036            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
1037            let rpc_host = rpc_url
1038                .host_str()
1039                .zip(rpc_url.port_or_known_default())
1040                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1041            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1042
1043            log::info!("connected to rpc endpoint {}", rpc_url);
1044
1045            match rpc_url.scheme() {
1046                "https" => {
1047                    rpc_url.set_scheme("wss").unwrap();
1048                    let request = request.uri(rpc_url.as_str()).body(())?;
1049                    let (stream, _) =
1050                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1051                    Ok(Connection::new(
1052                        stream
1053                            .map_err(|error| anyhow!(error))
1054                            .sink_map_err(|error| anyhow!(error)),
1055                    ))
1056                }
1057                "http" => {
1058                    rpc_url.set_scheme("ws").unwrap();
1059                    let request = request.uri(rpc_url.as_str()).body(())?;
1060                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1061                    Ok(Connection::new(
1062                        stream
1063                            .map_err(|error| anyhow!(error))
1064                            .sink_map_err(|error| anyhow!(error)),
1065                    ))
1066                }
1067                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1068            }
1069        })
1070    }
1071
1072    pub fn authenticate_with_browser(
1073        self: &Arc<Self>,
1074        cx: &AsyncAppContext,
1075    ) -> Task<Result<Credentials>> {
1076        let platform = cx.platform();
1077        let executor = cx.background();
1078        let http = self.http.clone();
1079
1080        executor.clone().spawn(async move {
1081            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1082            // zed server to encrypt the user's access token, so that it can'be intercepted by
1083            // any other app running on the user's device.
1084            let (public_key, private_key) =
1085                rpc::auth::keypair().expect("failed to generate keypair for auth");
1086            let public_key_string =
1087                String::try_from(public_key).expect("failed to serialize public key for auth");
1088
1089            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1090                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1091            }
1092
1093            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1094            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1095            let port = server.server_addr().port();
1096
1097            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1098            // that the user is signing in from a Zed app running on the same device.
1099            let mut url = format!(
1100                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1101                *ZED_SERVER_URL, port, public_key_string
1102            );
1103
1104            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1105                log::info!("impersonating user @{}", impersonate_login);
1106                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1107            }
1108
1109            platform.open_url(&url);
1110
1111            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1112            // access token from the query params.
1113            //
1114            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1115            // custom URL scheme instead of this local HTTP server.
1116            let (user_id, access_token) = executor
1117                .spawn(async move {
1118                    for _ in 0..100 {
1119                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1120                            let path = req.url();
1121                            let mut user_id = None;
1122                            let mut access_token = None;
1123                            let url = Url::parse(&format!("http://example.com{}", path))
1124                                .context("failed to parse login notification url")?;
1125                            for (key, value) in url.query_pairs() {
1126                                if key == "access_token" {
1127                                    access_token = Some(value.to_string());
1128                                } else if key == "user_id" {
1129                                    user_id = Some(value.to_string());
1130                                }
1131                            }
1132
1133                            let post_auth_url =
1134                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1135                            req.respond(
1136                                tiny_http::Response::empty(302).with_header(
1137                                    tiny_http::Header::from_bytes(
1138                                        &b"Location"[..],
1139                                        post_auth_url.as_bytes(),
1140                                    )
1141                                    .unwrap(),
1142                                ),
1143                            )
1144                            .context("failed to respond to login http request")?;
1145                            return Ok((
1146                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1147                                access_token
1148                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1149                            ));
1150                        }
1151                    }
1152
1153                    Err(anyhow!("didn't receive login redirect"))
1154                })
1155                .await?;
1156
1157            let access_token = private_key
1158                .decrypt_string(&access_token)
1159                .context("failed to decrypt access token")?;
1160            platform.activate(true);
1161
1162            Ok(Credentials {
1163                user_id: user_id.parse()?,
1164                access_token,
1165            })
1166        })
1167    }
1168
1169    async fn authenticate_as_admin(
1170        http: Arc<dyn HttpClient>,
1171        login: String,
1172        mut api_token: String,
1173    ) -> Result<Credentials> {
1174        #[derive(Deserialize)]
1175        struct AuthenticatedUserResponse {
1176            user: User,
1177        }
1178
1179        #[derive(Deserialize)]
1180        struct User {
1181            id: u64,
1182        }
1183
1184        // Use the collab server's admin API to retrieve the id
1185        // of the impersonated user.
1186        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1187        url.set_path("/user");
1188        url.set_query(Some(&format!("github_login={login}")));
1189        let request = Request::get(url.as_str())
1190            .header("Authorization", format!("token {api_token}"))
1191            .body("".into())?;
1192
1193        let mut response = http.send(request).await?;
1194        let mut body = String::new();
1195        response.body_mut().read_to_string(&mut body).await?;
1196        if !response.status().is_success() {
1197            Err(anyhow!(
1198                "admin user request failed {} - {}",
1199                response.status().as_u16(),
1200                body,
1201            ))?;
1202        }
1203        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1204
1205        // Use the admin API token to authenticate as the impersonated user.
1206        api_token.insert_str(0, "ADMIN_TOKEN:");
1207        Ok(Credentials {
1208            user_id: response.user.id,
1209            access_token: api_token,
1210        })
1211    }
1212
1213    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1214        self.peer.teardown();
1215        self.set_status(Status::SignedOut, cx);
1216    }
1217
1218    fn connection_id(&self) -> Result<ConnectionId> {
1219        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1220            Ok(connection_id)
1221        } else {
1222            Err(anyhow!("not connected"))
1223        }
1224    }
1225
1226    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1227        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1228        self.peer.send(self.connection_id()?, message)
1229    }
1230
1231    pub fn request<T: RequestMessage>(
1232        &self,
1233        request: T,
1234    ) -> impl Future<Output = Result<T::Response>> {
1235        self.request_envelope(request)
1236            .map_ok(|envelope| envelope.payload)
1237    }
1238
1239    pub fn request_envelope<T: RequestMessage>(
1240        &self,
1241        request: T,
1242    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1243        let client_id = self.id;
1244        log::debug!(
1245            "rpc request start. client_id:{}. name:{}",
1246            client_id,
1247            T::NAME
1248        );
1249        let response = self
1250            .connection_id()
1251            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1252        async move {
1253            let response = response?.await;
1254            log::debug!(
1255                "rpc request finish. client_id:{}. name:{}",
1256                client_id,
1257                T::NAME
1258            );
1259            response
1260        }
1261    }
1262
1263    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1264        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1265        self.peer.respond(receipt, response)
1266    }
1267
1268    fn respond_with_error<T: RequestMessage>(
1269        &self,
1270        receipt: Receipt<T>,
1271        error: proto::Error,
1272    ) -> Result<()> {
1273        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1274        self.peer.respond_with_error(receipt, error)
1275    }
1276
1277    fn handle_message(
1278        self: &Arc<Client>,
1279        message: Box<dyn AnyTypedEnvelope>,
1280        cx: &AsyncAppContext,
1281    ) {
1282        let mut state = self.state.write();
1283        let type_name = message.payload_type_name();
1284        let payload_type_id = message.payload_type_id();
1285        let sender_id = message.original_sender_id();
1286
1287        let mut subscriber = None;
1288
1289        if let Some(message_model) = state
1290            .models_by_message_type
1291            .get(&payload_type_id)
1292            .and_then(|model| model.upgrade(cx))
1293        {
1294            subscriber = Some(Subscriber::Model(message_model));
1295        } else if let Some((extract_entity_id, entity_type_id)) =
1296            state.entity_id_extractors.get(&payload_type_id).zip(
1297                state
1298                    .entity_types_by_message_type
1299                    .get(&payload_type_id)
1300                    .copied(),
1301            )
1302        {
1303            let entity_id = (extract_entity_id)(message.as_ref());
1304
1305            match state
1306                .entities_by_type_and_remote_id
1307                .get_mut(&(entity_type_id, entity_id))
1308            {
1309                Some(WeakSubscriber::Pending(pending)) => {
1310                    pending.push(message);
1311                    return;
1312                }
1313                Some(weak_subscriber @ _) => match weak_subscriber {
1314                    WeakSubscriber::Model(handle) => {
1315                        subscriber = handle.upgrade(cx).map(Subscriber::Model);
1316                    }
1317                    WeakSubscriber::View(handle) => {
1318                        subscriber = Some(Subscriber::View(handle.clone()));
1319                    }
1320                    WeakSubscriber::Pending(_) => {}
1321                },
1322                _ => {}
1323            }
1324        }
1325
1326        let subscriber = if let Some(subscriber) = subscriber {
1327            subscriber
1328        } else {
1329            log::info!("unhandled message {}", type_name);
1330            self.peer.respond_with_unhandled_message(message).log_err();
1331            return;
1332        };
1333
1334        let handler = state.message_handlers.get(&payload_type_id).cloned();
1335        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1336        // It also ensures we don't hold the lock while yielding back to the executor, as
1337        // that might cause the executor thread driving this future to block indefinitely.
1338        drop(state);
1339
1340        if let Some(handler) = handler {
1341            let future = handler(subscriber, message, &self, cx.clone());
1342            let client_id = self.id;
1343            log::debug!(
1344                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1345                client_id,
1346                sender_id,
1347                type_name
1348            );
1349            cx.foreground()
1350                .spawn(async move {
1351                    match future.await {
1352                        Ok(()) => {
1353                            log::debug!(
1354                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1355                                client_id,
1356                                sender_id,
1357                                type_name
1358                            );
1359                        }
1360                        Err(error) => {
1361                            log::error!(
1362                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1363                                client_id,
1364                                sender_id,
1365                                type_name,
1366                                error
1367                            );
1368                        }
1369                    }
1370                })
1371                .detach();
1372        } else {
1373            log::info!("unhandled message {}", type_name);
1374            self.peer.respond_with_unhandled_message(message).log_err();
1375        }
1376    }
1377
    /// Returns the `Telemetry` handle stored on this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1381}
1382
1383fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1384    if IMPERSONATE_LOGIN.is_some() {
1385        return None;
1386    }
1387
1388    let (user_id, access_token) = cx
1389        .platform()
1390        .read_credentials(&ZED_SERVER_URL)
1391        .log_err()
1392        .flatten()?;
1393    Some(Credentials {
1394        user_id: user_id.parse().ok()?,
1395        access_token: String::from_utf8(access_token).ok()?,
1396    })
1397}
1398
1399fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1400    cx.platform().write_credentials(
1401        &ZED_SERVER_URL,
1402        &credentials.user_id.to_string(),
1403        credentials.access_token.as_bytes(),
1404    )
1405}
1406
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the first `/` that follows
/// the id is taken as the access token, so tokens that themselves contain `/`
/// round-trip instead of being truncated at their first slash.
///
/// Returns `None` if the prefix is missing, the id is not a valid `u64`, or the
/// access token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1423
1424#[cfg(test)]
1425mod tests {
1426    use super::*;
1427    use crate::test::FakeServer;
1428    use gpui::{executor::Deterministic, TestAppContext};
1429    use parking_lot::Mutex;
1430    use std::future;
1431    use util::http::FakeHttpClient;
1432
1433    #[gpui::test(iterations = 10)]
1434    async fn test_reconnection(cx: &mut TestAppContext) {
1435        cx.foreground().forbid_parking();
1436
1437        let user_id = 5;
1438        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1439        let server = FakeServer::for_client(user_id, &client, cx).await;
1440        let mut status = client.status();
1441        assert!(matches!(
1442            status.next().await,
1443            Some(Status::Connected { .. })
1444        ));
1445        assert_eq!(server.auth_count(), 1);
1446
1447        server.forbid_connections();
1448        server.disconnect();
1449        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1450
1451        server.allow_connections();
1452        cx.foreground().advance_clock(Duration::from_secs(10));
1453        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1454        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1455
1456        server.forbid_connections();
1457        server.disconnect();
1458        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1459
1460        // Clear cached credentials after authentication fails
1461        server.roll_access_token();
1462        server.allow_connections();
1463        cx.foreground().advance_clock(Duration::from_secs(10));
1464        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1465        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1466    }
1467
1468    #[gpui::test(iterations = 10)]
1469    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1470        deterministic.forbid_parking();
1471
1472        let user_id = 5;
1473        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1474        let mut status = client.status();
1475
1476        // Time out when client tries to connect.
1477        client.override_authenticate(move |cx| {
1478            cx.foreground().spawn(async move {
1479                Ok(Credentials {
1480                    user_id,
1481                    access_token: "token".into(),
1482                })
1483            })
1484        });
1485        client.override_establish_connection(|_, cx| {
1486            cx.foreground().spawn(async move {
1487                future::pending::<()>().await;
1488                unreachable!()
1489            })
1490        });
1491        let auth_and_connect = cx.spawn({
1492            let client = client.clone();
1493            |cx| async move { client.authenticate_and_connect(false, &cx).await }
1494        });
1495        deterministic.run_until_parked();
1496        assert!(matches!(status.next().await, Some(Status::Connecting)));
1497
1498        deterministic.advance_clock(CONNECTION_TIMEOUT);
1499        assert!(matches!(
1500            status.next().await,
1501            Some(Status::ConnectionError { .. })
1502        ));
1503        auth_and_connect.await.unwrap_err();
1504
1505        // Allow the connection to be established.
1506        let server = FakeServer::for_client(user_id, &client, cx).await;
1507        assert!(matches!(
1508            status.next().await,
1509            Some(Status::Connected { .. })
1510        ));
1511
1512        // Disconnect client.
1513        server.forbid_connections();
1514        server.disconnect();
1515        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1516
1517        // Time out when re-establishing the connection.
1518        server.allow_connections();
1519        client.override_establish_connection(|_, cx| {
1520            cx.foreground().spawn(async move {
1521                future::pending::<()>().await;
1522                unreachable!()
1523            })
1524        });
1525        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1526        assert!(matches!(
1527            status.next().await,
1528            Some(Status::Reconnecting { .. })
1529        ));
1530
1531        deterministic.advance_clock(CONNECTION_TIMEOUT);
1532        assert!(matches!(
1533            status.next().await,
1534            Some(Status::ReconnectionError { .. })
1535        ));
1536    }
1537
1538    #[gpui::test(iterations = 10)]
1539    async fn test_authenticating_more_than_once(
1540        cx: &mut TestAppContext,
1541        deterministic: Arc<Deterministic>,
1542    ) {
1543        cx.foreground().forbid_parking();
1544
1545        let auth_count = Arc::new(Mutex::new(0));
1546        let dropped_auth_count = Arc::new(Mutex::new(0));
1547        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1548        client.override_authenticate({
1549            let auth_count = auth_count.clone();
1550            let dropped_auth_count = dropped_auth_count.clone();
1551            move |cx| {
1552                let auth_count = auth_count.clone();
1553                let dropped_auth_count = dropped_auth_count.clone();
1554                cx.foreground().spawn(async move {
1555                    *auth_count.lock() += 1;
1556                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1557                    future::pending::<()>().await;
1558                    unreachable!()
1559                })
1560            }
1561        });
1562
1563        let _authenticate = cx.spawn(|cx| {
1564            let client = client.clone();
1565            async move { client.authenticate_and_connect(false, &cx).await }
1566        });
1567        deterministic.run_until_parked();
1568        assert_eq!(*auth_count.lock(), 1);
1569        assert_eq!(*dropped_auth_count.lock(), 0);
1570
1571        let _authenticate = cx.spawn(|cx| {
1572            let client = client.clone();
1573            async move { client.authenticate_and_connect(false, &cx).await }
1574        });
1575        deterministic.run_until_parked();
1576        assert_eq!(*auth_count.lock(), 2);
1577        assert_eq!(*dropped_auth_count.lock(), 1);
1578    }
1579
1580    #[test]
1581    fn test_encode_and_decode_worktree_url() {
1582        let url = encode_worktree_url(5, "deadbeef");
1583        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1584        assert_eq!(
1585            decode_worktree_url(&format!("\n {}\t", url)),
1586            Some((5, "deadbeef".to_string()))
1587        );
1588        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1589    }
1590
1591    #[gpui::test]
1592    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1593        cx.foreground().forbid_parking();
1594
1595        let user_id = 5;
1596        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1597        let server = FakeServer::for_client(user_id, &client, cx).await;
1598
1599        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1600        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1601        client.add_model_message_handler(
1602            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
1603                match model.read_with(&cx, |model, _| model.id) {
1604                    1 => done_tx1.try_send(()).unwrap(),
1605                    2 => done_tx2.try_send(()).unwrap(),
1606                    _ => unreachable!(),
1607                }
1608                async { Ok(()) }
1609            },
1610        );
1611        let model1 = cx.add_model(|_| Model {
1612            id: 1,
1613            subscription: None,
1614        });
1615        let model2 = cx.add_model(|_| Model {
1616            id: 2,
1617            subscription: None,
1618        });
1619        let model3 = cx.add_model(|_| Model {
1620            id: 3,
1621            subscription: None,
1622        });
1623
1624        let _subscription1 = client
1625            .subscribe_to_entity(1)
1626            .unwrap()
1627            .set_model(&model1, &mut cx.to_async());
1628        let _subscription2 = client
1629            .subscribe_to_entity(2)
1630            .unwrap()
1631            .set_model(&model2, &mut cx.to_async());
1632        // Ensure dropping a subscription for the same entity type still allows receiving of
1633        // messages for other entity IDs of the same type.
1634        let subscription3 = client
1635            .subscribe_to_entity(3)
1636            .unwrap()
1637            .set_model(&model3, &mut cx.to_async());
1638        drop(subscription3);
1639
1640        server.send(proto::JoinProject { project_id: 1 });
1641        server.send(proto::JoinProject { project_id: 2 });
1642        done_rx1.next().await.unwrap();
1643        done_rx2.next().await.unwrap();
1644    }
1645
1646    #[gpui::test]
1647    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1648        cx.foreground().forbid_parking();
1649
1650        let user_id = 5;
1651        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1652        let server = FakeServer::for_client(user_id, &client, cx).await;
1653
1654        let model = cx.add_model(|_| Model::default());
1655        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1656        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1657        let subscription1 = client.add_message_handler(
1658            model.clone(),
1659            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1660                done_tx1.try_send(()).unwrap();
1661                async { Ok(()) }
1662            },
1663        );
1664        drop(subscription1);
1665        let _subscription2 = client.add_message_handler(
1666            model.clone(),
1667            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1668                done_tx2.try_send(()).unwrap();
1669                async { Ok(()) }
1670            },
1671        );
1672        server.send(proto::Ping {});
1673        done_rx2.next().await.unwrap();
1674    }
1675
1676    #[gpui::test]
1677    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1678        cx.foreground().forbid_parking();
1679
1680        let user_id = 5;
1681        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1682        let server = FakeServer::for_client(user_id, &client, cx).await;
1683
1684        let model = cx.add_model(|_| Model::default());
1685        let (done_tx, mut done_rx) = smol::channel::unbounded();
1686        let subscription = client.add_message_handler(
1687            model.clone(),
1688            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1689                model.update(&mut cx, |model, _| model.subscription.take());
1690                done_tx.try_send(()).unwrap();
1691                async { Ok(()) }
1692            },
1693        );
1694        model.update(cx, |model, _| {
1695            model.subscription = Some(subscription);
1696        });
1697        server.send(proto::Ping {});
1698        done_rx.next().await.unwrap();
1699    }
1700
1701    #[derive(Default)]
1702    struct Model {
1703        id: usize,
1704        subscription: Option<Subscription>,
1705    }
1706
1707    impl Entity for Model {
1708        type Event = ();
1709    }
1710}