client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  15use gpui::{
  16    actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
  17    Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{Arc, Weak},
  32    time::{Duration, Instant},
  33};
  34use thiserror::Error;
  35use url::Url;
  36use util::{ResultExt, TryFutureExt};
  37
  38pub use channel::*;
  39pub use rpc::*;
  40pub use user::*;
  41
  42lazy_static! {
  43    pub static ref ZED_SERVER_URL: String =
  44        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  45    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  46        .ok()
  47        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  48}
  49
  50pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894";
  51
// Declares the `Authenticate` action in the `client` namespace; `init` registers the
// global handler that reacts to it.
actions!(client, [Authenticate]);
  53
  54pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  55    cx.add_global_action(move |_: &Authenticate, cx| {
  56        let rpc = rpc.clone();
  57        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
  58            .detach();
  59    });
  60}
  61
  62pub struct Client {
  63    id: usize,
  64    peer: Arc<Peer>,
  65    http: Arc<dyn HttpClient>,
  66    state: RwLock<ClientState>,
  67
  68    #[cfg(any(test, feature = "test-support"))]
  69    authenticate: RwLock<
  70        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
  71    >,
  72    #[cfg(any(test, feature = "test-support"))]
  73    establish_connection: RwLock<
  74        Option<
  75            Box<
  76                dyn 'static
  77                    + Send
  78                    + Sync
  79                    + Fn(
  80                        &Credentials,
  81                        &AsyncAppContext,
  82                    ) -> Task<Result<Connection, EstablishConnectionError>>,
  83            >,
  84        >,
  85    >,
  86}
  87
  88#[derive(Error, Debug)]
  89pub enum EstablishConnectionError {
  90    #[error("upgrade required")]
  91    UpgradeRequired,
  92    #[error("unauthorized")]
  93    Unauthorized,
  94    #[error("{0}")]
  95    Other(#[from] anyhow::Error),
  96    #[error("{0}")]
  97    Http(#[from] http::Error),
  98    #[error("{0}")]
  99    Io(#[from] std::io::Error),
 100    #[error("{0}")]
 101    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
 102}
 103
 104impl From<WebsocketError> for EstablishConnectionError {
 105    fn from(error: WebsocketError) -> Self {
 106        if let WebsocketError::Http(response) = &error {
 107            match response.status() {
 108                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 109                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 110                _ => {}
 111            }
 112        }
 113        EstablishConnectionError::Other(error.into())
 114    }
 115}
 116
 117impl EstablishConnectionError {
 118    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 119        Self::Other(error.into())
 120    }
 121}
 122
 123#[derive(Copy, Clone, Debug, Eq, PartialEq)]
 124pub enum Status {
 125    SignedOut,
 126    UpgradeRequired,
 127    Authenticating,
 128    Connecting,
 129    ConnectionError,
 130    Connected { connection_id: ConnectionId },
 131    ConnectionLost,
 132    Reauthenticating,
 133    Reconnecting,
 134    ReconnectionError { next_reconnection: Instant },
 135}
 136
 137impl Status {
 138    pub fn is_connected(&self) -> bool {
 139        matches!(self, Self::Connected { .. })
 140    }
 141}
 142
 143struct ClientState {
 144    credentials: Option<Credentials>,
 145    status: (watch::Sender<Status>, watch::Receiver<Status>),
 146    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 147    _reconnect_task: Option<Task<()>>,
 148    reconnect_interval: Duration,
 149    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
 150    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
 151    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 152    message_handlers: HashMap<
 153        TypeId,
 154        Arc<
 155            dyn Send
 156                + Sync
 157                + Fn(
 158                    AnyEntityHandle,
 159                    Box<dyn AnyTypedEnvelope>,
 160                    &Arc<Client>,
 161                    AsyncAppContext,
 162                ) -> LocalBoxFuture<'static, Result<()>>,
 163        >,
 164    >,
 165}
 166
 167enum AnyWeakEntityHandle {
 168    Model(AnyWeakModelHandle),
 169    View(AnyWeakViewHandle),
 170}
 171
 172enum AnyEntityHandle {
 173    Model(AnyModelHandle),
 174    View(AnyViewHandle),
 175}
 176
 177#[derive(Clone, Debug)]
 178pub struct Credentials {
 179    pub user_id: u64,
 180    pub access_token: String,
 181}
 182
 183impl Default for ClientState {
 184    fn default() -> Self {
 185        Self {
 186            credentials: None,
 187            status: watch::channel_with(Status::SignedOut),
 188            entity_id_extractors: Default::default(),
 189            _reconnect_task: None,
 190            reconnect_interval: Duration::from_secs(5),
 191            models_by_message_type: Default::default(),
 192            entities_by_type_and_remote_id: Default::default(),
 193            entity_types_by_message_type: Default::default(),
 194            message_handlers: Default::default(),
 195        }
 196    }
 197}
 198
 199pub enum Subscription {
 200    Entity {
 201        client: Weak<Client>,
 202        id: (TypeId, u64),
 203    },
 204    Message {
 205        client: Weak<Client>,
 206        id: TypeId,
 207    },
 208}
 209
 210impl Drop for Subscription {
 211    fn drop(&mut self) {
 212        match self {
 213            Subscription::Entity { client, id } => {
 214                if let Some(client) = client.upgrade() {
 215                    let mut state = client.state.write();
 216                    let _ = state.entities_by_type_and_remote_id.remove(id);
 217                }
 218            }
 219            Subscription::Message { client, id } => {
 220                if let Some(client) = client.upgrade() {
 221                    let mut state = client.state.write();
 222                    let _ = state.entity_types_by_message_type.remove(id);
 223                    let _ = state.message_handlers.remove(id);
 224                }
 225            }
 226        }
 227    }
 228}
 229
 230impl Client {
 231    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 232        Arc::new(Self {
 233            id: 0,
 234            peer: Peer::new(),
 235            http,
 236            state: Default::default(),
 237
 238            #[cfg(any(test, feature = "test-support"))]
 239            authenticate: Default::default(),
 240            #[cfg(any(test, feature = "test-support"))]
 241            establish_connection: Default::default(),
 242        })
 243    }
 244
    /// Returns this client's numeric id: `0` from `new`, unless changed by `set_id` in
    /// test builds. It appears in the client's log messages.
    pub fn id(&self) -> usize {
        self.id
    }
 248
 249    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 250        self.http.clone()
 251    }
 252
    /// Test-only: overrides the id returned by [`Client::id`] and shown in log messages.
    /// Returns `self` so calls can be chained.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_id(&mut self, id: usize) -> &Self {
        self.id = id;
        self
    }
 258
 259    #[cfg(any(test, feature = "test-support"))]
 260    pub fn tear_down(&self) {
 261        let mut state = self.state.write();
 262        state._reconnect_task.take();
 263        state.message_handlers.clear();
 264        state.models_by_message_type.clear();
 265        state.entities_by_type_and_remote_id.clear();
 266        state.entity_id_extractors.clear();
 267        self.peer.reset();
 268    }
 269
 270    #[cfg(any(test, feature = "test-support"))]
 271    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 272    where
 273        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 274    {
 275        *self.authenticate.write() = Some(Box::new(authenticate));
 276        self
 277    }
 278
 279    #[cfg(any(test, feature = "test-support"))]
 280    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 281    where
 282        F: 'static
 283            + Send
 284            + Sync
 285            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 286    {
 287        *self.establish_connection.write() = Some(Box::new(connect));
 288        self
 289    }
 290
 291    pub fn user_id(&self) -> Option<u64> {
 292        self.state
 293            .read()
 294            .credentials
 295            .as_ref()
 296            .map(|credentials| credentials.user_id)
 297    }
 298
 299    pub fn status(&self) -> watch::Receiver<Status> {
 300        self.state.read().status.1.clone()
 301    }
 302
    /// Stores `status` in the watch channel observed through [`Client::status`] and updates
    /// the reconnection task:
    ///
    /// * `Connected`, `SignedOut` and `UpgradeRequired` clear any pending reconnection task;
    /// * `ConnectionLost` spawns a new reconnection loop, replacing any previous one;
    /// * every other status leaves the task untouched.
    ///
    /// NOTE(review): the client-state write lock is held while the watch value is updated
    /// and while the reconnection task is spawned.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    // Back-off starts at 100ms, grows by a random factor in [1, 2] after each
                    // failed attempt, and is capped at `reconnect_interval`.
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only while the failure left the client in `ConnectionError`;
                        // any other status (e.g. `UpgradeRequired`) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 343
 344    pub fn add_view_for_remote_entity<T: View>(
 345        self: &Arc<Self>,
 346        remote_id: u64,
 347        cx: &mut ViewContext<T>,
 348    ) -> Subscription {
 349        let id = (TypeId::of::<T>(), remote_id);
 350        self.state
 351            .write()
 352            .entities_by_type_and_remote_id
 353            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 354        Subscription::Entity {
 355            client: Arc::downgrade(self),
 356            id,
 357        }
 358    }
 359
 360    pub fn add_model_for_remote_entity<T: Entity>(
 361        self: &Arc<Self>,
 362        remote_id: u64,
 363        cx: &mut ModelContext<T>,
 364    ) -> Subscription {
 365        let id = (TypeId::of::<T>(), remote_id);
 366        self.state
 367            .write()
 368            .entities_by_type_and_remote_id
 369            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 370        Subscription::Entity {
 371            client: Arc::downgrade(self),
 372            id,
 373        }
 374    }
 375
 376    pub fn add_message_handler<M, E, H, F>(
 377        self: &Arc<Self>,
 378        model: ModelHandle<E>,
 379        handler: H,
 380    ) -> Subscription
 381    where
 382        M: EnvelopedMessage,
 383        E: Entity,
 384        H: 'static
 385            + Send
 386            + Sync
 387            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 388        F: 'static + Future<Output = Result<()>>,
 389    {
 390        let message_type_id = TypeId::of::<M>();
 391
 392        let mut state = self.state.write();
 393        state
 394            .models_by_message_type
 395            .insert(message_type_id, model.downgrade().into());
 396
 397        let prev_handler = state.message_handlers.insert(
 398            message_type_id,
 399            Arc::new(move |handle, envelope, client, cx| {
 400                let handle = if let AnyEntityHandle::Model(handle) = handle {
 401                    handle
 402                } else {
 403                    unreachable!();
 404                };
 405                let model = handle.downcast::<E>().unwrap();
 406                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 407                handler(model, *envelope, client.clone(), cx).boxed_local()
 408            }),
 409        );
 410        if prev_handler.is_some() {
 411            panic!("registered handler for the same message twice");
 412        }
 413
 414        Subscription::Message {
 415            client: Arc::downgrade(self),
 416            id: message_type_id,
 417        }
 418    }
 419
 420    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 421    where
 422        M: EntityMessage,
 423        E: View,
 424        H: 'static
 425            + Send
 426            + Sync
 427            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 428        F: 'static + Future<Output = Result<()>>,
 429    {
 430        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 431            if let AnyEntityHandle::View(handle) = handle {
 432                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 433            } else {
 434                unreachable!();
 435            }
 436        })
 437    }
 438
 439    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 440    where
 441        M: EntityMessage,
 442        E: Entity,
 443        H: 'static
 444            + Send
 445            + Sync
 446            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 447        F: 'static + Future<Output = Result<()>>,
 448    {
 449        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 450            if let AnyEntityHandle::Model(handle) = handle {
 451                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 452            } else {
 453                unreachable!();
 454            }
 455        })
 456    }
 457
 458    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 459    where
 460        M: EntityMessage,
 461        E: Entity,
 462        H: 'static
 463            + Send
 464            + Sync
 465            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 466        F: 'static + Future<Output = Result<()>>,
 467    {
 468        let model_type_id = TypeId::of::<E>();
 469        let message_type_id = TypeId::of::<M>();
 470
 471        let mut state = self.state.write();
 472        state
 473            .entity_types_by_message_type
 474            .insert(message_type_id, model_type_id);
 475        state
 476            .entity_id_extractors
 477            .entry(message_type_id)
 478            .or_insert_with(|| {
 479                |envelope| {
 480                    envelope
 481                        .as_any()
 482                        .downcast_ref::<TypedEnvelope<M>>()
 483                        .unwrap()
 484                        .payload
 485                        .remote_entity_id()
 486                }
 487            });
 488        let prev_handler = state.message_handlers.insert(
 489            message_type_id,
 490            Arc::new(move |handle, envelope, client, cx| {
 491                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 492                handler(handle, *envelope, client.clone(), cx).boxed_local()
 493            }),
 494        );
 495        if prev_handler.is_some() {
 496            panic!("registered handler for the same message twice");
 497        }
 498    }
 499
 500    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 501    where
 502        M: EntityMessage + RequestMessage,
 503        E: Entity,
 504        H: 'static
 505            + Send
 506            + Sync
 507            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 508        F: 'static + Future<Output = Result<M::Response>>,
 509    {
 510        self.add_model_message_handler(move |entity, envelope, client, cx| {
 511            Self::respond_to_request::<M, _>(
 512                envelope.receipt(),
 513                handler(entity, envelope, client.clone(), cx),
 514                client,
 515            )
 516        })
 517    }
 518
 519    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 520    where
 521        M: EntityMessage + RequestMessage,
 522        E: View,
 523        H: 'static
 524            + Send
 525            + Sync
 526            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 527        F: 'static + Future<Output = Result<M::Response>>,
 528    {
 529        self.add_view_message_handler(move |entity, envelope, client, cx| {
 530            Self::respond_to_request::<M, _>(
 531                envelope.receipt(),
 532                handler(entity, envelope, client.clone(), cx),
 533                client,
 534            )
 535        })
 536    }
 537
 538    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 539        receipt: Receipt<T>,
 540        response: F,
 541        client: Arc<Self>,
 542    ) -> Result<()> {
 543        match response.await {
 544            Ok(response) => {
 545                client.respond(receipt, response)?;
 546                Ok(())
 547            }
 548            Err(error) => {
 549                client.respond_with_error(
 550                    receipt,
 551                    proto::Error {
 552                        message: format!("{:?}", error),
 553                    },
 554                )?;
 555                Err(error)
 556            }
 557        }
 558    }
 559
 560    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 561        read_credentials_from_keychain(cx).is_some()
 562    }
 563
    /// Authenticates if needed, then opens a connection to the server.
    ///
    /// Returns `Ok(())` immediately when the client is already connected or an attempt is
    /// already in progress. Fails immediately with
    /// [`EstablishConnectionError::UpgradeRequired`] while the status is `UpgradeRequired`.
    ///
    /// Credentials come, in order, from the client state, from the keychain (only when
    /// `try_keychain` is `true`), or from [`Client::authenticate`]. If keychain credentials
    /// are rejected as unauthorized, they are deleted from the keychain and the flow is run
    /// once more with `try_keychain = false`.
    ///
    /// A fresh sign-in reports `Authenticating`/`Connecting`. Recovery after an error or a
    /// lost connection reports `Reauthenticating`/`Reconnecting`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
                false
            }
            // Already connected or mid-attempt: nothing to do.
            Status::Connected { .. }
            | Status::Connecting { .. }
            | Status::Reconnecting { .. }
            | Status::Authenticating
            | Status::Reauthenticating => return Ok(()),
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Tracks whether the credentials came from the keychain, which decides both whether
        // to write them back and how to react to an `Unauthorized` response.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            credentials = Some(match self.authenticate(&cx).await {
                Ok(credentials) => credentials,
                Err(err) => {
                    self.set_status(Status::ConnectionError, cx);
                    return Err(err);
                }
            });
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                // Persist to the keychain unless the credentials came from it already or
                // another login is being impersonated.
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain credentials: delete them and retry without the keychain,
                    // which falls through to `authenticate`.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 642
 643    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
 644        let executor = cx.background();
 645        log::info!("add connection to peer");
 646        let (connection_id, handle_io, mut incoming) = self
 647            .peer
 648            .add_connection(conn, move |duration| executor.timer(duration))
 649            .await;
 650        log::info!("set status to connected {}", connection_id);
 651        self.set_status(Status::Connected { connection_id }, cx);
 652        cx.foreground()
 653            .spawn({
 654                let cx = cx.clone();
 655                let this = self.clone();
 656                async move {
 657                    let mut message_id = 0_usize;
 658                    while let Some(message) = incoming.next().await {
 659                        let mut state = this.state.write();
 660                        message_id += 1;
 661                        let type_name = message.payload_type_name();
 662                        let payload_type_id = message.payload_type_id();
 663                        let sender_id = message.original_sender_id().map(|id| id.0);
 664
 665                        let model = state
 666                            .models_by_message_type
 667                            .get(&payload_type_id)
 668                            .and_then(|model| model.upgrade(&cx))
 669                            .map(AnyEntityHandle::Model)
 670                            .or_else(|| {
 671                                let entity_type_id =
 672                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
 673                                let entity_id = state
 674                                    .entity_id_extractors
 675                                    .get(&message.payload_type_id())
 676                                    .map(|extract_entity_id| {
 677                                        (extract_entity_id)(message.as_ref())
 678                                    })?;
 679
 680                                let entity = state
 681                                    .entities_by_type_and_remote_id
 682                                    .get(&(entity_type_id, entity_id))?;
 683                                if let Some(entity) = entity.upgrade(&cx) {
 684                                    Some(entity)
 685                                } else {
 686                                    state
 687                                        .entities_by_type_and_remote_id
 688                                        .remove(&(entity_type_id, entity_id));
 689                                    None
 690                                }
 691                            });
 692
 693                        let model = if let Some(model) = model {
 694                            model
 695                        } else {
 696                            log::info!("unhandled message {}", type_name);
 697                            continue;
 698                        };
 699
 700                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
 701                        {
 702                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
 703                            let future = handler(model, message, &this, cx.clone());
 704
 705                            let client_id = this.id;
 706                            log::debug!(
 707                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 708                                client_id,
 709                                message_id,
 710                                sender_id,
 711                                type_name
 712                            );
 713                            cx.foreground()
 714                                .spawn(async move {
 715                                    match future.await {
 716                                        Ok(()) => {
 717                                            log::debug!(
 718                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 719                                                client_id,
 720                                                message_id,
 721                                                sender_id,
 722                                                type_name
 723                                            );
 724                                        }
 725                                        Err(error) => {
 726                                            log::error!(
 727                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
 728                                                client_id,
 729                                                message_id,
 730                                                sender_id,
 731                                                type_name,
 732                                                error
 733                                            );
 734                                        }
 735                                    }
 736                                })
 737                                .detach();
 738                        } else {
 739                            log::info!("unhandled message {}", type_name);
 740                        }
 741
 742                        // Don't starve the main thread when receiving lots of messages at once.
 743                        smol::future::yield_now().await;
 744                    }
 745                }
 746            })
 747            .detach();
 748
 749        let handle_io = cx.background().spawn(handle_io);
 750        let this = self.clone();
 751        let cx = cx.clone();
 752        cx.foreground()
 753            .spawn(async move {
 754                match handle_io.await {
 755                    Ok(()) => {
 756                        if *this.status().borrow() == (Status::Connected { connection_id }) {
 757                            this.set_status(Status::SignedOut, &cx);
 758                        }
 759                    }
 760                    Err(err) => {
 761                        log::error!("connection error: {:?}", err);
 762                        this.set_status(Status::ConnectionLost, &cx);
 763                    }
 764                }
 765            })
 766            .detach();
 767    }
 768
 769    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 770        #[cfg(any(test, feature = "test-support"))]
 771        if let Some(callback) = self.authenticate.read().as_ref() {
 772            return callback(cx);
 773        }
 774
 775        self.authenticate_with_browser(cx)
 776    }
 777
 778    fn establish_connection(
 779        self: &Arc<Self>,
 780        credentials: &Credentials,
 781        cx: &AsyncAppContext,
 782    ) -> Task<Result<Connection, EstablishConnectionError>> {
 783        #[cfg(any(test, feature = "test-support"))]
 784        if let Some(callback) = self.establish_connection.read().as_ref() {
 785            return callback(credentials, cx);
 786        }
 787
 788        self.establish_websocket_connection(credentials, cx)
 789    }
 790
 791    fn establish_websocket_connection(
 792        self: &Arc<Self>,
 793        credentials: &Credentials,
 794        cx: &AsyncAppContext,
 795    ) -> Task<Result<Connection, EstablishConnectionError>> {
 796        let request = Request::builder()
 797            .header(
 798                "Authorization",
 799                format!("{} {}", credentials.user_id, credentials.access_token),
 800            )
 801            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
 802
 803        let http = self.http.clone();
 804        cx.background().spawn(async move {
 805            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
 806            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
 807            if rpc_response.status().is_redirection() {
 808                rpc_url = rpc_response
 809                    .headers()
 810                    .get("Location")
 811                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 812                    .to_str()
 813                    .map_err(|error| EstablishConnectionError::other(error))?
 814                    .to_string();
 815            }
 816            // Until we switch the zed.dev domain to point to the new Next.js app, there
 817            // will be no redirect required, and the app will connect directly to
 818            // wss://zed.dev/rpc.
 819            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
 820                Err(anyhow!(
 821                    "unexpected /rpc response status {}",
 822                    rpc_response.status()
 823                ))?
 824            }
 825
 826            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
 827            let rpc_host = rpc_url
 828                .host_str()
 829                .zip(rpc_url.port_or_known_default())
 830                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 831            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 832
 833            log::info!("connected to rpc endpoint {}", rpc_url);
 834
 835            match rpc_url.scheme() {
 836                "https" => {
 837                    rpc_url.set_scheme("wss").unwrap();
 838                    let request = request.uri(rpc_url.as_str()).body(())?;
 839                    let (stream, _) =
 840                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 841                    Ok(Connection::new(
 842                        stream
 843                            .map_err(|error| anyhow!(error))
 844                            .sink_map_err(|error| anyhow!(error)),
 845                    ))
 846                }
 847                "http" => {
 848                    rpc_url.set_scheme("ws").unwrap();
 849                    let request = request.uri(rpc_url.as_str()).body(())?;
 850                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 851                    Ok(Connection::new(
 852                        stream
 853                            .map_err(|error| anyhow!(error))
 854                            .sink_map_err(|error| anyhow!(error)),
 855                    ))
 856                }
 857                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 858            }
 859        })
 860    }
 861
 862    pub fn authenticate_with_browser(
 863        self: &Arc<Self>,
 864        cx: &AsyncAppContext,
 865    ) -> Task<Result<Credentials>> {
 866        let platform = cx.platform();
 867        let executor = cx.background();
 868        executor.clone().spawn(async move {
 869            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 870            // zed server to encrypt the user's access token, so that it can'be intercepted by
 871            // any other app running on the user's device.
 872            let (public_key, private_key) =
 873                rpc::auth::keypair().expect("failed to generate keypair for auth");
 874            let public_key_string =
 875                String::try_from(public_key).expect("failed to serialize public key for auth");
 876
 877            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 878            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 879            let port = server.server_addr().port();
 880
 881            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 882            // that the user is signing in from a Zed app running on the same device.
 883            let mut url = format!(
 884                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 885                *ZED_SERVER_URL, port, public_key_string
 886            );
 887
 888            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 889                log::info!("impersonating user @{}", impersonate_login);
 890                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 891            }
 892
 893            platform.open_url(&url);
 894
 895            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 896            // access token from the query params.
 897            //
 898            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 899            // custom URL scheme instead of this local HTTP server.
 900            let (user_id, access_token) = executor
 901                .spawn(async move {
 902                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
 903                        let path = req.url();
 904                        let mut user_id = None;
 905                        let mut access_token = None;
 906                        let url = Url::parse(&format!("http://example.com{}", path))
 907                            .context("failed to parse login notification url")?;
 908                        for (key, value) in url.query_pairs() {
 909                            if key == "access_token" {
 910                                access_token = Some(value.to_string());
 911                            } else if key == "user_id" {
 912                                user_id = Some(value.to_string());
 913                            }
 914                        }
 915
 916                        let post_auth_url =
 917                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 918                        req.respond(
 919                            tiny_http::Response::empty(302).with_header(
 920                                tiny_http::Header::from_bytes(
 921                                    &b"Location"[..],
 922                                    post_auth_url.as_bytes(),
 923                                )
 924                                .unwrap(),
 925                            ),
 926                        )
 927                        .context("failed to respond to login http request")?;
 928                        Ok((
 929                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 930                            access_token
 931                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 932                        ))
 933                    } else {
 934                        Err(anyhow!("didn't receive login redirect"))
 935                    }
 936                })
 937                .await?;
 938
 939            let access_token = private_key
 940                .decrypt_string(&access_token)
 941                .context("failed to decrypt access token")?;
 942            platform.activate(true);
 943
 944            Ok(Credentials {
 945                user_id: user_id.parse()?,
 946                access_token,
 947            })
 948        })
 949    }
 950
 951    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 952        let conn_id = self.connection_id()?;
 953        self.peer.disconnect(conn_id);
 954        self.set_status(Status::SignedOut, cx);
 955        Ok(())
 956    }
 957
 958    fn connection_id(&self) -> Result<ConnectionId> {
 959        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 960            Ok(connection_id)
 961        } else {
 962            Err(anyhow!("not connected"))
 963        }
 964    }
 965
 966    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 967        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 968        self.peer.send(self.connection_id()?, message)
 969    }
 970
 971    pub fn request<T: RequestMessage>(
 972        &self,
 973        request: T,
 974    ) -> impl Future<Output = Result<T::Response>> {
 975        let client_id = self.id;
 976        log::debug!(
 977            "rpc request start. client_id:{}. name:{}",
 978            client_id,
 979            T::NAME
 980        );
 981        let response = self
 982            .connection_id()
 983            .map(|conn_id| self.peer.request(conn_id, request));
 984        async move {
 985            let response = response?.await;
 986            log::debug!(
 987                "rpc request finish. client_id:{}. name:{}",
 988                client_id,
 989                T::NAME
 990            );
 991            response
 992        }
 993    }
 994
 995    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
 996        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 997        self.peer.respond(receipt, response)
 998    }
 999
