client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod http;
   5pub mod telemetry;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use db::Db;
  15use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
  16use gpui::{
  17    actions,
  18    serde_json::{self, Value},
  19    AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
  20    AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  21};
  22use http::HttpClient;
  23use lazy_static::lazy_static;
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  28use serde::Deserialize;
  29use settings::ReleaseChannel;
  30use std::{
  31    any::TypeId,
  32    collections::HashMap,
  33    convert::TryFrom,
  34    fmt::Write as _,
  35    future::Future,
  36    marker::PhantomData,
  37    path::PathBuf,
  38    sync::{Arc, Weak},
  39    time::{Duration, Instant},
  40};
  41use telemetry::Telemetry;
  42use thiserror::Error;
  43use url::Url;
  44use util::{ResultExt, TryFutureExt};
  45
  46pub use rpc::*;
  47pub use user::*;
  48
  49lazy_static! {
  50    pub static ref ZED_SERVER_URL: String =
  51        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  52    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  53        .ok()
  54        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  55    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  56        .ok()
  57        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  58}
  59
  60pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  61pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  62pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  63
// Declares the `client::Authenticate` action; `init` binds it to
// `Client::authenticate_and_connect`.
actions!(client, [Authenticate]);
  65
  66pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
  67    cx.add_global_action({
  68        let client = client.clone();
  69        move |_: &Authenticate, cx| {
  70            let client = client.clone();
  71            cx.spawn(
  72                |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  73            )
  74            .detach();
  75        }
  76    });
  77}
  78
  79pub struct Client {
  80    id: usize,
  81    peer: Arc<Peer>,
  82    http: Arc<dyn HttpClient>,
  83    telemetry: Arc<Telemetry>,
  84    state: RwLock<ClientState>,
  85
  86    #[allow(clippy::type_complexity)]
  87    #[cfg(any(test, feature = "test-support"))]
  88    authenticate: RwLock<
  89        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
  90    >,
  91
  92    #[allow(clippy::type_complexity)]
  93    #[cfg(any(test, feature = "test-support"))]
  94    establish_connection: RwLock<
  95        Option<
  96            Box<
  97                dyn 'static
  98                    + Send
  99                    + Sync
 100                    + Fn(
 101                        &Credentials,
 102                        &AsyncAppContext,
 103                    ) -> Task<Result<Connection, EstablishConnectionError>>,
 104            >,
 105        >,
 106    >,
 107}
 108
/// Errors that can occur while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered with HTTP 426 (see the `From<WebsocketError>` impl).
    /// `authenticate_and_connect` responds by setting `Status::UpgradeRequired`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered with HTTP 401 (see the `From<WebsocketError>` impl).
    /// `authenticate_and_connect` responds by dropping the stored credentials.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors with an unrecognized status.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Converted from `http::Error`.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// An I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// NOTE(review): despite the name, this wraps `tungstenite::http::Error` (an HTTP
    /// request/response construction error), not a websocket protocol error.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 124
 125impl From<WebsocketError> for EstablishConnectionError {
 126    fn from(error: WebsocketError) -> Self {
 127        if let WebsocketError::Http(response) = &error {
 128            match response.status() {
 129                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 130                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 131                _ => {}
 132            }
 133        }
 134        EstablishConnectionError::Other(error.into())
 135    }
 136}
 137
 138impl EstablishConnectionError {
 139    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 140        Self::Other(error.into())
 141    }
 142}
 143
/// Connection status of a `Client`, observable through `Client::status`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// The initial status.
    /// Also set when the connection's I/O ends cleanly while still connected.
    SignedOut,
    /// The server requires a newer client; `authenticate_and_connect` fails immediately.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening a connection, starting from `SignedOut`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected.
    /// `peer_id` comes from the server's hello; `connection_id` comes from the local `Peer`.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; setting this status spawns a reconnection loop.
    ConnectionLost,
    /// Obtaining credentials again after an error or a lost connection.
    Reauthenticating,
    /// Opening a connection again after an error or a lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 162
 163impl Status {
 164    pub fn is_connected(&self) -> bool {
 165        matches!(self, Self::Connected { .. })
 166    }
 167}
 168
