client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4#[cfg(test)]
   5mod channel_store_tests;
   6
   7pub mod channel_store;
   8pub mod telemetry;
   9pub mod user;
  10
  11use anyhow::{anyhow, Context, Result};
  12use async_recursion::async_recursion;
  13use async_tungstenite::tungstenite::{
  14    error::Error as WebsocketError,
  15    http::{Request, StatusCode},
  16};
  17use futures::{
  18    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  19    TryStreamExt,
  20};
  21use gpui::{
  22    actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
  23    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
  24    WeakViewHandle,
  25};
  26use lazy_static::lazy_static;
  27use parking_lot::RwLock;
  28use postage::watch;
  29use rand::prelude::*;
  30use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  31use schemars::JsonSchema;
  32use serde::{Deserialize, Serialize};
  33use std::{
  34    any::TypeId,
  35    collections::HashMap,
  36    convert::TryFrom,
  37    fmt::Write as _,
  38    future::Future,
  39    marker::PhantomData,
  40    path::PathBuf,
  41    sync::{Arc, Weak},
  42    time::{Duration, Instant},
  43};
  44use telemetry::Telemetry;
  45use thiserror::Error;
  46use url::Url;
  47use util::channel::ReleaseChannel;
  48use util::http::HttpClient;
  49use util::{ResultExt, TryFutureExt};
  50
  51pub use channel_store::*;
  52pub use rpc::*;
  53pub use telemetry::ClickhouseEvent;
  54pub use user::*;
  55
  56lazy_static! {
  57    pub static ref ZED_SERVER_URL: String =
  58        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  59    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  60        .ok()
  61        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  62    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  63        .ok()
  64        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  65    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  66        .ok()
  67        .and_then(|v| v.parse().ok());
  68    pub static ref ZED_APP_PATH: Option<PathBuf> =
  69        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  70}
  71
/// Client token constant exported by this module; where it is consumed is not visible in this
/// part of the file.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// Delay before the first retry of the reconnection loop started by `Client::set_status`;
/// the delay grows after every failed attempt.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time budget used by `Client::authenticate_and_connect` for establishing a connection and
/// receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

// Declares the `SignIn` and `SignOut` actions (handled by the global actions added in `init`).
actions!(client, [SignIn, SignOut]);
  77
/// Registers [`TelemetrySettings`] with the settings system for the given app context.
pub fn init_settings(cx: &mut AppContext) {
    settings::register::<TelemetrySettings>(cx);
}
  81
  82pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
  83    init_settings(cx);
  84
  85    let client = Arc::downgrade(client);
  86    cx.add_global_action({
  87        let client = client.clone();
  88        move |_: &SignIn, cx| {
  89            if let Some(client) = client.upgrade() {
  90                cx.spawn(
  91                    |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  92                )
  93                .detach();
  94            }
  95        }
  96    });
  97    cx.add_global_action({
  98        let client = client.clone();
  99        move |_: &SignOut, cx| {
 100            if let Some(client) = client.upgrade() {
 101                cx.spawn(|cx| async move {
 102                    client.disconnect(&cx);
 103                })
 104                .detach();
 105            }
 106        }
 107    });
 108}
 109
/// Client-side handle that owns the RPC peer, the HTTP client, telemetry, and the mutable
/// connection/subscription state.
pub struct Client {
    /// Identifier returned by `id()` and included in status log lines.
    id: usize,
    /// RPC peer that connections are added to.
    peer: Arc<Peer>,
    /// HTTP client handed out by `http_client()` and shared with telemetry at construction.
    http: Arc<dyn HttpClient>,
    /// Telemetry sink; its authenticated-user info is cleared when the client signs out.
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, reconnect task, and message-routing tables.
    state: RwLock<ClientState>,

    /// Test-only hook installed through `override_authenticate`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only hook installed through `override_establish_connection`.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 139
/// Reasons why establishing a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Failure reported by the HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// Underlying I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from tungstenite's re-exported `http` module (e.g. while building the request —
    /// NOTE(review): confirm against the call site, which is outside this part of the file).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 155
 156impl From<WebsocketError> for EstablishConnectionError {
 157    fn from(error: WebsocketError) -> Self {
 158        if let WebsocketError::Http(response) = &error {
 159            match response.status() {
 160                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 161                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 162                _ => {}
 163            }
 164        }
 165        EstablishConnectionError::Other(error.into())
 166    }
 167}
 168
impl EstablishConnectionError {
    /// Wraps any error convertible into `anyhow::Error` in the `Other` variant.
    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
        Self::Other(error.into())
    }
}
 174
/// Connection status published through the client's `watch` channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No session; this is the initial status of a new `ClientState`.
    SignedOut,
    /// The server rejected the connection with "upgrade required".
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Establishing a connection after `Authenticating`.
    Connecting,
    /// Authentication or connection establishment failed or timed out.
    ConnectionError,
    /// A connection is established and the server's hello message has been received.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// Setting this status makes `Client::set_status` spawn the reconnection task.
    ConnectionLost,
    /// Obtaining credentials again when the previous status was not `SignedOut`.
    Reauthenticating,
    /// Establishing a connection after `Reauthenticating`.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 193
 194impl Status {
 195    pub fn is_connected(&self) -> bool {
 196        matches!(self, Self::Connected { .. })
 197    }
 198
 199    pub fn is_signed_out(&self) -> bool {
 200        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 201    }
 202}
 203
