client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  15use gpui::{
  16    actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
  17    Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{
  32        atomic::{AtomicUsize, Ordering},
  33        Arc, Weak,
  34    },
  35    time::{Duration, Instant},
  36};
  37use thiserror::Error;
  38use url::Url;
  39use util::{ResultExt, TryFutureExt};
  40
  41pub use channel::*;
  42pub use rpc::*;
  43pub use user::*;
  44
  45lazy_static! {
  46    pub static ref ZED_SERVER_URL: String =
  47        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  48    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  49        .ok()
  50        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  51}
  52
  53pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894";
  54
// Declares the `Authenticate` action in the `client` namespace; `init` below installs
// the global handler that runs `authenticate_and_connect` when it is dispatched.
actions!(client, [Authenticate]);
  56
  57pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  58    cx.add_global_action(move |_: &Authenticate, cx| {
  59        let rpc = rpc.clone();
  60        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
  61            .detach();
  62    });
  63}
  64
/// Client for the Zed server: owns the RPC peer, the connection status, and the
/// registries that route incoming messages to models, views, and handlers.
///
/// The credential and connection steps can be replaced through `override_authenticate`
/// and `override_establish_connection` (test-support only); when left as `None`, the
/// browser-based and websocket-based defaults are used.
pub struct Client {
    /// Id drawn from a global counter in `Client::new`; included in RPC log lines.
    id: usize,
    /// RPC peer that connections are added to (and reset in `tear_down`).
    peer: Arc<Peer>,
    /// HTTP client, used e.g. to resolve the `/rpc` endpoint before connecting.
    http: Arc<dyn HttpClient>,
    /// Credentials, status channel, and handler registries.
    state: RwLock<ClientState>,
    /// Optional replacement for the credential step; see `Client::authenticate`.
    authenticate:
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    /// Optional replacement for the connection step; see `Client::establish_connection`.
    establish_connection: Option<
        Box<
            dyn 'static
                + Send
                + Sync
                + Fn(
                    &Credentials,
                    &AsyncAppContext,
                ) -> Task<Result<Connection, EstablishConnectionError>>,
        >,
    >,
}
  84
/// Errors that can occur while opening a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered HTTP 426 Upgrade Required
    /// (presumably the client's protocol version is too old — confirm server-side).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered HTTP 401 Unauthorized, i.e. the credentials were rejected.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error reported by this crate's `http` module.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure, e.g. while opening the TCP stream.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure building the websocket upgrade request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 100
 101impl From<WebsocketError> for EstablishConnectionError {
 102    fn from(error: WebsocketError) -> Self {
 103        if let WebsocketError::Http(response) = &error {
 104            match response.status() {
 105                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 106                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 107                _ => {}
 108            }
 109        }
 110        EstablishConnectionError::Other(error.into())
 111    }
 112}
 113
 114impl EstablishConnectionError {
 115    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 116        Self::Other(error.into())
 117    }
 118}
 119
