client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod telemetry;
   5pub mod user;
   6
   7use anyhow::{anyhow, Context, Result};
   8use async_recursion::async_recursion;
   9use async_tungstenite::tungstenite::{
  10    error::Error as WebsocketError,
  11    http::{Request, StatusCode},
  12};
  13use futures::{
  14    future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
  15    TryStreamExt,
  16};
  17use gpui::{
  18    actions,
  19    serde_json::{self, Value},
  20    AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext, AppVersion,
  21    AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  22};
  23use lazy_static::lazy_static;
  24use parking_lot::RwLock;
  25use postage::watch;
  26use rand::prelude::*;
  27use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
  28use serde::Deserialize;
  29use settings::{Settings, TelemetrySettings};
  30use std::{
  31    any::TypeId,
  32    collections::HashMap,
  33    convert::TryFrom,
  34    fmt::Write as _,
  35    future::Future,
  36    marker::PhantomData,
  37    path::PathBuf,
  38    sync::{Arc, Weak},
  39    time::{Duration, Instant},
  40};
  41use telemetry::Telemetry;
  42use thiserror::Error;
  43use url::Url;
  44use util::channel::ReleaseChannel;
  45use util::http::HttpClient;
  46use util::{ResultExt, TryFutureExt};
  47
  48pub use rpc::*;
  49pub use user::*;
  50
  51lazy_static! {
  52    pub static ref ZED_SERVER_URL: String =
  53        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  54    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  55        .ok()
  56        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  57    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
  58        .ok()
  59        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  60    pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
  61        .ok()
  62        .and_then(|v| v.parse().ok());
  63    pub static ref ZED_APP_PATH: Option<PathBuf> =
  64        std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
  65}
  66
  67pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
  68pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
  69pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
  70
  71actions!(client, [SignIn, SignOut]);
  72
  73pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
  74    cx.add_global_action({
  75        let client = client.clone();
  76        move |_: &SignIn, cx| {
  77            let client = client.clone();
  78            cx.spawn(
  79                |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  80            )
  81            .detach();
  82        }
  83    });
  84    cx.add_global_action({
  85        let client = client.clone();
  86        move |_: &SignOut, cx| {
  87            let client = client.clone();
  88            cx.spawn(|cx| async move {
  89                client.disconnect(&cx);
  90            })
  91            .detach();
  92        }
  93    });
  94}
  95
/// Client for the Zed server: owns the RPC [`Peer`], an HTTP client, telemetry,
/// and the lock-protected connection/subscription state.
pub struct Client {
    /// Identifier shown in this client's log lines; `new` sets it to 0, tests may change it.
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, entity subscriptions and message handlers.
    state: RwLock<ClientState>,

    // Test-only override consulted by `authenticate` before the browser flow.
    // The boxed closure type is spelled inline, hence the clippy allow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    // Test-only override for opening the server connection; presumably consulted by
    // `establish_connection` (its body is not visible here — confirm).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 125
/// Reasons an attempt to open a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP `426 Upgrade Required`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP `401 Unauthorized`.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors that are not 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the project's HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from the `http` crate as re-exported by tungstenite (not a websocket protocol error).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 141
 142impl From<WebsocketError> for EstablishConnectionError {
 143    fn from(error: WebsocketError) -> Self {
 144        if let WebsocketError::Http(response) = &error {
 145            match response.status() {
 146                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 147                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 148                _ => {}
 149            }
 150        }
 151        EstablishConnectionError::Other(error.into())
 152    }
 153}
 154
 155impl EstablishConnectionError {
 156    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 157        Self::Other(error.into())
 158    }
 159}
 160
/// Connection lifecycle state of a [`Client`], published on a watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No session; also the initial state.
    SignedOut,
    /// The server requires a newer client; `authenticate_and_connect` refuses to proceed.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening a connection, starting from `SignedOut`.
    Connecting,
    /// The most recent authentication or connection attempt failed.
    ConnectionError,
    /// Connected; identifies this client's peer id and the live connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O task failed; entering this state starts a reconnect loop.
    ConnectionLost,
    /// Obtaining credentials when not starting from `SignedOut`.
    Reauthenticating,
    /// Opening a connection when not starting from `SignedOut`.
    Reconnecting,
    /// A reconnect attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
 179
 180impl Status {
 181    pub fn is_connected(&self) -> bool {
 182        matches!(self, Self::Connected { .. })
 183    }
 184
 185    pub fn is_signed_out(&self) -> bool {
 186        matches!(self, Self::SignedOut | Self::UpgradeRequired)
 187    }
 188}
 189
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Stored after a successful connection; cleared when the server answers `Unauthorized`.
    credentials: Option<Credentials>,
    /// Status watch channel: `set_status` writes the sender, `status()` clones the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function reading the remote entity id out of an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned by `set_status` on `ConnectionLost`; cleared on `Connected`,
    /// `SignedOut`, `UpgradeRequired` and `teardown`. NOTE(review): relies on dropping
    /// the task to stop the loop — confirm gpui `Task` drop semantics.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized delay between reconnect attempts.
    reconnect_interval: Duration,
    /// Subscriber registered for each `(entity type, remote entity id)`.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered by `add_message_handler` for each message type.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type targeted by each entity-addressed message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message type.
    // The handler signature is spelled inline, hence the clippy allow.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 214