/// Mutable state of a `Client`, guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current session, if any.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: function extracting the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Task running the reconnection loop; replaced or cleared by `Client::set_status`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Subscribers keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Per message type: the model registered through `Client::add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Per message type: the entity type registered for it by the entity message handlers.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Per message type: the type-erased handler invoked for incoming envelopes.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 228
/// Entry stored per remote entity in `ClientState::entities_by_type_and_remote_id`.
enum WeakSubscriber {
    /// A model, held weakly.
    Model(AnyWeakModelHandle),
    /// A view, held weakly.
    View(AnyWeakViewHandle),
    /// Envelopes held for a `PendingEntitySubscription`; `set_model` later hands them to
    /// `Client::handle_message`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
 234
/// Target handed to a message handler when a message is dispatched.
enum Subscriber {
    /// A strong, type-erased model handle.
    Model(AnyModelHandle),
    /// A weak, type-erased view handle.
    View(AnyWeakViewHandle),
}
 239
/// Credentials used to establish a connection to the server.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Id of the authenticated user.
    pub user_id: u64,
    /// Access token presented when connecting.
    pub access_token: String,
}
 245
 246impl Default for ClientState {
 247    fn default() -> Self {
 248        Self {
 249            credentials: None,
 250            status: watch::channel_with(Status::SignedOut),
 251            entity_id_extractors: Default::default(),
 252            _reconnect_task: None,
 253            reconnect_interval: Duration::from_secs(5),
 254            models_by_message_type: Default::default(),
 255            entities_by_type_and_remote_id: Default::default(),
 256            entity_types_by_message_type: Default::default(),
 257            message_handlers: Default::default(),
 258        }
 259    }
 260}
 261
/// Handle for a registration made on a `Client`; dropping it removes the registration.
pub enum Subscription {
    /// Registration of an entity for a given (entity type, remote id) pair.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for a given message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 272
 273impl Drop for Subscription {
 274    fn drop(&mut self) {
 275        match self {
 276            Subscription::Entity { client, id } => {
 277                if let Some(client) = client.upgrade() {
 278                    let mut state = client.state.write();
 279                    let _ = state.entities_by_type_and_remote_id.remove(id);
 280                }
 281            }
 282            Subscription::Message { client, id } => {
 283                if let Some(client) = client.upgrade() {
 284                    let mut state = client.state.write();
 285                    let _ = state.entity_types_by_message_type.remove(id);
 286                    let _ = state.message_handlers.remove(id);
 287                }
 288            }
 289        }
 290    }
 291}
 292
/// Reservation created by `Client::subscribe_to_entity` for an entity whose model does not
/// exist yet; finish it with `set_model`.
pub struct PendingEntitySubscription<T: Entity> {
    /// Client that holds the pending entry.
    client: Arc<Client>,
    /// Remote id of the entity being subscribed to.
    remote_id: u64,
    /// Ties the subscription to the entity type `T` without storing a value.
    _entity_type: PhantomData<T>,
    /// Set by `set_model`; tells `Drop` not to discard the entry.
    consumed: bool,
}
 299
 300impl<T: Entity> PendingEntitySubscription<T> {
 301    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 302        self.consumed = true;
 303        let mut state = self.client.state.write();
 304        let id = (TypeId::of::<T>(), self.remote_id);
 305        let Some(WeakSubscriber::Pending(messages)) =
 306            state.entities_by_type_and_remote_id.remove(&id)
 307        else {
 308            unreachable!()
 309        };
 310
 311        state
 312            .entities_by_type_and_remote_id
 313            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 314        drop(state);
 315        for message in messages {
 316            self.client.handle_message(message, cx);
 317        }
 318        Subscription::Entity {
 319            client: Arc::downgrade(&self.client),
 320            id,
 321        }
 322    }
 323}
 324
 325impl<T: Entity> Drop for PendingEntitySubscription<T> {
 326    fn drop(&mut self) {
 327        if !self.consumed {
 328            let mut state = self.client.state.write();
 329            if let Some(WeakSubscriber::Pending(messages)) = state
 330                .entities_by_type_and_remote_id
 331                .remove(&(TypeId::of::<T>(), self.remote_id))
 332            {
 333                for message in messages {
 334                    log::info!("unhandled message {}", message.payload_type_name());
 335                }
 336            }
 337        }
 338    }
 339}
 340
/// Resolved telemetry settings (user value if present, otherwise the default).
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Resolved value of the `diagnostics` key.
    pub diagnostics: bool,
    /// Resolved value of the `metrics` key.
    pub metrics: bool,
}
 346
