client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod telemetry;
   7pub mod user;
   8
   9use anyhow::{anyhow, Context, Result};
  10use async_recursion::async_recursion;
  11use async_tungstenite::tungstenite::{
  12    error::Error as WebsocketError,
  13    http::{Request, StatusCode},
  14};
  15use db::Db;
  16use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  17use gpui::{
  18    actions,
  19    serde_json::{json, Value},
  20    AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
  21    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
  22    ViewHandle,
  23};
  24use http::HttpClient;
  25use lazy_static::lazy_static;
  26use parking_lot::RwLock;
  27use postage::watch;
  28use rand::prelude::*;
  29use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  30use std::{
  31    any::TypeId,
  32    collections::HashMap,
  33    convert::TryFrom,
  34    fmt::Write as _,
  35    future::Future,
  36    path::PathBuf,
  37    sync::{Arc, Weak},
  38    time::{Duration, Instant},
  39};
  40use telemetry::Telemetry;
  41use thiserror::Error;
  42use url::Url;
  43use util::{ResultExt, TryFutureExt};
  44
  45pub use channel::*;
  46pub use rpc::*;
  47pub use user::*;
  48
  49lazy_static! {
  50    pub static ref ZED_SERVER_URL: String =
  51        std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
  52    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  53        .ok()
  54        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  55}
  56
/// Shared client token constant.
///
/// NOTE(review): nothing in this section reads it. Confirm how and where the server
/// expects this value before changing it.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";

// Declares the `Authenticate` and `TestTelemetry` actions in the `client` namespace.
// Their global handlers are registered in `init`.
actions!(client, [Authenticate, TestTelemetry]);
  60
  61pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
  62    cx.add_global_action({
  63        let client = client.clone();
  64        move |_: &Authenticate, cx| {
  65            let client = client.clone();
  66            cx.spawn(
  67                |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
  68            )
  69            .detach();
  70        }
  71    });
  72    cx.add_global_action({
  73        let client = client.clone();
  74        move |_: &TestTelemetry, _| {
  75            client.report_event(
  76                "test_telemetry",
  77                json!({
  78                    "test_property": "test_value"
  79                }),
  80            )
  81        }
  82    });
  83}
  84
/// Connection to the Zed server.
///
/// Owns the RPC [`Peer`], the HTTP client used to reach the server, telemetry, and the
/// lock-protected [`ClientState`]. That state holds the credentials, the connection status,
/// and the tables that route incoming messages to registered models, views, and handlers.
pub struct Client {
    /// Identifier included in log messages; 0 unless changed via `set_id` in tests.
    id: usize,
    /// RPC peer that connections are added to (see `set_connection`).
    peer: Arc<Peer>,
    /// HTTP client used when establishing the websocket connection; exposed via `http_client`.
    http: Arc<dyn HttpClient>,
    /// Telemetry sink; `set_status` reports the signed-in user id to it.
    telemetry: Arc<Telemetry>,
    /// Mutable client state. This is a blocking lock: do not hold its guard across an `.await`.
    state: RwLock<ClientState>,

    /// Test-only replacement for the default `authenticate` flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for the default websocket `establish_connection` flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
 114
/// Reasons an attempt to open the RPC connection can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (see `From<WebsocketError>`).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (see `From<WebsocketError>`).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors that are not 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from this crate's `http` module.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite its name, this wraps tungstenite's re-exported `http::Error`
    /// (presumably from building the handshake request), not a websocket protocol error.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 130
 131impl From<WebsocketError> for EstablishConnectionError {
 132    fn from(error: WebsocketError) -> Self {
 133        if let WebsocketError::Http(response) = &error {
 134            match response.status() {
 135                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 136                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 137                _ => {}
 138            }
 139        }
 140        EstablishConnectionError::Other(error.into())
 141    }
 142}
 143
 144impl EstablishConnectionError {
 145    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 146        Self::Other(error.into())
 147    }
 148}
 149
