client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod http;
   5pub mod telemetry;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
  15use gpui::{
  16    actions,
  17    serde_json::{self, Value},
  18    AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext, AppVersion,
  19    AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  20};
  21use http::HttpClient;
  22use lazy_static::lazy_static;
  23use parking_lot::RwLock;
  24use postage::watch;
  25use rand::prelude::*;
  26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  27use serde::Deserialize;
  28use settings::{Settings, TelemetrySettings};
  29use std::{
  30    any::TypeId,
  31    collections::HashMap,
  32    convert::TryFrom,
  33    fmt::Write as _,
  34    future::Future,
  35    marker::PhantomData,
  36    path::PathBuf,
  37    sync::{Arc, Weak},
  38    time::{Duration, Instant},
  39};
  40use telemetry::Telemetry;
  41use thiserror::Error;
  42use url::Url;
  43use util::channel::ReleaseChannel;
  44use util::{ResultExt, TryFutureExt};
  45
  46pub use rpc::*;
  47pub use user::*;
  48
  49lazy_static! {
  50    pub static ref ZED_SERVER_URL: String =
  51        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  52    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  53        .ok()
  54        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  55    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  56        .ok()
  57        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  58    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  59        .ok()
  60        .and_then(|v| v.parse().ok());
  61    pub static ref ZED_APP_PATH: Option<PathBuf> =
  62        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  63}
  64
  65pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  66pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  67pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  68
  69actions!(client, [Authenticate]);
  70
  71pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
  72    cx.add_global_action({
  73        let client = client.clone();
  74        move |_: &Authenticate, cx| {
  75            let client = client.clone();
  76            cx.spawn(
  77                |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  78            )
  79            .detach();
  80        }
  81    });
  82}
  83
/// Client for the Zed server.
///
/// Holds the RPC [`Peer`] that server connections are added to, the HTTP client,
/// telemetry, and all mutable connection/subscription state behind one `RwLock`.
pub struct Client {
    /// Identifier included in status log lines; `new` sets it to 0 and tests can change
    /// it with `set_id`.
    id: usize,
    /// RPC peer; `set_connection` adds each server connection to it.
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, reconnect task and message-routing tables.
    state: RwLock<ClientState>,

    /// Test-only replacement for the browser authentication flow, installed with
    /// `override_authenticate` and consulted first by `authenticate`.
    // The boxed callback type is long but used only here, so the lint is silenced.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for opening the websocket, installed with
    /// `override_establish_connection` and consulted first by `establish_connection`.
    // Same reason as above for silencing the lint.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 113
/// Errors returned while opening a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (upgrade required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including every websocket error that is not a 401/426 response.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An `http::Error`.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// An I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// An error from tungstenite's `http` types. Despite the variant name, websocket
    /// protocol errors are converted via `From<WebsocketError>` instead, not stored here.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 129
 130impl From<WebsocketError> for EstablishConnectionError {
 131    fn from(error: WebsocketError) -> Self {
 132        if let WebsocketError::Http(response) = &error {
 133            match response.status() {
 134                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 135                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 136                _ => {}
 137            }
 138        }
 139        EstablishConnectionError::Other(error.into())
 140    }
 141}
 142
 143impl EstablishConnectionError {
 144    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 145        Self::Other(error.into())
 146    }
 147}
 148
 149#[derive(Copy, Clone, Debug, PartialEq)]
 150pub enum Status {
 151    SignedOut,
 152    UpgradeRequired,
 153    Authenticating,
 154    Connecting,
 155    ConnectionError,
 156    Connected {
 157        peer_id: PeerId,
 158        connection_id: ConnectionId,
 159    },
 160    ConnectionLost,
 161    Reauthenticating,
 162    Reconnecting,
 163    ReconnectionError {
 164        next_reconnection: Instant,
 165    },
 166}
 167
 168impl Status {
 169    pub fn is_connected(&self) -> bool {
 170        matches!(self, Self::Connected { .. })
 171    }
 172}
 173