/// Registration stored under an `(entity type, remote id)` key.
enum WeakSubscriber {
    /// A model, held weakly.
    Model(AnyWeakModelHandle),
    /// A view, held weakly.
    View(AnyWeakViewHandle),
    /// Messages held for a [`PendingEntitySubscription`]; `set_model` replays them
    /// through `handle_message`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}

/// Strong handle to the entity a message is dispatched to, as passed to handlers.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 225
 226#[derive(Clone, Debug)]
 227pub struct Credentials {
 228    pub user_id: u64,
 229    pub access_token: String,
 230}
 231
 232impl Default for ClientState {
 233    fn default() -> Self {
 234        Self {
 235            credentials: None,
 236            status: watch::channel_with(Status::SignedOut),
 237            entity_id_extractors: Default::default(),
 238            _reconnect_task: None,
 239            reconnect_interval: Duration::from_secs(5),
 240            models_by_message_type: Default::default(),
 241            entities_by_type_and_remote_id: Default::default(),
 242            entity_types_by_message_type: Default::default(),
 243            message_handlers: Default::default(),
 244        }
 245    }
 246}
 247
/// Registration made on a [`Client`]; dropping it undoes the registration.
///
/// Holds the client weakly, so an outstanding subscription does not keep it alive.
pub enum Subscription {
    /// Entry in `entities_by_type_and_remote_id` under `(entity type, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler registration keyed by message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 258
 259impl Drop for Subscription {
 260    fn drop(&mut self) {
 261        match self {
 262            Subscription::Entity { client, id } => {
 263                if let Some(client) = client.upgrade() {
 264                    let mut state = client.state.write();
 265                    let _ = state.entities_by_type_and_remote_id.remove(id);
 266                }
 267            }
 268            Subscription::Message { client, id } => {
 269                if let Some(client) = client.upgrade() {
 270                    let mut state = client.state.write();
 271                    let _ = state.entity_types_by_message_type.remove(id);
 272                    let _ = state.message_handlers.remove(id);
 273                }
 274            }
 275        }
 276    }
 277}
 278
/// Subscription to a remote entity whose local model does not exist yet.
///
/// Created by `Client::subscribe_to_entity`. Call `set_model` to bind a model and
/// replay the messages held meanwhile; dropping it unconsumed removes the
/// registration and logs each held message as unhandled.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the converted registration alone.
    consumed: bool,
}
 285
 286impl<T: Entity> PendingEntitySubscription<T> {
 287    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
 288        self.consumed = true;
 289        let mut state = self.client.state.write();
 290        let id = (TypeId::of::<T>(), self.remote_id);
 291        let Some(WeakSubscriber::Pending(messages)) =
 292            state.entities_by_type_and_remote_id.remove(&id)
 293        else {
 294            unreachable!()
 295        };
 296
 297        state
 298            .entities_by_type_and_remote_id
 299            .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
 300        drop(state);
 301        for message in messages {
 302            self.client.handle_message(message, cx);
 303        }
 304        Subscription::Entity {
 305            client: Arc::downgrade(&self.client),
 306            id,
 307        }
 308    }
 309}
 310
 311impl<T: Entity> Drop for PendingEntitySubscription<T> {
 312    fn drop(&mut self) {
 313        if !self.consumed {
 314            let mut state = self.client.state.write();
 315            if let Some(WeakSubscriber::Pending(messages)) = state
 316                .entities_by_type_and_remote_id
 317                .remove(&(TypeId::of::<T>(), self.remote_id))
 318            {
 319                for message in messages {
 320                    log::info!("unhandled message {}", message.payload_type_name());
 321                }
 322            }
 323        }
 324    }
 325}
 326
 327impl Client {
 328    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 329        Arc::new(Self {
 330            id: 0,
 331            peer: Peer::new(0),
 332            telemetry: Telemetry::new(http.clone(), cx),
 333            http,
 334            state: Default::default(),
 335
 336            #[cfg(any(test, feature = "test-support"))]
 337            authenticate: Default::default(),
 338            #[cfg(any(test, feature = "test-support"))]
 339            establish_connection: Default::default(),
 340        })
 341    }
 342
 343    pub fn id(&self) -> usize {
 344        self.id
 345    }
 346
 347    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 348        self.http.clone()
 349    }
 350
 351    #[cfg(any(test, feature = "test-support"))]
 352    pub fn set_id(&mut self, id: usize) -> &Self {
 353        self.id = id;
 354        self
 355    }
 356
 357    #[cfg(any(test, feature = "test-support"))]
 358    pub fn teardown(&self) {
 359        let mut state = self.state.write();
 360        state._reconnect_task.take();
 361        state.message_handlers.clear();
 362        state.models_by_message_type.clear();
 363        state.entities_by_type_and_remote_id.clear();
 364        state.entity_id_extractors.clear();
 365        self.peer.teardown();
 366    }
 367
 368    #[cfg(any(test, feature = "test-support"))]
 369    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 370    where
 371        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 372    {
 373        *self.authenticate.write() = Some(Box::new(authenticate));
 374        self
 375    }
 376
 377    #[cfg(any(test, feature = "test-support"))]
 378    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 379    where
 380        F: 'static
 381            + Send
 382            + Sync
 383            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 384    {
 385        *self.establish_connection.write() = Some(Box::new(connect));
 386        self
 387    }
 388
 389    pub fn user_id(&self) -> Option<u64> {
 390        self.state
 391            .read()
 392            .credentials
 393            .as_ref()
 394            .map(|credentials| credentials.user_id)
 395    }
 396
 397    pub fn peer_id(&self) -> Option<PeerId> {
 398        if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
 399            Some(*peer_id)
 400        } else {
 401            None
 402        }
 403    }
 404
 405    pub fn status(&self) -> watch::Receiver<Status> {
 406        self.state.read().status.1.clone()
 407    }
 408
    /// Logs and publishes `status` on the status watch channel, then applies the
    /// side effects of entering it:
    ///
    /// - `Connected`: clears the reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop that retries
    ///   `authenticate_and_connect`, sleeping between attempts with a delay that
    ///   grows by a random factor and is capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears telemetry's authenticated-user info
    ///   and the reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed in test builds keeps the retry delays deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failed attempt left the client in
                        // `ConnectionError`; any other status (e.g. `UpgradeRequired`)
                        // ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                let telemetry_settings = cx.read(|cx| cx.global::<Settings>().telemetry());
                self.telemetry
                    .set_authenticated_user_info(None, false, telemetry_settings);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 456
 457    pub fn add_view_for_remote_entity<T: View>(
 458        self: &Arc<Self>,
 459        remote_id: u64,
 460        cx: &mut ViewContext<T>,
 461    ) -> Subscription {
 462        let id = (TypeId::of::<T>(), remote_id);
 463        self.state
 464            .write()
 465            .entities_by_type_and_remote_id
 466            .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
 467        Subscription::Entity {
 468            client: Arc::downgrade(self),
 469            id,
 470        }
 471    }
 472
 473    pub fn subscribe_to_entity<T: Entity>(
 474        self: &Arc<Self>,
 475        remote_id: u64,
 476    ) -> PendingEntitySubscription<T> {
 477        let id = (TypeId::of::<T>(), remote_id);
 478        self.state
 479            .write()
 480            .entities_by_type_and_remote_id
 481            .insert(id, WeakSubscriber::Pending(Default::default()));
 482
 483        PendingEntitySubscription {
 484            client: self.clone(),
 485            remote_id,
 486            consumed: false,
 487            _entity_type: PhantomData,
 488        }
 489    }
 490
 491    pub fn add_message_handler<M, E, H, F>(
 492        self: &Arc<Self>,
 493        model: ModelHandle<E>,
 494        handler: H,
 495    ) -> Subscription
 496    where
 497        M: EnvelopedMessage,
 498        E: Entity,
 499        H: 'static
 500            + Send
 501            + Sync
 502            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 503        F: 'static + Future<Output = Result<()>>,
 504    {
 505        let message_type_id = TypeId::of::<M>();
 506
 507        let mut state = self.state.write();
 508        state
 509            .models_by_message_type
 510            .insert(message_type_id, model.downgrade().into_any());
 511
 512        let prev_handler = state.message_handlers.insert(
 513            message_type_id,
 514            Arc::new(move |handle, envelope, client, cx| {
 515                let handle = if let Subscriber::Model(handle) = handle {
 516                    handle
 517                } else {
 518                    unreachable!();
 519                };
 520                let model = handle.downcast::<E>().unwrap();
 521                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 522                handler(model, *envelope, client.clone(), cx).boxed_local()
 523            }),
 524        );
 525        if prev_handler.is_some() {
 526            panic!("registered handler for the same message twice");
 527        }
 528
 529        Subscription::Message {
 530            client: Arc::downgrade(self),
 531            id: message_type_id,
 532        }
 533    }
 534
 535    pub fn add_request_handler<M, E, H, F>(
 536        self: &Arc<Self>,
 537        model: ModelHandle<E>,
 538        handler: H,
 539    ) -> Subscription
 540    where
 541        M: RequestMessage,
 542        E: Entity,
 543        H: 'static
 544            + Send
 545            + Sync
 546            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 547        F: 'static + Future<Output = Result<M::Response>>,
 548    {
 549        self.add_message_handler(model, move |handle, envelope, this, cx| {
 550            Self::respond_to_request(
 551                envelope.receipt(),
 552                handler(handle, envelope, this.clone(), cx),
 553                this,
 554            )
 555        })
 556    }
 557
 558    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 559    where
 560        M: EntityMessage,
 561        E: View,
 562        H: 'static
 563            + Send
 564            + Sync
 565            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 566        F: 'static + Future<Output = Result<()>>,
 567    {
 568        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 569            if let Subscriber::View(handle) = handle {
 570                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 571            } else {
 572                unreachable!();
 573            }
 574        })
 575    }
 576
 577    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 578    where
 579        M: EntityMessage,
 580        E: Entity,
 581        H: 'static
 582            + Send
 583            + Sync
 584            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 585        F: 'static + Future<Output = Result<()>>,
 586    {
 587        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 588            if let Subscriber::Model(handle) = handle {
 589                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 590            } else {
 591                unreachable!();
 592            }
 593        })
 594    }
 595
 596    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 597    where
 598        M: EntityMessage,
 599        E: Entity,
 600        H: 'static
 601            + Send
 602            + Sync
 603            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 604        F: 'static + Future<Output = Result<()>>,
 605    {
 606        let model_type_id = TypeId::of::<E>();
 607        let message_type_id = TypeId::of::<M>();
 608
 609        let mut state = self.state.write();
 610        state
 611            .entity_types_by_message_type
 612            .insert(message_type_id, model_type_id);
 613        state
 614            .entity_id_extractors
 615            .entry(message_type_id)
 616            .or_insert_with(|| {
 617                |envelope| {
 618                    envelope
 619                        .as_any()
 620                        .downcast_ref::<TypedEnvelope<M>>()
 621                        .unwrap()
 622                        .payload
 623                        .remote_entity_id()
 624                }
 625            });
 626        let prev_handler = state.message_handlers.insert(
 627            message_type_id,
 628            Arc::new(move |handle, envelope, client, cx| {
 629                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 630                handler(handle, *envelope, client.clone(), cx).boxed_local()
 631            }),
 632        );
 633        if prev_handler.is_some() {
 634            panic!("registered handler for the same message twice");
 635        }
 636    }
 637
 638    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 639    where
 640        M: EntityMessage + RequestMessage,
 641        E: Entity,
 642        H: 'static
 643            + Send
 644            + Sync
 645            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 646        F: 'static + Future<Output = Result<M::Response>>,
 647    {
 648        self.add_model_message_handler(move |entity, envelope, client, cx| {
 649            Self::respond_to_request::<M, _>(
 650                envelope.receipt(),
 651                handler(entity, envelope, client.clone(), cx),
 652                client,
 653            )
 654        })
 655    }
 656
 657    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 658    where
 659        M: EntityMessage + RequestMessage,
 660        E: View,
 661        H: 'static
 662            + Send
 663            + Sync
 664            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 665        F: 'static + Future<Output = Result<M::Response>>,
 666    {
 667        self.add_view_message_handler(move |entity, envelope, client, cx| {
 668            Self::respond_to_request::<M, _>(
 669                envelope.receipt(),
 670                handler(entity, envelope, client.clone(), cx),
 671                client,
 672            )
 673        })
 674    }
 675
 676    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 677        receipt: Receipt<T>,
 678        response: F,
 679        client: Arc<Self>,
 680    ) -> Result<()> {
 681        match response.await {
 682            Ok(response) => {
 683                client.respond(receipt, response)?;
 684                Ok(())
 685            }
 686            Err(error) => {
 687                client.respond_with_error(
 688                    receipt,
 689                    proto::Error {
 690                        message: format!("{:?}", error),
 691                    },
 692                )?;
 693                Err(error)
 694            }
 695        }
 696    }
 697
 698    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 699        read_credentials_from_keychain(cx).is_some()
 700    }
 701
    /// Obtains credentials if needed and connects to the server, moving `status`
    /// through the intermediate states. Starting from `SignedOut` it uses
    /// `Authenticating`/`Connecting`; otherwise `Reauthenticating`/`Reconnecting`.
    ///
    /// Returns `Ok(())` immediately when the status is `Connected`, `Connecting` or
    /// `Reconnecting`.
    ///
    /// Credentials are looked up in this order:
    ///
    /// 1. the credentials stored in memory;
    /// 2. the keychain, when `try_keychain` is true;
    /// 3. `authenticate`.
    ///
    /// Opening the connection and receiving the server hello share one
    /// `CONNECTION_TIMEOUT`.
    ///
    /// # Errors
    ///
    /// - The status is already `UpgradeRequired`, or the server requires an upgrade.
    /// - Authentication fails, or a status change cancels it while it is pending.
    /// - The server rejects the credentials. Exception: if they came from the keychain,
    ///   they are deleted and the whole flow is retried without the keychain.
    /// - Any other connection error, or the timeout elapses.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                cx.read(|cx| {
                    self.report_event(
                        "read credentials from keychain",
                        Default::default(),
                        cx.global::<Settings>().telemetry(),
                    );
                });
            }
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Skip the value already in the channel so that only a later status
            // change cancels authentication (assumes the watch receiver yields its
            // current value first — confirm against postage semantics).
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single timeout bounds both opening the connection and the hello
        // handshake performed by `set_connection`.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist credentials that did not come from the keychain,
                        // unless impersonating another login.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Keychain credentials were rejected: delete them and retry
                            // the whole flow without consulting the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
 813
    /// Adds `conn` to the peer, waits for the server's `Hello` to learn this
    /// client's peer id, then sets the status to `Connected`.
    ///
    /// On success it leaves two detached foreground tasks running:
    ///
    /// - one feeds incoming messages to `handle_message`;
    /// - one awaits the connection's I/O task and then moves the status to
    ///   `SignedOut` (clean close while still on this connection) or
    ///   `ConnectionLost` (I/O error).
    ///
    /// # Errors
    ///
    /// Disconnects the connection and fails in any of these cases:
    ///
    /// - no message arrives;
    /// - the first message is not a `Hello`;
    /// - the hello carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first message on a new connection must be the server's `Hello`.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if the status still describes this connection,
                        // so a status set for a later connection is left alone.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
 909
 910    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 911        #[cfg(any(test, feature = "test-support"))]
 912        if let Some(callback) = self.authenticate.read().as_ref() {
 913            return callback(cx);
 914        }
 915
 916        self.authenticate_with_browser(cx)
 917    }
 918
 919    fn establish_connection(
 920        self: &Arc<Self>,
 921        credentials: &Credentials,
 922        cx: &AsyncAppContext,
 923    ) -> Task<Result<Connection, EstablishConnectionError>> {
 924        #[cfg(any(test, feature = "test-support"))]
 925        if let Some(callback) = self.establish_connection.read().as_ref() {
 926            return callback(credentials, cx);
 927        }
 928
 929        self.establish_websocket_connection(credentials, cx)
 930    }
 931
 932    async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
 933        let preview_param = if is_preview { "?preview=1" } else { "" };
 934        let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
 935        let response = http.get(&url, Default::default(), false).await?;
 936
 937        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
 938        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
 939        // which requires authorization via an HTTP header.
 940        //
 941        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
 942        // of a collab server. In that case, a request to the /rpc endpoint will
 943        // return an 'unauthorized' response.
 944        let collab_url = if response.status().is_redirection() {
 945            response
 946                .headers()
 947                .get("Location")
 948                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 949                .to_str()
 950                .map_err(EstablishConnectionError::other)?
 951                .to_string()
 952        } else if response.status() == StatusCode::UNAUTHORIZED {
 953            url
 954        } else {
 955            Err(anyhow!(
 956                "unexpected /rpc response status {}",
 957                response.status()
 958            ))?
 959        };
 960
 961        Url::parse(&collab_url).context("invalid rpc url")
 962    }
 963
 964    fn establish_websocket_connection(
 965        self: &Arc<Self>,
 966        credentials: &Credentials,
 967        cx: &AsyncAppContext,
 968    ) -> Task<Result<Connection, EstablishConnectionError>> {
 969        let is_preview = cx.read(|cx| {
 970            if cx.has_global::<ReleaseChannel>() {
 971                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
 972            } else {
 973                false
 974            }
 975        });
 976
 977        let request = Request::builder()
 978            .header(
 979                "Authorization",
 980                format!("{} {}", credentials.user_id, credentials.access_token),
 981            )
 982            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 983
 984        let http = self.http.clone();
 985        cx.background().spawn(async move {
 986            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
 987            let rpc_host = rpc_url
 988                .host_str()
 989                .zip(rpc_url.port_or_known_default())
 990                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 991            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 992
 993            log::info!("connected to rpc endpoint {}", rpc_url);
 994
 995            match rpc_url.scheme() {
 996                "https" => {
 997                    rpc_url.set_scheme("wss").unwrap();
 998                    let request = request.uri(rpc_url.as_str()).body(())?;
 999                    let (stream, _) =
1000                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1001                    Ok(Connection::new(
1002                        stream
1003                            .map_err(|error| anyhow!(error))
1004                            .sink_map_err(|error| anyhow!(error)),
1005                    ))
1006                }
1007                "http" => {
1008                    rpc_url.set_scheme("ws").unwrap();
1009                    let request = request.uri(rpc_url.as_str()).body(())?;
1010                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1011                    Ok(Connection::new(
1012                        stream
1013                            .map_err(|error| anyhow!(error))
1014                            .sink_map_err(|error| anyhow!(error)),
1015                    ))
1016                }
1017                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1018            }
1019        })
1020    }
1021
1022    pub fn authenticate_with_browser(
1023        self: &Arc<Self>,
1024        cx: &AsyncAppContext,
1025    ) -> Task<Result<Credentials>> {
1026        let platform = cx.platform();
1027        let executor = cx.background();
1028        let telemetry = self.telemetry.clone();
1029        let http = self.http.clone();
1030        let metrics_enabled = cx.read(|cx| cx.global::<Settings>().telemetry());
1031
1032        executor.clone().spawn(async move {
1033            // Generate a pair of asymmetric encryption keys. The public key will be used by the
1034            // zed server to encrypt the user's access token, so that it can'be intercepted by
1035            // any other app running on the user's device.
1036            let (public_key, private_key) =
1037                rpc::auth::keypair().expect("failed to generate keypair for auth");
1038            let public_key_string =
1039                String::try_from(public_key).expect("failed to serialize public key for auth");
1040
1041            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1042                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1043            }
1044
1045            // Start an HTTP server to receive the redirect from Zed's sign-in page.
1046            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1047            let port = server.server_addr().port();
1048
1049            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1050            // that the user is signing in from a Zed app running on the same device.
1051            let mut url = format!(
1052                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1053                *ZED_SERVER_URL, port, public_key_string
1054            );
1055
1056            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1057                log::info!("impersonating user @{}", impersonate_login);
1058                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1059            }
1060
1061            platform.open_url(&url);
1062
1063            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1064            // access token from the query params.
1065            //
1066            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1067            // custom URL scheme instead of this local HTTP server.
1068            let (user_id, access_token) = executor
1069                .spawn(async move {
1070                    for _ in 0..100 {
1071                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1072                            let path = req.url();
1073                            let mut user_id = None;
1074                            let mut access_token = None;
1075                            let url = Url::parse(&format!("http://example.com{}", path))
1076                                .context("failed to parse login notification url")?;
1077                            for (key, value) in url.query_pairs() {
1078                                if key == "access_token" {
1079                                    access_token = Some(value.to_string());
1080                                } else if key == "user_id" {
1081                                    user_id = Some(value.to_string());
1082                                }
1083                            }
1084
1085                            let post_auth_url =
1086                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1087                            req.respond(
1088                                tiny_http::Response::empty(302).with_header(
1089                                    tiny_http::Header::from_bytes(
1090                                        &b"Location"[..],
1091                                        post_auth_url.as_bytes(),
1092                                    )
1093                                    .unwrap(),
1094                                ),
1095                            )
1096                            .context("failed to respond to login http request")?;
1097                            return Ok((
1098                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1099                                access_token
1100                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1101                            ));
1102                        }
1103                    }
1104
1105                    Err(anyhow!("didn't receive login redirect"))
1106                })
1107                .await?;
1108
1109            let access_token = private_key
1110                .decrypt_string(&access_token)
1111                .context("failed to decrypt access token")?;
1112            platform.activate(true);
1113
1114            telemetry.report_event(
1115                "authenticate with browser",
1116                Default::default(),
1117                metrics_enabled,
1118            );
1119
1120            Ok(Credentials {
1121                user_id: user_id.parse()?,
1122                access_token,
1123            })
1124        })
1125    }
1126
1127    async fn authenticate_as_admin(
1128        http: Arc<dyn HttpClient>,
1129        login: String,
1130        mut api_token: String,
1131    ) -> Result<Credentials> {
1132        #[derive(Deserialize)]
1133        struct AuthenticatedUserResponse {
1134            user: User,
1135        }
1136
1137        #[derive(Deserialize)]
1138        struct User {
1139            id: u64,
1140        }
1141
1142        // Use the collab server's admin API to retrieve the id
1143        // of the impersonated user.
1144        let mut url = Self::get_rpc_url(http.clone(), false).await?;
1145        url.set_path("/user");
1146        url.set_query(Some(&format!("github_login={login}")));
1147        let request = Request::get(url.as_str())
1148            .header("Authorization", format!("token {api_token}"))
1149            .body("".into())?;
1150
1151        let mut response = http.send(request).await?;
1152        let mut body = String::new();
1153        response.body_mut().read_to_string(&mut body).await?;
1154        if !response.status().is_success() {
1155            Err(anyhow!(
1156                "admin user request failed {} - {}",
1157                response.status().as_u16(),
1158                body,
1159            ))?;
1160        }
1161        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1162
1163        // Use the admin API token to authenticate as the impersonated user.
1164        api_token.insert_str(0, "ADMIN_TOKEN:");
1165        Ok(Credentials {
1166            user_id: response.user.id,
1167            access_token: api_token,
1168        })
1169    }
1170
1171    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1172        self.peer.teardown();
1173        self.set_status(Status::SignedOut, cx);
1174    }
1175
1176    fn connection_id(&self) -> Result<ConnectionId> {
1177        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1178            Ok(connection_id)
1179        } else {
1180            Err(anyhow!("not connected"))
1181        }
1182    }
1183
1184    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1185        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1186        self.peer.send(self.connection_id()?, message)
1187    }
1188
1189    pub fn request<T: RequestMessage>(
1190        &self,
1191        request: T,
1192    ) -> impl Future<Output = Result<T::Response>> {
1193        self.request_envelope(request)
1194            .map_ok(|envelope| envelope.payload)
1195    }
1196
1197    pub fn request_envelope<T: RequestMessage>(
1198        &self,
1199        request: T,
1200    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1201        let client_id = self.id;
1202        log::debug!(
1203            "rpc request start. client_id:{}. name:{}",
1204            client_id,
1205            T::NAME
1206        );
1207        let response = self
1208            .connection_id()
1209            .map(|conn_id| self.peer.request_envelope(conn_id, request));
1210        async move {
1211            let response = response?.await;
1212            log::debug!(
1213                "rpc request finish. client_id:{}. name:{}",
1214                client_id,
1215                T::NAME
1216            );
1217            response
1218        }
1219    }
1220
1221    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1222        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1223        self.peer.respond(receipt, response)
1224    }
1225
1226    fn respond_with_error<T: RequestMessage>(
1227        &self,
1228        receipt: Receipt<T>,
1229        error: proto::Error,
1230    ) -> Result<()> {
1231        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1232        self.peer.respond_with_error(receipt, error)
1233    }
1234
    /// Routes an incoming RPC message to the handler registered for its payload type.
    ///
    /// The subscriber is resolved from `self.state` in this order:
    /// 1. A model registered for the whole message type (`models_by_message_type`),
    ///    if its weak handle still upgrades.
    /// 2. Otherwise, for message types with a registered entity-id extractor, the entry in
    ///    `entities_by_type_and_remote_id` keyed by `(entity type, remote entity id)`.
    ///    If that entry is `WeakSubscriber::Pending`, the message is pushed onto it and
    ///    this method returns without invoking any handler.
    ///
    /// If no live subscriber or no handler is found, the message is logged as unhandled and
    /// `respond_with_unhandled_message` is sent to the peer. Otherwise the handler's future
    /// is spawned (detached) on the foreground executor, and its success or error is logged.
    fn handle_message(
        self: &Arc<Client>,
        message: Box<dyn AnyTypedEnvelope>,
        cx: &AsyncAppContext,
    ) {
        let mut state = self.state.write();
        let type_name = message.payload_type_name();
        let payload_type_id = message.payload_type_id();
        let sender_id = message.original_sender_id();

        let mut subscriber = None;

        // A model subscribed to the entire message type takes precedence over any
        // per-entity subscription.
        if let Some(message_model) = state
            .models_by_message_type
            .get(&payload_type_id)
            .and_then(|model| model.upgrade(cx))
        {
            subscriber = Some(Subscriber::Model(message_model));
        } else if let Some((extract_entity_id, entity_type_id)) =
            state.entity_id_extractors.get(&payload_type_id).zip(
                state
                    .entity_types_by_message_type
                    .get(&payload_type_id)
                    .copied(),
            )
        {
            // Pull the remote entity id out of the payload to find its subscriber.
            let entity_id = (extract_entity_id)(message.as_ref());

            match state
                .entities_by_type_and_remote_id
                .get_mut(&(entity_type_id, entity_id))
            {
                // No model/view attached to this subscription yet: queue the message on
                // the pending entry rather than dispatching it now.
                Some(WeakSubscriber::Pending(pending)) => {
                    pending.push(message);
                    return;
                }
                // If the weak handle no longer upgrades, `subscriber` stays `None` and the
                // message is treated as unhandled below.
                Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx),
                _ => {}
            }
        }

        let subscriber = if let Some(subscriber) = subscriber {
            subscriber
        } else {
            log::info!("unhandled message {}", type_name);
            self.peer.respond_with_unhandled_message(message).log_err();
            return;
        };

        let handler = state.message_handlers.get(&payload_type_id).cloned();
        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
        // It also ensures we don't hold the lock while yielding back to the executor, as
        // that might cause the executor thread driving this future to block indefinitely.
        drop(state);

        if let Some(handler) = handler {
            // The handler is invoked synchronously here; only the future it returns is
            // spawned onto the foreground executor.
            let future = handler(subscriber, message, &self, cx.clone());
            let client_id = self.id;
            log::debug!(
                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                sender_id,
                type_name
            );
            cx.foreground()
                .spawn(async move {
                    match future.await {
                        Ok(()) => {
                            log::debug!(
                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
                                client_id,
                                sender_id,
                                type_name
                            );
                        }
                        Err(error) => {
                            log::error!(
                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
                                client_id,
                                sender_id,
                                type_name,
                                error
                            );
                        }
                    }
                })
                .detach();
        } else {
            log::info!("unhandled message {}", type_name);
            self.peer.respond_with_unhandled_message(message).log_err();
        }
    }