/// Connection lifecycle of a [`Client`], published through a watch channel.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Initial state. Also entered after a clean disconnect of the current connection, or
    /// after keychain credentials were rejected.
    SignedOut,
    /// The server requires a newer client; `authenticate_and_connect` refuses to proceed.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening the connection after `Authenticating`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection within the peer.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O failed. Entering this state starts the reconnect loop.
    ConnectionLost,
    /// Like `Authenticating`, but starting from a non-`SignedOut` state.
    Reauthenticating,
    /// Like `Connecting`, but following `Reauthenticating`.
    Reconnecting,
    /// A reconnect attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
 163
 164impl Status {
 165    pub fn is_connected(&self) -> bool {
 166        matches!(self, Self::Connected { .. })
 167    }
 168}
 169
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials of the last successful connection; cleared when the server answers
    /// `Unauthorized`.
    credentials: Option<Credentials>,
    /// Watch channel carrying the current [`Status`]. `status()` hands out clones of the
    /// receiver; `set_status` writes through the sender.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per entity-message type, a function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; replaced or dropped by `set_status`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the randomized reconnect back-off delay.
    reconnect_interval: Duration,
    /// Remote entities (views or models) registered via `add_*_for_remote_entity`,
    /// keyed by (entity type, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model that receives every message of a given type (set by `add_message_handler`).
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that handles a given entity-message type (set by `add_entity_message_handler`).
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 194
/// Weak, type-erased reference to a model or view registered for remote-entity routing.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
 199
/// Strong, type-erased handle to the model or view an incoming message is dispatched to.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 204
/// User id and access token sent in the `Authorization` header of the RPC handshake.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `access_token`.
/// Avoid logging values of this type.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
 210
 211impl Default for ClientState {
 212    fn default() -> Self {
 213        Self {
 214            credentials: None,
 215            status: watch::channel_with(Status::SignedOut),
 216            entity_id_extractors: Default::default(),
 217            _reconnect_task: None,
 218            reconnect_interval: Duration::from_secs(5),
 219            models_by_message_type: Default::default(),
 220            entities_by_type_and_remote_id: Default::default(),
 221            entity_types_by_message_type: Default::default(),
 222            message_handlers: Default::default(),
 223        }
 224    }
 225}
 226