1000    fn respond_with_error<T: RequestMessage>(
1001        &self,
1002        receipt: Receipt<T>,
1003        error: proto::Error,
1004    ) -> Result<()> {
1005        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1006        self.peer.respond_with_error(receipt, error)
1007    }
1008}
1009
1010impl AnyWeakEntityHandle {
1011    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1012        match self {
1013            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1014            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1015        }
1016    }
1017}
1018
1019fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1020    if IMPERSONATE_LOGIN.is_some() {
1021        return None;
1022    }
1023
1024    let (user_id, access_token) = cx
1025        .platform()
1026        .read_credentials(&ZED_SERVER_URL)
1027        .log_err()
1028        .flatten()?;
1029    Some(Credentials {
1030        user_id: user_id.parse().ok()?,
1031        access_token: String::from_utf8(access_token).ok()?,
1032    })
1033}
1034
1035fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1036    cx.platform().write_credentials(
1037        &ZED_SERVER_URL,
1038        &credentials.user_id.to_string(),
1039        credentials.access_token.as_bytes(),
1040    )
1041}
1042
/// Scheme and path prefix shared by every shareable worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a shareable worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` in these cases:
/// - the prefix is missing;
/// - the id is not a `u64`;
/// - the token segment is missing or empty.
///
/// Any path segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id = segments.next()?.parse::<u64>().ok()?;
    let access_token = segments.next().filter(|token| !token.is_empty())?;
    Some((id, access_token.to_string()))
}
1059
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Drives the client through two server-side disconnects. After the first, reconnection
    // must reuse the cached credentials. After the second, the access token is rolled, so the
    // client must authenticate again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // While connections are refused, a disconnect should surface as a reconnection error.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Once connections are allowed again, advancing the fake clock lets the retry fire.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Rolling the token invalidates the cached credentials, so the client is expected to
        // clear them after authentication fails and authenticate again.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Round-trips a worktree URL (also with surrounding whitespace) and rejects a foreign one.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Registers three models as remote entities 1, 2 and 3, then drops the subscription for 3.
    // `JoinProject` messages for projects 1 and 2 must still reach the handler with the
    // matching model (the handler signals per model id).
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping a `Ping` handler subscription, registering a new `Ping` handler for the
    // same model must work, and the new handler must receive the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // The handler takes its own subscription out of the model (dropping it) while running.
    // The test passes if that completes and the handler signals `done`.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test model. `id` lets handlers tell instances apart, and `subscription` can hold
    // a handler subscription so a handler can drop it.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}