/// Shape of the `telemetry` section in a settings file; every key is optional.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Value of the `diagnostics` key, if present.
    pub diagnostics: Option<bool>,
    /// Value of the `metrics` key, if present.
    pub metrics: Option<bool>,
}
 352
 353impl settings::Setting for TelemetrySettings {
 354    const KEY: Option<&'static str> = Some("telemetry");
 355
 356    type FileContent = TelemetrySettingsContent;
 357
 358    fn load(
 359        default_value: &Self::FileContent,
 360        user_values: &[&Self::FileContent],
 361        _: &AppContext,
 362    ) -> Result<Self> {
 363        Ok(Self {
 364            diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
 365                default_value
 366                    .diagnostics
 367                    .ok_or_else(Self::missing_default)?,
 368            ),
 369            metrics: user_values
 370                .first()
 371                .and_then(|v| v.metrics)
 372                .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
 373        })
 374    }
 375}
 376
 377impl Client {
 378    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 379        Arc::new(Self {
 380            id: 0,
 381            peer: Peer::new(0),
 382            telemetry: Telemetry::new(http.clone(), cx),
 383            http,
 384            state: Default::default(),
 385
 386            #[cfg(any(test, feature = "test-support"))]
 387            authenticate: Default::default(),
 388            #[cfg(any(test, feature = "test-support"))]
 389            establish_connection: Default::default(),
 390        })
 391    }
 392
    /// Returns this client's id.
    pub fn id(&self) -> usize {
        self.id
    }
 396
 397    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 398        self.http.clone()
 399    }
 400
    /// Test-only: overrides the client's id and returns the client for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_id(&mut self, id: usize) -> &Self {
        self.id = id;
        self
    }
 406
 407    #[cfg(any(test, feature = "test-support"))]
 408    pub fn teardown(&self) {
 409        let mut state = self.state.write();
 410        state._reconnect_task.take();
 411        state.message_handlers.clear();
 412        state.models_by_message_type.clear();
 413        state.entities_by_type_and_remote_id.clear();
 414        state.entity_id_extractors.clear();
 415        self.peer.teardown();
 416    }
 417
 418    #[cfg(any(test, feature = "test-support"))]
 419    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 420    where
 421        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 422    {
 423        *self.authenticate.write() = Some(Box::new(authenticate));
 424        self
 425    }
 426
 427    #[cfg(any(test, feature = "test-support"))]
 428    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 429    where
 430        F: 'static
 431            + Send
 432            + Sync
 433            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 434    {
 435        *self.establish_connection.write() = Some(Box::new(connect));
 436        self
 437    }
 438
 439    pub fn user_id(&self) -> Option<u64> {
 440        self.state
 441            .read()
 442            .credentials
 443            .as_ref()
 444            .map(|credentials| credentials.user_id)
 445    }
 446
 447    pub fn peer_id(&self) -> Option<PeerId> {
 448        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 449            Some(*peer_id)
 450        } else {
 451            None
 452        }
 453    }
 454
 455    pub fn status(&self) -> watch::Receiver<Status> {
 456        self.state.read().status.1.clone()
 457    }
 458
 459    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 460        log::info!("set status on client {}: {:?}", self.id, status);
 461        let mut state = self.state.write();
 462        *state.status.0.borrow_mut() = status;
 463
 464        match status {
 465            Status::Connected { .. } => {
 466                state._reconnect_task = None;
 467            }
 468            Status::ConnectionLost => {
 469                let this = self.clone();
 470                let reconnect_interval = state.reconnect_interval;
 471                state._reconnect_task = Some(cx.spawn(|cx| async move {
 472                    #[cfg(any(test, feature = "test-support"))]
 473                    let mut rng = StdRng::seed_from_u64(0);
 474                    #[cfg(not(any(test, feature = "test-support")))]
 475                    let mut rng = StdRng::from_entropy();
 476
 477                    let mut delay = INITIAL_RECONNECTION_DELAY;
 478                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
 479                        log::error!("failed to connect {}", error);
 480                        if matches!(*this.status().borrow(), Status::ConnectionError) {
 481                            this.set_status(
 482                                Status::ReconnectionError {
 483                                    next_reconnection: Instant::now() + delay,
 484                                },
 485                                &cx,
 486                            );
 487                            cx.background().timer(delay).await;
 488                            delay = delay
 489                                .mul_f32(rng.gen_range(1.0..=2.0))
 490                                .min(reconnect_interval);
 491                        } else {
 492                            break;
 493                        }
 494                    }
 495                }));
 496            }
 497            Status::SignedOut | Status::UpgradeRequired => {
 498                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
 499                state._reconnect_task.take();
 500            }
 501            _ => {}
 502        }
 503    }
 504
 505    pub fn add_view_for_remote_entity<T: View>(
 506        self: &Arc<Self>,
 507        remote_id: u64,
 508        cx: &mut ViewContext<T>,
 509    ) -> Subscription {
 510        let id = (TypeId::of::<T>(), remote_id);
 511        self.state
 512            .write()
 513            .entities_by_type_and_remote_id
 514            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 515        Subscription::Entity {
 516            client: Arc::downgrade(self),
 517            id,
 518        }
 519    }
 520
 521    pub fn subscribe_to_entity<T: Entity>(
 522        self: &Arc<Self>,
 523        remote_id: u64,
 524    ) -> Result<PendingEntitySubscription<T>> {
 525        let id = (TypeId::of::<T>(), remote_id);
 526
 527        let mut state = self.state.write();
 528        if state.entities_by_type_and_remote_id.contains_key(&id) {
 529            return Err(anyhow!("already subscribed to entity"));
 530        } else {
 531            state
 532                .entities_by_type_and_remote_id
 533                .insert(id, WeakSubscriber::Pending(Default::default()));
 534            Ok(PendingEntitySubscription {
 535                client: self.clone(),
 536                remote_id,
 537                consumed: false,
 538                _entity_type: PhantomData,
 539            })
 540        }
 541    }
 542
 543    pub fn add_message_handler<M, E, H, F>(
 544        self: &Arc<Self>,
 545        model: ModelHandle<E>,
 546        handler: H,
 547    ) -> Subscription
 548    where
 549        M: EnvelopedMessage,
 550        E: Entity,
 551        H: 'static
 552            + Send
 553            + Sync
 554            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 555        F: 'static + Future<Output = Result<()>>,
 556    {
 557        let message_type_id = TypeId::of::<M>();
 558
 559        let mut state = self.state.write();
 560        state
 561            .models_by_message_type
 562            .insert(message_type_id, model.downgrade().into_any());
 563
 564        let prev_handler = state.message_handlers.insert(
 565            message_type_id,
 566            Arc::new(move |handle, envelope, client, cx| {
 567                let handle = if let Subscriber::Model(handle) = handle {
 568                    handle
 569                } else {
 570                    unreachable!();
 571                };
 572                let model = handle.downcast::<E>().unwrap();
 573                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 574                handler(model, *envelope, client.clone(), cx).boxed_local()
 575            }),
 576        );
 577        if prev_handler.is_some() {
 578            panic!(
 579                "registered handler for the same message {} twice",
 580                std::any::type_name::<M>()
 581            );
 582        }
 583
 584        Subscription::Message {
 585            client: Arc::downgrade(self),
 586            id: message_type_id,
 587        }
 588    }
 589
 590    pub fn add_request_handler<M, E, H, F>(
 591        self: &Arc<Self>,
 592        model: ModelHandle<E>,
 593        handler: H,
 594    ) -> Subscription
 595    where
 596        M: RequestMessage,
 597        E: Entity,
 598        H: 'static
 599            + Send
 600            + Sync
 601            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 602        F: 'static + Future<Output = Result<M::Response>>,
 603    {
 604        self.add_message_handler(model, move |handle, envelope, this, cx| {
 605            Self::respond_to_request(
 606                envelope.receipt(),
 607                handler(handle, envelope, this.clone(), cx),
 608                this,
 609            )
 610        })
 611    }
 612
 613    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 614    where
 615        M: EntityMessage,
 616        E: View,
 617        H: 'static
 618            + Send
 619            + Sync
 620            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 621        F: 'static + Future<Output = Result<()>>,
 622    {
 623        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 624            if let Subscriber::View(handle) = handle {
 625                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 626            } else {
 627                unreachable!();
 628            }
 629        })
 630    }
 631
 632    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 633    where
 634        M: EntityMessage,
 635        E: Entity,
 636        H: 'static
 637            + Send
 638            + Sync
 639            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 640        F: 'static + Future<Output = Result<()>>,
 641    {
 642        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 643            if let Subscriber::Model(handle) = handle {
 644                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 645            } else {
 646                unreachable!();
 647            }
 648        })
 649    }
 650
 651    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 652    where
 653        M: EntityMessage,
 654        E: Entity,
 655        H: 'static
 656            + Send
 657            + Sync
 658            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 659        F: 'static + Future<Output = Result<()>>,
 660    {
 661        let model_type_id = TypeId::of::<E>();
 662        let message_type_id = TypeId::of::<M>();
 663
 664        let mut state = self.state.write();
 665        state
 666            .entity_types_by_message_type
 667            .insert(message_type_id, model_type_id);
 668        state
 669            .entity_id_extractors
 670            .entry(message_type_id)
 671            .or_insert_with(|| {
 672                |envelope| {
 673                    envelope
 674                        .as_any()
 675                        .downcast_ref::<TypedEnvelope<M>>()
 676                        .unwrap()
 677                        .payload
 678                        .remote_entity_id()
 679                }
 680            });
 681        let prev_handler = state.message_handlers.insert(
 682            message_type_id,
 683            Arc::new(move |handle, envelope, client, cx| {
 684                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 685                handler(handle, *envelope, client.clone(), cx).boxed_local()
 686            }),
 687        );
 688        if prev_handler.is_some() {
 689            panic!("registered handler for the same message twice");
 690        }
 691    }
 692
 693    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 694    where
 695        M: EntityMessage + RequestMessage,
 696        E: Entity,
 697        H: 'static
 698            + Send
 699            + Sync
 700            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 701        F: 'static + Future<Output = Result<M::Response>>,
 702    {
 703        self.add_model_message_handler(move |entity, envelope, client, cx| {
 704            Self::respond_to_request::<M, _>(
 705                envelope.receipt(),
 706                handler(entity, envelope, client.clone(), cx),
 707                client,
 708            )
 709        })
 710    }
 711
 712    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 713    where
 714        M: EntityMessage + RequestMessage,
 715        E: View,
 716        H: 'static
 717            + Send
 718            + Sync
 719            + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 720        F: 'static + Future<Output = Result<M::Response>>,
 721    {
 722        self.add_view_message_handler(move |entity, envelope, client, cx| {
 723            Self::respond_to_request::<M, _>(
 724                envelope.receipt(),
 725                handler(entity, envelope, client.clone(), cx),
 726                client,
 727            )
 728        })
 729    }
 730
 731    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 732        receipt: Receipt<T>,
 733        response: F,
 734        client: Arc<Self>,
 735    ) -> Result<()> {
 736        match response.await {
 737            Ok(response) => {
 738                client.respond(receipt, response)?;
 739                Ok(())
 740            }
 741            Err(error) => {
 742                client.respond_with_error(
 743                    receipt,
 744                    proto::Error {
 745                        message: format!("{:?}", error),
 746                    },
 747                )?;
 748                Err(error)
 749            }
 750        }
 751    }
 752
    /// Returns `true` when `read_credentials_from_keychain` yields credentials.
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
 756
    /// Obtains credentials if none are stored, then establishes a connection, updating the
    /// client's status at every step.
    ///
    /// `try_keychain` allows reading credentials from the keychain when the client state
    /// holds none; otherwise `self.authenticate` is used.
    ///
    /// Returns `Ok(())` immediately when the status is already `Connected`, `Connecting`,
    /// or `Reconnecting`.
    ///
    /// # Errors
    ///
    /// * the status is `UpgradeRequired`, or the server demands an upgrade;
    /// * authentication fails, or the status changes while authentication is in flight;
    /// * the server rejects the credentials (after one retry without the keychain when the
    ///   rejected credentials came from it);
    /// * establishing the connection or waiting for the hello message fails or exceeds
    ///   `CONNECTION_TIMEOUT`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `true` selects the first-time statuses (Authenticating / Connecting); `false`
        // selects the "re-" variants (Reauthenticating / Reconnecting).
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Credential sources, in order: client state, keychain (if allowed), `authenticate`.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably consumes the receiver's current value so that the
            // `next()` in the select below only resolves on a later status change —
            // confirm against `postage::watch` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                // A status update while authenticating cancels the attempt.
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path that left `credentials` empty has returned above.
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer covers both establishing the connection and waiting for the hello
        // message inside `set_connection`.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only newly obtained credentials, and never while
                        // impersonating another login.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // The keychain credentials were rejected: delete them and
                            // retry once without the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 859
 860    async fn set_connection(
 861        self: &Arc<Self>,
 862        conn: Connection,
 863        cx: &AsyncAppContext,
 864    ) -> Result<()> {
 865        let executor = cx.background();
 866        log::info!("add connection to peer");
 867        let (connection_id, handle_io, mut incoming) = self
 868            .peer
 869            .add_connection(conn, move |duration| executor.timer(duration));
 870        let handle_io = cx.background().spawn(handle_io);
 871
 872        let peer_id = async {
 873            log::info!("waiting for server hello");
 874            let message = incoming
 875                .next()
 876                .await
 877                .ok_or_else(|| anyhow!("no hello message received"))?;
 878            log::info!("got server hello");
 879            let hello_message_type_name = message.payload_type_name().to_string();
 880            let hello = message
 881                .into_any()
 882                .downcast::<TypedEnvelope<proto::Hello>>()
 883                .map_err(|_| {
 884                    anyhow!(
 885                        "invalid hello message received: {:?}",
 886                        hello_message_type_name
 887                    )
 888                })?;
 889            let peer_id = hello
 890                .payload
 891                .peer_id
 892                .ok_or_else(|| anyhow!("invalid peer id"))?;
 893            Ok(peer_id)
 894        };
 895
 896        let peer_id = match peer_id.await {
 897            Ok(peer_id) => peer_id,
 898            Err(error) => {
 899                self.peer.disconnect(connection_id);
 900                return Err(error);
 901            }
 902        };
 903
 904        log::info!(
 905            "set status to connected (connection id: {:?}, peer id: {:?})",
 906            connection_id,
 907            peer_id
 908        );
 909        self.set_status(
 910            Status::Connected {
 911                peer_id,
 912                connection_id,
 913            },
 914            cx,
 915        );
 916        cx.foreground()
 917            .spawn({
 918                let cx = cx.clone();
 919                let this = self.clone();
 920                async move {
 921                    while let Some(message) = incoming.next().await {
 922                        this.handle_message(message, &cx);
 923                        // Don't starve the main thread when receiving lots of messages at once.
 924                        smol::future::yield_now().await;
 925                    }
 926                }
 927            })
 928            .detach();
 929
 930        let this = self.clone();
 931        let cx = cx.clone();
 932        cx.foreground()
 933            .spawn(async move {
 934                match handle_io.await {
 935                    Ok(()) => {
 936                        if this.status().borrow().clone()
 937                            == (Status::Connected {
 938                                connection_id,
 939                                peer_id,
 940                            })
 941                        {
 942                            this.set_status(Status::SignedOut, &cx);
 943                        }
 944                    }
 945                    Err(err) => {
 946                        log::error!("connection error: {:?}", err);
 947                        this.set_status(Status::ConnectionLost, &cx);
 948                    }
 949                }
 950            })
 951            .detach();
 952
 953        Ok(())
 954    }
 955
 956    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 957        #[cfg(any(test, feature = "test-support"))]
 958        if let Some(callback) = self.authenticate.read().as_ref() {
 959            return callback(cx);
 960        }
 961
 962        self.authenticate_with_browser(cx)
 963    }
 964
 965    fn establish_connection(
 966        self: &Arc<Self>,
 967        credentials: &Credentials,
 968        cx: &AsyncAppContext,
 969    ) -> Task<Result<Connection, EstablishConnectionError>> {
 970        #[cfg(any(test, feature = "test-support"))]
 971        if let Some(callback) = self.establish_connection.read().as_ref() {
 972            return callback(credentials, cx);
 973        }
 974
 975        self.establish_websocket_connection(credentials, cx)
 976    }
 977
 978    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 979        let preview_param = if is_preview { "?preview=1" } else { "" };
 980        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 981        let response = http.get(&url, Default::default(), false).await?;
 982
 983        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 984        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 985        // which requires authorization via an HTTP header.
 986        //
 987        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 988        // of a collab server. In that case, a request to the /rpc endpoint will
 989        // return an 'unauthorized' response.
 990        let collab_url = if response.status().is_redirection() {
 991            response
 992                .headers()
 993                .get("Location")
 994                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 995                .to_str()
 996                .map_err(EstablishConnectionError::other)?
 997                .to_string()
 998        } else if response.status() == StatusCode::UNAUTHORIZED {
 999            url
1000        } else {
1001            Err(anyhow!(
1002                "unexpected /rpc response status {}",
1003                response.status()
1004            ))?
1005        };
1006
1007        Url::parse(&collab_url).context("invalid rpc url")
1008    }
1009
1010    fn establish_websocket_connection(
1011        self: &Arc<Self>,
1012        credentials: &Credentials,
1013        cx: &AsyncAppContext,
1014    ) -> Task<Result<Connection, EstablishConnectionError>> {
1015        let is_preview = cx.read(|cx| {
1016            if cx.has_global::<ReleaseChannel>() {
1017                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
1018            } else {
1019                false
1020            }
1021        });
1022
1023        let request = Request::builder()
1024            .header(
1025                "Authorization",
1026                format!("{} {}", credentials.user_id, credentials.access_token),
1027            )
1028            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1029
1030        let http = self.http.clone();
1031        cx.background().spawn(async move {
1032            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
1033            let rpc_host = rpc_url
1034                .host_str()
1035                .zip(rpc_url.port_or_known_default())
1036                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1037            let stream = smol::net::TcpStream::connect(rpc_host).await?;
1038
1039            log::info!("connected to rpc endpoint {}", rpc_url);
1040
1041            match rpc_url.scheme() {
1042                "https" => {
1043                    rpc_url.set_scheme("wss").unwrap();
1044                    let request = request.uri(rpc_url.as_str()).body(())?;
1045                    let (stream, _) =
1046                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1047                    Ok(Connection::new(
1048                        stream
1049                            .map_err(|error| anyhow!(error))
1050                            .sink_map_err(|error| anyhow!(error)),
1051                    ))
1052                }
1053                "http" => {
1054                    rpc_url.set_scheme("ws").unwrap();
1055                    let request = request.uri(rpc_url.as_str()).body(())?;
1056                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1057                    Ok(Connection::new(
1058                        stream
1059                            .map_err(|error| anyhow!(error))
1060                            .sink_map_err(|error| anyhow!(error)),
1061                    ))
1062                }
1063                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1064            }
1065        })
1066    }
1067
1068    pub fn authenticate_with_browser(
1069        self: &Arc<Self>,
1070        cx: &AsyncAppContext,
1071    ) -> Task<Result<Credentials>> {
1072        let platform = cx.platform();
1073        let executor = cx.background();
1074        let http = self.http.clone();
1075
1076        executor.clone().spawn(async move {
1077            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1078            // zed server to encrypt the user's access token, so that it can'be intercepted by
1079            // any other app running on the user's device.
1080            let (public_key, private_key) =
1081                rpc::auth::keypair().expect("failed to generate keypair for auth");
1082            let public_key_string =
1083                String::try_from(public_key).expect("failed to serialize public key for auth");
1084
1085            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1086                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1087            }
1088
1089            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1090            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1091            let port = server.server_addr().port();
1092
1093            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1094            // that the user is signing in from a Zed app running on the same device.
1095            let mut url = format!(
1096                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1097                *ZED_SERVER_URL, port, public_key_string
1098            );
1099
1100            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1101                log::info!("impersonating user @{}", impersonate_login);
1102                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1103            }
1104
1105            platform.open_url(&url);
1106
1107            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1108            // access token from the query params.
1109            //
1110            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1111            // custom URL scheme instead of this local HTTP server.
1112            let (user_id, access_token) = executor
1113                .spawn(async move {
1114                    for _ in 0..100 {
1115                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1116                            let path = req.url();
1117                            let mut user_id = None;
1118                            let mut access_token = None;
1119                            let url = Url::parse(&format!("http://example.com{}", path))
1120                                .context("failed to parse login notification url")?;
1121                            for (key, value) in url.query_pairs() {
1122                                if key == "access_token" {
1123                                    access_token = Some(value.to_string());
1124                                } else if key == "user_id" {
1125                                    user_id = Some(value.to_string());
1126                                }
1127                            }
1128
1129                            let post_auth_url =
1130                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1131                            req.respond(
1132                                tiny_http::Response::empty(302).with_header(
1133                                    tiny_http::Header::from_bytes(
1134                                        &b"Location"[..],
1135                                        post_auth_url.as_bytes(),
1136                                    )
1137                                    .unwrap(),
1138                                ),
1139                            )
1140                            .context("failed to respond to login http request")?;
1141                            return Ok((
1142                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1143                                access_token
1144                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1145                            ));
1146                        }
1147                    }
1148
1149                    Err(anyhow!("didn't receive login redirect"))
1150                })
1151                .await?;
1152
1153            let access_token = private_key
1154                .decrypt_string(&access_token)
1155                .context("failed to decrypt access token")?;
1156            platform.activate(true);
1157
1158            Ok(Credentials {
1159                user_id: user_id.parse()?,
1160                access_token,
1161            })
1162        })
1163    }
1164
1165    async fn authenticate_as_admin(
1166        http: Arc<dyn HttpClient>,
1167        login: String,
1168        mut api_token: String,
1169    ) -> Result<Credentials> {
1170        #[derive(Deserialize)]
1171        struct AuthenticatedUserResponse {
1172            user: User,
1173        }
1174
1175        #[derive(Deserialize)]
1176        struct User {
1177            id: u64,
1178        }
1179
1180        // Use the collab server's admin API to retrieve the id
1181        // of the impersonated user.
1182        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1183        url.set_path("/user");
1184        url.set_query(Some(&format!("github_login={login}")));
1185        let request = Request::get(url.as_str())
1186            .header("Authorization", format!("token {api_token}"))
1187            .body("".into())?;
1188
1189        let mut response = http.send(request).await?;
1190        let mut body = String::new();
1191        response.body_mut().read_to_string(&mut body).await?;
1192        if !response.status().is_success() {
1193            Err(anyhow!(
1194                "admin user request failed {} - {}",
1195                response.status().as_u16(),
1196                body,
1197            ))?;
1198        }
1199        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1200
1201        // Use the admin API token to authenticate as the impersonated user.
1202        api_token.insert_str(0, "ADMIN_TOKEN:");
1203        Ok(Credentials {
1204            user_id: response.user.id,
1205            access_token: api_token,
1206        })
1207    }
1208
    /// Tears down the peer and then sets the client status to
    /// `Status::SignedOut`.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1213
