client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  15use gpui::{
  16    actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
  17    Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{Arc, Weak},
  32    time::{Duration, Instant},
  33};
  34use thiserror::Error;
  35use url::Url;
  36use util::{ResultExt, TryFutureExt};
  37
  38pub use channel::*;
  39pub use rpc::*;
  40pub use user::*;
  41
  42lazy_static! {
  43    pub static ref ZED_SERVER_URL: String =
  44        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  45    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  46        .ok()
  47        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  48}
  49
  50pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894";
  51
// Declares the `Authenticate` action in the `client` namespace. `init` registers a
// global handler for it that runs `Client::authenticate_and_connect`.
actions!(client, [Authenticate]);
  53
  54pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  55    cx.add_global_action(move |_: &Authenticate, cx| {
  56        let rpc = rpc.clone();
  57        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
  58            .detach();
  59    });
  60}
  61
/// Client for the Zed RPC server (reached at `{ZED_SERVER_URL}/rpc`).
///
/// Owns the RPC [`Peer`], the HTTP client used during connection setup, and the
/// lock-protected [`ClientState`] holding status, credentials, and message routing.
pub struct Client {
    /// Identifier returned by [`Client::id`] and included in log messages; `0` unless
    /// changed with `set_id` in test builds.
    id: usize,
    /// RPC peer that established connections are added to (see `set_connection`).
    peer: Arc<Peer>,
    /// HTTP client; also used to probe the `/rpc` endpoint before opening the websocket.
    http: Arc<dyn HttpClient>,
    /// Mutable client state (status channel, credentials, handlers, entity registry).
    state: RwLock<ClientState>,

    /// Test-only replacement for the authentication step. When set, `authenticate`
    /// calls it instead of `authenticate_with_browser`.
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,
    /// Test-only replacement for connection setup. When set, `establish_connection`
    /// calls it instead of `establish_websocket_connection`.
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
  87
/// Errors that can occur while establishing a connection to the RPC server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with `426 Upgrade Required`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with `401 Unauthorized`.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, wrapped as an [`anyhow::Error`].
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from this crate's `http` module (e.g. while probing the `/rpc` endpoint).
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure, e.g. while opening the TCP stream to the RPC host.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from tungstenite's `http` types, e.g. while building the upgrade request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 103
 104impl From<WebsocketError> for EstablishConnectionError {
 105    fn from(error: WebsocketError) -> Self {
 106        if let WebsocketError::Http(response) = &error {
 107            match response.status() {
 108                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 109                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 110                _ => {}
 111            }
 112        }
 113        EstablishConnectionError::Other(error.into())
 114    }
 115}
 116
 117impl EstablishConnectionError {
 118    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 119        Self::Other(error.into())
 120    }
 121}
 122
/// Connection status of a [`Client`], published through its status watch channel.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Initial status. Also set when the connection closes cleanly, or when
    /// keychain credentials were rejected (just before retrying without them).
    SignedOut,
    /// The server requires a newer client. `authenticate_and_connect` refuses to
    /// proceed from this status.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening the connection after `Authenticating`.
    Connecting,
    /// Authentication or connection failed.
    ConnectionError,
    /// Connected. `connection_id` identifies the connection inside the peer.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O ended with an error. Entering this status starts the
    /// reconnect loop.
    ConnectionLost,
    /// Obtaining credentials again, starting from any status other than `SignedOut`.
    Reauthenticating,
    /// Opening the connection after `Reauthenticating`.
    Reconnecting,
    /// A reconnect attempt failed. The next attempt is scheduled for `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
 136
 137impl Status {
 138    pub fn is_connected(&self) -> bool {
 139        matches!(self, Self::Connected { .. })
 140    }
 141}
 142
