client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod telemetry;
   7pub mod user;
   8
   9use anyhow::{anyhow, Context, Result};
  10use async_recursion::async_recursion;
  11use async_tungstenite::tungstenite::{
  12    error::Error as WebsocketError,
  13    http::{Request, StatusCode},
  14};
  15use db::Db;
  16use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  17use gpui::{
  18    actions,
  19    serde_json::{json, Value},
  20    AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
  21    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
  22    ViewHandle,
  23};
  24use http::HttpClient;
  25use lazy_static::lazy_static;
  26use parking_lot::RwLock;
  27use postage::watch;
  28use rand::prelude::*;
  29use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  30use std::{
  31    any::TypeId,
  32    collections::HashMap,
  33    convert::TryFrom,
  34    fmt::Write as _,
  35    future::Future,
  36    sync::{Arc, Weak},
  37    time::{Duration, Instant},
  38};
  39use telemetry::Telemetry;
  40use thiserror::Error;
  41use url::Url;
  42use util::{ResultExt, TryFutureExt};
  43
  44pub use channel::*;
  45pub use rpc::*;
  46pub use user::*;
  47
  48lazy_static! {
  49    pub static ref ZED_SERVER_URL: String =
  50        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  51    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  52        .ok()
  53        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  54}
  55
/// Client token constant exported by this module.
// NOTE(review): nothing in this part of the file shows where the token is consumed — confirm usage.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";

// Declares the `Authenticate` and `TestTelemetry` action types via gpui's `actions!` macro;
// `init` registers a global handler for each of them.
actions!(client, [Authenticate, TestTelemetry]);
  59
  60pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
  61    cx.add_global_action({
  62        let client = client.clone();
  63        move |_: &Authenticate, cx| {
  64            let client = client.clone();
  65            cx.spawn(
  66                |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  67            )
  68            .detach();
  69        }
  70    });
  71    cx.add_global_action({
  72        let client = client.clone();
  73        move |_: &TestTelemetry, _| {
  74            client.report_event(
  75                "test_telemetry",
  76                json!({
  77                    "test_property": "test_value"
  78                }),
  79            )
  80        }
  81    });
  82}
  83
/// Client for the Zed server: bundles the RPC peer, the HTTP client, the
/// telemetry reporter and the lock-protected connection state.
pub struct Client {
    /// Identifier included in this client's log output; `new` initializes it to 0.
    id: usize,
    /// RPC peer that connections are added to in `set_connection`.
    peer: Arc<Peer>,
    /// HTTP client, also handed to `Telemetry::new` and exposed via `http_client`.
    http: Arc<dyn HttpClient>,
    /// Telemetry reporter; informed of the current user id on status changes.
    telemetry: Arc<Telemetry>,
    /// Mutable client state (credentials, status channel, handler registries).
    state: RwLock<ClientState>,

    /// Test-only override consulted by `authenticate` before falling back to
    /// the browser-based flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only override consulted by `establish_connection` before falling
    /// back to the websocket connection.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 113
/// Errors produced while establishing a connection to the server.
///
/// `UpgradeRequired` and `Unauthorized` are matched explicitly by
/// `authenticate_and_connect`; the remaining variants wrap an underlying error
/// and display it verbatim.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced when the websocket handshake answers with HTTP status
    /// `UPGRADE_REQUIRED` (see the `From<WebsocketError>` impl).
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced when the websocket handshake answers with HTTP status
    /// `UNAUTHORIZED` (see the `From<WebsocketError>` impl).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error of type `http::Error`.
    // NOTE(review): presumably the error type of this crate's own `http` module — confirm.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error of type `tungstenite::http::Error`.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 129
 130impl From<WebsocketError> for EstablishConnectionError {
 131    fn from(error: WebsocketError) -> Self {
 132        if let WebsocketError::Http(response) = &error {
 133            match response.status() {
 134                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 135                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 136                _ => {}
 137            }
 138        }
 139        EstablishConnectionError::Other(error.into())
 140    }
 141}
 142
 143impl EstablishConnectionError {
 144    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 145        Self::Other(error.into())
 146    }
 147}
 148
/// Connection status of a `Client`, published through the watch channel
/// returned by `Client::status`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Initial status; also set when the connection's I/O ends without error.
    SignedOut,
    /// Set when establishing a connection fails with `UpgradeRequired`.
    UpgradeRequired,
    /// Authentication in progress, entered from `SignedOut`.
    Authenticating,
    /// Connection being established, entered from `SignedOut`.
    Connecting,
    /// Authentication or connection establishment failed.
    ConnectionError,
    /// A connection with the given id has been added to the peer.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O ended with an error; triggers the reconnect task.
    ConnectionLost,
    /// Authentication in progress, entered from a non-`SignedOut` status.
    Reauthenticating,
    /// Connection being established, entered from a non-`SignedOut` status.
    Reconnecting,
    /// A reconnect attempt failed; the next attempt is scheduled for
    /// `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
 162
 163impl Status {
 164    pub fn is_connected(&self) -> bool {
 165        matches!(self, Self::Connected { .. })
 166    }
 167}
 168