/// Registration handle returned by the `Client` registration methods.
///
/// Dropping it unregisters the corresponding routing entries, provided the client is
/// still alive. Only a weak reference to the client is held, so a subscription does not
/// keep the client alive.
pub enum Subscription {
    /// A remote entity registered under `(entity type, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler registered for the message type `id`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 237
 238impl Drop for Subscription {
 239    fn drop(&mut self) {
 240        match self {
 241            Subscription::Entity { client, id } => {
 242                if let Some(client) = client.upgrade() {
 243                    let mut state = client.state.write();
 244                    let _ = state.entities_by_type_and_remote_id.remove(id);
 245                }
 246            }
 247            Subscription::Message { client, id } => {
 248                if let Some(client) = client.upgrade() {
 249                    let mut state = client.state.write();
 250                    let _ = state.entity_types_by_message_type.remove(id);
 251                    let _ = state.message_handlers.remove(id);
 252                }
 253            }
 254        }
 255    }
 256}
 257
 258impl Client {
 259    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
 260        Arc::new(Self {
 261            id: 0,
 262            peer: Peer::new(),
 263            telemetry: Telemetry::new(http.clone(), cx),
 264            http,
 265            state: Default::default(),
 266
 267            #[cfg(any(test, feature = "test-support"))]
 268            authenticate: Default::default(),
 269            #[cfg(any(test, feature = "test-support"))]
 270            establish_connection: Default::default(),
 271        })
 272    }
 273
 274    pub fn id(&self) -> usize {
 275        self.id
 276    }
 277
 278    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 279        self.http.clone()
 280    }
 281
 282    #[cfg(any(test, feature = "test-support"))]
 283    pub fn set_id(&mut self, id: usize) -> &Self {
 284        self.id = id;
 285        self
 286    }
 287
 288    #[cfg(any(test, feature = "test-support"))]
 289    pub fn tear_down(&self) {
 290        let mut state = self.state.write();
 291        state._reconnect_task.take();
 292        state.message_handlers.clear();
 293        state.models_by_message_type.clear();
 294        state.entities_by_type_and_remote_id.clear();
 295        state.entity_id_extractors.clear();
 296        self.peer.reset();
 297    }
 298
 299    #[cfg(any(test, feature = "test-support"))]
 300    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 301    where
 302        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 303    {
 304        *self.authenticate.write() = Some(Box::new(authenticate));
 305        self
 306    }
 307
 308    #[cfg(any(test, feature = "test-support"))]
 309    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 310    where
 311        F: 'static
 312            + Send
 313            + Sync
 314            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 315    {
 316        *self.establish_connection.write() = Some(Box::new(connect));
 317        self
 318    }
 319
 320    pub fn user_id(&self) -> Option<u64> {
 321        self.state
 322            .read()
 323            .credentials
 324            .as_ref()
 325            .map(|credentials| credentials.user_id)
 326    }
 327
 328    pub fn status(&self) -> watch::Receiver<Status> {
 329        self.state.read().status.1.clone()
 330    }
 331
    /// Publishes `status` to every `status()` receiver and runs the side effects tied to
    /// the new state. The state write lock is held for the whole function.
    ///
    /// - `Connected`: reports the stored user id to telemetry and drops the reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop that retries `authenticate_and_connect`.
    ///   After each failure that leaves the status at `ConnectionError`, it publishes
    ///   `ReconnectionError`, waits, and grows the delay. The delay starts at 100ms, is
    ///   multiplied by a random factor in `1.0..=2.0`, and is capped at
    ///   `reconnect_interval`. The loop ends on success, or when a failure leaves any other
    ///   status.
    /// - `SignedOut` / `UpgradeRequired`: reports the user id to telemetry and drops the
    ///   reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;
        let user_id = state.credentials.as_ref().map(|c| c.user_id);

        match status {
            Status::Connected { .. } => {
                self.telemetry.set_user_id(user_id);
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Only keep retrying while the failed attempt left us in
                        // `ConnectionError`; any other status (e.g. `UpgradeRequired`, or a
                        // status set by someone else) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized growth (jitter), capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_user_id(user_id);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 375
 376    pub fn add_view_for_remote_entity<T: View>(
 377        self: &Arc<Self>,
 378        remote_id: u64,
 379        cx: &mut ViewContext<T>,
 380    ) -> Subscription {
 381        let id = (TypeId::of::<T>(), remote_id);
 382        self.state
 383            .write()
 384            .entities_by_type_and_remote_id
 385            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 386        Subscription::Entity {
 387            client: Arc::downgrade(self),
 388            id,
 389        }
 390    }
 391
 392    pub fn add_model_for_remote_entity<T: Entity>(
 393        self: &Arc<Self>,
 394        remote_id: u64,
 395        cx: &mut ModelContext<T>,
 396    ) -> Subscription {
 397        let id = (TypeId::of::<T>(), remote_id);
 398        self.state
 399            .write()
 400            .entities_by_type_and_remote_id
 401            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 402        Subscription::Entity {
 403            client: Arc::downgrade(self),
 404            id,
 405        }
 406    }
 407
    /// Registers `handler` for every incoming `M` message and routes those messages to
    /// `model`, whatever entity id the message carries.
    ///
    /// Only a weak handle to `model` is stored. Once the model is released, `M` messages
    /// are no longer routed to it. The registration lasts until the returned
    /// [`Subscription`] is dropped.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` is already registered.
    pub fn add_message_handler<M, E, H, F>(
        self: &Arc<Self>,
        model: ModelHandle<E>,
        handler: H,
    ) -> Subscription
    where
        M: EnvelopedMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        let message_type_id = TypeId::of::<M>();

        let mut state = self.state.write();
        state
            .models_by_message_type
            .insert(message_type_id, model.downgrade().into());

        // Type-erased wrapper. The dispatcher only has an `AnyEntityHandle` and a boxed
        // envelope, so the concrete `E` and `M` are recovered here. The closure stays
        // inline so its signature is inferred from the `message_handlers` value type.
        let prev_handler = state.message_handlers.insert(
            message_type_id,
            Arc::new(move |handle, envelope, client, cx| {
                // `M` is resolved through `models_by_message_type`, so the handle is a model.
                let handle = if let AnyEntityHandle::Model(handle) = handle {
                    handle
                } else {
                    unreachable!();
                };
                let model = handle.downcast::<E>().unwrap();
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                handler(model, *envelope, client.clone(), cx).boxed_local()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered handler for the same message twice");
        }

        Subscription::Message {
            client: Arc::downgrade(self),
            id: message_type_id,
        }
    }
 451
 452    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 453    where
 454        M: EntityMessage,
 455        E: View,
 456        H: 'static
 457            + Send
 458            + Sync
 459            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 460        F: 'static + Future<Output = Result<()>>,
 461    {
 462        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 463            if let AnyEntityHandle::View(handle) = handle {
 464                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 465            } else {
 466                unreachable!();
 467            }
 468        })
 469    }
 470
 471    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 472    where
 473        M: EntityMessage,
 474        E: Entity,
 475        H: 'static
 476            + Send
 477            + Sync
 478            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 479        F: 'static + Future<Output = Result<()>>,
 480    {
 481        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 482            if let AnyEntityHandle::Model(handle) = handle {
 483                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 484            } else {
 485                unreachable!();
 486            }
 487        })
 488    }
 489
    /// Shared registration for entity-addressed messages.
    ///
    /// Records that `M` targets entities of type `E`, and installs (once per `M`) a function
    /// that reads the remote entity id from an `M` envelope. It then stores a type-erased
    /// wrapper around `handler`. Unlike `add_message_handler`, no [`Subscription`] is
    /// returned, so the registration lasts until `tear_down` or the client is dropped.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` is already registered.
    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        let model_type_id = TypeId::of::<E>();
        let message_type_id = TypeId::of::<M>();

        let mut state = self.state.write();
        state
            .entity_types_by_message_type
            .insert(message_type_id, model_type_id);
        // A non-capturing closure that coerces to a plain `fn` pointer; monomorphized per `M`.
        state
            .entity_id_extractors
            .entry(message_type_id)
            .or_insert_with(|| {
                |envelope| {
                    envelope
                        .as_any()
                        .downcast_ref::<TypedEnvelope<M>>()
                        .unwrap()
                        .payload
                        .remote_entity_id()
                }
            });
        // Closure stays inline so its signature is inferred from the map's value type.
        let prev_handler = state.message_handlers.insert(
            message_type_id,
            Arc::new(move |handle, envelope, client, cx| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                handler(handle, *envelope, client.clone(), cx).boxed_local()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered handler for the same message twice");
        }
    }
 531
 532    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 533    where
 534        M: EntityMessage + RequestMessage,
 535        E: Entity,
 536        H: 'static
 537            + Send
 538            + Sync
 539            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 540        F: 'static + Future<Output = Result<M::Response>>,
 541    {
 542        self.add_model_message_handler(move |entity, envelope, client, cx| {
 543            Self::respond_to_request::<M, _>(
 544                envelope.receipt(),
 545                handler(entity, envelope, client.clone(), cx),
 546                client,
 547            )
 548        })
 549    }
 550
 551    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 552    where
 553        M: EntityMessage + RequestMessage,
 554        E: View,
 555        H: 'static
 556            + Send
 557            + Sync
 558            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 559        F: 'static + Future<Output = Result<M::Response>>,
 560    {
 561        self.add_view_message_handler(move |entity, envelope, client, cx| {
 562            Self::respond_to_request::<M, _>(
 563                envelope.receipt(),
 564                handler(entity, envelope, client.clone(), cx),
 565                client,
 566            )
 567        })
 568    }
 569
 570    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 571        receipt: Receipt<T>,
 572        response: F,
 573        client: Arc<Self>,
 574    ) -> Result<()> {
 575        match response.await {
 576            Ok(response) => {
 577                client.respond(receipt, response)?;
 578                Ok(())
 579            }
 580            Err(error) => {
 581                client.respond_with_error(
 582                    receipt,
 583                    proto::Error {
 584                        message: format!("{:?}", error),
 585                    },
 586                )?;
 587                Err(error)
 588            }
 589        }
 590    }
 591
 592    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 593        read_credentials_from_keychain(cx).is_some()
 594    }
 595
    /// Obtains credentials (if needed) and opens a connection to the server.
    ///
    /// Returns `Ok(())` immediately if already connected or a connection is in progress, and
    /// fails immediately in `UpgradeRequired`. Credentials are looked up in this order: those
    /// already stored on the client; then the keychain (only when `try_keychain`); then the
    /// `authenticate` flow. The interactive flow is abandoned with "authentication canceled"
    /// if the status changes while it is pending.
    ///
    /// If the server rejects keychain credentials, they are deleted and the method recurses
    /// once with `try_keychain = false`. Credentials that were not read from the keychain are
    /// written back to it after a successful connection, unless `ZED_IMPERSONATE` is set.
    ///
    /// # Errors
    ///
    /// Authentication errors, "authentication canceled", or an [`EstablishConnectionError`]
    /// converted into `anyhow::Error`. The status is updated accordingly before returning.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `true` means this is a fresh sign-in (Authenticating/Connecting); `false` means
        // a retry (Reauthenticating/Reconnecting).
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                self.report_event("read credentials from keychain", Default::default());
            }
        }
        if credentials.is_none() {
            // NOTE(review): the first `next()` presumably yields the current status at once
            // (postage watch semantics). That way the select below only fires on a *later*
            // status change, i.e. when someone cancels or supersedes this attempt.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain credentials: forget them and retry once, forcing the
                    // interactive flow.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 686
 687    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
 688        let executor = cx.background();
 689        log::info!("add connection to peer");
 690        let (connection_id, handle_io, mut incoming) = self
 691            .peer
 692            .add_connection(conn, move |duration| executor.timer(duration))
 693            .await;
 694        log::info!("set status to connected {}", connection_id);
 695        self.set_status(Status::Connected { connection_id }, cx);
 696        cx.foreground()
 697            .spawn({
 698                let cx = cx.clone();
 699                let this = self.clone();
 700                async move {
 701                    let mut message_id = 0_usize;
 702                    while let Some(message) = incoming.next().await {
 703                        let mut state = this.state.write();
 704                        message_id += 1;
 705                        let type_name = message.payload_type_name();
 706                        let payload_type_id = message.payload_type_id();
 707                        let sender_id = message.original_sender_id().map(|id| id.0);
 708
 709                        let model = state
 710                            .models_by_message_type
 711                            .get(&payload_type_id)
 712                            .and_then(|model| model.upgrade(&cx))
 713                            .map(AnyEntityHandle::Model)
 714                            .or_else(|| {
 715                                let entity_type_id =
 716                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
 717                                let entity_id = state
 718                                    .entity_id_extractors
 719                                    .get(&message.payload_type_id())
 720                                    .map(|extract_entity_id| {
 721                                        (extract_entity_id)(message.as_ref())
 722                                    })?;
 723
 724                                let entity = state
 725                                    .entities_by_type_and_remote_id
 726                                    .get(&(entity_type_id, entity_id))?;
 727                                if let Some(entity) = entity.upgrade(&cx) {
 728                                    Some(entity)
 729                                } else {
 730                                    state
 731                                        .entities_by_type_and_remote_id
 732                                        .remove(&(entity_type_id, entity_id));
 733                                    None
 734                                }
 735                            });
 736
 737                        let model = if let Some(model) = model {
 738                            model
 739                        } else {
 740                            log::info!("unhandled message {}", type_name);
 741                            continue;
 742                        };
 743
 744                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
 745                        {
 746                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
 747                            let future = handler(model, message, &this, cx.clone());
 748
 749                            let client_id = this.id;
 750                            log::debug!(
 751                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 752                                client_id,
 753                                message_id,
 754                                sender_id,
 755                                type_name
 756                            );
 757                            cx.foreground()
 758                                .spawn(async move {
 759                                    match future.await {
 760                                        Ok(()) => {
 761                                            log::debug!(
 762                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 763                                                client_id,
 764                                                message_id,
 765                                                sender_id,
 766                                                type_name
 767                                            );
 768                                        }
 769                                        Err(error) => {
 770                                            log::error!(
 771                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
 772                                                client_id,
 773                                                message_id,
 774                                                sender_id,
 775                                                type_name,
 776                                                error
 777                                            );
 778                                        }
 779                                    }
 780                                })
 781                                .detach();
 782                        } else {
 783                            log::info!("unhandled message {}", type_name);
 784                        }
 785
 786                        // Don't starve the main thread when receiving lots of messages at once.
 787                        smol::future::yield_now().await;
 788                    }
 789                }
 790            })
 791            .detach();
 792
 793        let handle_io = cx.background().spawn(handle_io);
 794        let this = self.clone();
 795        let cx = cx.clone();
 796        cx.foreground()
 797            .spawn(async move {
 798                match handle_io.await {
 799                    Ok(()) => {
 800                        if *this.status().borrow() == (Status::Connected { connection_id }) {
 801                            this.set_status(Status::SignedOut, &cx);
 802                        }
 803                    }
 804                    Err(err) => {
 805                        log::error!("connection error: {:?}", err);
 806                        this.set_status(Status::ConnectionLost, &cx);
 807                    }
 808                }
 809            })
 810            .detach();
 811    }
 812
 813    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 814        #[cfg(any(test, feature = "test-support"))]
 815        if let Some(callback) = self.authenticate.read().as_ref() {
 816            return callback(cx);
 817        }
 818
 819        self.authenticate_with_browser(cx)
 820    }
 821
 822    fn establish_connection(
 823        self: &Arc<Self>,
 824        credentials: &Credentials,
 825        cx: &AsyncAppContext,
 826    ) -> Task<Result<Connection, EstablishConnectionError>> {
 827        #[cfg(any(test, feature = "test-support"))]
 828        if let Some(callback) = self.establish_connection.read().as_ref() {
 829            return callback(credentials, cx);
 830        }
 831
 832        self.establish_websocket_connection(credentials, cx)
 833    }
 834
 835    fn establish_websocket_connection(
 836        self: &Arc<Self>,
 837        credentials: &Credentials,
 838        cx: &AsyncAppContext,
 839    ) -> Task<Result<Connection, EstablishConnectionError>> {
 840        let request = Request::builder()
 841            .header(
 842                "Authorization",
 843                format!("{} {}", credentials.user_id, credentials.access_token),
 844            )
 845            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 846
 847        let http = self.http.clone();
 848        cx.background().spawn(async move {
 849            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
 850            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
 851            if rpc_response.status().is_redirection() {
 852                rpc_url = rpc_response
 853                    .headers()
 854                    .get("Location")
 855                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 856                    .to_str()
 857                    .map_err(EstablishConnectionError::other)?
 858                    .to_string();
 859            }
 860            // Until we switch the zed.dev domain to point to the new Next.js app, there
 861            // will be no redirect required, and the app will connect directly to
 862            // wss://zed.dev/rpc.
 863            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
 864                Err(anyhow!(
 865                    "unexpected /rpc response status {}",
 866                    rpc_response.status()
 867                ))?
 868            }
 869
 870            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
 871            let rpc_host = rpc_url
 872                .host_str()
 873                .zip(rpc_url.port_or_known_default())
 874                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 875            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 876
 877            log::info!("connected to rpc endpoint {}", rpc_url);
 878
 879            match rpc_url.scheme() {
 880                "https" => {
 881                    rpc_url.set_scheme("wss").unwrap();
 882                    let request = request.uri(rpc_url.as_str()).body(())?;
 883                    let (stream, _) =
 884                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 885                    Ok(Connection::new(
 886                        stream
 887                            .map_err(|error| anyhow!(error))
 888                            .sink_map_err(|error| anyhow!(error)),
 889                    ))
 890                }
 891                "http" => {
 892                    rpc_url.set_scheme("ws").unwrap();
 893                    let request = request.uri(rpc_url.as_str()).body(())?;
 894                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 895                    Ok(Connection::new(
 896                        stream
 897                            .map_err(|error| anyhow!(error))
 898                            .sink_map_err(|error| anyhow!(error)),
 899                    ))
 900                }
 901                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 902            }
 903        })
 904    }
 905
 906    pub fn authenticate_with_browser(
 907        self: &Arc<Self>,
 908        cx: &AsyncAppContext,
 909    ) -> Task<Result<Credentials>> {
 910        let platform = cx.platform();
 911        let executor = cx.background();
 912        let telemetry = self.telemetry.clone();
 913        executor.clone().spawn(async move {
 914            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 915            // zed server to encrypt the user's access token, so that it can'be intercepted by
 916            // any other app running on the user's device.
 917            let (public_key, private_key) =
 918                rpc::auth::keypair().expect("failed to generate keypair for auth");
 919            let public_key_string =
 920                String::try_from(public_key).expect("failed to serialize public key for auth");
 921
 922            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 923            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 924            let port = server.server_addr().port();
 925
 926            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 927            // that the user is signing in from a Zed app running on the same device.
 928            let mut url = format!(
 929                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 930                *ZED_SERVER_URL, port, public_key_string
 931            );
 932
 933            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 934                log::info!("impersonating user @{}", impersonate_login);
 935                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 936            }
 937
 938            platform.open_url(&url);
 939
 940            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 941            // access token from the query params.
 942            //
 943            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 944            // custom URL scheme instead of this local HTTP server.
 945            let (user_id, access_token) = executor
 946                .spawn(async move {
 947                    for _ in 0..100 {
 948                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
 949                            let path = req.url();
 950                            let mut user_id = None;
 951                            let mut access_token = None;
 952                            let url = Url::parse(&format!("http://example.com{}", path))
 953                                .context("failed to parse login notification url")?;
 954                            for (key, value) in url.query_pairs() {
 955                                if key == "access_token" {
 956                                    access_token = Some(value.to_string());
 957                                } else if key == "user_id" {
 958                                    user_id = Some(value.to_string());
 959                                }
 960                            }
 961
 962                            let post_auth_url =
 963                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 964                            req.respond(
 965                                tiny_http::Response::empty(302).with_header(
 966                                    tiny_http::Header::from_bytes(
 967                                        &b"Location"[..],
 968                                        post_auth_url.as_bytes(),
 969                                    )
 970                                    .unwrap(),
 971                                ),
 972                            )
 973                            .context("failed to respond to login http request")?;
 974                            return Ok((
 975                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 976                                access_token
 977                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 978                            ));
 979                        }
 980                    }
 981
 982                    Err(anyhow!("didn't receive login redirect"))
 983                })
 984                .await?;
 985
 986            let access_token = private_key
 987                .decrypt_string(&access_token)
 988                .context("failed to decrypt access token")?;
 989            platform.activate(true);
 990
 991            telemetry.report_event("authenticate with browser", Default::default());
 992
 993            Ok(Credentials {
 994                user_id: user_id.parse()?,
 995                access_token,
 996            })
 997        })
 998    }
 999