1327
1328    pub fn start_telemetry(&self) {
1329        self.telemetry.start();
1330    }
1331
1332    pub fn report_event(
1333        &self,
1334        kind: &str,
1335        properties: Value,
1336        telemetry_settings: TelemetrySettings,
1337    ) {
1338        self.telemetry
1339            .report_event(kind, properties.clone(), telemetry_settings);
1340    }
1341
1342    pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1343        self.telemetry.log_file_path()
1344    }
1345
1346    pub fn metrics_id(&self) -> Option<Arc<str>> {
1347        self.telemetry.metrics_id()
1348    }
1349
1350    pub fn is_staff(&self) -> Option<bool> {
1351        self.telemetry.is_staff()
1352    }
1353}
1354
1355impl WeakSubscriber {
1356    fn upgrade(&self, cx: &AsyncAppContext) -> Option<Subscriber> {
1357        match self {
1358            WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model),
1359            WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View),
1360            WeakSubscriber::Pending(_) => None,
1361        }
1362    }
1363}
1364
1365fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1366    if IMPERSONATE_LOGIN.is_some() {
1367        return None;
1368    }
1369
1370    let (user_id, access_token) = cx
1371        .platform()
1372        .read_credentials(&ZED_SERVER_URL)
1373        .log_err()
1374        .flatten()?;
1375    Some(Credentials {
1376        user_id: user_id.parse().ok()?,
1377        access_token: String::from_utf8(access_token).ok()?,
1378    })
1379}
1380
1381fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1382    cx.platform().write_credentials(
1383        &ZED_SERVER_URL,
1384        &credentials.user_id.to_string(),
1385        credentials.access_token.as_bytes(),
1386    )
1387}
1388
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
///
/// The token is embedded verbatim. [`decode_worktree_url`] recovers it exactly, even when
/// it contains `/`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Leading and trailing whitespace is ignored. Everything after the first `/` that follows
/// the id is the access token, so tokens containing `/` survive a round trip. Returns
/// `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1405
// Tests for `Client`: reconnection and connection timeouts, repeated authentication,
// worktree URL encoding, and routing of incoming messages to subscribed models.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;
    use util::http::FakeHttpClient;

    // After the server drops the client, it reconnects using its cached credentials
    // (auth_count stays 1). After the server rolls the access token, it authenticates
    // again (auth_count becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes must end in `ConnectionError` once
    // `CONNECTION_TIMEOUT` elapses. A stalled reconnection attempt must likewise end in
    // `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still
    // authenticating drops the first authentication future. The count of dropped
    // futures goes from 0 to 1, while a second authentication starts.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this authentication future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs round-trip through encode/decode, surrounding whitespace is ignored,
    // and a URL without the expected prefix is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Messages carrying an entity id are delivered to the model subscribed to that id.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Once the subscription for a `Ping` handler is dropped, a new handler can be
    // registered for the same message type, and it receives the next message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the model) while it runs, and
    // the message is still handled.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal entity used as a message subscriber in these tests. `id` identifies which
    // model received a message; `subscription` lets a handler drop its own subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}