/// Mutable state of a [`Client`], kept behind the client's `RwLock`.
struct ClientState {
    /// Credentials of the current session. Stored after a successful connection and
    /// cleared when the server answers `Unauthorized`.
    credentials: Option<Credentials>,
    /// Status watch channel. `set_status` writes the sender; `Client::status`
    /// hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// For each entity-message type, a function extracting the remote entity id from
    /// a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`. It is replaced or cleared by
    /// `set_status`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the delay between reconnect attempts (5 seconds by default).
    reconnect_interval: Duration,
    /// Weak handles to local entities, keyed by `(entity TypeId, remote entity id)`.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model receiving every message of a given type (see `add_message_handler`).
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that messages of a given type are addressed to (entity handlers).
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased dispatcher per message type. It is called with the resolved
    /// target entity, the envelope, the client, and an async app context.
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 166
/// Weak reference to a local entity registered as the target of remote messages.
/// It is either a model or a view.
enum AnyWeakEntityHandle {
    /// Registered with `add_model_for_remote_entity`.
    Model(AnyWeakModelHandle),
    /// Registered with `add_view_for_remote_entity`.
    View(AnyWeakViewHandle),
}
 171
/// Strong handle to the resolved target of an incoming message. It is passed to
/// the type-erased message handlers.
enum AnyEntityHandle {
    /// A model target.
    Model(AnyModelHandle),
    /// A view target.
    View(AnyViewHandle),
}
 176
/// Credentials for the RPC server. When opening the websocket connection they are
/// sent as the `Authorization` header value `"<user_id> <access_token>"`.
///
/// NOTE(review): the derived `Debug` prints `access_token` verbatim. Avoid logging
/// this type with `{:?}`.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Id of the authenticated user (also reported by `Client::user_id`).
    pub user_id: u64,
    /// Access token that accompanies `user_id`.
    pub access_token: String,
}
 182
 183impl Default for ClientState {
 184    fn default() -> Self {
 185        Self {
 186            credentials: None,
 187            status: watch::channel_with(Status::SignedOut),
 188            entity_id_extractors: Default::default(),
 189            _reconnect_task: None,
 190            reconnect_interval: Duration::from_secs(5),
 191            models_by_message_type: Default::default(),
 192            entities_by_type_and_remote_id: Default::default(),
 193            entity_types_by_message_type: Default::default(),
 194            message_handlers: Default::default(),
 195        }
 196    }
 197}
 198