1000    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
1001        let conn_id = self.connection_id()?;
1002        self.peer.disconnect(conn_id);
1003        self.set_status(Status::SignedOut, cx);
1004        Ok(())
1005    }
1006
1007    fn connection_id(&self) -> Result<ConnectionId> {
1008        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1009            Ok(connection_id)
1010        } else {
1011            Err(anyhow!("not connected"))
1012        }
1013    }
1014
1015    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1016        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1017        self.peer.send(self.connection_id()?, message)
1018    }
1019
1020    pub fn request<T: RequestMessage>(
1021        &self,
1022        request: T,
1023    ) -> impl Future<Output = Result<T::Response>> {
1024        let client_id = self.id;
1025        log::debug!(
1026            "rpc request start. client_id:{}. name:{}",
1027            client_id,
1028            T::NAME
1029        );
1030        let response = self
1031            .connection_id()
1032            .map(|conn_id| self.peer.request(conn_id, request));
1033        async move {
1034            let response = response?.await;
1035            log::debug!(
1036                "rpc request finish. client_id:{}. name:{}",
1037                client_id,
1038                T::NAME
1039            );
1040            response
1041        }
1042    }
1043
1044    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1045        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1046        self.peer.respond(receipt, response)
1047    }
1048
1049    fn respond_with_error<T: RequestMessage>(
1050        &self,
1051        receipt: Receipt<T>,
1052        error: proto::Error,
1053    ) -> Result<()> {
1054        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1055        self.peer.respond_with_error(receipt, error)
1056    }
1057
1058    pub fn start_telemetry(&self, db: Arc<Db>) {
1059        self.telemetry.start(db);
1060    }
1061
1062    pub fn report_event(&self, kind: &str, properties: Value) {
1063        self.telemetry.report_event(kind, properties)
1064    }
1065
1066    pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1067        self.telemetry.log_file_path()
1068    }
1069}
1070
1071impl AnyWeakEntityHandle {
1072    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1073        match self {
1074            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1075            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1076        }
1077    }
1078}
1079
1080fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1081    if IMPERSONATE_LOGIN.is_some() {
1082        return None;
1083    }
1084
1085    let (user_id, access_token) = cx
1086        .platform()
1087        .read_credentials(&ZED_SERVER_URL)
1088        .log_err()
1089        .flatten()?;
1090    Some(Credentials {
1091        user_id: user_id.parse().ok()?,
1092        access_token: String::from_utf8(access_token).ok()?,
1093    })
1094}
1095
1096fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1097    cx.platform().write_credentials(
1098        &ZED_SERVER_URL,
1099        &credentials.user_id.to_string(),
1100        credentials.access_token.as_bytes(),
1101    )
1102}
1103
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Leading and trailing whitespace is ignored. Everything after the first `/` that follows
/// the id is treated as the access token. Tokens that themselves contain `/` therefore
/// round-trip intact instead of being truncated at their first slash.
///
/// Returns `None` if any of these hold:
/// - the prefix is missing;
/// - there is no `/` after the id;
/// - the id is not a valid `u64`;
/// - the token is empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1120
#[cfg(test)]
mod tests {
    // Tests for `Client`. They cover:
    // - reconnection and re-authentication;
    // - overlapping authentication attempts;
    // - worktree URL encoding;
    // - the lifetime of message-handler subscriptions.
    // Most tests drive the client against a `FakeServer`, constructing it with
    // `FakeHttpClient::with_404_response()` (presumably an HTTP client that answers every
    // request with 404).
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // After a server-side disconnect the client keeps retrying. The first reconnect reuses
    // the cached credentials (auth count stays at 1). Once the access token has been
    // rolled, the client authenticates again (auth count goes to 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection and refuse new ones until the client reports a failed
        // reconnection attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock so the client retries
        // (presumably past its reconnection back-off).
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Roll the access token so the cached credentials are rejected; the client should
        // clear them and authenticate again.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // The overridden authenticator never finishes. A second `authenticate_and_connect`
    // started while the first is still pending runs a new authentication and drops the
    // first one's in-flight future. The `defer` guard counts the drop.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Incremented only when this future is dropped before completing.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Round-trips a worktree URL, also with surrounding whitespace, and rejects a URL
    // with the wrong prefix.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // A model-level `JoinProject` handler is registered, and models 1 and 2 are subscribed
    // to the remote entities with the same ids. Each message must reach the model whose
    // `id` matches the message's `project_id`.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a handler's subscription is dropped, a new handler for the same message type
    // and model can be registered, and it receives the next message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        // Binding `_done_rx1` (rather than `_`) keeps the first channel open for the
        // whole test.
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription, which is stored on the model, while it is
    // running. The test passes if the handler still runs to completion and signals
    // `done_tx`.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity. `id` tells models apart in handlers, and `subscription` is an
    // optional slot for holding a message subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}