/// Mutable state of a `Client`, guarded by `Client::state`.
struct ClientState {
    /// Credentials stored after a successful connection and cleared when the
    /// server rejects them as unauthorized.
    credentials: Option<Credentials>,
    /// Watch channel carrying the current `Status`: the sender is written by
    /// `set_status`, the receiver is cloned out by `status`.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect task spawned on `ConnectionLost`; replaced or dropped on later
    /// status changes.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound applied to the delay between reconnect attempts.
    reconnect_interval: Duration,
    /// Weak handles to entities, keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Weak handles to models registered via `add_message_handler`, keyed by message type.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that handles each entity-addressed message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handlers, keyed by message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 193
/// Type-erased weak handle to either a model or a view, as stored in
/// `ClientState::entities_by_type_and_remote_id`.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
 198
/// Type-erased strong handle to either a model or a view; this is what
/// registered message handlers receive as their first argument.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 203
/// User id and access token used to authorize a connection; both are sent in
/// the `Authorization` header of the websocket request.
// NOTE(review): the derived `Debug` impl prints `access_token` — confirm that is acceptable for logs.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
 209
 210impl Default for ClientState {
 211    fn default() -> Self {
 212        Self {
 213            credentials: None,
 214            status: watch::channel_with(Status::SignedOut),
 215            entity_id_extractors: Default::default(),
 216            _reconnect_task: None,
 217            reconnect_interval: Duration::from_secs(5),
 218            models_by_message_type: Default::default(),
 219            entities_by_type_and_remote_id: Default::default(),
 220            entity_types_by_message_type: Default::default(),
 221            message_handlers: Default::default(),
 222        }
 223    }
 224}
 225
/// Registration guard returned by the `add_*_for_remote_entity` methods and by
/// `add_message_handler`; dropping it removes the corresponding entries from
/// the client's state (see the `Drop` impl).
pub enum Subscription {
    /// Registration of an entity under (entity type, remote entity id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for a message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 236
 237impl Drop for Subscription {
 238    fn drop(&mut self) {
 239        match self {
 240            Subscription::Entity { client, id } => {
 241                if let Some(client) = client.upgrade() {
 242                    let mut state = client.state.write();
 243                    let _ = state.entities_by_type_and_remote_id.remove(id);
 244                }
 245            }
 246            Subscription::Message { client, id } => {
 247                if let Some(client) = client.upgrade() {
 248                    let mut state = client.state.write();
 249                    let _ = state.entity_types_by_message_type.remove(id);
 250                    let _ = state.message_handlers.remove(id);
 251                }
 252            }
 253        }
 254    }
 255}
 256
 257impl Client {
 258    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 259        Arc::new(Self {
 260            id: 0,
 261            peer: Peer::new(),
 262            telemetry: Telemetry::new(http.clone(), cx),
 263            http,
 264            state: Default::default(),
 265
 266            #[cfg(any(test, feature = "test-support"))]
 267            authenticate: Default::default(),
 268            #[cfg(any(test, feature = "test-support"))]
 269            establish_connection: Default::default(),
 270        })
 271    }
 272
    /// Returns this client's identifier (the value that appears in its log lines).
    pub fn id(&self) -> usize {
        self.id
    }
 276
 277    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 278        self.http.clone()
 279    }
 280
    /// Test-only: overwrites the client's identifier and returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_id(&mut self, id: usize) -> &Self {
        self.id = id;
        self
    }
 286
 287    #[cfg(any(test, feature = "test-support"))]
 288    pub fn tear_down(&self) {
 289        let mut state = self.state.write();
 290        state._reconnect_task.take();
 291        state.message_handlers.clear();
 292        state.models_by_message_type.clear();
 293        state.entities_by_type_and_remote_id.clear();
 294        state.entity_id_extractors.clear();
 295        self.peer.reset();
 296    }
 297
 298    #[cfg(any(test, feature = "test-support"))]
 299    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 300    where
 301        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 302    {
 303        *self.authenticate.write() = Some(Box::new(authenticate));
 304        self
 305    }
 306
 307    #[cfg(any(test, feature = "test-support"))]
 308    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 309    where
 310        F: 'static
 311            + Send
 312            + Sync
 313            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 314    {
 315        *self.establish_connection.write() = Some(Box::new(connect));
 316        self
 317    }
 318
 319    pub fn user_id(&self) -> Option<u64> {
 320        self.state
 321            .read()
 322            .credentials
 323            .as_ref()
 324            .map(|credentials| credentials.user_id)
 325    }
 326
 327    pub fn status(&self) -> watch::Receiver<Status> {
 328        self.state.read().status.1.clone()
 329    }
 330
 331    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 332        log::info!("set status on client {}: {:?}", self.id, status);
 333        let mut state = self.state.write();
 334        *state.status.0.borrow_mut() = status;
 335
 336        match status {
 337            Status::Connected { .. } => {
 338                self.telemetry.set_user_id(self.user_id());
 339                state._reconnect_task = None;
 340            }
 341            Status::ConnectionLost => {
 342                let this = self.clone();
 343                let reconnect_interval = state.reconnect_interval;
 344                state._reconnect_task = Some(cx.spawn(|cx| async move {
 345                    let mut rng = StdRng::from_entropy();
 346                    let mut delay = Duration::from_millis(100);
 347                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
 348                        log::error!("failed to connect {}", error);
 349                        if matches!(*this.status().borrow(), Status::ConnectionError) {
 350                            this.set_status(
 351                                Status::ReconnectionError {
 352                                    next_reconnection: Instant::now() + delay,
 353                                },
 354                                &cx,
 355                            );
 356                            cx.background().timer(delay).await;
 357                            delay = delay
 358                                .mul_f32(rng.gen_range(1.0..=2.0))
 359                                .min(reconnect_interval);
 360                        } else {
 361                            break;
 362                        }
 363                    }
 364                }));
 365            }
 366            Status::SignedOut | Status::UpgradeRequired => {
 367                self.telemetry.set_user_id(self.user_id());
 368                state._reconnect_task.take();
 369            }
 370            _ => {}
 371        }
 372    }
 373
 374    pub fn add_view_for_remote_entity<T: View>(
 375        self: &Arc<Self>,
 376        remote_id: u64,
 377        cx: &mut ViewContext<T>,
 378    ) -> Subscription {
 379        let id = (TypeId::of::<T>(), remote_id);
 380        self.state
 381            .write()
 382            .entities_by_type_and_remote_id
 383            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 384        Subscription::Entity {
 385            client: Arc::downgrade(self),
 386            id,
 387        }
 388    }
 389
 390    pub fn add_model_for_remote_entity<T: Entity>(
 391        self: &Arc<Self>,
 392        remote_id: u64,
 393        cx: &mut ModelContext<T>,
 394    ) -> Subscription {
 395        let id = (TypeId::of::<T>(), remote_id);
 396        self.state
 397            .write()
 398            .entities_by_type_and_remote_id
 399            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 400        Subscription::Entity {
 401            client: Arc::downgrade(self),
 402            id,
 403        }
 404    }
 405
 406    pub fn add_message_handler<M, E, H, F>(
 407        self: &Arc<Self>,
 408        model: ModelHandle<E>,
 409        handler: H,
 410    ) -> Subscription
 411    where
 412        M: EnvelopedMessage,
 413        E: Entity,
 414        H: 'static
 415            + Send
 416            + Sync
 417            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 418        F: 'static + Future<Output = Result<()>>,
 419    {
 420        let message_type_id = TypeId::of::<M>();
 421
 422        let mut state = self.state.write();
 423        state
 424            .models_by_message_type
 425            .insert(message_type_id, model.downgrade().into());
 426
 427        let prev_handler = state.message_handlers.insert(
 428            message_type_id,
 429            Arc::new(move |handle, envelope, client, cx| {
 430                let handle = if let AnyEntityHandle::Model(handle) = handle {
 431                    handle
 432                } else {
 433                    unreachable!();
 434                };
 435                let model = handle.downcast::<E>().unwrap();
 436                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 437                handler(model, *envelope, client.clone(), cx).boxed_local()
 438            }),
 439        );
 440        if prev_handler.is_some() {
 441            panic!("registered handler for the same message twice");
 442        }
 443
 444        Subscription::Message {
 445            client: Arc::downgrade(self),
 446            id: message_type_id,
 447        }
 448    }
 449
 450    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 451    where
 452        M: EntityMessage,
 453        E: View,
 454        H: 'static
 455            + Send
 456            + Sync
 457            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 458        F: 'static + Future<Output = Result<()>>,
 459    {
 460        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 461            if let AnyEntityHandle::View(handle) = handle {
 462                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 463            } else {
 464                unreachable!();
 465            }
 466        })
 467    }
 468
 469    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 470    where
 471        M: EntityMessage,
 472        E: Entity,
 473        H: 'static
 474            + Send
 475            + Sync
 476            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 477        F: 'static + Future<Output = Result<()>>,
 478    {
 479        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 480            if let AnyEntityHandle::Model(handle) = handle {
 481                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 482            } else {
 483                unreachable!();
 484            }
 485        })
 486    }
 487
 488    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 489    where
 490        M: EntityMessage,
 491        E: Entity,
 492        H: 'static
 493            + Send
 494            + Sync
 495            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 496        F: 'static + Future<Output = Result<()>>,
 497    {
 498        let model_type_id = TypeId::of::<E>();
 499        let message_type_id = TypeId::of::<M>();
 500
 501        let mut state = self.state.write();
 502        state
 503            .entity_types_by_message_type
 504            .insert(message_type_id, model_type_id);
 505        state
 506            .entity_id_extractors
 507            .entry(message_type_id)
 508            .or_insert_with(|| {
 509                |envelope| {
 510                    envelope
 511                        .as_any()
 512                        .downcast_ref::<TypedEnvelope<M>>()
 513                        .unwrap()
 514                        .payload
 515                        .remote_entity_id()
 516                }
 517            });
 518        let prev_handler = state.message_handlers.insert(
 519            message_type_id,
 520            Arc::new(move |handle, envelope, client, cx| {
 521                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 522                handler(handle, *envelope, client.clone(), cx).boxed_local()
 523            }),
 524        );
 525        if prev_handler.is_some() {
 526            panic!("registered handler for the same message twice");
 527        }
 528    }
 529
 530    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 531    where
 532        M: EntityMessage + RequestMessage,
 533        E: Entity,
 534        H: 'static
 535            + Send
 536            + Sync
 537            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 538        F: 'static + Future<Output = Result<M::Response>>,
 539    {
 540        self.add_model_message_handler(move |entity, envelope, client, cx| {
 541            Self::respond_to_request::<M, _>(
 542                envelope.receipt(),
 543                handler(entity, envelope, client.clone(), cx),
 544                client,
 545            )
 546        })
 547    }
 548
 549    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 550    where
 551        M: EntityMessage + RequestMessage,
 552        E: View,
 553        H: 'static
 554            + Send
 555            + Sync
 556            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 557        F: 'static + Future<Output = Result<M::Response>>,
 558    {
 559        self.add_view_message_handler(move |entity, envelope, client, cx| {
 560            Self::respond_to_request::<M, _>(
 561                envelope.receipt(),
 562                handler(entity, envelope, client.clone(), cx),
 563                client,
 564            )
 565        })
 566    }
 567
 568    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 569        receipt: Receipt<T>,
 570        response: F,
 571        client: Arc<Self>,
 572    ) -> Result<()> {
 573        match response.await {
 574            Ok(response) => {
 575                client.respond(receipt, response)?;
 576                Ok(())
 577            }
 578            Err(error) => {
 579                client.respond_with_error(
 580                    receipt,
 581                    proto::Error {
 582                        message: format!("{:?}", error),
 583                    },
 584                )?;
 585                Err(error)
 586            }
 587        }
 588    }
 589
    /// Returns `true` when `read_credentials_from_keychain` yields credentials.
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
 593
    /// Obtains credentials (stored, from the keychain when `try_keychain` is
    /// set, or via `authenticate`) and establishes a connection with them,
    /// updating the client's status along the way.
    ///
    /// Returns `Ok(())` immediately when the client is already connected or in
    /// the middle of connecting.
    ///
    /// # Errors
    ///
    /// * `EstablishConnectionError::UpgradeRequired` when the status already is,
    ///   or the connection attempt ends in, `UpgradeRequired`.
    /// * The authentication error, after setting status `ConnectionError`.
    /// * "authentication canceled" when the status changes while waiting for
    ///   authentication to finish.
    /// * `EstablishConnectionError::Unauthorized` when credentials that did not
    ///   come from the keychain are rejected.
    /// * Any other connection error, after setting status `ConnectionError`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Decide between the "first connection" and "reconnection" status
        // sequences, and bail out early when there is nothing to do.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Prefer credentials already stored on the client, then the keychain.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                self.report_event("read credentials from keychain", Default::default());
            }
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably consumes the receiver's current value so that the
            // `next()` in the select below only resolves on a later status change —
            // confirm against postage::watch semantics.
            let _ = status_rx.next().await;
            // Race authentication against a status change; authentication is
            // polled first because the select is biased.
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path that leaves `credentials` empty has returned above.
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                // Persist only credentials that did not come from the keychain,
                // and never while an impersonation login is configured.
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // The keychain credentials were rejected: delete them and
                    // retry once without consulting the keychain.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 684
 685    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
 686        let executor = cx.background();
 687        log::info!("add connection to peer");
 688        let (connection_id, handle_io, mut incoming) = self
 689            .peer
 690            .add_connection(conn, move |duration| executor.timer(duration))
 691            .await;
 692        log::info!("set status to connected {}", connection_id);
 693        self.set_status(Status::Connected { connection_id }, cx);
 694        cx.foreground()
 695            .spawn({
 696                let cx = cx.clone();
 697                let this = self.clone();
 698                async move {
 699                    let mut message_id = 0_usize;
 700                    while let Some(message) = incoming.next().await {
 701                        let mut state = this.state.write();
 702                        message_id += 1;
 703                        let type_name = message.payload_type_name();
 704                        let payload_type_id = message.payload_type_id();
 705                        let sender_id = message.original_sender_id().map(|id| id.0);
 706
 707                        let model = state
 708                            .models_by_message_type
 709                            .get(&payload_type_id)
 710                            .and_then(|model| model.upgrade(&cx))
 711                            .map(AnyEntityHandle::Model)
 712                            .or_else(|| {
 713                                let entity_type_id =
 714                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
 715                                let entity_id = state
 716                                    .entity_id_extractors
 717                                    .get(&message.payload_type_id())
 718                                    .map(|extract_entity_id| {
 719                                        (extract_entity_id)(message.as_ref())
 720                                    })?;
 721
 722                                let entity = state
 723                                    .entities_by_type_and_remote_id
 724                                    .get(&(entity_type_id, entity_id))?;
 725                                if let Some(entity) = entity.upgrade(&cx) {
 726                                    Some(entity)
 727                                } else {
 728                                    state
 729                                        .entities_by_type_and_remote_id
 730                                        .remove(&(entity_type_id, entity_id));
 731                                    None
 732                                }
 733                            });
 734
 735                        let model = if let Some(model) = model {
 736                            model
 737                        } else {
 738                            log::info!("unhandled message {}", type_name);
 739                            continue;
 740                        };
 741
 742                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
 743                        {
 744                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
 745                            let future = handler(model, message, &this, cx.clone());
 746
 747                            let client_id = this.id;
 748                            log::debug!(
 749                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 750                                client_id,
 751                                message_id,
 752                                sender_id,
 753                                type_name
 754                            );
 755                            cx.foreground()
 756                                .spawn(async move {
 757                                    match future.await {
 758                                        Ok(()) => {
 759                                            log::debug!(
 760                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 761                                                client_id,
 762                                                message_id,
 763                                                sender_id,
 764                                                type_name
 765                                            );
 766                                        }
 767                                        Err(error) => {
 768                                            log::error!(
 769                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
 770                                                client_id,
 771                                                message_id,
 772                                                sender_id,
 773                                                type_name,
 774                                                error
 775                                            );
 776                                        }
 777                                    }
 778                                })
 779                                .detach();
 780                        } else {
 781                            log::info!("unhandled message {}", type_name);
 782                        }
 783
 784                        // Don't starve the main thread when receiving lots of messages at once.
 785                        smol::future::yield_now().await;
 786                    }
 787                }
 788            })
 789            .detach();
 790
 791        let handle_io = cx.background().spawn(handle_io);
 792        let this = self.clone();
 793        let cx = cx.clone();
 794        cx.foreground()
 795            .spawn(async move {
 796                match handle_io.await {
 797                    Ok(()) => {
 798                        if *this.status().borrow() == (Status::Connected { connection_id }) {
 799                            this.set_status(Status::SignedOut, &cx);
 800                        }
 801                    }
 802                    Err(err) => {
 803                        log::error!("connection error: {:?}", err);
 804                        this.set_status(Status::ConnectionLost, &cx);
 805                    }
 806                }
 807            })
 808            .detach();
 809    }
 810
 811    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 812        #[cfg(any(test, feature = "test-support"))]
 813        if let Some(callback) = self.authenticate.read().as_ref() {
 814            return callback(cx);
 815        }
 816
 817        self.authenticate_with_browser(cx)
 818    }
 819
 820    fn establish_connection(
 821        self: &Arc<Self>,
 822        credentials: &Credentials,
 823        cx: &AsyncAppContext,
 824    ) -> Task<Result<Connection, EstablishConnectionError>> {
 825        #[cfg(any(test, feature = "test-support"))]
 826        if let Some(callback) = self.establish_connection.read().as_ref() {
 827            return callback(credentials, cx);
 828        }
 829
 830        self.establish_websocket_connection(credentials, cx)
 831    }
 832
 833    fn establish_websocket_connection(
 834        self: &Arc<Self>,
 835        credentials: &Credentials,
 836        cx: &AsyncAppContext,
 837    ) -> Task<Result<Connection, EstablishConnectionError>> {
 838        let request = Request::builder()
 839            .header(
 840                "Authorization",
 841                format!("{} {}", credentials.user_id, credentials.access_token),
 842            )
 843            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 844
 845        let http = self.http.clone();
 846        cx.background().spawn(async move {
 847            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
 848            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
 849            if rpc_response.status().is_redirection() {
 850                rpc_url = rpc_response
 851                    .headers()
 852                    .get("Location")
 853                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 854                    .to_str()
 855                    .map_err(EstablishConnectionError::other)?
 856                    .to_string();
 857            }
 858            // Until we switch the zed.dev domain to point to the new Next.js app, there
 859            // will be no redirect required, and the app will connect directly to
 860            // wss://zed.dev/rpc.
 861            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
 862                Err(anyhow!(
 863                    "unexpected /rpc response status {}",
 864                    rpc_response.status()
 865                ))?
 866            }
 867
 868            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
 869            let rpc_host = rpc_url
 870                .host_str()
 871                .zip(rpc_url.port_or_known_default())
 872                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 873            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 874
 875            log::info!("connected to rpc endpoint {}", rpc_url);
 876
 877            match rpc_url.scheme() {
 878                "https" => {
 879                    rpc_url.set_scheme("wss").unwrap();
 880                    let request = request.uri(rpc_url.as_str()).body(())?;
 881                    let (stream, _) =
 882                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 883                    Ok(Connection::new(
 884                        stream
 885                            .map_err(|error| anyhow!(error))
 886                            .sink_map_err(|error| anyhow!(error)),
 887                    ))
 888                }
 889                "http" => {
 890                    rpc_url.set_scheme("ws").unwrap();
 891                    let request = request.uri(rpc_url.as_str()).body(())?;
 892                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 893                    Ok(Connection::new(
 894                        stream
 895                            .map_err(|error| anyhow!(error))
 896                            .sink_map_err(|error| anyhow!(error)),
 897                    ))
 898                }
 899                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 900            }
 901        })
 902    }
 903
 904    pub fn authenticate_with_browser(
 905        self: &Arc<Self>,
 906        cx: &AsyncAppContext,
 907    ) -> Task<Result<Credentials>> {
 908        let platform = cx.platform();
 909        let executor = cx.background();
 910        let telemetry = self.telemetry.clone();
 911        executor.clone().spawn(async move {
 912            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 913            // zed server to encrypt the user's access token, so that it can'be intercepted by
 914            // any other app running on the user's device.
 915            let (public_key, private_key) =
 916                rpc::auth::keypair().expect("failed to generate keypair for auth");
 917            let public_key_string =
 918                String::try_from(public_key).expect("failed to serialize public key for auth");
 919
 920            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 921            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 922            let port = server.server_addr().port();
 923
 924            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 925            // that the user is signing in from a Zed app running on the same device.
 926            let mut url = format!(
 927                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 928                *ZED_SERVER_URL, port, public_key_string
 929            );
 930
 931            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 932                log::info!("impersonating user @{}", impersonate_login);
 933                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 934            }
 935
 936            platform.open_url(&url);
 937
 938            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 939            // access token from the query params.
 940            //
 941            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 942            // custom URL scheme instead of this local HTTP server.
 943            let (user_id, access_token) = executor
 944                .spawn(async move {
 945                    for _ in 0..100 {
 946                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
 947                            let path = req.url();
 948                            let mut user_id = None;
 949                            let mut access_token = None;
 950                            let url = Url::parse(&format!("http://example.com{}", path))
 951                                .context("failed to parse login notification url")?;
 952                            for (key, value) in url.query_pairs() {
 953                                if key == "access_token" {
 954                                    access_token = Some(value.to_string());
 955                                } else if key == "user_id" {
 956                                    user_id = Some(value.to_string());
 957                                }
 958                            }
 959
 960                            let post_auth_url =
 961                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 962                            req.respond(
 963                                tiny_http::Response::empty(302).with_header(
 964                                    tiny_http::Header::from_bytes(
 965                                        &b"Location"[..],
 966                                        post_auth_url.as_bytes(),
 967                                    )
 968                                    .unwrap(),
 969                                ),
 970                            )
 971                            .context("failed to respond to login http request")?;
 972                            return Ok((
 973                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 974                                access_token
 975                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 976                            ));
 977                        }
 978                    }
 979
 980                    Err(anyhow!("didn't receive login redirect"))
 981                })
 982                .await?;
 983
 984            let access_token = private_key
 985                .decrypt_string(&access_token)
 986                .context("failed to decrypt access token")?;
 987            platform.activate(true);
 988
 989            telemetry.report_event("authenticate with browser", Default::default());
 990
 991            Ok(Credentials {
 992                user_id: user_id.parse()?,
 993                access_token,
 994            })
 995        })
 996    }
 997
 998    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 999        let conn_id = self.connection_id()?;