/// Registration handle returned by the `Client::add_*` methods.
///
/// Dropping it removes the registration, provided the client is still alive.
pub enum Subscription {
    /// A local entity registered under `(entity TypeId, remote entity id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler registered under the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 209
 210impl Drop for Subscription {
 211    fn drop(&mut self) {
 212        match self {
 213            Subscription::Entity { client, id } => {
 214                if let Some(client) = client.upgrade() {
 215                    let mut state = client.state.write();
 216                    let _ = state.entities_by_type_and_remote_id.remove(id);
 217                }
 218            }
 219            Subscription::Message { client, id } => {
 220                if let Some(client) = client.upgrade() {
 221                    let mut state = client.state.write();
 222                    let _ = state.entity_types_by_message_type.remove(id);
 223                    let _ = state.message_handlers.remove(id);
 224                }
 225            }
 226        }
 227    }
 228}
 229
 230impl Client {
 231    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 232        Arc::new(Self {
 233            id: 0,
 234            peer: Peer::new(),
 235            http,
 236            state: Default::default(),
 237
 238            #[cfg(any(test, feature = "test-support"))]
 239            authenticate: Default::default(),
 240            #[cfg(any(test, feature = "test-support"))]
 241            establish_connection: Default::default(),
 242        })
 243    }
 244
 245    pub fn id(&self) -> usize {
 246        self.id
 247    }
 248
 249    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 250        self.http.clone()
 251    }
 252
 253    #[cfg(any(test, feature = "test-support"))]
 254    pub fn set_id(&mut self, id: usize) -> &Self {
 255        self.id = id;
 256        self
 257    }
 258
 259    #[cfg(any(test, feature = "test-support"))]
 260    pub fn tear_down(&self) {
 261        let mut state = self.state.write();
 262        state._reconnect_task.take();
 263        state.message_handlers.clear();
 264        state.models_by_message_type.clear();
 265        state.entities_by_type_and_remote_id.clear();
 266        state.entity_id_extractors.clear();
 267        self.peer.reset();
 268    }
 269
 270    #[cfg(any(test, feature = "test-support"))]
 271    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 272    where
 273        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 274    {
 275        *self.authenticate.write() = Some(Box::new(authenticate));
 276        self
 277    }
 278
 279    #[cfg(any(test, feature = "test-support"))]
 280    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 281    where
 282        F: 'static
 283            + Send
 284            + Sync
 285            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 286    {
 287        *self.establish_connection.write() = Some(Box::new(connect));
 288        self
 289    }
 290
 291    pub fn user_id(&self) -> Option<u64> {
 292        self.state
 293            .read()
 294            .credentials
 295            .as_ref()
 296            .map(|credentials| credentials.user_id)
 297    }
 298
 299    pub fn status(&self) -> watch::Receiver<Status> {
 300        self.state.read().status.1.clone()
 301    }
 302
    /// Publishes `status` on the status watch channel and manages the reconnect task.
    ///
    /// * `Connected`: clears any pending reconnect task.
    /// * `ConnectionLost`: spawns a reconnect loop, replacing any existing one. The
    ///   loop calls `authenticate_and_connect(true, ..)` until it succeeds. After each
    ///   failure that leaves the status at `ConnectionError`, it does three things:
    ///   - publishes `ReconnectionError` with the time of the next attempt;
    ///   - sleeps for the current delay;
    ///   - multiplies the delay (initially 100 ms) by a random factor in `[1, 2]`,
    ///     capped at `reconnect_interval`.
    ///
    ///   A failure that leaves any other status ends the loop.
    /// * `SignedOut` / `UpgradeRequired`: clears any pending reconnect task.
    /// * Any other status: only the channel is updated.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        // The state write lock is held for the remainder of this function.
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                // Copied out so the spawned task never needs to read the locked state.
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failure left us in `ConnectionError`.
                        // Any other status (e.g. `UpgradeRequired`) stops the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized exponential backoff, capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 343
 344    pub fn add_view_for_remote_entity<T: View>(
 345        self: &Arc<Self>,
 346        remote_id: u64,
 347        cx: &mut ViewContext<T>,
 348    ) -> Subscription {
 349        let id = (TypeId::of::<T>(), remote_id);
 350        self.state
 351            .write()
 352            .entities_by_type_and_remote_id
 353            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 354        Subscription::Entity {
 355            client: Arc::downgrade(self),
 356            id,
 357        }
 358    }
 359
 360    pub fn add_model_for_remote_entity<T: Entity>(
 361        self: &Arc<Self>,
 362        remote_id: u64,
 363        cx: &mut ModelContext<T>,
 364    ) -> Subscription {
 365        let id = (TypeId::of::<T>(), remote_id);
 366        self.state
 367            .write()
 368            .entities_by_type_and_remote_id
 369            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 370        Subscription::Entity {
 371            client: Arc::downgrade(self),
 372            id,
 373        }
 374    }
 375
 376    pub fn add_message_handler<M, E, H, F>(
 377        self: &Arc<Self>,
 378        model: ModelHandle<E>,
 379        handler: H,
 380    ) -> Subscription
 381    where
 382        M: EnvelopedMessage,
 383        E: Entity,
 384        H: 'static
 385            + Send
 386            + Sync
 387            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 388        F: 'static + Future<Output = Result<()>>,
 389    {
 390        let message_type_id = TypeId::of::<M>();
 391
 392        let mut state = self.state.write();
 393        state
 394            .models_by_message_type
 395            .insert(message_type_id, model.downgrade().into());
 396
 397        let prev_handler = state.message_handlers.insert(
 398            message_type_id,
 399            Arc::new(move |handle, envelope, client, cx| {
 400                let handle = if let AnyEntityHandle::Model(handle) = handle {
 401                    handle
 402                } else {
 403                    unreachable!();
 404                };
 405                let model = handle.downcast::<E>().unwrap();
 406                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 407                handler(model, *envelope, client.clone(), cx).boxed_local()
 408            }),
 409        );
 410        if prev_handler.is_some() {
 411            panic!("registered handler for the same message twice");
 412        }
 413
 414        Subscription::Message {
 415            client: Arc::downgrade(self),
 416            id: message_type_id,
 417        }
 418    }
 419
 420    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 421    where
 422        M: EntityMessage,
 423        E: View,
 424        H: 'static
 425            + Send
 426            + Sync
 427            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 428        F: 'static + Future<Output = Result<()>>,
 429    {
 430        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 431            if let AnyEntityHandle::View(handle) = handle {
 432                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 433            } else {
 434                unreachable!();
 435            }
 436        })
 437    }
 438
 439    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 440    where
 441        M: EntityMessage,
 442        E: Entity,
 443        H: 'static
 444            + Send
 445            + Sync
 446            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 447        F: 'static + Future<Output = Result<()>>,
 448    {
 449        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 450            if let AnyEntityHandle::Model(handle) = handle {
 451                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 452            } else {
 453                unreachable!();
 454            }
 455        })
 456    }
 457
 458    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 459    where
 460        M: EntityMessage,
 461        E: Entity,
 462        H: 'static
 463            + Send
 464            + Sync
 465            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 466        F: 'static + Future<Output = Result<()>>,
 467    {
 468        let model_type_id = TypeId::of::<E>();
 469        let message_type_id = TypeId::of::<M>();
 470
 471        let mut state = self.state.write();
 472        state
 473            .entity_types_by_message_type
 474            .insert(message_type_id, model_type_id);
 475        state
 476            .entity_id_extractors
 477            .entry(message_type_id)
 478            .or_insert_with(|| {
 479                |envelope| {
 480                    envelope
 481                        .as_any()
 482                        .downcast_ref::<TypedEnvelope<M>>()
 483                        .unwrap()
 484                        .payload
 485                        .remote_entity_id()
 486                }
 487            });
 488        let prev_handler = state.message_handlers.insert(
 489            message_type_id,
 490            Arc::new(move |handle, envelope, client, cx| {
 491                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 492                handler(handle, *envelope, client.clone(), cx).boxed_local()
 493            }),
 494        );
 495        if prev_handler.is_some() {
 496            panic!("registered handler for the same message twice");
 497        }
 498    }
 499
 500    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 501    where
 502        M: EntityMessage + RequestMessage,
 503        E: Entity,
 504        H: 'static
 505            + Send
 506            + Sync
 507            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 508        F: 'static + Future<Output = Result<M::Response>>,
 509    {
 510        self.add_model_message_handler(move |entity, envelope, client, cx| {
 511            Self::respond_to_request::<M, _>(
 512                envelope.receipt(),
 513                handler(entity, envelope, client.clone(), cx),
 514                client,
 515            )
 516        })
 517    }
 518
 519    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 520    where
 521        M: EntityMessage + RequestMessage,
 522        E: View,
 523        H: 'static
 524            + Send
 525            + Sync
 526            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 527        F: 'static + Future<Output = Result<M::Response>>,
 528    {
 529        self.add_view_message_handler(move |entity, envelope, client, cx| {
 530            Self::respond_to_request::<M, _>(
 531                envelope.receipt(),
 532                handler(entity, envelope, client.clone(), cx),
 533                client,
 534            )
 535        })
 536    }
 537
 538    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 539        receipt: Receipt<T>,
 540        response: F,
 541        client: Arc<Self>,
 542    ) -> Result<()> {
 543        match response.await {
 544            Ok(response) => {
 545                client.respond(receipt, response)?;
 546                Ok(())
 547            }
 548            Err(error) => {
 549                client.respond_with_error(
 550                    receipt,
 551                    proto::Error {
 552                        message: format!("{:?}", error),
 553                    },
 554                )?;
 555                Err(error)
 556            }
 557        }
 558    }
 559
 560    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 561        read_credentials_from_keychain(cx).is_some()
 562    }
 563
    /// Authenticates if needed and opens a connection to the server, moving the
    /// client's [`Status`] through the intermediate states.
    ///
    /// Early exits:
    /// * Returns `Ok(())` immediately when already `Connected`, `Connecting`, or
    ///   `Reconnecting`.
    /// * Fails with `UpgradeRequired` when that is the current status.
    ///
    /// Status path:
    /// * From `SignedOut` the status goes `Authenticating` → `Connecting`.
    /// * From any other status it goes `Reauthenticating` → `Reconnecting`.
    ///
    /// Credential sources, tried in order:
    /// 1. the client state;
    /// 2. the keychain, when `try_keychain` is true;
    /// 3. [`Client::authenticate`]. This step is abandoned with
    ///    "authentication canceled" if the status changes while it is pending.
    ///
    /// After connecting:
    /// * On success, the credentials are stored in the client state. They are also
    ///   written to the keychain unless they came from it or `IMPERSONATE_LOGIN` is
    ///   set.
    /// * If keychain credentials are rejected as unauthorized, they are deleted from
    ///   the keychain and the flow is retried once with `try_keychain = false`. This
    ///   recursion is why `async_recursion` is required.
    ///
    /// # Errors
    ///
    /// Returns an error when authentication or connection fails. In those paths the
    /// status is set to `ConnectionError` or `UpgradeRequired`, except for the
    /// "authentication canceled" path, which leaves the status as it is.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Drain the value the receiver starts with, so that the second `next()`
            // below fires only on a *later* status change and aborts the pending
            // authentication. NOTE(review): assumes postage watch receivers yield the
            // current value first.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(&cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        // Every path above either filled `credentials` or returned.
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                // Persist freshly obtained credentials, except when impersonating.
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain entry: remove it and retry without the keychain.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 651
    /// Installs an established connection and starts processing it.
    ///
    /// Adds `conn` to the peer and publishes `Connected`. It then spawns two detached
    /// tasks:
    /// 1. A foreground task that dispatches each incoming message to its registered
    ///    handler.
    /// 2. A task that awaits the connection's I/O future (run on the background
    ///    executor) and updates the status when it ends:
    ///    * `SignedOut` on a clean close, if this connection is still the current one;
    ///    * `ConnectionLost` on error, which makes `set_status` start the reconnect loop.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration))
            .await;
        log::info!("set status to connected {}", connection_id);
        self.set_status(Status::Connected { connection_id }, cx);
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    let mut message_id = 0_usize;
                    while let Some(message) = incoming.next().await {
                        let mut state = this.state.write();
                        message_id += 1;
                        let type_name = message.payload_type_name();
                        let payload_type_id = message.payload_type_id();
                        let sender_id = message.original_sender_id().map(|id| id.0);

                        // Resolve the target. First try the model registered for this
                        // message type (`add_message_handler`). Otherwise use the entity
                        // registered for the message's entity type and remote id, and
                        // prune that entry if its weak handle is dead.
                        let model = state
                            .models_by_message_type
                            .get(&payload_type_id)
                            .and_then(|model| model.upgrade(&cx))
                            .map(AnyEntityHandle::Model)
                            .or_else(|| {
                                let entity_type_id =
                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
                                let entity_id = state
                                    .entity_id_extractors
                                    .get(&message.payload_type_id())
                                    .map(|extract_entity_id| {
                                        (extract_entity_id)(message.as_ref())
                                    })?;

                                let entity = state
                                    .entities_by_type_and_remote_id
                                    .get(&(entity_type_id, entity_id))?;
                                if let Some(entity) = entity.upgrade(&cx) {
                                    Some(entity)
                                } else {
                                    state
                                        .entities_by_type_and_remote_id
                                        .remove(&(entity_type_id, entity_id));
                                    None
                                }
                            });

                        // NOTE(review): this `continue` skips the `yield_now` at the end
                        // of the loop, so a burst of unroutable messages is processed
                        // without yielding to the main thread.
                        let model = if let Some(model) = model {
                            model
                        } else {
                            log::info!("unhandled message {}", type_name);
                            continue;
                        };

                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
                        {
                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
                            let future = handler(model, message, &this, cx.clone());

                            let client_id = this.id;
                            log::debug!(
                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
                                client_id,
                                message_id,
                                sender_id,
                                type_name
                            );
                            // Run the handler's future as its own detached task and log
                            // the outcome.
                            cx.foreground()
                                .spawn(async move {
                                    match future.await {
                                        Ok(()) => {
                                            log::debug!(
                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
                                                client_id,
                                                message_id,
                                                sender_id,
                                                type_name
                                            );
                                        }
                                        Err(error) => {
                                            log::error!(
                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
                                                client_id,
                                                message_id,
                                                sender_id,
                                                type_name,
                                                error
                                            );
                                        }
                                    }
                                })
                                .detach();
                        } else {
                            log::info!("unhandled message {}", type_name);
                        }

                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        // Drive the connection's I/O on the background executor and react to its
        // completion on the foreground.
        let handle_io = cx.background().spawn(handle_io);
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if *this.status().borrow() == (Status::Connected { connection_id }) {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();
    }
 777
 778    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 779        #[cfg(any(test, feature = "test-support"))]
 780        if let Some(callback) = self.authenticate.read().as_ref() {
 781            return callback(cx);
 782        }
 783
 784        self.authenticate_with_browser(cx)
 785    }
 786
 787    fn establish_connection(
 788        self: &Arc<Self>,
 789        credentials: &Credentials,
 790        cx: &AsyncAppContext,
 791    ) -> Task<Result<Connection, EstablishConnectionError>> {
 792        #[cfg(any(test, feature = "test-support"))]
 793        if let Some(callback) = self.establish_connection.read().as_ref() {
 794            return callback(credentials, cx);
 795        }
 796
 797        self.establish_websocket_connection(credentials, cx)
 798    }
 799
 800    fn establish_websocket_connection(
 801        self: &Arc<Self>,
 802        credentials: &Credentials,
 803        cx: &AsyncAppContext,
 804    ) -> Task<Result<Connection, EstablishConnectionError>> {
 805        let request = Request::builder()
 806            .header(
 807                "Authorization",
 808                format!("{} {}", credentials.user_id, credentials.access_token),
 809            )
 810            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 811
 812        let http = self.http.clone();
 813        cx.background().spawn(async move {
 814            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
 815            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
 816            if rpc_response.status().is_redirection() {
 817                rpc_url = rpc_response
 818                    .headers()
 819                    .get("Location")
 820                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 821                    .to_str()
 822                    .map_err(|error| EstablishConnectionError::other(error))?
 823                    .to_string();
 824            }
 825            // Until we switch the zed.dev domain to point to the new Next.js app, there
 826            // will be no redirect required, and the app will connect directly to
 827            // wss://zed.dev/rpc.
 828            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
 829                Err(anyhow!(
 830                    "unexpected /rpc response status {}",
 831                    rpc_response.status()
 832                ))?
 833            }
 834
 835            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
 836            let rpc_host = rpc_url
 837                .host_str()
 838                .zip(rpc_url.port_or_known_default())
 839                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 840            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 841
 842            log::info!("connected to rpc endpoint {}", rpc_url);
 843
 844            match rpc_url.scheme() {
 845                "https" => {
 846                    rpc_url.set_scheme("wss").unwrap();
 847                    let request = request.uri(rpc_url.as_str()).body(())?;
 848                    let (stream, _) =
 849                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 850                    Ok(Connection::new(
 851                        stream
 852                            .map_err(|error| anyhow!(error))
 853                            .sink_map_err(|error| anyhow!(error)),
 854                    ))
 855                }
 856                "http" => {
 857                    rpc_url.set_scheme("ws").unwrap();
 858                    let request = request.uri(rpc_url.as_str()).body(())?;
 859                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 860                    Ok(Connection::new(
 861                        stream
 862                            .map_err(|error| anyhow!(error))
 863                            .sink_map_err(|error| anyhow!(error)),
 864                    ))
 865                }
 866                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 867            }
 868        })
 869    }
 870
 871    pub fn authenticate_with_browser(
 872        self: &Arc<Self>,
 873        cx: &AsyncAppContext,
 874    ) -> Task<Result<Credentials>> {
 875        let platform = cx.platform();
 876        let executor = cx.background();
 877        executor.clone().spawn(async move {
 878            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 879            // zed server to encrypt the user's access token, so that it can'be intercepted by
 880            // any other app running on the user's device.
 881            let (public_key, private_key) =
 882                rpc::auth::keypair().expect("failed to generate keypair for auth");
 883            let public_key_string =
 884                String::try_from(public_key).expect("failed to serialize public key for auth");
 885
 886            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 887            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 888            let port = server.server_addr().port();
 889
 890            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 891            // that the user is signing in from a Zed app running on the same device.
 892            let mut url = format!(
 893                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 894                *ZED_SERVER_URL, port, public_key_string
 895            );
 896
 897            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 898                log::info!("impersonating user @{}", impersonate_login);
 899                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 900            }
 901
 902            platform.open_url(&url);
 903
 904            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 905            // access token from the query params.
 906            //
 907            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 908            // custom URL scheme instead of this local HTTP server.
 909            let (user_id, access_token) = executor
 910                .spawn(async move {
 911                    for _ in 0..100 {
 912                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
 913                            let path = req.url();
 914                            let mut user_id = None;
 915                            let mut access_token = None;
 916                            let url = Url::parse(&format!("http://example.com{}", path))
 917                                .context("failed to parse login notification url")?;
 918                            for (key, value) in url.query_pairs() {
 919                                if key == "access_token" {
 920                                    access_token = Some(value.to_string());
 921                                } else if key == "user_id" {
 922                                    user_id = Some(value.to_string());
 923                                }
 924                            }
 925
 926                            let post_auth_url =
 927                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 928                            req.respond(
 929                                tiny_http::Response::empty(302).with_header(
 930                                    tiny_http::Header::from_bytes(
 931                                        &b"Location"[..],
 932                                        post_auth_url.as_bytes(),
 933                                    )
 934                                    .unwrap(),
 935                                ),
 936                            )
 937                            .context("failed to respond to login http request")?;
 938                            return Ok((
 939                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 940                                access_token
 941                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 942                            ));
 943                        }
 944                    }
 945
 946                    Err(anyhow!("didn't receive login redirect"))
 947                })
 948                .await?;
 949
 950            let access_token = private_key
 951                .decrypt_string(&access_token)
 952                .context("failed to decrypt access token")?;
 953            platform.activate(true);
 954
 955            Ok(Credentials {
 956                user_id: user_id.parse()?,
 957                access_token,
 958            })
 959        })
 960    }
 961
 962    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 963        let conn_id = self.connection_id()?;
 964        self.peer.disconnect(conn_id);
 965        self.set_status(Status::SignedOut, cx);
 966        Ok(())
 967    }
 968
 969    fn connection_id(&self) -> Result<ConnectionId> {
 970        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 971            Ok(connection_id)
 972        } else {
 973            Err(anyhow!("not connected"))
 974        }
 975    }
 976
 977    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 978        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 979        self.peer.send(self.connection_id()?, message)
 980    }
 981
 982    pub fn request<T: RequestMessage>(
 983        &self,
 984        request: T,
 985    ) -> impl Future<Output = Result<T::Response>> {
 986        let client_id = self.id;
 987        log::debug!(
 988            "rpc request start. client_id:{}. name:{}",
 989            client_id,
 990            T::NAME
 991        );
 992        let response = self
 993            .connection_id()
 994            .map(|conn_id| self.peer.request(conn_id, request));
 995        async move {
 996            let response = response?.await;
 997            log::debug!(
 998                "rpc request finish. client_id:{}. name:{}",
 999                client_id,
1000                T::NAME
1001            );
1002            response
1003        }
1004    }
1005
1006    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1007        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1008        self.peer.respond(receipt, response)
1009    }
1010
1011    fn respond_with_error<T: RequestMessage>(
1012        &self,
1013        receipt: Receipt<T>,
1014        error: proto::Error,
1015    ) -> Result<()> {
1016        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1017        self.peer.respond_with_error(receipt, error)
1018    }
1019}
1020
1021impl AnyWeakEntityHandle {
1022    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1023        match self {
1024            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1025            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1026        }
1027    }
1028}
1029
1030fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1031    if IMPERSONATE_LOGIN.is_some() {
1032        return None;
1033    }
1034
1035    let (user_id, access_token) = cx
1036        .platform()
1037        .read_credentials(&ZED_SERVER_URL)
1038        .log_err()
1039        .flatten()?;
1040    Some(Credentials {
1041        user_id: user_id.parse().ok()?,
1042        access_token: String::from_utf8(access_token).ok()?,
1043    })
1044}
1045
1046fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1047    cx.platform().write_credentials(
1048        &ZED_SERVER_URL,
1049        &credentials.user_id.to_string(),
1050        credentials.access_token.as_bytes(),
1051    )
1052}
1053
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored, and any path segments after the token are ignored.
/// Returns `None` if:
/// - the prefix is missing;
/// - the id is not a `u64`;
/// - the token segment is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let remainder = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = remainder.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next()? {
        "" => None,
        token => Some((id, token.to_owned())),
    }
}
1070
// Tests for reconnection, authentication, worktree URLs, and message-handler subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // After the server drops the connection, the client reconnects once connections are
    // allowed again. It reuses cached credentials (auth count stays 1) until the access token
    // is rolled, at which point it re-authenticates (auth count becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Advance the fake clock so the next reconnection attempt runs.
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Starting a second `authenticate_and_connect` while the first is still stuck in
    // authentication starts a new authentication (auth_count == 2) and drops the first
    // attempt's authentication future (dropped_auth_count == 1).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        // Counts authentication attempts started, and those whose future was dropped.
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = Client::new(FakeHttpClient::with_404_response());
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Runs when this future is dropped, recording the cancellation.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    // Never completes, so authentication stays in progress.
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Round-trips a worktree URL and checks surrounding whitespace is tolerated and a foreign
    // URL is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Messages addressed to remote entity ids 1 and 2 are routed to the models registered for
    // those ids, even after a third registration (id 3) of the same type has been dropped.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                // Signal the channel matching the model the message was routed to.
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Dropping a message-handler subscription and then registering a new handler for the same
    // model and message type lets the new handler receive messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler takes its own subscription out of the model (dropping it) while handling a
    // message. The test passes if the handler still signals completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        // Store the subscription on the model so the handler can drop it.
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: `id` tells handlers which model they were invoked for, and
    // `subscription` lets a test keep a subscription alive on the model.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}