/// Connection lifecycle states, published through the watch channel returned by
/// [`Client::status`].
#[derive(Copy, Clone, Debug)]
pub enum Status {
    /// Not connected; the initial status of a new client.
    SignedOut,
    /// The server demanded an upgrade; `authenticate_and_connect` fails immediately
    /// while in this status.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening a connection, starting from `SignedOut`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection within the peer.
    Connected { connection_id: ConnectionId },
    /// An established connection failed; setting this status starts the reconnect loop.
    ConnectionLost,
    /// Obtaining credentials while recovering from an error or lost connection.
    Reauthenticating,
    /// Opening a connection while recovering from an error or lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
 133
impl Status {
    /// Returns `true` only for `Status::Connected { .. }`.
    pub fn is_connected(&self) -> bool {
        matches!(self, Self::Connected { .. })
    }
}
 139
/// Mutable state of a [`Client`], kept behind `Client::state`.
struct ClientState {
    /// Credentials stored after a successful connection; cleared when the server
    /// answers `Unauthorized`.
    credentials: Option<Credentials>,
    /// Both ends of the status watch channel: `set_status` writes through the sender,
    /// `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per entity-message TypeId, a function reading the remote entity id out of a
    /// type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; reset on `Connected`, `SignedOut`,
    /// and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Weak handles to entities receiving entity messages, keyed by
    /// `(entity TypeId, remote entity id)`.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model registered by `add_message_handler`, keyed by message TypeId.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity TypeId targeted by each entity-message TypeId.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message TypeId.
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 163
/// Weak, type-erased handle to a registered entity: either a model or a view.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
 168
/// Strong, type-erased handle to a model or view; this is what message handlers receive.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 173
 174#[derive(Clone, Debug)]
 175pub struct Credentials {
 176    pub user_id: u64,
 177    pub access_token: String,
 178}
 179
 180impl Default for ClientState {
 181    fn default() -> Self {
 182        Self {
 183            credentials: None,
 184            status: watch::channel_with(Status::SignedOut),
 185            entity_id_extractors: Default::default(),
 186            _reconnect_task: None,
 187            reconnect_interval: Duration::from_secs(5),
 188            models_by_message_type: Default::default(),
 189            entities_by_type_and_remote_id: Default::default(),
 190            entity_types_by_message_type: Default::default(),
 191            message_handlers: Default::default(),
 192        }
 193    }
 194}
 195
/// Registration handle returned by `add_view_for_remote_entity`,
/// `add_model_for_remote_entity`, and `add_message_handler`.
///
/// Dropping it removes the registration from the client, provided the client (held
/// weakly) still exists.
pub enum Subscription {
    /// An entity registered under `(entity TypeId, remote entity id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A handler registered for the message type `id`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 206
 207impl Drop for Subscription {
 208    fn drop(&mut self) {
 209        match self {
 210            Subscription::Entity { client, id } => {
 211                if let Some(client) = client.upgrade() {
 212                    let mut state = client.state.write();
 213                    let _ = state.entities_by_type_and_remote_id.remove(id);
 214                }
 215            }
 216            Subscription::Message { client, id } => {
 217                if let Some(client) = client.upgrade() {
 218                    let mut state = client.state.write();
 219                    let _ = state.entity_types_by_message_type.remove(id);
 220                    let _ = state.message_handlers.remove(id);
 221                }
 222            }
 223        }
 224    }
 225}
 226
 227impl Client {
 228    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 229        lazy_static! {
 230            static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
 231        }
 232
 233        Arc::new(Self {
 234            id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
 235            peer: Peer::new(),
 236            http,
 237            state: Default::default(),
 238            authenticate: None,
 239            establish_connection: None,
 240        })
 241    }
 242
    /// Returns the id assigned to this client by `Client::new`.
    pub fn id(&self) -> usize {
        self.id
    }
 246
 247    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 248        self.http.clone()
 249    }
 250
 251    #[cfg(any(test, feature = "test-support"))]
 252    pub fn tear_down(&self) {
 253        let mut state = self.state.write();
 254        state._reconnect_task.take();
 255        state.message_handlers.clear();
 256        state.models_by_message_type.clear();
 257        state.entities_by_type_and_remote_id.clear();
 258        state.entity_id_extractors.clear();
 259        self.peer.reset();
 260    }
 261
    /// Test-support hook: replaces the credential step used by `Client::authenticate`,
    /// which otherwise falls back to `authenticate_with_browser`. Returns `self` for
    /// chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
    {
        self.authenticate = Some(Box::new(authenticate));
        self
    }
 270
    /// Test-support hook: replaces the connection step used by
    /// `Client::establish_connection`, which otherwise falls back to
    /// `establish_websocket_connection`. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        self.establish_connection = Some(Box::new(connect));
        self
    }
 282
 283    pub fn user_id(&self) -> Option<u64> {
 284        self.state
 285            .read()
 286            .credentials
 287            .as_ref()
 288            .map(|credentials| credentials.user_id)
 289    }
 290
 291    pub fn status(&self) -> watch::Receiver<Status> {
 292        self.state.read().status.1.clone()
 293    }
 294
    /// Publishes `status` on the status watch channel and manages the reconnect task.
    ///
    /// - `Connected { .. }`: clears any pending reconnect task.
    /// - `ConnectionLost`: spawns a task that calls `authenticate_and_connect` until it
    ///   succeeds. After each failure it logs the error, publishes `ReconnectionError`
    ///   with the time of the next attempt, and waits.
    /// - `SignedOut` / `UpgradeRequired`: clears any pending reconnect task.
    /// - Any other status only updates the channel.
    ///
    /// The state write lock is held for the whole call, including the `cx.spawn`.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    // First retry waits 100ms.
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        this.set_status(
                            Status::ReconnectionError {
                                next_reconnection: Instant::now() + delay,
                            },
                            &cx,
                        );
                        cx.background().timer(delay).await;
                        // Randomized backoff: grow by a factor in [1.0, 2.0], never
                        // exceeding `reconnect_interval`.
                        delay = delay
                            .mul_f32(rng.gen_range(1.0..=2.0))
                            .min(reconnect_interval);
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 330
 331    pub fn add_view_for_remote_entity<T: View>(
 332        self: &Arc<Self>,
 333        remote_id: u64,
 334        cx: &mut ViewContext<T>,
 335    ) -> Subscription {
 336        let id = (TypeId::of::<T>(), remote_id);
 337        self.state
 338            .write()
 339            .entities_by_type_and_remote_id
 340            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 341        Subscription::Entity {
 342            client: Arc::downgrade(self),
 343            id,
 344        }
 345    }
 346
 347    pub fn add_model_for_remote_entity<T: Entity>(
 348        self: &Arc<Self>,
 349        remote_id: u64,
 350        cx: &mut ModelContext<T>,
 351    ) -> Subscription {
 352        let id = (TypeId::of::<T>(), remote_id);
 353        self.state
 354            .write()
 355            .entities_by_type_and_remote_id
 356            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 357        Subscription::Entity {
 358            client: Arc::downgrade(self),
 359            id,
 360        }
 361    }
 362
 363    pub fn add_message_handler<M, E, H, F>(
 364        self: &Arc<Self>,
 365        model: ModelHandle<E>,
 366        handler: H,
 367    ) -> Subscription
 368    where
 369        M: EnvelopedMessage,
 370        E: Entity,
 371        H: 'static
 372            + Send
 373            + Sync
 374            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 375        F: 'static + Future<Output = Result<()>>,
 376    {
 377        let message_type_id = TypeId::of::<M>();
 378
 379        let mut state = self.state.write();
 380        state
 381            .models_by_message_type
 382            .insert(message_type_id, model.downgrade().into());
 383
 384        let prev_handler = state.message_handlers.insert(
 385            message_type_id,
 386            Arc::new(move |handle, envelope, client, cx| {
 387                let handle = if let AnyEntityHandle::Model(handle) = handle {
 388                    handle
 389                } else {
 390                    unreachable!();
 391                };
 392                let model = handle.downcast::<E>().unwrap();
 393                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 394                handler(model, *envelope, client.clone(), cx).boxed_local()
 395            }),
 396        );
 397        if prev_handler.is_some() {
 398            panic!("registered handler for the same message twice");
 399        }
 400
 401        Subscription::Message {
 402            client: Arc::downgrade(self),
 403            id: message_type_id,
 404        }
 405    }
 406
 407    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 408    where
 409        M: EntityMessage,
 410        E: View,
 411        H: 'static
 412            + Send
 413            + Sync
 414            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 415        F: 'static + Future<Output = Result<()>>,
 416    {
 417        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 418            if let AnyEntityHandle::View(handle) = handle {
 419                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 420            } else {
 421                unreachable!();
 422            }
 423        })
 424    }
 425
 426    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 427    where
 428        M: EntityMessage,
 429        E: Entity,
 430        H: 'static
 431            + Send
 432            + Sync
 433            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 434        F: 'static + Future<Output = Result<()>>,
 435    {
 436        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 437            if let AnyEntityHandle::Model(handle) = handle {
 438                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 439            } else {
 440                unreachable!();
 441            }
 442        })
 443    }
 444
 445    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 446    where
 447        M: EntityMessage,
 448        E: Entity,
 449        H: 'static
 450            + Send
 451            + Sync
 452            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 453        F: 'static + Future<Output = Result<()>>,
 454    {
 455        let model_type_id = TypeId::of::<E>();
 456        let message_type_id = TypeId::of::<M>();
 457
 458        let mut state = self.state.write();
 459        state
 460            .entity_types_by_message_type
 461            .insert(message_type_id, model_type_id);
 462        state
 463            .entity_id_extractors
 464            .entry(message_type_id)
 465            .or_insert_with(|| {
 466                |envelope| {
 467                    envelope
 468                        .as_any()
 469                        .downcast_ref::<TypedEnvelope<M>>()
 470                        .unwrap()
 471                        .payload
 472                        .remote_entity_id()
 473                }
 474            });
 475        let prev_handler = state.message_handlers.insert(
 476            message_type_id,
 477            Arc::new(move |handle, envelope, client, cx| {
 478                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 479                handler(handle, *envelope, client.clone(), cx).boxed_local()
 480            }),
 481        );
 482        if prev_handler.is_some() {
 483            panic!("registered handler for the same message twice");
 484        }
 485    }
 486
 487    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 488    where
 489        M: EntityMessage + RequestMessage,
 490        E: Entity,
 491        H: 'static
 492            + Send
 493            + Sync
 494            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 495        F: 'static + Future<Output = Result<M::Response>>,
 496    {
 497        self.add_model_message_handler(move |entity, envelope, client, cx| {
 498            Self::respond_to_request::<M, _>(
 499                envelope.receipt(),
 500                handler(entity, envelope, client.clone(), cx),
 501                client,
 502            )
 503        })
 504    }
 505
 506    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 507    where
 508        M: EntityMessage + RequestMessage,
 509        E: View,
 510        H: 'static
 511            + Send
 512            + Sync
 513            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 514        F: 'static + Future<Output = Result<M::Response>>,
 515    {
 516        self.add_view_message_handler(move |entity, envelope, client, cx| {
 517            Self::respond_to_request::<M, _>(
 518                envelope.receipt(),
 519                handler(entity, envelope, client.clone(), cx),
 520                client,
 521            )
 522        })
 523    }
 524
 525    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 526        receipt: Receipt<T>,
 527        response: F,
 528        client: Arc<Self>,
 529    ) -> Result<()> {
 530        match response.await {
 531            Ok(response) => {
 532                client.respond(receipt, response)?;
 533                Ok(())
 534            }
 535            Err(error) => {
 536                client.respond_with_error(
 537                    receipt,
 538                    proto::Error {
 539                        message: error.to_string(),
 540                    },
 541                )?;
 542                Err(error)
 543            }
 544        }
 545    }
 546
    /// Returns whether `read_credentials_from_keychain` currently yields credentials.
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
 550
    /// Authenticates if needed, then connects to the server, publishing status changes
    /// along the way.
    ///
    /// Returns `Ok(())` immediately if the client is already connected or an attempt is
    /// already in progress. Fails immediately if the status is `UpgradeRequired`.
    ///
    /// Credentials come from the first of these that provides them:
    /// 1. the stored state;
    /// 2. the keychain, only when `try_keychain` is true;
    /// 3. a fresh `Client::authenticate` call.
    ///
    /// On success the credentials are stored, written to the keychain (unless they came
    /// from it or `ZED_IMPERSONATE` is set), and the connection is installed via
    /// `set_connection`. If keychain credentials are rejected as unauthorized, they are
    /// deleted from the platform credential store and the flow is retried without the
    /// keychain.
    ///
    /// # Errors
    ///
    /// Returns the authentication or connection error after publishing
    /// `ConnectionError` (or `UpgradeRequired`).
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `SignedOut` starts a first-time flow (Authenticating/Connecting); the error
        // states start a recovery flow (Reauthenticating/Reconnecting).
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
                false
            }
            Status::Connected { .. }
            | Status::Connecting { .. }
            | Status::Reconnecting { .. }
            | Status::Authenticating
            | Status::Reauthenticating => return Ok(()),
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Remember whether the credentials came from the keychain, so they are not
        // written back and can be deleted if the server rejects them.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            credentials = Some(match self.authenticate(&cx).await {
                Ok(credentials) => credentials,
                Err(err) => {
                    self.set_status(Status::ConnectionError, cx);
                    return Err(err);
                }
            });
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain entry: delete it and retry, this time without
                    // consulting the keychain.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 629
    /// Adds `conn` to the peer and starts servicing it.
    ///
    /// This does four things:
    /// 1. Spawns a foreground task that dispatches each incoming message to its
    ///    registered handler.
    /// 2. Publishes `Status::Connected`.
    /// 3. Runs the connection's I/O future on the background executor.
    /// 4. Spawns a foreground task that awaits that future. It publishes `SignedOut`
    ///    when the future completes successfully, and `ConnectionLost` (which starts the
    ///    reconnect loop in `set_status`) when it fails.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
        let executor = cx.background();
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration))
            .await;
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    // Sequence number of each received message, used only for logging.
                    let mut message_id = 0_usize;
                    while let Some(message) = incoming.next().await {
                        // The write lock is held while resolving the receiver (a dead
                        // entity's registration may be pruned) and released before the
                        // handler is invoked.
                        let mut state = this.state.write();
                        message_id += 1;
                        let type_name = message.payload_type_name();
                        let payload_type_id = message.payload_type_id();
                        let sender_id = message.original_sender_id().map(|id| id.0);

                        // Resolve the receiver. Prefer a model registered for this message
                        // type. Otherwise use the entity addressed by the remote id in the
                        // message, removing its registration if it can no longer be upgraded.
                        let model = state
                            .models_by_message_type
                            .get(&payload_type_id)
                            .and_then(|model| model.upgrade(&cx))
                            .map(AnyEntityHandle::Model)
                            .or_else(|| {
                                let entity_type_id =
                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
                                let entity_id = state
                                    .entity_id_extractors
                                    .get(&message.payload_type_id())
                                    .map(|extract_entity_id| {
                                        (extract_entity_id)(message.as_ref())
                                    })?;

                                let entity = state
                                    .entities_by_type_and_remote_id
                                    .get(&(entity_type_id, entity_id))?;
                                if let Some(entity) = entity.upgrade(&cx) {
                                    Some(entity)
                                } else {
                                    state
                                        .entities_by_type_and_remote_id
                                        .remove(&(entity_type_id, entity_id));
                                    None
                                }
                            });

                        let model = if let Some(model) = model {
                            model
                        } else {
                            log::info!("unhandled message {}", type_name);
                            continue;
                        };

                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
                        {
                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
                            let future = handler(model, message, &this, cx.clone());

                            let client_id = this.id;
                            log::debug!(
                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
                                client_id,
                                message_id,
                                sender_id,
                                type_name
                            );
                            // Run the handler in its own task so a slow handler does not
                            // block processing of later messages; its outcome is logged.
                            cx.foreground()
                                .spawn(async move {
                                    match future.await {
                                        Ok(()) => {
                                            log::debug!(
                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
                                                client_id,
                                                message_id,
                                                sender_id,
                                                type_name
                                            );
                                        }
                                        Err(error) => {
                                            log::error!(
                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
                                                client_id,
                                                message_id,
                                                sender_id,
                                                type_name,
                                                error
                                            );
                                        }
                                    }
                                })
                                .detach();
                        } else {
                            log::info!("unhandled message {}", type_name);
                        }

                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        self.set_status(Status::Connected { connection_id }, cx);

        let handle_io = cx.background().spawn(handle_io);
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => this.set_status(Status::SignedOut, &cx),
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();
    }
 750
 751    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 752        if let Some(callback) = self.authenticate.as_ref() {
 753            callback(cx)
 754        } else {
 755            self.authenticate_with_browser(cx)
 756        }
 757    }
 758
 759    fn establish_connection(
 760        self: &Arc<Self>,
 761        credentials: &Credentials,
 762        cx: &AsyncAppContext,
 763    ) -> Task<Result<Connection, EstablishConnectionError>> {
 764        if let Some(callback) = self.establish_connection.as_ref() {
 765            callback(credentials, cx)
 766        } else {
 767            self.establish_websocket_connection(credentials, cx)
 768        }
 769    }
 770
    /// Opens a websocket connection to the server's RPC endpoint on the background
    /// executor.
    ///
    /// Endpoint discovery: `{ZED_SERVER_URL}/rpc` is requested over HTTP.
    /// - A redirect's `Location` header supplies the endpoint URL.
    /// - A 426 Upgrade Required response means `/rpc` itself is the endpoint.
    /// - Any other status is an error.
    ///
    /// The upgrade request carries `Authorization: <user_id> <access_token>` and an
    /// `x-zed-protocol-version` header. `https` endpoints are upgraded over TLS as
    /// `wss`; `http` endpoints are upgraded as plain `ws`. Any other scheme is rejected.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
            if rpc_response.status().is_redirection() {
                rpc_url = rpc_response
                    .headers()
                    .get("Location")
                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
                    .to_str()
                    .map_err(|error| EstablishConnectionError::other(error))?
                    .to_string();
            }
            // Until we switch the zed.dev domain to point to the new Next.js app, there
            // will be no redirect required, and the app will connect directly to
            // wss://zed.dev/rpc.
            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
                Err(anyhow!(
                    "unexpected /rpc response status {}",
                    rpc_response.status()
                ))?
            }

            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            // Logged once the TCP connection is open, before the websocket handshake.
            log::info!("connected to rpc endpoint {}", rpc_url);

            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
 841
 842    pub fn authenticate_with_browser(
 843        self: &Arc<Self>,
 844        cx: &AsyncAppContext,
 845    ) -> Task<Result<Credentials>> {
 846        let platform = cx.platform();
 847        let executor = cx.background();
 848        executor.clone().spawn(async move {
 849            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 850            // zed server to encrypt the user's access token, so that it can'be intercepted by
 851            // any other app running on the user's device.
 852            let (public_key, private_key) =
 853                rpc::auth::keypair().expect("failed to generate keypair for auth");
 854            let public_key_string =
 855                String::try_from(public_key).expect("failed to serialize public key for auth");
 856
 857            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 858            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 859            let port = server.server_addr().port();
 860
 861            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 862            // that the user is signing in from a Zed app running on the same device.
 863            let mut url = format!(
 864                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 865                *ZED_SERVER_URL, port, public_key_string
 866            );
 867
 868            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 869                log::info!("impersonating user @{}", impersonate_login);
 870                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 871            }
 872
 873            platform.open_url(&url);
 874
 875            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 876            // access token from the query params.
 877            //
 878            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 879            // custom URL scheme instead of this local HTTP server.
 880            let (user_id, access_token) = executor
 881                .spawn(async move {
 882                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
 883                        let path = req.url();
 884                        let mut user_id = None;
 885                        let mut access_token = None;
 886                        let url = Url::parse(&format!("http://example.com{}", path))
 887                            .context("failed to parse login notification url")?;
 888                        for (key, value) in url.query_pairs() {
 889                            if key == "access_token" {
 890                                access_token = Some(value.to_string());
 891                            } else if key == "user_id" {
 892                                user_id = Some(value.to_string());
 893                            }
 894                        }
 895
 896                        let post_auth_url =
 897                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 898                        req.respond(
 899                            tiny_http::Response::empty(302).with_header(
 900                                tiny_http::Header::from_bytes(
 901                                    &b"Location"[..],
 902                                    post_auth_url.as_bytes(),
 903                                )
 904                                .unwrap(),
 905                            ),
 906                        )
 907                        .context("failed to respond to login http request")?;
 908                        Ok((
 909                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 910                            access_token
 911                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 912                        ))
 913                    } else {
 914                        Err(anyhow!("didn't receive login redirect"))
 915                    }
 916                })
 917                .await?;
 918
 919            let access_token = private_key
 920                .decrypt_string(&access_token)
 921                .context("failed to decrypt access token")?;
 922            platform.activate(true);
 923
 924            Ok(Credentials {
 925                user_id: user_id.parse()?,
 926                access_token,
 927            })
 928        })
 929    }
 930
 931    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 932        let conn_id = self.connection_id()?;
 933        self.peer.disconnect(conn_id);
 934        self.set_status(Status::SignedOut, cx);
 935        Ok(())
 936    }
 937
 938    fn connection_id(&self) -> Result<ConnectionId> {
 939        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 940            Ok(connection_id)
 941        } else {
 942            Err(anyhow!("not connected"))
 943        }
 944    }
 945
 946    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 947        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 948        self.peer.send(self.connection_id()?, message)
 949    }
 950
 951    pub fn request<T: RequestMessage>(
 952        &self,
 953        request: T,
 954    ) -> impl Future<Output = Result<T::Response>> {
 955        let client_id = self.id;
 956        log::debug!(
 957            "rpc request start. client_id:{}. name:{}",
 958            client_id,
 959            T::NAME
 960        );
 961        let response = self
 962            .connection_id()
 963            .map(|conn_id| self.peer.request(conn_id, request));
 964        async move {
 965            let response = response?.await;
 966            log::debug!(
 967                "rpc request finish. client_id:{}. name:{}",
 968                client_id,
 969                T::NAME
 970            );
 971            response
 972        }
 973    }
 974
 975    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
 976        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 977        self.peer.respond(receipt, response)
 978    }
 979
 980    fn respond_with_error<T: RequestMessage>(
 981        &self,
 982        receipt: Receipt<T>,
 983        error: proto::Error,
 984    ) -> Result<()> {
 985        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 986        self.peer.respond_with_error(receipt, error)
 987    }
 988}
 989
 990impl AnyWeakEntityHandle {
 991    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
 992        match self {
 993            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
 994            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
 995        }
 996    }
 997}
 998
 999fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1000    if IMPERSONATE_LOGIN.is_some() {
1001        return None;
1002    }
1003
1004    let (user_id, access_token) = cx
1005        .platform()
1006        .read_credentials(&ZED_SERVER_URL)
1007        .log_err()
1008        .flatten()?;
1009    Some(Credentials {
1010        user_id: user_id.parse().ok()?,
1011        access_token: String::from_utf8(access_token).ok()?,
1012    })
1013}
1014
1015fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1016    cx.platform().write_credentials(
1017        &ZED_SERVER_URL,
1018        &credentials.user_id.to_string(),
1019        credentials.access_token.as_bytes(),
1020    )
1021}
1022
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
///
/// The token is embedded verbatim; `decode_worktree_url` recovers it exactly, including
/// any `/` characters it contains.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by `encode_worktree_url` back into `(id, access_token)`.
///
/// Leading/trailing whitespace is ignored. Everything after the first `/` that follows the
/// id is the token, so tokens containing `/` round-trip. Returns `None` when the prefix is
/// missing, the id is not a valid `u64`, or the token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1039
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Drives the client through two disconnects.
    // - First disconnect: the client reports `ReconnectionError` while the server refuses
    //   connections, then reconnects without re-authenticating once connections are allowed.
    // - Second disconnect: the server rolls the access token, so the client must
    //   authenticate again to reconnect.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Advance the fake clock so the pending reconnect attempt can run
        // (presumably the client retries after a delay of at most 10s — confirm against
        // the reconnect logic).
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Invalidate the cached token: the client should discard its cached credentials
        // once authenticating with them fails, and sign in again.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Round-trips a worktree URL, including one padded with surrounding whitespace, and
    // rejects a URL with the wrong prefix.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // A single model message handler must route each message to the model registered for
    // the message's entity id (here `project_id`).
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
                // Signal the channel matching the model that received the message; model 3's
                // subscription is dropped below, so it must never be reached.
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping the subscription for a message handler, registering a new handler for
    // the same message type must succeed and the new handler must receive the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        // The first receiver is bound (not discarded), so the first channel stays open; the
        // test only asserts that the second handler fires.
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the model) while it is running and
    // still complete.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                // Taking the subscription out of the model drops it from inside the handler.
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: `id` tells instances apart in handlers, and `subscription` lets a
    // test park a handler subscription on the model itself.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}