/// Mutable state of a `Client`, kept behind `Client::state`.
struct ClientState {
    /// Credentials of the current session.
    /// Stored once a connection is established; taken when the server reports them as
    /// unauthorized.
    credentials: Option<Credentials>,
    /// Watch channel carrying the connection status.
    /// `Client::set_status` writes to the sender; `Client::status` hands out receiver clones.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from a type-erased
    /// envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection loop spawned on `Status::ConnectionLost`.
    /// Later status changes replace or drop it.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Entities registered under `(entity TypeId, remote id)`.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered by `Client::add_message_handler` for each message type.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that each entity-message type is routed to (see `add_entity_message_handler`).
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message type.
    /// Registering a second handler for the same type panics.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 193
/// What is registered for an `(entity type, remote id)` pair in
/// `ClientState::entities_by_type_and_remote_id`.
enum WeakSubscriber {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
    /// Placeholder inserted by `Client::subscribe_to_entity`.
    /// `PendingEntitySubscription::set_model` replays the envelopes it holds through
    /// `handle_message`; dropping the guard unconsumed logs them instead.
    /// NOTE(review): presumably filled by `handle_message` while pending — confirm.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}

/// Strong handle to the entity a type-erased message handler is invoked with.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 204
 205#[derive(Clone, Debug)]
 206pub struct Credentials {
 207    pub user_id: u64,
 208    pub access_token: String,
 209}
 210
 211impl Default for ClientState {
 212    fn default() -> Self {
 213        Self {
 214            credentials: None,
 215            status: watch::channel_with(Status::SignedOut),
 216            entity_id_extractors: Default::default(),
 217            _reconnect_task: None,
 218            reconnect_interval: Duration::from_secs(5),
 219            models_by_message_type: Default::default(),
 220            entities_by_type_and_remote_id: Default::default(),
 221            entity_types_by_message_type: Default::default(),
 222            message_handlers: Default::default(),
 223        }
 224    }
 225}
 226
/// Guard for a registration made with a `Client`; its `Drop` impl unregisters the entry.
///
/// Holds the client weakly, so an outstanding subscription does not keep the client alive.
pub enum Subscription {
    /// An entity registered under `(entity TypeId, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler registered under the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 237
 238impl Drop for Subscription {
 239    fn drop(&mut self) {
 240        match self {
 241            Subscription::Entity { client, id } => {
 242                if let Some(client) = client.upgrade() {
 243                    let mut state = client.state.write();
 244                    let _ = state.entities_by_type_and_remote_id.remove(id);
 245                }
 246            }
 247            Subscription::Message { client, id } => {
 248                if let Some(client) = client.upgrade() {
 249                    let mut state = client.state.write();
 250                    let _ = state.entity_types_by_message_type.remove(id);
 251                    let _ = state.message_handlers.remove(id);
 252                }
 253            }
 254        }
 255    }
 256}
 257
/// Guard returned by `Client::subscribe_to_entity` for the entity of type `T` identified by
/// `remote_id`.
///
/// Completing it with `set_model` registers the model. Dropping it unconsumed removes the
/// pending placeholder and logs any envelopes it held.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    // Set by `set_model` so that `Drop` leaves the registration in place.
    consumed: bool,
}
 264
 265impl<T: Entity> PendingEntitySubscription<T> {
 266    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 267        self.consumed = true;
 268        let mut state = self.client.state.write();
 269        let id = (TypeId::of::<T>(), self.remote_id);
 270        let Some(WeakSubscriber::Pending(messages)) =
 271            state.entities_by_type_and_remote_id.remove(&id)
 272        else {
 273            unreachable!()
 274        };
 275
 276        state
 277            .entities_by_type_and_remote_id
 278            .insert(id, WeakSubscriber::Model(model.downgrade().into()));
 279        drop(state);
 280        for message in messages {
 281            self.client.handle_message(message, cx);
 282        }
 283        Subscription::Entity {
 284            client: Arc::downgrade(&self.client),
 285            id,
 286        }
 287    }
 288}
 289
 290impl<T: Entity> Drop for PendingEntitySubscription<T> {
 291    fn drop(&mut self) {
 292        if !self.consumed {
 293            let mut state = self.client.state.write();
 294            if let Some(WeakSubscriber::Pending(messages)) = state
 295                .entities_by_type_and_remote_id
 296                .remove(&(TypeId::of::<T>(), self.remote_id))
 297            {
 298                for message in messages {
 299                    log::info!("unhandled message {}", message.payload_type_name());
 300                }
 301            }
 302        }
 303    }
 304}
 305
 306impl Client {
 307    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 308        Arc::new(Self {
 309            id: 0,
 310            peer: Peer::new(),
 311            telemetry: Telemetry::new(http.clone(), cx),
 312            http,
 313            state: Default::default(),
 314
 315            #[cfg(any(test, feature = "test-support"))]
 316            authenticate: Default::default(),
 317            #[cfg(any(test, feature = "test-support"))]
 318            establish_connection: Default::default(),
 319        })
 320    }
 321
 322    pub fn id(&self) -> usize {
 323        self.id
 324    }
 325
 326    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 327        self.http.clone()
 328    }
 329
 330    #[cfg(any(test, feature = "test-support"))]
 331    pub fn set_id(&mut self, id: usize) -> &Self {
 332        self.id = id;
 333        self
 334    }
 335
 336    #[cfg(any(test, feature = "test-support"))]
 337    pub fn tear_down(&self) {
 338        let mut state = self.state.write();
 339        state._reconnect_task.take();
 340        state.message_handlers.clear();
 341        state.models_by_message_type.clear();
 342        state.entities_by_type_and_remote_id.clear();
 343        state.entity_id_extractors.clear();
 344        self.peer.reset();
 345    }
 346
 347    #[cfg(any(test, feature = "test-support"))]
 348    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 349    where
 350        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 351    {
 352        *self.authenticate.write() = Some(Box::new(authenticate));
 353        self
 354    }
 355
 356    #[cfg(any(test, feature = "test-support"))]
 357    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 358    where
 359        F: 'static
 360            + Send
 361            + Sync
 362            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 363    {
 364        *self.establish_connection.write() = Some(Box::new(connect));
 365        self
 366    }
 367
 368    pub fn user_id(&self) -> Option<u64> {
 369        self.state
 370            .read()
 371            .credentials
 372            .as_ref()
 373            .map(|credentials| credentials.user_id)
 374    }
 375
 376    pub fn peer_id(&self) -> Option<PeerId> {
 377        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 378            Some(*peer_id)
 379        } else {
 380            None
 381        }
 382    }
 383
 384    pub fn status(&self) -> watch::Receiver<Status> {
 385        self.state.read().status.1.clone()
 386    }
 387
    /// Writes `status` into the watch channel observed via `status()`, then performs the side
    /// effects tied to it:
    ///
    /// - `Connected`: drops any pending reconnection task.
    /// - `ConnectionLost`: spawns a reconnection task that calls `authenticate_and_connect`
    ///   (keychain allowed) in a loop. The first delay is `INITIAL_RECONNECTION_DELAY`; each
    ///   later delay is multiplied by a random factor in `1.0..=2.0` and capped at
    ///   `state.reconnect_interval`. The loop keeps retrying only while a failed attempt
    ///   leaves the status at `ConnectionError`.
    /// - `SignedOut` / `UpgradeRequired`: resets telemetry's authenticated-user info and
    ///   drops the reconnection task.
    ///
    /// The state write lock is held for the entire call, including while the task is spawned.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed under test so the backoff jitter is reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only after ordinary connection failures. Any other status left
                        // by the failed attempt (e.g. SignedOut, UpgradeRequired) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized exponential backoff, capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 433
 434    pub fn add_view_for_remote_entity<T: View>(
 435        self: &Arc<Self>,
 436        remote_id: u64,
 437        cx: &mut ViewContext<T>,
 438    ) -> Subscription {
 439        let id = (TypeId::of::<T>(), remote_id);
 440        self.state
 441            .write()
 442            .entities_by_type_and_remote_id
 443            .insert(id, WeakSubscriber::View(cx.weak_handle().into()));
 444        Subscription::Entity {
 445            client: Arc::downgrade(self),
 446            id,
 447        }
 448    }
 449
 450    pub fn subscribe_to_entity<T: Entity>(
 451        self: &Arc<Self>,
 452        remote_id: u64,
 453    ) -> PendingEntitySubscription<T> {
 454        let id = (TypeId::of::<T>(), remote_id);
 455        self.state
 456            .write()
 457            .entities_by_type_and_remote_id
 458            .insert(id, WeakSubscriber::Pending(Default::default()));
 459
 460        PendingEntitySubscription {
 461            client: self.clone(),
 462            remote_id,
 463            consumed: false,
 464            _entity_type: PhantomData,
 465        }
 466    }
 467
 468    pub fn add_message_handler<M, E, H, F>(
 469        self: &Arc<Self>,
 470        model: ModelHandle<E>,
 471        handler: H,
 472    ) -> Subscription
 473    where
 474        M: EnvelopedMessage,
 475        E: Entity,
 476        H: 'static
 477            + Send
 478            + Sync
 479            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 480        F: 'static + Future<Output = Result<()>>,
 481    {
 482        let message_type_id = TypeId::of::<M>();
 483
 484        let mut state = self.state.write();
 485        state
 486            .models_by_message_type
 487            .insert(message_type_id, model.downgrade().into());
 488
 489        let prev_handler = state.message_handlers.insert(
 490            message_type_id,
 491            Arc::new(move |handle, envelope, client, cx| {
 492                let handle = if let Subscriber::Model(handle) = handle {
 493                    handle
 494                } else {
 495                    unreachable!();
 496                };
 497                let model = handle.downcast::<E>().unwrap();
 498                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 499                handler(model, *envelope, client.clone(), cx).boxed_local()
 500            }),
 501        );
 502        if prev_handler.is_some() {
 503            panic!("registered handler for the same message twice");
 504        }
 505
 506        Subscription::Message {
 507            client: Arc::downgrade(self),
 508            id: message_type_id,
 509        }
 510    }
 511
 512    pub fn add_request_handler<M, E, H, F>(
 513        self: &Arc<Self>,
 514        model: ModelHandle<E>,
 515        handler: H,
 516    ) -> Subscription
 517    where
 518        M: RequestMessage,
 519        E: Entity,
 520        H: 'static
 521            + Send
 522            + Sync
 523            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 524        F: 'static + Future<Output = Result<M::Response>>,
 525    {
 526        self.add_message_handler(model, move |handle, envelope, this, cx| {
 527            Self::respond_to_request(
 528                envelope.receipt(),
 529                handler(handle, envelope, this.clone(), cx),
 530                this,
 531            )
 532        })
 533    }
 534
 535    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 536    where
 537        M: EntityMessage,
 538        E: View,
 539        H: 'static
 540            + Send
 541            + Sync
 542            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 543        F: 'static + Future<Output = Result<()>>,
 544    {
 545        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 546            if let Subscriber::View(handle) = handle {
 547                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 548            } else {
 549                unreachable!();
 550            }
 551        })
 552    }
 553
 554    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 555    where
 556        M: EntityMessage,
 557        E: Entity,
 558        H: 'static
 559            + Send
 560            + Sync
 561            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 562        F: 'static + Future<Output = Result<()>>,
 563    {
 564        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 565            if let Subscriber::Model(handle) = handle {
 566                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 567            } else {
 568                unreachable!();
 569            }
 570        })
 571    }
 572
 573    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 574    where
 575        M: EntityMessage,
 576        E: Entity,
 577        H: 'static
 578            + Send
 579            + Sync
 580            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 581        F: 'static + Future<Output = Result<()>>,
 582    {
 583        let model_type_id = TypeId::of::<E>();
 584        let message_type_id = TypeId::of::<M>();
 585
 586        let mut state = self.state.write();
 587        state
 588            .entity_types_by_message_type
 589            .insert(message_type_id, model_type_id);
 590        state
 591            .entity_id_extractors
 592            .entry(message_type_id)
 593            .or_insert_with(|| {
 594                |envelope| {
 595                    envelope
 596                        .as_any()
 597                        .downcast_ref::<TypedEnvelope<M>>()
 598                        .unwrap()
 599                        .payload
 600                        .remote_entity_id()
 601                }
 602            });
 603        let prev_handler = state.message_handlers.insert(
 604            message_type_id,
 605            Arc::new(move |handle, envelope, client, cx| {
 606                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 607                handler(handle, *envelope, client.clone(), cx).boxed_local()
 608            }),
 609        );
 610        if prev_handler.is_some() {
 611            panic!("registered handler for the same message twice");
 612        }
 613    }
 614
 615    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 616    where
 617        M: EntityMessage + RequestMessage,
 618        E: Entity,
 619        H: 'static
 620            + Send
 621            + Sync
 622            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 623        F: 'static + Future<Output = Result<M::Response>>,
 624    {
 625        self.add_model_message_handler(move |entity, envelope, client, cx| {
 626            Self::respond_to_request::<M, _>(
 627                envelope.receipt(),
 628                handler(entity, envelope, client.clone(), cx),
 629                client,
 630            )
 631        })
 632    }
 633
 634    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 635    where
 636        M: EntityMessage + RequestMessage,
 637        E: View,
 638        H: 'static
 639            + Send
 640            + Sync
 641            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 642        F: 'static + Future<Output = Result<M::Response>>,
 643    {
 644        self.add_view_message_handler(move |entity, envelope, client, cx| {
 645            Self::respond_to_request::<M, _>(
 646                envelope.receipt(),
 647                handler(entity, envelope, client.clone(), cx),
 648                client,
 649            )
 650        })
 651    }
 652
 653    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 654        receipt: Receipt<T>,
 655        response: F,
 656        client: Arc<Self>,
 657    ) -> Result<()> {
 658        match response.await {
 659            Ok(response) => {
 660                client.respond(receipt, response)?;
 661                Ok(())
 662            }
 663            Err(error) => {
 664                client.respond_with_error(
 665                    receipt,
 666                    proto::Error {
 667                        message: format!("{:?}", error),
 668                    },
 669                )?;
 670                Err(error)
 671            }
 672        }
 673    }
 674
 675    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 676        read_credentials_from_keychain(cx).is_some()
 677    }
 678
    /// Authenticates if needed, then opens a connection to the server, updating `status()`
    /// along the way.
    ///
    /// Early exits:
    /// - Returns `Ok(())` immediately if the client is already `Connected`, `Connecting`, or
    ///   `Reconnecting`.
    /// - Fails immediately with `UpgradeRequired` if that is the current status.
    ///
    /// Credentials are taken from the first available source:
    /// 1. the in-memory state;
    /// 2. the keychain, only when `try_keychain` is true;
    /// 3. `authenticate`.
    ///
    /// Once the connection is established, the credentials are stored in memory. They are also
    /// written to the keychain unless they came from it or `ZED_IMPERSONATE` is set. If the
    /// server rejects keychain credentials as unauthorized, they are deleted from the keychain
    /// and the whole flow is retried once with `try_keychain = false`.
    ///
    /// A single `CONNECTION_TIMEOUT` timer bounds establishing the connection and receiving
    /// the server hello together.
    ///
    /// # Errors
    /// Fails if:
    /// - authentication fails, or is canceled by a status change (the status is then left as
    ///   set by whatever changed it);
    /// - the connection is rejected or times out;
    /// - the hello handshake in `set_connection` fails.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                self.report_event("read credentials from keychain", Default::default());
            }
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Consume the value already in the channel, so that only a *later* status change
            // cancels authentication below.
            // NOTE(review): assumes postage's watch receiver yields the current value first.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer shared by the connection attempt and the wait for the server hello.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Credentials are persisted before the hello handshake completes.
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry once, forcing
                            // a fresh `authenticate`.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 784
 785    async fn set_connection(
 786        self: &Arc<Self>,
 787        conn: Connection,
 788        cx: &AsyncAppContext,
 789    ) -> Result<()> {
 790        let executor = cx.background();
 791        log::info!("add connection to peer");
 792        let (connection_id, handle_io, mut incoming) = self
 793            .peer
 794            .add_connection(conn, move |duration| executor.timer(duration));
 795        let handle_io = cx.background().spawn(handle_io);
 796
 797        let peer_id = async {
 798            log::info!("waiting for server hello");
 799            let message = incoming
 800                .next()
 801                .await
 802                .ok_or_else(|| anyhow!("no hello message received"))?;
 803            log::info!("got server hello");
 804            let hello_message_type_name = message.payload_type_name().to_string();
 805            let hello = message
 806                .into_any()
 807                .downcast::<TypedEnvelope<proto::Hello>>()
 808                .map_err(|_| {
 809                    anyhow!(
 810                        "invalid hello message received: {:?}",
 811                        hello_message_type_name
 812                    )
 813                })?;
 814            Ok(PeerId(hello.payload.peer_id))
 815        };
 816
 817        let peer_id = match peer_id.await {
 818            Ok(peer_id) => peer_id,
 819            Err(error) => {
 820                self.peer.disconnect(connection_id);
 821                return Err(error);
 822            }
 823        };
 824
 825        log::info!(
 826            "set status to connected (connection id: {}, peer id: {})",
 827            connection_id,
 828            peer_id
 829        );
 830        self.set_status(
 831            Status::Connected {
 832                peer_id,
 833                connection_id,
 834            },
 835            cx,
 836        );
 837        cx.foreground()
 838            .spawn({
 839                let cx = cx.clone();
 840                let this = self.clone();
 841                async move {
 842                    while let Some(message) = incoming.next().await {
 843                        this.handle_message(message, &cx);
 844                        // Don't starve the main thread when receiving lots of messages at once.
 845                        smol::future::yield_now().await;
 846                    }
 847                }
 848            })
 849            .detach();
 850
 851        let this = self.clone();
 852        let cx = cx.clone();
 853        cx.foreground()
 854            .spawn(async move {
 855                match handle_io.await {
 856                    Ok(()) => {
 857                        if *this.status().borrow()
 858                            == (Status::Connected {
 859                                connection_id,
 860                                peer_id,
 861                            })
 862                        {
 863                            this.set_status(Status::SignedOut, &cx);
 864                        }
 865                    }
 866                    Err(err) => {
 867                        log::error!("connection error: {:?}", err);
 868                        this.set_status(Status::ConnectionLost, &cx);
 869                    }
 870                }
 871            })
 872            .detach();
 873
 874        Ok(())
 875    }
 876
 877    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 878        #[cfg(any(test, feature = "test-support"))]
 879        if let Some(callback) = self.authenticate.read().as_ref() {
 880            return callback(cx);
 881        }
 882
 883        self.authenticate_with_browser(cx)
 884    }
 885
 886    fn establish_connection(
 887        self: &Arc<Self>,
 888        credentials: &Credentials,
 889        cx: &AsyncAppContext,
 890    ) -> Task<Result<Connection, EstablishConnectionError>> {
 891        #[cfg(any(test, feature = "test-support"))]
 892        if let Some(callback) = self.establish_connection.read().as_ref() {
 893            return callback(credentials, cx);
 894        }
 895
 896        self.establish_websocket_connection(credentials, cx)
 897    }
 898
 899    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 900        let preview_param = if is_preview { "?preview=1" } else { "" };
 901        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 902        let response = http.get(&url, Default::default(), false).await?;
 903
 904        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 905        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 906        // which requires authorization via an HTTP header.
 907        //
 908        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 909        // of a collab server. In that case, a request to the /rpc endpoint will
 910        // return an 'unauthorized' response.
 911        let collab_url = if response.status().is_redirection() {
 912            response
 913                .headers()
 914                .get("Location")
 915                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 916                .to_str()
 917                .map_err(EstablishConnectionError::other)?
 918                .to_string()
 919        } else if response.status() == StatusCode::UNAUTHORIZED {
 920            url
 921        } else {
 922            Err(anyhow!(
 923                "unexpected /rpc response status {}",
 924                response.status()
 925            ))?
 926        };
 927
 928        Url::parse(&collab_url).context("invalid rpc url")
 929    }
 930
 931    fn establish_websocket_connection(
 932        self: &Arc<Self>,
 933        credentials: &Credentials,
 934        cx: &AsyncAppContext,
 935    ) -> Task<Result<Connection, EstablishConnectionError>> {
 936        let is_preview = cx.read(|cx| {
 937            if cx.has_global::<ReleaseChannel>() {
 938                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
 939            } else {
 940                false
 941            }
 942        });
 943
 944        let request = Request::builder()
 945            .header(
 946                "Authorization",
 947                format!("{} {}", credentials.user_id, credentials.access_token),
 948            )
 949            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 950
 951        let http = self.http.clone();
 952        cx.background().spawn(async move {
 953            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
 954            let rpc_host = rpc_url
 955                .host_str()
 956                .zip(rpc_url.port_or_known_default())
 957                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 958            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 959
 960            log::info!("connected to rpc endpoint {}", rpc_url);
 961
 962            match rpc_url.scheme() {
 963                "https" => {
 964                    rpc_url.set_scheme("wss").unwrap();
 965                    let request = request.uri(rpc_url.as_str()).body(())?;
 966                    let (stream, _) =
 967                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 968                    Ok(Connection::new(
 969                        stream
 970                            .map_err(|error| anyhow!(error))
 971                            .sink_map_err(|error| anyhow!(error)),
 972                    ))
 973                }
 974                "http" => {
 975                    rpc_url.set_scheme("ws").unwrap();
 976                    let request = request.uri(rpc_url.as_str()).body(())?;
 977                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 978                    Ok(Connection::new(
 979                        stream
 980                            .map_err(|error| anyhow!(error))
 981                            .sink_map_err(|error| anyhow!(error)),
 982                    ))
 983                }
 984                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 985            }
 986        })
 987    }
 988
 989    pub fn authenticate_with_browser(
 990        self: &Arc<Self>,
 991        cx: &AsyncAppContext,
 992    ) -> Task<Result<Credentials>> {
 993        let platform = cx.platform();
 994        let executor = cx.background();
 995        let telemetry = self.telemetry.clone();
 996        let http = self.http.clone();
 997        executor.clone().spawn(async move {
 998            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 999            // zed server to encrypt the user's access token, so that it can'be intercepted by
1000            // any other app running on the user's device.
1001            let (public_key, private_key) =
1002                rpc::auth::keypair().expect("failed to generate keypair for auth");
1003            let public_key_string =
1004                String::try_from(public_key).expect("failed to serialize public key for auth");
1005
1006            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1007                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1008            }
1009
1010            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1011            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1012            let port = server.server_addr().port();
1013
1014            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1015            // that the user is signing in from a Zed app running on the same device.
1016            let mut url = format!(
1017                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1018                *ZED_SERVER_URL, port, public_key_string
1019            );
1020
1021            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1022                log::info!("impersonating user @{}", impersonate_login);
1023                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1024            }
1025
1026            platform.open_url(&url);
1027
1028            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1029            // access token from the query params.
1030            //
1031            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1032            // custom URL scheme instead of this local HTTP server.
1033            let (user_id, access_token) = executor
1034                .spawn(async move {
1035                    for _ in 0..100 {
1036                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1037                            let path = req.url();
1038                            let mut user_id = None;
1039                            let mut access_token = None;
1040                            let url = Url::parse(&format!("http://example.com{}", path))
1041                                .context("failed to parse login notification url")?;
1042                            for (key, value) in url.query_pairs() {
1043                                if key == "access_token" {
1044                                    access_token = Some(value.to_string());
1045                                } else if key == "user_id" {
1046                                    user_id = Some(value.to_string());
1047                                }
1048                            }
1049
1050                            let post_auth_url =
1051                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1052                            req.respond(
1053                                tiny_http::Response::empty(302).with_header(
1054                                    tiny_http::Header::from_bytes(
1055                                        &b"Location"[..],
1056                                        post_auth_url.as_bytes(),
1057                                    )
1058                                    .unwrap(),
1059                                ),
1060                            )
1061                            .context("failed to respond to login http request")?;
1062                            return Ok((
1063                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1064                                access_token
1065                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1066                            ));
1067                        }
1068                    }
1069
1070                    Err(anyhow!("didn't receive login redirect"))
1071                })
1072                .await?;
1073
1074            let access_token = private_key
1075                .decrypt_string(&access_token)
1076                .context("failed to decrypt access token")?;
1077            platform.activate(true);
1078
1079            telemetry.report_event("authenticate with browser", Default::default());
1080
1081            Ok(Credentials {
1082                user_id: user_id.parse()?,
1083                access_token,
1084            })
1085        })
1086    }
1087
1088    async fn authenticate_as_admin(
1089        http: Arc<dyn HttpClient>,
1090        login: String,
1091        mut api_token: String,
1092    ) -> Result<Credentials> {
1093        #[derive(Deserialize)]
1094        struct AuthenticatedUserResponse {
1095            user: User,
1096        }
1097
1098        #[derive(Deserialize)]
1099        struct User {
1100            id: u64,
1101        }
1102
1103        // Use the collab server's admin API to retrieve the id
1104        // of the impersonated user.
1105        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1106        url.set_path("/user");
1107        url.set_query(Some(&format!("github_login={login}")));
1108        let request = Request::get(url.as_str())
1109            .header("Authorization", format!("token {api_token}"))
1110            .body("".into())?;
1111
1112        let mut response = http.send(request).await?;
1113        let mut body = String::new();
1114        response.body_mut().read_to_string(&mut body).await?;
1115        if !response.status().is_success() {
1116            Err(anyhow!(
1117                "admin user request failed {} - {}",
1118                response.status().as_u16(),
1119                body,
1120            ))?;
1121        }
1122        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1123
1124        // Use the admin API token to authenticate as the impersonated user.
1125        api_token.insert_str(0, "ADMIN_TOKEN:");
1126        Ok(Credentials {
1127            user_id: response.user.id,
1128            access_token: api_token,
1129        })
1130    }
1131
1132    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
1133        let conn_id = self.connection_id()?;
1134        self.peer.disconnect(conn_id);
1135        self.set_status(Status::SignedOut, cx);
1136        Ok(())
1137    }
1138
1139    fn connection_id(&self) -> Result<ConnectionId> {
1140        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1141            Ok(connection_id)
1142        } else {
1143            Err(anyhow!("not connected"))
1144        }
1145    }
1146
1147    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1148        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1149        self.peer.send(self.connection_id()?, message)
1150    }
1151
1152    pub fn request<T: RequestMessage>(
1153        &self,
1154        request: T,
1155    ) -> impl Future<Output = Result<T::Response>> {
1156        let client_id = self.id;
1157        log::debug!(
1158            "rpc request start. client_id:{}. name:{}",
1159            client_id,
1160            T::NAME
1161        );
1162        let response = self
1163            .connection_id()
1164            .map(|conn_id| self.peer.request(conn_id, request));
1165        async move {
1166            let response = response?.await;
1167            log::debug!(
1168                "rpc request finish. client_id:{}. name:{}",
1169                client_id,
1170                T::NAME
1171            );
1172            response
1173        }
1174    }
1175
1176    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1177        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1178        self.peer.respond(receipt, response)
1179    }
1180
1181    fn respond_with_error<T: RequestMessage>(
1182        &self,
1183        receipt: Receipt<T>,
1184        error: proto::Error,
1185    ) -> Result<()> {
1186        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1187        self.peer.respond_with_error(receipt, error)
1188    }
1189
1190    fn handle_message(
1191        self: &Arc<Client>,
1192        message: Box<dyn AnyTypedEnvelope>,
1193        cx: &AsyncAppContext,
1194    ) {
1195        let mut state = self.state.write();
1196        let type_name = message.payload_type_name();
1197        let payload_type_id = message.payload_type_id();
1198        let sender_id = message.original_sender_id().map(|id| id.0);
1199
1200        let mut subscriber = None;
1201
1202        if let Some(message_model) = state
1203            .models_by_message_type
1204            .get(&payload_type_id)
1205            .and_then(|model| model.upgrade(cx))
1206        {
1207            subscriber = Some(Subscriber::Model(message_model));
1208        } else if let Some((extract_entity_id, entity_type_id)) =
1209            state.entity_id_extractors.get(&payload_type_id).zip(
1210                state
1211                    .entity_types_by_message_type
1212                    .get(&payload_type_id)
1213                    .copied(),
1214            )
1215        {
1216            let entity_id = (extract_entity_id)(message.as_ref());
1217
1218            match state
1219                .entities_by_type_and_remote_id
1220                .get_mut(&(entity_type_id, entity_id))
1221            {
1222                Some(WeakSubscriber::Pending(pending)) => {
1223                    pending.push(message);
1224                    return;
1225                }
1226                Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx),
1227                _ => {}
1228            }
1229        }
1230
1231        let subscriber = if let Some(subscriber) = subscriber {
1232            subscriber
1233        } else {
1234            log::info!("unhandled message {}", type_name);
1235            return;
1236        };
1237
1238        let handler = state.message_handlers.get(&payload_type_id).cloned();
1239        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1240        // It also ensures we don't hold the lock while yielding back to the executor, as
1241        // that might cause the executor thread driving this future to block indefinitely.
1242        drop(state);
1243
1244        if let Some(handler) = handler {
1245            let future = handler(subscriber, message, &self, cx.clone());
1246            let client_id = self.id;
1247            log::debug!(
1248                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1249                client_id,
1250                sender_id,
1251                type_name
1252            );
1253            cx.foreground()
1254                .spawn(async move {
1255                    match future.await {
1256                        Ok(()) => {
1257                            log::debug!(
1258                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1259                                client_id,
1260                                sender_id,
1261                                type_name
1262                            );
1263                        }
1264                        Err(error) => {
1265                            log::error!(
1266                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1267                                client_id,
1268                                sender_id,
1269                                type_name,
1270                                error
1271                            );
1272                        }
1273                    }
1274                })
1275                .detach();
1276        } else {
1277            log::info!("unhandled message {}", type_name);
1278        }
1279    }
1280
1281    pub fn start_telemetry(&self, db: Db) {
1282        self.telemetry.start(db.clone());
1283    }
1284
1285    pub fn report_event(&self, kind: &str, properties: Value) {
1286        self.telemetry.report_event(kind, properties.clone());
1287    }
1288
1289    pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1290        self.telemetry.log_file_path()
1291    }
1292}
1293
1294impl WeakSubscriber {
1295    fn upgrade(&self, cx: &AsyncAppContext) -> Option<Subscriber> {
1296        match self {
1297            WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model),
1298            WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View),
1299            WeakSubscriber::Pending(_) => None,
1300        }
1301    }
1302}
1303
1304fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1305    if IMPERSONATE_LOGIN.is_some() {
1306        return None;
1307    }
1308
1309    let (user_id, access_token) = cx
1310        .platform()
1311        .read_credentials(&ZED_SERVER_URL)
1312        .log_err()
1313        .flatten()?;
1314    Some(Credentials {
1315        user_id: user_id.parse().ok()?,
1316        access_token: String::from_utf8(access_token).ok()?,
1317    })
1318}
1319
1320fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1321    cx.platform().write_credentials(
1322        &ZED_SERVER_URL,
1323        &credentials.user_id.to_string(),
1324        credentials.access_token.as_bytes(),
1325    )
1326}
1327
/// Scheme-and-host prefix shared by every encoded worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/{id}/{access_token}` URL identifying a shared worktree.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the access-token segment is absent or empty.
///
/// Only the path segment right after the id is taken as the token; any further
/// `/`-separated segments are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id_segment, remainder) = path.split_once('/')?;
    let id = id_segment.parse::<u64>().ok()?;
    let access_token = remainder.split('/').next()?;
    (!access_token.is_empty()).then(|| (id, access_token.to_string()))
}
1344
#[cfg(test)]
mod tests {
    // Tests for `Client`'s connection lifecycle and message routing. They run against the
    // `FakeServer` / `FakeHttpClient` test doubles from `crate::test`.
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // The client reconnects after the server drops it. Cached credentials are reused while
    // they stay valid, and the client re-authenticates once the token has been rolled.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while new connections are refused; wait for the failed retry.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Let the next retry (after advancing the clock) succeed.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Invalidate the cached token; the client must discard it and authenticate again.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes is abandoned after CONNECTION_TIMEOUT.
    // This holds for the initial connect and for automatic reconnection.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // A second `authenticate_and_connect` call starts a new authentication and drops the
    // still-pending one from the first call (observed via the `defer` guard's counter).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs round-trip, tolerate surrounding whitespace and reject foreign URLs.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity-scoped messages reach the model subscribed to the matching remote id. Here the
    // `JoinProject` project_id selects the model.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a handler registration is dropped, a new handler for the same message type and
    // model can be registered, and it receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler can drop its own subscription (stored on the model) while it is running.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used by the tests above: an id for identifying which model handled a
    // message, and a slot for holding a subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}