1214    fn connection_id(&self) -> Result<ConnectionId> {
1215        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1216            Ok(connection_id)
1217        } else {
1218            Err(anyhow!("not connected"))
1219        }
1220    }
1221
1222    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1223        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1224        self.peer.send(self.connection_id()?, message)
1225    }
1226
1227    pub fn request<T: RequestMessage>(
1228        &self,
1229        request: T,
1230    ) -> impl Future<Output = Result<T::Response>> {
1231        self.request_envelope(request)
1232            .map_ok(|envelope| envelope.payload)
1233    }
1234
1235    pub fn request_envelope<T: RequestMessage>(
1236        &self,
1237        request: T,
1238    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1239        let client_id = self.id;
1240        log::debug!(
1241            "rpc request start. client_id:{}. name:{}",
1242            client_id,
1243            T::NAME
1244        );
1245        let response = self
1246            .connection_id()
1247            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1248        async move {
1249            let response = response?.await;
1250            log::debug!(
1251                "rpc request finish. client_id:{}. name:{}",
1252                client_id,
1253                T::NAME
1254            );
1255            response
1256        }
1257    }
1258
1259    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1260        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1261        self.peer.respond(receipt, response)
1262    }
1263
1264    fn respond_with_error<T: RequestMessage>(
1265        &self,
1266        receipt: Receipt<T>,
1267        error: proto::Error,
1268    ) -> Result<()> {
1269        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1270        self.peer.respond_with_error(receipt, error)
1271    }
1272
1273    fn handle_message(
1274        self: &Arc<Client>,
1275        message: Box<dyn AnyTypedEnvelope>,
1276        cx: &AsyncAppContext,
1277    ) {
1278        let mut state = self.state.write();
1279        let type_name = message.payload_type_name();
1280        let payload_type_id = message.payload_type_id();
1281        let sender_id = message.original_sender_id();
1282
1283        let mut subscriber = None;
1284
1285        if let Some(message_model) = state
1286            .models_by_message_type
1287            .get(&payload_type_id)
1288            .and_then(|model| model.upgrade(cx))
1289        {
1290            subscriber = Some(Subscriber::Model(message_model));
1291        } else if let Some((extract_entity_id, entity_type_id)) =
1292            state.entity_id_extractors.get(&payload_type_id).zip(
1293                state
1294                    .entity_types_by_message_type
1295                    .get(&payload_type_id)
1296                    .copied(),
1297            )
1298        {
1299            let entity_id = (extract_entity_id)(message.as_ref());
1300
1301            match state
1302                .entities_by_type_and_remote_id
1303                .get_mut(&(entity_type_id, entity_id))
1304            {
1305                Some(WeakSubscriber::Pending(pending)) => {
1306                    pending.push(message);
1307                    return;
1308                }
1309                Some(weak_subscriber @ _) => match weak_subscriber {
1310                    WeakSubscriber::Model(handle) => {
1311                        subscriber = handle.upgrade(cx).map(Subscriber::Model);
1312                    }
1313                    WeakSubscriber::View(handle) => {
1314                        subscriber = Some(Subscriber::View(handle.clone()));
1315                    }
1316                    WeakSubscriber::Pending(_) => {}
1317                },
1318                _ => {}
1319            }
1320        }
1321
1322        let subscriber = if let Some(subscriber) = subscriber {
1323            subscriber
1324        } else {
1325            log::info!("unhandled message {}", type_name);
1326            self.peer.respond_with_unhandled_message(message).log_err();
1327            return;
1328        };
1329
1330        let handler = state.message_handlers.get(&payload_type_id).cloned();
1331        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1332        // It also ensures we don't hold the lock while yielding back to the executor, as
1333        // that might cause the executor thread driving this future to block indefinitely.
1334        drop(state);
1335
1336        if let Some(handler) = handler {
1337            let future = handler(subscriber, message, &self, cx.clone());
1338            let client_id = self.id;
1339            log::debug!(
1340                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1341                client_id,
1342                sender_id,
1343                type_name
1344            );
1345            cx.foreground()
1346                .spawn(async move {
1347                    match future.await {
1348                        Ok(()) => {
1349                            log::debug!(
1350                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1351                                client_id,
1352                                sender_id,
1353                                type_name
1354                            );
1355                        }
1356                        Err(error) => {
1357                            log::error!(
1358                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1359                                client_id,
1360                                sender_id,
1361                                type_name,
1362                                error
1363                            );
1364                        }
1365                    }
1366                })
1367                .detach();
1368        } else {
1369            log::info!("unhandled message {}", type_name);
1370            self.peer.respond_with_unhandled_message(message).log_err();
1371        }
1372    }
1373
    /// Returns the telemetry handle held by this client.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1377}