/// Mutable state of a [`Client`], always accessed through `Client::state`'s `RwLock`.
struct ClientState {
    /// Credentials of the current/last connection; cleared when the server rejects them.
    credentials: Option<Credentials>,
    /// Watch channel for [`Status`]: `.0` is written by `set_status`, `.1` is cloned by
    /// `Client::status`.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message `TypeId`, a function reading the remote entity id out of an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Backoff loop spawned on `Status::ConnectionLost`; replaced or dropped by `set_status`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the reconnect backoff delay.
    reconnect_interval: Duration,
    /// Registered entities keyed by (entity `TypeId`, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered for a message type via `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity `TypeId` that an entity message type is addressed to.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message `TypeId`.
    // The handler signature is spelled out once here, so the lint is silenced.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 198
/// Entry stored in `entities_by_type_and_remote_id`.
enum WeakSubscriber {
    /// Weak handle to a model bound with `PendingEntitySubscription::set_model`.
    Model(AnyWeakModelHandle),
    /// Weak handle to a view registered with `add_view_for_remote_entity`.
    View(AnyWeakViewHandle),
    /// Placeholder inserted by `subscribe_to_entity`; `set_model` replays these envelopes.
    // NOTE(review): presumably filled by `handle_message` while pending — not visible here.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}

/// Strong handle to the entity a message is dispatched to; handed to message handlers.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 209
 210#[derive(Clone, Debug)]
 211pub struct Credentials {
 212    pub user_id: u64,
 213    pub access_token: String,
 214}
 215
 216impl Default for ClientState {
 217    fn default() -> Self {
 218        Self {
 219            credentials: None,
 220            status: watch::channel_with(Status::SignedOut),
 221            entity_id_extractors: Default::default(),
 222            _reconnect_task: None,
 223            reconnect_interval: Duration::from_secs(5),
 224            models_by_message_type: Default::default(),
 225            entities_by_type_and_remote_id: Default::default(),
 226            entity_types_by_message_type: Default::default(),
 227            message_handlers: Default::default(),
 228        }
 229    }
 230}
 231
/// Handle for a registration on a [`Client`]; dropping it removes the registration
/// (if the client is still alive).
pub enum Subscription {
    /// An entity registration keyed by (entity `TypeId`, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message-handler registration keyed by message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 242
 243impl Drop for Subscription {
 244    fn drop(&mut self) {
 245        match self {
 246            Subscription::Entity { client, id } => {
 247                if let Some(client) = client.upgrade() {
 248                    let mut state = client.state.write();
 249                    let _ = state.entities_by_type_and_remote_id.remove(id);
 250                }
 251            }
 252            Subscription::Message { client, id } => {
 253                if let Some(client) = client.upgrade() {
 254                    let mut state = client.state.write();
 255                    let _ = state.entity_types_by_message_type.remove(id);
 256                    let _ = state.message_handlers.remove(id);
 257                }
 258            }
 259        }
 260    }
 261}
 262
/// Entity registration created by `Client::subscribe_to_entity` that is not yet bound to
/// a model. Call `set_model` to complete it; dropping it unbound removes the pending entry
/// and logs any envelopes that were buffered for it.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    /// Ties the subscription to `T`, whose `TypeId` is part of the registry key.
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` skips the cleanup.
    consumed: bool,
}
 269
 270impl<T: Entity> PendingEntitySubscription<T> {
 271    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 272        self.consumed = true;
 273        let mut state = self.client.state.write();
 274        let id = (TypeId::of::<T>(), self.remote_id);
 275        let Some(WeakSubscriber::Pending(messages)) =
 276            state.entities_by_type_and_remote_id.remove(&id)
 277        else {
 278            unreachable!()
 279        };
 280
 281        state
 282            .entities_by_type_and_remote_id
 283            .insert(id, WeakSubscriber::Model(model.downgrade().into()));
 284        drop(state);
 285        for message in messages {
 286            self.client.handle_message(message, cx);
 287        }
 288        Subscription::Entity {
 289            client: Arc::downgrade(&self.client),
 290            id,
 291        }
 292    }
 293}
 294
 295impl<T: Entity> Drop for PendingEntitySubscription<T> {
 296    fn drop(&mut self) {
 297        if !self.consumed {
 298            let mut state = self.client.state.write();
 299            if let Some(WeakSubscriber::Pending(messages)) = state
 300                .entities_by_type_and_remote_id
 301                .remove(&(TypeId::of::<T>(), self.remote_id))
 302            {
 303                for message in messages {
 304                    log::info!("unhandled message {}", message.payload_type_name());
 305                }
 306            }
 307        }
 308    }
 309}
 310
 311impl Client {
 312    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 313        Arc::new(Self {
 314            id: 0,
 315            peer: Peer::new(0),
 316            telemetry: Telemetry::new(http.clone(), cx),
 317            http,
 318            state: Default::default(),
 319
 320            #[cfg(any(test, feature = "test-support"))]
 321            authenticate: Default::default(),
 322            #[cfg(any(test, feature = "test-support"))]
 323            establish_connection: Default::default(),
 324        })
 325    }
 326
 327    pub fn id(&self) -> usize {
 328        self.id
 329    }
 330
 331    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 332        self.http.clone()
 333    }
 334
 335    #[cfg(any(test, feature = "test-support"))]
 336    pub fn set_id(&mut self, id: usize) -> &Self {
 337        self.id = id;
 338        self
 339    }
 340
 341    #[cfg(any(test, feature = "test-support"))]
 342    pub fn teardown(&self) {
 343        let mut state = self.state.write();
 344        state._reconnect_task.take();
 345        state.message_handlers.clear();
 346        state.models_by_message_type.clear();
 347        state.entities_by_type_and_remote_id.clear();
 348        state.entity_id_extractors.clear();
 349        self.peer.teardown();
 350    }
 351
 352    #[cfg(any(test, feature = "test-support"))]
 353    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 354    where
 355        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 356    {
 357        *self.authenticate.write() = Some(Box::new(authenticate));
 358        self
 359    }
 360
 361    #[cfg(any(test, feature = "test-support"))]
 362    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 363    where
 364        F: 'static
 365            + Send
 366            + Sync
 367            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 368    {
 369        *self.establish_connection.write() = Some(Box::new(connect));
 370        self
 371    }
 372
 373    pub fn user_id(&self) -> Option<u64> {
 374        self.state
 375            .read()
 376            .credentials
 377            .as_ref()
 378            .map(|credentials| credentials.user_id)
 379    }
 380
 381    pub fn peer_id(&self) -> Option<PeerId> {
 382        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 383            Some(*peer_id)
 384        } else {
 385            None
 386        }
 387    }
 388
 389    pub fn status(&self) -> watch::Receiver<Status> {
 390        self.state.read().status.1.clone()
 391    }
 392
    /// Publishes `status` on the status watch channel and updates reconnection state.
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns a retry loop (replacing any previous one) that calls
    ///   `authenticate_and_connect` with randomized, capped backoff.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry's authenticated user info
    ///   and drops any reconnect task.
    ///
    /// The `state` write lock is held for the entire match below.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Deterministic jitter in test builds, entropy-seeded otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Only keep retrying after an ordinary connection error; any other
                        // resulting status (e.g. `UpgradeRequired`) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                let telemetry_settings = cx.read(|cx| cx.global::<Settings>().telemetry());
                self.telemetry
                    .set_authenticated_user_info(None, false, telemetry_settings);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 440
 441    pub fn add_view_for_remote_entity<T: View>(
 442        self: &Arc<Self>,
 443        remote_id: u64,
 444        cx: &mut ViewContext<T>,
 445    ) -> Subscription {
 446        let id = (TypeId::of::<T>(), remote_id);
 447        self.state
 448            .write()
 449            .entities_by_type_and_remote_id
 450            .insert(id, WeakSubscriber::View(cx.weak_handle().into()));
 451        Subscription::Entity {
 452            client: Arc::downgrade(self),
 453            id,
 454        }
 455    }
 456
 457    pub fn subscribe_to_entity<T: Entity>(
 458        self: &Arc<Self>,
 459        remote_id: u64,
 460    ) -> PendingEntitySubscription<T> {
 461        let id = (TypeId::of::<T>(), remote_id);
 462        self.state
 463            .write()
 464            .entities_by_type_and_remote_id
 465            .insert(id, WeakSubscriber::Pending(Default::default()));
 466
 467        PendingEntitySubscription {
 468            client: self.clone(),
 469            remote_id,
 470            consumed: false,
 471            _entity_type: PhantomData,
 472        }
 473    }
 474
 475    pub fn add_message_handler<M, E, H, F>(
 476        self: &Arc<Self>,
 477        model: ModelHandle<E>,
 478        handler: H,
 479    ) -> Subscription
 480    where
 481        M: EnvelopedMessage,
 482        E: Entity,
 483        H: 'static
 484            + Send
 485            + Sync
 486            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 487        F: 'static + Future<Output = Result<()>>,
 488    {
 489        let message_type_id = TypeId::of::<M>();
 490
 491        let mut state = self.state.write();
 492        state
 493            .models_by_message_type
 494            .insert(message_type_id, model.downgrade().into());
 495
 496        let prev_handler = state.message_handlers.insert(
 497            message_type_id,
 498            Arc::new(move |handle, envelope, client, cx| {
 499                let handle = if let Subscriber::Model(handle) = handle {
 500                    handle
 501                } else {
 502                    unreachable!();
 503                };
 504                let model = handle.downcast::<E>().unwrap();
 505                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 506                handler(model, *envelope, client.clone(), cx).boxed_local()
 507            }),
 508        );
 509        if prev_handler.is_some() {
 510            panic!("registered handler for the same message twice");
 511        }
 512
 513        Subscription::Message {
 514            client: Arc::downgrade(self),
 515            id: message_type_id,
 516        }
 517    }
 518
 519    pub fn add_request_handler<M, E, H, F>(
 520        self: &Arc<Self>,
 521        model: ModelHandle<E>,
 522        handler: H,
 523    ) -> Subscription
 524    where
 525        M: RequestMessage,
 526        E: Entity,
 527        H: 'static
 528            + Send
 529            + Sync
 530            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 531        F: 'static + Future<Output = Result<M::Response>>,
 532    {
 533        self.add_message_handler(model, move |handle, envelope, this, cx| {
 534            Self::respond_to_request(
 535                envelope.receipt(),
 536                handler(handle, envelope, this.clone(), cx),
 537                this,
 538            )
 539        })
 540    }
 541
 542    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 543    where
 544        M: EntityMessage,
 545        E: View,
 546        H: 'static
 547            + Send
 548            + Sync
 549            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 550        F: 'static + Future<Output = Result<()>>,
 551    {
 552        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 553            if let Subscriber::View(handle) = handle {
 554                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 555            } else {
 556                unreachable!();
 557            }
 558        })
 559    }
 560
 561    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 562    where
 563        M: EntityMessage,
 564        E: Entity,
 565        H: 'static
 566            + Send
 567            + Sync
 568            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 569        F: 'static + Future<Output = Result<()>>,
 570    {
 571        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 572            if let Subscriber::Model(handle) = handle {
 573                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 574            } else {
 575                unreachable!();
 576            }
 577        })
 578    }
 579
 580    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 581    where
 582        M: EntityMessage,
 583        E: Entity,
 584        H: 'static
 585            + Send
 586            + Sync
 587            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 588        F: 'static + Future<Output = Result<()>>,
 589    {
 590        let model_type_id = TypeId::of::<E>();
 591        let message_type_id = TypeId::of::<M>();
 592
 593        let mut state = self.state.write();
 594        state
 595            .entity_types_by_message_type
 596            .insert(message_type_id, model_type_id);
 597        state
 598            .entity_id_extractors
 599            .entry(message_type_id)
 600            .or_insert_with(|| {
 601                |envelope| {
 602                    envelope
 603                        .as_any()
 604                        .downcast_ref::<TypedEnvelope<M>>()
 605                        .unwrap()
 606                        .payload
 607                        .remote_entity_id()
 608                }
 609            });
 610        let prev_handler = state.message_handlers.insert(
 611            message_type_id,
 612            Arc::new(move |handle, envelope, client, cx| {
 613                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 614                handler(handle, *envelope, client.clone(), cx).boxed_local()
 615            }),
 616        );
 617        if prev_handler.is_some() {
 618            panic!("registered handler for the same message twice");
 619        }
 620    }
 621
 622    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 623    where
 624        M: EntityMessage + RequestMessage,
 625        E: Entity,
 626        H: 'static
 627            + Send
 628            + Sync
 629            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 630        F: 'static + Future<Output = Result<M::Response>>,
 631    {
 632        self.add_model_message_handler(move |entity, envelope, client, cx| {
 633            Self::respond_to_request::<M, _>(
 634                envelope.receipt(),
 635                handler(entity, envelope, client.clone(), cx),
 636                client,
 637            )
 638        })
 639    }
 640
 641    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 642    where
 643        M: EntityMessage + RequestMessage,
 644        E: View,
 645        H: 'static
 646            + Send
 647            + Sync
 648            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 649        F: 'static + Future<Output = Result<M::Response>>,
 650    {
 651        self.add_view_message_handler(move |entity, envelope, client, cx| {
 652            Self::respond_to_request::<M, _>(
 653                envelope.receipt(),
 654                handler(entity, envelope, client.clone(), cx),
 655                client,
 656            )
 657        })
 658    }
 659
 660    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 661        receipt: Receipt<T>,
 662        response: F,
 663        client: Arc<Self>,
 664    ) -> Result<()> {
 665        match response.await {
 666            Ok(response) => {
 667                client.respond(receipt, response)?;
 668                Ok(())
 669            }
 670            Err(error) => {
 671                client.respond_with_error(
 672                    receipt,
 673                    proto::Error {
 674                        message: format!("{:?}", error),
 675                    },
 676                )?;
 677                Err(error)
 678            }
 679        }
 680    }
 681
 682    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 683        read_credentials_from_keychain(cx).is_some()
 684    }
 685
    /// Authenticates and connects to the server, driving the status through
    /// `Authenticating`/`Reauthenticating` → `Connecting`/`Reconnecting` → `Connected`.
    ///
    /// Credentials come from, in order:
    /// 1. the cached credentials;
    /// 2. the keychain (only when `try_keychain` is true);
    /// 3. `authenticate`.
    ///
    /// Returns `Ok(())` immediately when the client is already connected or a connection
    /// attempt is in progress.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted from the
    /// keychain, the status becomes `SignedOut`, and the method recurses once with
    /// `try_keychain = false` (hence `async_recursion`).
    ///
    /// # Errors
    ///
    /// Fails when:
    /// - an upgrade is required;
    /// - authentication fails or is canceled by a status change;
    /// - establishing the connection fails;
    /// - `CONNECTION_TIMEOUT` elapses. One deadline covers both opening the connection and
    ///   waiting for the server's hello.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            // Already connected or mid-connection: nothing to do.
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                cx.read(|cx| {
                    self.report_event(
                        "read credentials from keychain",
                        Default::default(),
                        cx.global::<Settings>().telemetry(),
                    );
                });
            }
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably consumes the value already in the watch channel, so
            // that only a later status change cancels authentication below — confirm.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                // Any status change while authenticating cancels this attempt.
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Either credentials were found above or the function has already returned.
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One deadline shared by opening the connection and waiting for the hello.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist new credentials, except those already read from the
                        // keychain or obtained while impersonating.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry once
                            // without the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 797
 798    async fn set_connection(
 799        self: &Arc<Self>,
 800        conn: Connection,
 801        cx: &AsyncAppContext,
 802    ) -> Result<()> {
 803        let executor = cx.background();
 804        log::info!("add connection to peer");
 805        let (connection_id, handle_io, mut incoming) = self
 806            .peer
 807            .add_connection(conn, move |duration| executor.timer(duration));
 808        let handle_io = cx.background().spawn(handle_io);
 809
 810        let peer_id = async {
 811            log::info!("waiting for server hello");
 812            let message = incoming
 813                .next()
 814                .await
 815                .ok_or_else(|| anyhow!("no hello message received"))?;
 816            log::info!("got server hello");
 817            let hello_message_type_name = message.payload_type_name().to_string();
 818            let hello = message
 819                .into_any()
 820                .downcast::<TypedEnvelope<proto::Hello>>()
 821                .map_err(|_| {
 822                    anyhow!(
 823                        "invalid hello message received: {:?}",
 824                        hello_message_type_name
 825                    )
 826                })?;
 827            let peer_id = hello
 828                .payload
 829                .peer_id
 830                .ok_or_else(|| anyhow!("invalid peer id"))?;
 831            Ok(peer_id)
 832        };
 833
 834        let peer_id = match peer_id.await {
 835            Ok(peer_id) => peer_id,
 836            Err(error) => {
 837                self.peer.disconnect(connection_id);
 838                return Err(error);
 839            }
 840        };
 841
 842        log::info!(
 843            "set status to connected (connection id: {:?}, peer id: {:?})",
 844            connection_id,
 845            peer_id
 846        );
 847        self.set_status(
 848            Status::Connected {
 849                peer_id,
 850                connection_id,
 851            },
 852            cx,
 853        );
 854        cx.foreground()
 855            .spawn({
 856                let cx = cx.clone();
 857                let this = self.clone();
 858                async move {
 859                    while let Some(message) = incoming.next().await {
 860                        this.handle_message(message, &cx);
 861                        // Don't starve the main thread when receiving lots of messages at once.
 862                        smol::future::yield_now().await;
 863                    }
 864                }
 865            })
 866            .detach();
 867
 868        let this = self.clone();
 869        let cx = cx.clone();
 870        cx.foreground()
 871            .spawn(async move {
 872                match handle_io.await {
 873                    Ok(()) => {
 874                        if this.status().borrow().clone()
 875                            == (Status::Connected {
 876                                connection_id,
 877                                peer_id,
 878                            })
 879                        {
 880                            this.set_status(Status::SignedOut, &cx);
 881                        }
 882                    }
 883                    Err(err) => {
 884                        log::error!("connection error: {:?}", err);
 885                        this.set_status(Status::ConnectionLost, &cx);
 886                    }
 887                }
 888            })
 889            .detach();
 890
 891        Ok(())
 892    }
 893
 894    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 895        #[cfg(any(test, feature = "test-support"))]
 896        if let Some(callback) = self.authenticate.read().as_ref() {
 897            return callback(cx);
 898        }
 899
 900        self.authenticate_with_browser(cx)
 901    }
 902
 903    fn establish_connection(
 904        self: &Arc<Self>,
 905        credentials: &Credentials,
 906        cx: &AsyncAppContext,
 907    ) -> Task<Result<Connection, EstablishConnectionError>> {
 908        #[cfg(any(test, feature = "test-support"))]
 909        if let Some(callback) = self.establish_connection.read().as_ref() {
 910            return callback(credentials, cx);
 911        }
 912
 913        self.establish_websocket_connection(credentials, cx)
 914    }
 915
 916    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 917        let preview_param = if is_preview { "?preview=1" } else { "" };
 918        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 919        let response = http.get(&url, Default::default(), false).await?;
 920
 921        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 922        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 923        // which requires authorization via an HTTP header.
 924        //
 925        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 926        // of a collab server. In that case, a request to the /rpc endpoint will
 927        // return an 'unauthorized' response.
 928        let collab_url = if response.status().is_redirection() {
 929            response
 930                .headers()
 931                .get("Location")
 932                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 933                .to_str()
 934                .map_err(EstablishConnectionError::other)?
 935                .to_string()
 936        } else if response.status() == StatusCode::UNAUTHORIZED {
 937            url
 938        } else {
 939            Err(anyhow!(
 940                "unexpected /rpc response status {}",
 941                response.status()
 942            ))?
 943        };
 944
 945        Url::parse(&collab_url).context("invalid rpc url")
 946    }
 947
 948    fn establish_websocket_connection(
 949        self: &Arc<Self>,
 950        credentials: &Credentials,
 951        cx: &AsyncAppContext,
 952    ) -> Task<Result<Connection, EstablishConnectionError>> {
 953        let is_preview = cx.read(|cx| {
 954            if cx.has_global::<ReleaseChannel>() {
 955                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
 956            } else {
 957                false
 958            }
 959        });
 960
 961        let request = Request::builder()
 962            .header(
 963                "Authorization",
 964                format!("{} {}", credentials.user_id, credentials.access_token),
 965            )
 966            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 967
 968        let http = self.http.clone();
 969        cx.background().spawn(async move {
 970            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
 971            let rpc_host = rpc_url
 972                .host_str()
 973                .zip(rpc_url.port_or_known_default())
 974                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 975            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 976
 977            log::info!("connected to rpc endpoint {}", rpc_url);
 978
 979            match rpc_url.scheme() {
 980                "https" => {
 981                    rpc_url.set_scheme("wss").unwrap();
 982                    let request = request.uri(rpc_url.as_str()).body(())?;
 983                    let (stream, _) =
 984                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 985                    Ok(Connection::new(
 986                        stream
 987                            .map_err(|error| anyhow!(error))
 988                            .sink_map_err(|error| anyhow!(error)),
 989                    ))
 990                }
 991                "http" => {
 992                    rpc_url.set_scheme("ws").unwrap();
 993                    let request = request.uri(rpc_url.as_str()).body(())?;
 994                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 995                    Ok(Connection::new(
 996                        stream
 997                            .map_err(|error| anyhow!(error))
 998                            .sink_map_err(|error| anyhow!(error)),
 999                    ))
1000                }
1001                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1002            }
1003        })
1004    }
1005
1006    pub fn authenticate_with_browser(
1007        self: &Arc<Self>,
1008        cx: &AsyncAppContext,
1009    ) -> Task<Result<Credentials>> {
1010        let platform = cx.platform();
1011        let executor = cx.background();
1012        let telemetry = self.telemetry.clone();
1013        let http = self.http.clone();
1014        let metrics_enabled = cx.read(|cx| cx.global::<Settings>().telemetry());
1015
1016        executor.clone().spawn(async move {
1017            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1018            // zed server to encrypt the user's access token, so that it can'be intercepted by
1019            // any other app running on the user's device.
1020            let (public_key, private_key) =
1021                rpc::auth::keypair().expect("failed to generate keypair for auth");
1022            let public_key_string =
1023                String::try_from(public_key).expect("failed to serialize public key for auth");
1024
1025            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1026                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1027            }
1028
1029            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1030            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1031            let port = server.server_addr().port();
1032
1033            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1034            // that the user is signing in from a Zed app running on the same device.
1035            let mut url = format!(
1036                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1037                *ZED_SERVER_URL, port, public_key_string
1038            );
1039
1040            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1041                log::info!("impersonating user @{}", impersonate_login);
1042                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1043            }
1044
1045            platform.open_url(&url);
1046
1047            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1048            // access token from the query params.
1049            //
1050            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1051            // custom URL scheme instead of this local HTTP server.
1052            let (user_id, access_token) = executor
1053                .spawn(async move {
1054                    for _ in 0..100 {
1055                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1056                            let path = req.url();
1057                            let mut user_id = None;
1058                            let mut access_token = None;
1059                            let url = Url::parse(&format!("http://example.com{}", path))
1060                                .context("failed to parse login notification url")?;
1061                            for (key, value) in url.query_pairs() {
1062                                if key == "access_token" {
1063                                    access_token = Some(value.to_string());
1064                                } else if key == "user_id" {
1065                                    user_id = Some(value.to_string());
1066                                }
1067                            }
1068
1069                            let post_auth_url =
1070                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1071                            req.respond(
1072                                tiny_http::Response::empty(302).with_header(
1073                                    tiny_http::Header::from_bytes(
1074                                        &b"Location"[..],
1075                                        post_auth_url.as_bytes(),
1076                                    )
1077                                    .unwrap(),
1078                                ),
1079                            )
1080                            .context("failed to respond to login http request")?;
1081                            return Ok((
1082                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1083                                access_token
1084                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1085                            ));
1086                        }
1087                    }
1088
1089                    Err(anyhow!("didn't receive login redirect"))
1090                })
1091                .await?;
1092
1093            let access_token = private_key
1094                .decrypt_string(&access_token)
1095                .context("failed to decrypt access token")?;
1096            platform.activate(true);
1097
1098            telemetry.report_event(
1099                "authenticate with browser",
1100                Default::default(),
1101                metrics_enabled,
1102            );
1103
1104            Ok(Credentials {
1105                user_id: user_id.parse()?,
1106                access_token,
1107            })
1108        })
1109    }
1110
1111    async fn authenticate_as_admin(
1112        http: Arc<dyn HttpClient>,
1113        login: String,
1114        mut api_token: String,
1115    ) -> Result<Credentials> {
1116        #[derive(Deserialize)]
1117        struct AuthenticatedUserResponse {
1118            user: User,
1119        }
1120
1121        #[derive(Deserialize)]
1122        struct User {
1123            id: u64,
1124        }
1125
1126        // Use the collab server's admin API to retrieve the id
1127        // of the impersonated user.
1128        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1129        url.set_path("/user");
1130        url.set_query(Some(&format!("github_login={login}")));
1131        let request = Request::get(url.as_str())
1132            .header("Authorization", format!("token {api_token}"))
1133            .body("".into())?;
1134
1135        let mut response = http.send(request).await?;
1136        let mut body = String::new();
1137        response.body_mut().read_to_string(&mut body).await?;
1138        if !response.status().is_success() {
1139            Err(anyhow!(
1140                "admin user request failed {} - {}",
1141                response.status().as_u16(),
1142                body,
1143            ))?;
1144        }
1145        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1146
1147        // Use the admin API token to authenticate as the impersonated user.
1148        api_token.insert_str(0, "ADMIN_TOKEN:");
1149        Ok(Credentials {
1150            user_id: response.user.id,
1151            access_token: api_token,
1152        })
1153    }
1154
1155    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
1156        let conn_id = self.connection_id()?;
1157        self.peer.disconnect(conn_id);
1158        self.set_status(Status::SignedOut, cx);
1159        Ok(())
1160    }
1161
1162    fn connection_id(&self) -> Result<ConnectionId> {
1163        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1164            Ok(connection_id)
1165        } else {
1166            Err(anyhow!("not connected"))
1167        }
1168    }
1169
1170    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1171        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1172        self.peer.send(self.connection_id()?, message)
1173    }
1174
1175    pub fn request<T: RequestMessage>(
1176        &self,
1177        request: T,
1178    ) -> impl Future<Output = Result<T::Response>> {
1179        let client_id = self.id;
1180        log::debug!(
1181            "rpc request start. client_id:{}. name:{}",
1182            client_id,
1183            T::NAME
1184        );
1185        let response = self
1186            .connection_id()
1187            .map(|conn_id| self.peer.request(conn_id, request));
1188        async move {
1189            let response = response?.await;
1190            log::debug!(
1191                "rpc request finish. client_id:{}. name:{}",
1192                client_id,
1193                T::NAME
1194            );
1195            response
1196        }
1197    }
1198
1199    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1200        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1201        self.peer.respond(receipt, response)
1202    }
1203
1204    fn respond_with_error<T: RequestMessage>(
1205        &self,
1206        receipt: Receipt<T>,
1207        error: proto::Error,
1208    ) -> Result<()> {
1209        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1210        self.peer.respond_with_error(receipt, error)
1211    }
1212
1213    fn handle_message(
1214        self: &Arc<Client>,
1215        message: Box<dyn AnyTypedEnvelope>,
1216        cx: &AsyncAppContext,
1217    ) {
1218        let mut state = self.state.write();
1219        let type_name = message.payload_type_name();
1220        let payload_type_id = message.payload_type_id();
1221        let sender_id = message.original_sender_id();
1222
1223        let mut subscriber = None;
1224
1225        if let Some(message_model) = state
1226            .models_by_message_type
1227            .get(&payload_type_id)
1228            .and_then(|model| model.upgrade(cx))
1229        {
1230            subscriber = Some(Subscriber::Model(message_model));
1231        } else if let Some((extract_entity_id, entity_type_id)) =
1232            state.entity_id_extractors.get(&payload_type_id).zip(
1233                state
1234                    .entity_types_by_message_type
1235                    .get(&payload_type_id)
1236                    .copied(),
1237            )
1238        {
1239            let entity_id = (extract_entity_id)(message.as_ref());
1240
1241            match state
1242                .entities_by_type_and_remote_id
1243                .get_mut(&(entity_type_id, entity_id))
1244            {
1245                Some(WeakSubscriber::Pending(pending)) => {
1246                    pending.push(message);
1247                    return;
1248                }
1249                Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx),
1250                _ => {}
1251            }
1252        }
1253
1254        let subscriber = if let Some(subscriber) = subscriber {
1255            subscriber
1256        } else {
1257            log::info!("unhandled message {}", type_name);
1258            self.peer.respond_with_unhandled_message(message).log_err();
1259            return;
1260        };
1261
1262        let handler = state.message_handlers.get(&payload_type_id).cloned();
1263        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1264        // It also ensures we don't hold the lock while yielding back to the executor, as
1265        // that might cause the executor thread driving this future to block indefinitely.
1266        drop(state);
1267
1268        if let Some(handler) = handler {
1269            let future = handler(subscriber, message, &self, cx.clone());
1270            let client_id = self.id;
1271            log::debug!(
1272                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1273                client_id,
1274                sender_id,
1275                type_name
1276            );
1277            cx.foreground()
1278                .spawn(async move {
1279                    match future.await {
1280                        Ok(()) => {
1281                            log::debug!(
1282                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1283                                client_id,
1284                                sender_id,
1285                                type_name
1286                            );
1287                        }
1288                        Err(error) => {
1289                            log::error!(
1290                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1291                                client_id,
1292                                sender_id,
1293                                type_name,
1294                                error
1295                            );
1296                        }
1297                    }
1298                })
1299                .detach();
1300        } else {
1301            log::info!("unhandled message {}", type_name);
1302            self.peer.respond_with_unhandled_message(message).log_err();
1303        }
1304    }
1305
1306    pub fn start_telemetry(&self) {
1307        self.telemetry.start();
1308    }
1309
1310    pub fn report_event(
1311        &self,
1312        kind: &str,
1313        properties: Value,
1314        telemetry_settings: TelemetrySettings,
1315    ) {
1316        self.telemetry
1317            .report_event(kind, properties.clone(), telemetry_settings);
1318    }
1319
1320    pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1321        self.telemetry.log_file_path()
1322    }
1323
1324    pub fn metrics_id(&self) -> Option<Arc<str>> {
1325        self.telemetry.metrics_id()
1326    }
1327
1328    pub fn is_staff(&self) -> Option<bool> {
1329        self.telemetry.is_staff()
1330    }
1331}
1332
1333impl WeakSubscriber {
1334    fn upgrade(&self, cx: &AsyncAppContext) -> Option<Subscriber> {
1335        match self {
1336            WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model),
1337            WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View),
1338            WeakSubscriber::Pending(_) => None,
1339        }
1340    }
1341}
1342
1343fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1344    if IMPERSONATE_LOGIN.is_some() {
1345        return None;
1346    }
1347
1348    let (user_id, access_token) = cx
1349        .platform()
1350        .read_credentials(&ZED_SERVER_URL)
1351        .log_err()
1352        .flatten()?;
1353    Some(Credentials {
1354        user_id: user_id.parse().ok()?,
1355        access_token: String::from_utf8(access_token).ok()?,
1356    })
1357}
1358
1359fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1360    cx.platform().write_credentials(
1361        &ZED_SERVER_URL,
1362        &credentials.user_id.to_string(),
1363        credentials.access_token.as_bytes(),
1364    )
1365}
1366
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a worktree URL back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` in any of these cases:
///
/// * the prefix is missing;
/// * the id segment is not a `u64`;
/// * there is no token segment;
/// * the token segment is empty.
///
/// Any path segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let remainder = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id_segment, tail) = remainder.split_once('/')?;
    let id = id_segment.parse::<u64>().ok()?;
    let access_token = tail.split('/').next().filter(|token| !token.is_empty())?;
    Some((id, access_token.to_owned()))
}
1383
// Tests for the RPC client: reconnection and connection timeouts, repeated
// authentication, worktree URL encoding, and message subscription/handler lifetimes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // After the server drops the client, the status goes to ReconnectionError while
    // connections are forbidden. Once connections are allowed and the clock advances,
    // the client reconnects:
    // - with its cached credentials (auth count unchanged);
    // - after re-authenticating, once the server has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // When establishing the connection never completes, both the initial attempt
    // (Connecting -> ConnectionError) and a later reconnection attempt
    // (Reconnecting -> ReconnectionError) fail once CONNECTION_TIMEOUT elapses.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still
    // authenticating starts a new authentication and drops the first one's future.
    // The drop is observed through the deferred counter.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Encoding then decoding round-trips, surrounding whitespace is ignored, and
    // URLs without the worktree prefix are rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Messages addressed to subscribed entity ids reach the handler with the model
    // whose id matches the message's project id.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a message-handler subscription is dropped, a new handler can be
    // registered for the same message type and model, and it receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler can drop its own subscription while it is handling a message. The
    // subscription is stored on the model and taken out of it inside the handler.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: `id` identifies the model inside handlers, and
    // `subscription` lets a test tie a subscription's lifetime to the model.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}