1000        self.peer.disconnect(conn_id);
1001        self.set_status(Status::SignedOut, cx);
1002        Ok(())
1003    }
1004
1005    fn connection_id(&self) -> Result<ConnectionId> {
1006        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1007            Ok(connection_id)
1008        } else {
1009            Err(anyhow!("not connected"))
1010        }
1011    }
1012
1013    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1014        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1015        self.peer.send(self.connection_id()?, message)
1016    }
1017
1018    pub fn request<T: RequestMessage>(
1019        &self,
1020        request: T,
1021    ) -> impl Future<Output = Result<T::Response>> {
1022        let client_id = self.id;
1023        log::debug!(
1024            "rpc request start. client_id:{}. name:{}",
1025            client_id,
1026            T::NAME
1027        );
1028        let response = self
1029            .connection_id()
1030            .map(|conn_id| self.peer.request(conn_id, request));
1031        async move {
1032            let response = response?.await;
1033            log::debug!(
1034                "rpc request finish. client_id:{}. name:{}",
1035                client_id,
1036                T::NAME
1037            );
1038            response
1039        }
1040    }
1041
1042    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1043        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1044        self.peer.respond(receipt, response)
1045    }
1046
1047    fn respond_with_error<T: RequestMessage>(
1048        &self,
1049        receipt: Receipt<T>,
1050        error: proto::Error,
1051    ) -> Result<()> {
1052        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1053        self.peer.respond_with_error(receipt, error)
1054    }
1055
1056    pub fn start_telemetry(&self, db: Arc<Db>) {
1057        self.telemetry.start(db);
1058    }
1059
1060    pub fn report_event(&self, kind: &str, properties: Value) {
1061        self.telemetry.report_event(kind, properties)
1062    }
1063}
1064
1065impl AnyWeakEntityHandle {
1066    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1067        match self {
1068            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1069            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1070        }
1071    }
1072}
1073
1074fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1075    if IMPERSONATE_LOGIN.is_some() {
1076        return None;
1077    }
1078
1079    let (user_id, access_token) = cx
1080        .platform()
1081        .read_credentials(&ZED_SERVER_URL)
1082        .log_err()
1083        .flatten()?;
1084    Some(Credentials {
1085        user_id: user_id.parse().ok()?,
1086        access_token: String::from_utf8(access_token).ok()?,
1087    })
1088}
1089
1090fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1091    cx.platform().write_credentials(
1092        &ZED_SERVER_URL,
1093        &credentials.user_id.to_string(),
1094        credentials.access_token.as_bytes(),
1095    )
1096}
1097
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into its id and access token.
///
/// Surrounding whitespace is ignored. Everything after the first `/` that
/// follows the id is treated as the access token, so a token that itself
/// contains `/` round-trips intact instead of being silently truncated.
///
/// Returns `None` when the prefix is missing, the id is not a valid `u64`,
/// or the access token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only once: the id never contains '/', but the token may.
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1114
1115#[cfg(test)]
1116mod tests {
1117    use super::*;
1118    use crate::test::{FakeHttpClient, FakeServer};
1119    use gpui::{executor::Deterministic, TestAppContext};
1120    use parking_lot::Mutex;
1121    use std::future;
1122
1123    #[gpui::test(iterations = 10)]
1124    async fn test_reconnection(cx: &mut TestAppContext) {
1125        cx.foreground().forbid_parking();
1126
1127        let user_id = 5;
1128        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1129        let server = FakeServer::for_client(user_id, &client, cx).await;
1130        let mut status = client.status();
1131        assert!(matches!(
1132            status.next().await,
1133            Some(Status::Connected { .. })
1134        ));
1135        assert_eq!(server.auth_count(), 1);
1136
1137        server.forbid_connections();
1138        server.disconnect();
1139        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1140
1141        server.allow_connections();
1142        cx.foreground().advance_clock(Duration::from_secs(10));
1143        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1144        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1145
1146        server.forbid_connections();
1147        server.disconnect();
1148        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1149
1150        // Clear cached credentials after authentication fails
1151        server.roll_access_token();
1152        server.allow_connections();
1153        cx.foreground().advance_clock(Duration::from_secs(10));
1154        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1155        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1156    }
1157
1158    #[gpui::test(iterations = 10)]
1159    async fn test_authenticating_more_than_once(
1160        cx: &mut TestAppContext,
1161        deterministic: Arc<Deterministic>,
1162    ) {
1163        cx.foreground().forbid_parking();
1164
1165        let auth_count = Arc::new(Mutex::new(0));
1166        let dropped_auth_count = Arc::new(Mutex::new(0));
1167        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1168        client.override_authenticate({
1169            let auth_count = auth_count.clone();
1170            let dropped_auth_count = dropped_auth_count.clone();
1171            move |cx| {
1172                let auth_count = auth_count.clone();
1173                let dropped_auth_count = dropped_auth_count.clone();
1174                cx.foreground().spawn(async move {
1175                    *auth_count.lock() += 1;
1176                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1177                    future::pending::<()>().await;
1178                    unreachable!()
1179                })
1180            }
1181        });
1182
1183        let _authenticate = cx.spawn(|cx| {
1184            let client = client.clone();
1185            async move { client.authenticate_and_connect(false, &cx).await }
1186        });
1187        deterministic.run_until_parked();
1188        assert_eq!(*auth_count.lock(), 1);
1189        assert_eq!(*dropped_auth_count.lock(), 0);
1190
1191        let _authenticate = cx.spawn(|cx| {
1192            let client = client.clone();
1193            async move { client.authenticate_and_connect(false, &cx).await }
1194        });
1195        deterministic.run_until_parked();
1196        assert_eq!(*auth_count.lock(), 2);
1197        assert_eq!(*dropped_auth_count.lock(), 1);
1198    }
1199
1200    #[test]
1201    fn test_encode_and_decode_worktree_url() {
1202        let url = encode_worktree_url(5, "deadbeef");
1203        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1204        assert_eq!(
1205            decode_worktree_url(&format!("\n {}\t", url)),
1206            Some((5, "deadbeef".to_string()))
1207        );
1208        assert_eq!(decode_worktree_url("not://the-right-format"), None);
1209    }
1210
1211    #[gpui::test]
1212    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1213        cx.foreground().forbid_parking();
1214
1215        let user_id = 5;
1216        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1217        let server = FakeServer::for_client(user_id, &client, cx).await;
1218
1219        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1220        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1221        client.add_model_message_handler(
1222            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
1223                match model.read_with(&cx, |model, _| model.id) {
1224                    1 => done_tx1.try_send(()).unwrap(),
1225                    2 => done_tx2.try_send(()).unwrap(),
1226                    _ => unreachable!(),
1227                }
1228                async { Ok(()) }
1229            },
1230        );
1231        let model1 = cx.add_model(|_| Model {
1232            id: 1,
1233            subscription: None,
1234        });
1235        let model2 = cx.add_model(|_| Model {
1236            id: 2,
1237            subscription: None,
1238        });
1239        let model3 = cx.add_model(|_| Model {
1240            id: 3,
1241            subscription: None,
1242        });
1243
1244        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
1245        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
1246        // Ensure dropping a subscription for the same entity type still allows receiving of
1247        // messages for other entity IDs of the same type.
1248        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
1249        drop(subscription3);
1250
1251        server.send(proto::JoinProject { project_id: 1 });
1252        server.send(proto::JoinProject { project_id: 2 });
1253        done_rx1.next().await.unwrap();
1254        done_rx2.next().await.unwrap();
1255    }
1256
1257    #[gpui::test]
1258    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1259        cx.foreground().forbid_parking();
1260
1261        let user_id = 5;
1262        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1263        let server = FakeServer::for_client(user_id, &client, cx).await;
1264
1265        let model = cx.add_model(|_| Model::default());
1266        let (done_tx1, _done_rx1) = smol::channel::unbounded();
1267        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1268        let subscription1 = client.add_message_handler(
1269            model.clone(),
1270            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1271                done_tx1.try_send(()).unwrap();
1272                async { Ok(()) }
1273            },
1274        );
1275        drop(subscription1);
1276        let _subscription2 =
1277            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1278                done_tx2.try_send(()).unwrap();
1279                async { Ok(()) }
1280            });
1281        server.send(proto::Ping {});
1282        done_rx2.next().await.unwrap();
1283    }
1284
1285    #[gpui::test]
1286    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1287        cx.foreground().forbid_parking();
1288
1289        let user_id = 5;
1290        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1291        let server = FakeServer::for_client(user_id, &client, cx).await;
1292
1293        let model = cx.add_model(|_| Model::default());
1294        let (done_tx, mut done_rx) = smol::channel::unbounded();
1295        let subscription = client.add_message_handler(
1296            model.clone(),
1297            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1298                model.update(&mut cx, |model, _| model.subscription.take());
1299                done_tx.try_send(()).unwrap();
1300                async { Ok(()) }
1301            },
1302        );
1303        model.update(cx, |model, _| {
1304            model.subscription = Some(subscription);
1305        });
1306        server.send(proto::Ping {});
1307        done_rx.next().await.unwrap();
1308    }
1309
1310    #[derive(Default)]
1311    struct Model {
1312        id: usize,
1313        subscription: Option<Subscription>,
1314    }
1315
1316    impl Entity for Model {
1317        type Event = ();
1318    }
1319}