1378
1379fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1380    if IMPERSONATE_LOGIN.is_some() {
1381        return None;
1382    }
1383
1384    let (user_id, access_token) = cx
1385        .platform()
1386        .read_credentials(&ZED_SERVER_URL)
1387        .log_err()
1388        .flatten()?;
1389    Some(Credentials {
1390        user_id: user_id.parse().ok()?,
1391        access_token: String::from_utf8(access_token).ok()?,
1392    })
1393}
1394
1395fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1396    cx.platform().write_credentials(
1397        &ZED_SERVER_URL,
1398        &credentials.user_id.to_string(),
1399        credentials.access_token.as_bytes(),
1400    )
1401}
1402
/// Prefix shared by every worktree URL produced by `encode_worktree_url` and
/// required by `decode_worktree_url`.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";
1404
1405pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
1406    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
1407}
1408
1409pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
1410    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
1411    let mut parts = path.split('/');
1412    let id = parts.next()?.parse::<u64>().ok()?;
1413    let access_token = parts.next()?;
1414    if access_token.is_empty() {
1415        return None;
1416    }
1417    Some((id, access_token.to_string()))
1418}
1419
1420#[cfg(test)]
1421mod tests {
1422    use super::*;
1423    use crate::test::FakeServer;
1424    use gpui::{executor::Deterministic, TestAppContext};
1425    use parking_lot::Mutex;
1426    use std::future;
1427    use util::http::FakeHttpClient;
1428
    // Checks that the client reconnects after the server drops the connection,
    // reusing its credentials when they are still accepted and authenticating
    // again once the server has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while new connections are refused, and wait until
        // the client reports a reconnection error.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Once connections are allowed again, advancing the clock lets the
        // client reconnect.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }
1463
    // Checks that a connection attempt that never completes is abandoned after
    // `CONNECTION_TIMEOUT`, both for the initial connection (status becomes
    // `ConnectionError`) and for a reconnection (status becomes
    // `ReconnectionError`).
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Establishing the connection never finishes, so only the timeout can
        // end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }
1533
    // Checks that starting a second `authenticate_and_connect` while the first
    // is still authenticating drops the first authentication future: the
    // override is invoked twice and the first invocation's drop guard fires.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        // The override never completes; it only counts how often it was started
        // and how often its future was dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // A second attempt starts a new authentication and drops the first one.
        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }
1575
1576    #[test]
1577    fn test_encode_and_decode_worktree_url() {
1578        let url = encode_worktree_url(5, "deadbeef");
1579        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1580        assert_eq!(
1581            decode_worktree_url(&format!("\n {}\t", url)),
1582            Some((5, "deadbeef".to_string()))
1583        );
1584        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1585    }
1586
    // Checks that messages are delivered to the model subscribed under the
    // matching entity id, and that dropping the subscription of one entity
    // does not stop delivery to other entities of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // The handler signals on the channel matching the id of the model that
        // received the message; any other model receiving it is a failure.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        // One message per remaining subscription; both must be handled.
        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }
1641
    // Checks that a message handler can be registered again for the same
    // message type after the previous subscription was dropped, and that the
    // new handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        // Drop the first subscription before registering the second handler.
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }
1671
1672    #[gpui::test]
1673    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1674        cx.foreground().forbid_parking();
1675
1676        let user_id = 5;
1677        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1678        let server = FakeServer::for_client(user_id, &client, cx).await;
1679
1680        let model = cx.add_model(|_| Model::default());
1681        let (done_tx, mut done_rx) = smol::channel::unbounded();
1682        let subscription = client.add_message_handler(
1683            model.clone(),
1684            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1685                model.update(&mut cx, |model, _| model.subscription.take());
1686                done_tx.try_send(()).unwrap();
1687                async { Ok(()) }
1688            },
1689        );
1690        model.update(cx, |model, _| {
1691            model.subscription = Some(subscription);
1692        });
1693        server.send(proto::Ping {});
1694        done_rx.next().await.unwrap();
1695    }
1696
1697    #[derive(Default)]
1698    struct Model {
1699        id: usize,
1700        subscription: Option<Subscription>,
1701    }
1702
1703    impl Entity for Model {
1704        type Event = ();
1705    }
1706}