client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  15use gpui::{
  16    actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
  17    Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{
  32        atomic::{AtomicUsize, Ordering},
  33        Arc, Weak,
  34    },
  35    time::{Duration, Instant},
  36};
  37use thiserror::Error;
  38use url::Url;
  39use util::{ResultExt, TryFutureExt};
  40
  41pub use channel::*;
  42pub use rpc::*;
  43pub use user::*;
  44
  45lazy_static! {
  46    pub static ref ZED_SERVER_URL: String =
  47        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  48    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  49        .ok()
  50        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  51}
  52
  53pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894";
  54
// Declares the `Authenticate` action in the `client` namespace; `init` registers
// the global handler that runs `Client::authenticate_and_connect` when it is dispatched.
actions!(client, [Authenticate]);
  56
  57pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  58    cx.add_global_action(move |_: &Authenticate, cx| {
  59        let rpc = rpc.clone();
  60        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
  61            .detach();
  62    });
  63}
  64
/// RPC client: authenticates, owns the connection status, and routes incoming
/// messages to the models and views registered with it.
pub struct Client {
    /// Process-unique id assigned by `Client::new`; included in log messages.
    id: usize,
    /// Peer that owns the live connection (see `set_connection`).
    peer: Arc<Peer>,
    /// HTTP client used when establishing the websocket connection.
    http: Arc<dyn HttpClient>,
    /// Mutable client state (credentials, status, routing tables).
    state: RwLock<ClientState>,

    /// Test-only replacement for the authentication step; consulted first by
    /// `Client::authenticate` (installed with `override_authenticate`).
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,
    /// Test-only replacement for opening the connection; consulted first by
    /// `Client::establish_connection` (installed with `override_establish_connection`).
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
  90
/// Error produced while establishing a connection to the RPC server.
///
/// `UpgradeRequired` and `Unauthorized` are kept distinct because
/// `Client::authenticate_and_connect` reacts to them specifically; every other
/// failure is wrapped.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Mapped from a websocket handshake response with HTTP status 426.
    #[error("upgrade required")]
    UpgradeRequired,
    /// Mapped from a websocket handshake response with HTTP status 401.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error type of this crate's `http` module (presumably from `HttpClient::get` — confirm).
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O error, e.g. from opening the TCP stream to the RPC host.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite its name, this wraps the `http` crate error re-exported by
    /// tungstenite, e.g. from building the websocket upgrade request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 106
 107impl From<WebsocketError> for EstablishConnectionError {
 108    fn from(error: WebsocketError) -> Self {
 109        if let WebsocketError::Http(response) = &error {
 110            match response.status() {
 111                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 112                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 113                _ => {}
 114            }
 115        }
 116        EstablishConnectionError::Other(error.into())
 117    }
 118}
 119
 120impl EstablishConnectionError {
 121    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 122        Self::Other(error.into())
 123    }
 124}
 125
/// Connection state of a [`Client`], published through the watch channel
/// returned by `Client::status`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Not connected and no attempt in progress; the initial state.
    SignedOut,
    /// The server required an upgrade; `authenticate_and_connect` fails
    /// immediately while in this state.
    UpgradeRequired,
    /// Obtaining credentials for an attempt started from `SignedOut`.
    Authenticating,
    /// Opening the connection after `Authenticating`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `connection_id` is the id the peer assigned to the connection.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O failed; `set_status` starts a reconnection loop.
    ConnectionLost,
    /// Like `Authenticating`, for an attempt started after an error or a lost connection.
    Reauthenticating,
    /// Like `Connecting`, following `Reauthenticating`.
    Reconnecting,
    /// A reconnection attempt failed; the next one is due at `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
 139
 140impl Status {
 141    pub fn is_connected(&self) -> bool {
 142        matches!(self, Self::Connected { .. })
 143    }
 144}
 145
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current session: stored after a successful connection
    /// and cleared when the server answers `Unauthorized`.
    credentials: Option<Credentials>,
    /// Status watch channel: `set_status` writes through `.0`, and
    /// `Client::status` hands out clones of `.1`.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that reads the remote entity id out of a
    /// type-erased envelope (registered by `add_entity_message_handler`).
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection loop spawned on `ConnectionLost`; cleared on `Connected`,
    /// `SignedOut` and `UpgradeRequired` (dropping the task presumably cancels
    /// it — gpui `Task` semantics, confirm).
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Weak handles of local models/views, keyed by (entity `TypeId`, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model that receives each message type registered via `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// For entity messages: message `TypeId` -> `TypeId` of the addressed entity type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type; registering a second one panics.
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 169
/// Type-erased weak handle to a local model or view registered for a remote entity
/// (stored in `ClientState::entities_by_type_and_remote_id`).
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
 174
/// Strong, type-erased handle to the model or view an incoming message is
/// dispatched to; passed to the handlers in `ClientState::message_handlers`.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 179
 180#[derive(Clone, Debug)]
 181pub struct Credentials {
 182    pub user_id: u64,
 183    pub access_token: String,
 184}
 185
 186impl Default for ClientState {
 187    fn default() -> Self {
 188        Self {
 189            credentials: None,
 190            status: watch::channel_with(Status::SignedOut),
 191            entity_id_extractors: Default::default(),
 192            _reconnect_task: None,
 193            reconnect_interval: Duration::from_secs(5),
 194            models_by_message_type: Default::default(),
 195            entities_by_type_and_remote_id: Default::default(),
 196            entity_types_by_message_type: Default::default(),
 197            message_handlers: Default::default(),
 198        }
 199    }
 200}
 201
/// Registration handle returned by `Client::add_view_for_remote_entity`,
/// `Client::add_model_for_remote_entity` and `Client::add_message_handler`.
///
/// It holds only a weak reference to the client. Dropping it removes the
/// matching routing entries if the client is still alive.
pub enum Subscription {
    /// Entity registration, keyed by (entity `TypeId`, remote entity id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler registration, keyed by the message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 212
 213impl Drop for Subscription {
 214    fn drop(&mut self) {
 215        match self {
 216            Subscription::Entity { client, id } => {
 217                if let Some(client) = client.upgrade() {
 218                    let mut state = client.state.write();
 219                    let _ = state.entities_by_type_and_remote_id.remove(id);
 220                }
 221            }
 222            Subscription::Message { client, id } => {
 223                if let Some(client) = client.upgrade() {
 224                    let mut state = client.state.write();
 225                    let _ = state.entity_types_by_message_type.remove(id);
 226                    let _ = state.message_handlers.remove(id);
 227                }
 228            }
 229        }
 230    }
 231}
 232
 233impl Client {
 234    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 235        lazy_static! {
 236            static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
 237        }
 238
 239        Arc::new(Self {
 240            id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
 241            peer: Peer::new(),
 242            http,
 243            state: Default::default(),
 244
 245            #[cfg(any(test, feature = "test-support"))]
 246            authenticate: Default::default(),
 247            #[cfg(any(test, feature = "test-support"))]
 248            establish_connection: Default::default(),
 249        })
 250    }
 251
 252    pub fn id(&self) -> usize {
 253        self.id
 254    }
 255
 256    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 257        self.http.clone()
 258    }
 259
 260    #[cfg(any(test, feature = "test-support"))]
 261    pub fn tear_down(&self) {
 262        let mut state = self.state.write();
 263        state._reconnect_task.take();
 264        state.message_handlers.clear();
 265        state.models_by_message_type.clear();
 266        state.entities_by_type_and_remote_id.clear();
 267        state.entity_id_extractors.clear();
 268        self.peer.reset();
 269    }
 270
 271    #[cfg(any(test, feature = "test-support"))]
 272    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 273    where
 274        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 275    {
 276        *self.authenticate.write() = Some(Box::new(authenticate));
 277        self
 278    }
 279
 280    #[cfg(any(test, feature = "test-support"))]
 281    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 282    where
 283        F: 'static
 284            + Send
 285            + Sync
 286            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 287    {
 288        *self.establish_connection.write() = Some(Box::new(connect));
 289        self
 290    }
 291
 292    pub fn user_id(&self) -> Option<u64> {
 293        self.state
 294            .read()
 295            .credentials
 296            .as_ref()
 297            .map(|credentials| credentials.user_id)
 298    }
 299
 300    pub fn status(&self) -> watch::Receiver<Status> {
 301        self.state.read().status.1.clone()
 302    }
 303
    /// Publishes `status` on the status watch channel and manages the reconnect task.
    ///
    /// - `Connected`: clears `_reconnect_task`.
    /// - `ConnectionLost`: spawns a task that calls `authenticate_and_connect`
    ///   immediately and keeps retrying while each failure leaves the status at
    ///   `ConnectionError`. Before each retry it publishes `ReconnectionError`
    ///   with the next retry time and sleeps. The delay starts at 100 ms and is
    ///   multiplied by a random factor in `[1, 2]` after each wait, capped at
    ///   `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears `_reconnect_task`.
    ///
    /// The state write lock is held for the whole call, including while the
    /// reconnect task is spawned.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only if the failed attempt ended in `ConnectionError`;
                        // any other outcome (e.g. `UpgradeRequired`) stops the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized growth, capped at `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
 344
 345    pub fn add_view_for_remote_entity<T: View>(
 346        self: &Arc<Self>,
 347        remote_id: u64,
 348        cx: &mut ViewContext<T>,
 349    ) -> Subscription {
 350        let id = (TypeId::of::<T>(), remote_id);
 351        self.state
 352            .write()
 353            .entities_by_type_and_remote_id
 354            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 355        Subscription::Entity {
 356            client: Arc::downgrade(self),
 357            id,
 358        }
 359    }
 360
 361    pub fn add_model_for_remote_entity<T: Entity>(
 362        self: &Arc<Self>,
 363        remote_id: u64,
 364        cx: &mut ModelContext<T>,
 365    ) -> Subscription {
 366        let id = (TypeId::of::<T>(), remote_id);
 367        self.state
 368            .write()
 369            .entities_by_type_and_remote_id
 370            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 371        Subscription::Entity {
 372            client: Arc::downgrade(self),
 373            id,
 374        }
 375    }
 376
 377    pub fn add_message_handler<M, E, H, F>(
 378        self: &Arc<Self>,
 379        model: ModelHandle<E>,
 380        handler: H,
 381    ) -> Subscription
 382    where
 383        M: EnvelopedMessage,
 384        E: Entity,
 385        H: 'static
 386            + Send
 387            + Sync
 388            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 389        F: 'static + Future<Output = Result<()>>,
 390    {
 391        let message_type_id = TypeId::of::<M>();
 392
 393        let mut state = self.state.write();
 394        state
 395            .models_by_message_type
 396            .insert(message_type_id, model.downgrade().into());
 397
 398        let prev_handler = state.message_handlers.insert(
 399            message_type_id,
 400            Arc::new(move |handle, envelope, client, cx| {
 401                let handle = if let AnyEntityHandle::Model(handle) = handle {
 402                    handle
 403                } else {
 404                    unreachable!();
 405                };
 406                let model = handle.downcast::<E>().unwrap();
 407                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 408                handler(model, *envelope, client.clone(), cx).boxed_local()
 409            }),
 410        );
 411        if prev_handler.is_some() {
 412            panic!("registered handler for the same message twice");
 413        }
 414
 415        Subscription::Message {
 416            client: Arc::downgrade(self),
 417            id: message_type_id,
 418        }
 419    }
 420
 421    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 422    where
 423        M: EntityMessage,
 424        E: View,
 425        H: 'static
 426            + Send
 427            + Sync
 428            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 429        F: 'static + Future<Output = Result<()>>,
 430    {
 431        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 432            if let AnyEntityHandle::View(handle) = handle {
 433                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 434            } else {
 435                unreachable!();
 436            }
 437        })
 438    }
 439
 440    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 441    where
 442        M: EntityMessage,
 443        E: Entity,
 444        H: 'static
 445            + Send
 446            + Sync
 447            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 448        F: 'static + Future<Output = Result<()>>,
 449    {
 450        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 451            if let AnyEntityHandle::Model(handle) = handle {
 452                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 453            } else {
 454                unreachable!();
 455            }
 456        })
 457    }
 458
 459    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 460    where
 461        M: EntityMessage,
 462        E: Entity,
 463        H: 'static
 464            + Send
 465            + Sync
 466            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 467        F: 'static + Future<Output = Result<()>>,
 468    {
 469        let model_type_id = TypeId::of::<E>();
 470        let message_type_id = TypeId::of::<M>();
 471
 472        let mut state = self.state.write();
 473        state
 474            .entity_types_by_message_type
 475            .insert(message_type_id, model_type_id);
 476        state
 477            .entity_id_extractors
 478            .entry(message_type_id)
 479            .or_insert_with(|| {
 480                |envelope| {
 481                    envelope
 482                        .as_any()
 483                        .downcast_ref::<TypedEnvelope<M>>()
 484                        .unwrap()
 485                        .payload
 486                        .remote_entity_id()
 487                }
 488            });
 489        let prev_handler = state.message_handlers.insert(
 490            message_type_id,
 491            Arc::new(move |handle, envelope, client, cx| {
 492                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 493                handler(handle, *envelope, client.clone(), cx).boxed_local()
 494            }),
 495        );
 496        if prev_handler.is_some() {
 497            panic!("registered handler for the same message twice");
 498        }
 499    }
 500
 501    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 502    where
 503        M: EntityMessage + RequestMessage,
 504        E: Entity,
 505        H: 'static
 506            + Send
 507            + Sync
 508            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 509        F: 'static + Future<Output = Result<M::Response>>,
 510    {
 511        self.add_model_message_handler(move |entity, envelope, client, cx| {
 512            Self::respond_to_request::<M, _>(
 513                envelope.receipt(),
 514                handler(entity, envelope, client.clone(), cx),
 515                client,
 516            )
 517        })
 518    }
 519
 520    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 521    where
 522        M: EntityMessage + RequestMessage,
 523        E: View,
 524        H: 'static
 525            + Send
 526            + Sync
 527            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 528        F: 'static + Future<Output = Result<M::Response>>,
 529    {
 530        self.add_view_message_handler(move |entity, envelope, client, cx| {
 531            Self::respond_to_request::<M, _>(
 532                envelope.receipt(),
 533                handler(entity, envelope, client.clone(), cx),
 534                client,
 535            )
 536        })
 537    }
 538
 539    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 540        receipt: Receipt<T>,
 541        response: F,
 542        client: Arc<Self>,
 543    ) -> Result<()> {
 544        match response.await {
 545            Ok(response) => {
 546                client.respond(receipt, response)?;
 547                Ok(())
 548            }
 549            Err(error) => {
 550                client.respond_with_error(
 551                    receipt,
 552                    proto::Error {
 553                        message: error.to_string(),
 554                    },
 555                )?;
 556                Err(error)
 557            }
 558        }
 559    }
 560
 561    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 562        read_credentials_from_keychain(cx).is_some()
 563    }
 564
    /// Obtains credentials (if needed) and connects to the server, publishing each
    /// intermediate `Status`.
    ///
    /// Returns `Ok(())` without doing anything if a connection or an attempt is
    /// already in progress (`Connected`, `Connecting`, `Reconnecting`,
    /// `Authenticating`, `Reauthenticating`). Fails with
    /// `EstablishConnectionError::UpgradeRequired` if the status is `UpgradeRequired`.
    ///
    /// Credentials come, in order, from the cached state, the keychain (only when
    /// `try_keychain` is true), or `authenticate`. After a successful connection
    /// they are cached. They are also written to the keychain, unless they were
    /// read from it or `IMPERSONATE_LOGIN` is set.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted from
    /// the keychain and the whole flow restarts with `try_keychain = false`.
    ///
    /// # Errors
    ///
    /// Returns the authentication or connection error after setting the status to
    /// `ConnectionError` (or `UpgradeRequired` for that error).
    // `async_recursion` is required because the `Unauthorized` branch awaits this
    // function recursively; `?Send` leaves the boxed future non-`Send`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `was_disconnected` selects between the first-connection statuses
        // (Authenticating/Connecting) and the re-connection ones
        // (Reauthenticating/Reconnecting).
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
                false
            }
            Status::Connected { .. }
            | Status::Connecting { .. }
            | Status::Reconnecting { .. }
            | Status::Authenticating
            | Status::Reauthenticating => return Ok(()),
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Remember whether the credentials came from the keychain: such credentials
        // are not written back, and are deleted if the server rejects them.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            credentials = Some(match self.authenticate(&cx).await {
                Ok(credentials) => credentials,
                Err(err) => {
                    self.set_status(Status::ConnectionError, cx);
                    return Err(err);
                }
            });
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain credentials: delete them and restart without the
                    // keychain, which falls through to `authenticate`.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 643
 644    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
 645        let executor = cx.background();
 646        log::info!("add connection to peer");
 647        let (connection_id, handle_io, mut incoming) = self
 648            .peer
 649            .add_connection(conn, move |duration| executor.timer(duration))
 650            .await;
 651        log::info!("set status to connected {}", connection_id);
 652        self.set_status(Status::Connected { connection_id }, cx);
 653        cx.foreground()
 654            .spawn({
 655                let cx = cx.clone();
 656                let this = self.clone();
 657                async move {
 658                    let mut message_id = 0_usize;
 659                    while let Some(message) = incoming.next().await {
 660                        let mut state = this.state.write();
 661                        message_id += 1;
 662                        let type_name = message.payload_type_name();
 663                        let payload_type_id = message.payload_type_id();
 664                        let sender_id = message.original_sender_id().map(|id| id.0);
 665
 666                        let model = state
 667                            .models_by_message_type
 668                            .get(&payload_type_id)
 669                            .and_then(|model| model.upgrade(&cx))
 670                            .map(AnyEntityHandle::Model)
 671                            .or_else(|| {
 672                                let entity_type_id =
 673                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
 674                                let entity_id = state
 675                                    .entity_id_extractors
 676                                    .get(&message.payload_type_id())
 677                                    .map(|extract_entity_id| {
 678                                        (extract_entity_id)(message.as_ref())
 679                                    })?;
 680
 681                                let entity = state
 682                                    .entities_by_type_and_remote_id
 683                                    .get(&(entity_type_id, entity_id))?;
 684                                if let Some(entity) = entity.upgrade(&cx) {
 685                                    Some(entity)
 686                                } else {
 687                                    state
 688                                        .entities_by_type_and_remote_id
 689                                        .remove(&(entity_type_id, entity_id));
 690                                    None
 691                                }
 692                            });
 693
 694                        let model = if let Some(model) = model {
 695                            model
 696                        } else {
 697                            log::info!("unhandled message {}", type_name);
 698                            continue;
 699                        };
 700
 701                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
 702                        {
 703                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
 704                            let future = handler(model, message, &this, cx.clone());
 705
 706                            let client_id = this.id;
 707                            log::debug!(
 708                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 709                                client_id,
 710                                message_id,
 711                                sender_id,
 712                                type_name
 713                            );
 714                            cx.foreground()
 715                                .spawn(async move {
 716                                    match future.await {
 717                                        Ok(()) => {
 718                                            log::debug!(
 719                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 720                                                client_id,
 721                                                message_id,
 722                                                sender_id,
 723                                                type_name
 724                                            );
 725                                        }
 726                                        Err(error) => {
 727                                            log::error!(
 728                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
 729                                                client_id,
 730                                                message_id,
 731                                                sender_id,
 732                                                type_name,
 733                                                error
 734                                            );
 735                                        }
 736                                    }
 737                                })
 738                                .detach();
 739                        } else {
 740                            log::info!("unhandled message {}", type_name);
 741                        }
 742
 743                        // Don't starve the main thread when receiving lots of messages at once.
 744                        smol::future::yield_now().await;
 745                    }
 746                }
 747            })
 748            .detach();
 749
 750        let handle_io = cx.background().spawn(handle_io);
 751        let this = self.clone();
 752        let cx = cx.clone();
 753        cx.foreground()
 754            .spawn(async move {
 755                match handle_io.await {
 756                    Ok(()) => {
 757                        if *this.status().borrow() == (Status::Connected { connection_id }) {
 758                            this.set_status(Status::SignedOut, &cx);
 759                        }
 760                    }
 761                    Err(err) => {
 762                        log::error!("connection error: {:?}", err);
 763                        this.set_status(Status::ConnectionLost, &cx);
 764                    }
 765                }
 766            })
 767            .detach();
 768    }
 769
 770    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 771        #[cfg(any(test, feature = "test-support"))]
 772        if let Some(callback) = self.authenticate.read().as_ref() {
 773            return callback(cx);
 774        }
 775
 776        self.authenticate_with_browser(cx)
 777    }
 778
 779    fn establish_connection(
 780        self: &Arc<Self>,
 781        credentials: &Credentials,
 782        cx: &AsyncAppContext,
 783    ) -> Task<Result<Connection, EstablishConnectionError>> {
 784        #[cfg(any(test, feature = "test-support"))]
 785        if let Some(callback) = self.establish_connection.read().as_ref() {
 786            return callback(credentials, cx);
 787        }
 788
 789        self.establish_websocket_connection(credentials, cx)
 790    }
 791
    /// Opens a websocket connection to the server's RPC endpoint.
    ///
    /// The upgrade request carries two headers:
    /// - `Authorization: "<user_id> <access_token>"`
    /// - `x-zed-protocol-version: rpc::PROTOCOL_VERSION`
    ///
    /// The endpoint is discovered with a GET to `{ZED_SERVER_URL}/rpc`:
    /// - A redirect response's `Location` header becomes the RPC URL.
    /// - A `426 Upgrade Required` response means `{ZED_SERVER_URL}/rpc` itself is the endpoint.
    /// - Any other status is an error.
    ///
    /// After a TCP connection to the URL's host and port is opened:
    /// - An `https` URL is upgraded as `wss` over TLS.
    /// - An `http` URL is upgraded as plain `ws`.
    /// - Any other scheme is rejected.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        // Endpoint discovery and the websocket handshake both run on the background executor.
        cx.background().spawn(async move {
            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
            // NOTE(review): the third argument is `false`. Presumably this disables automatic
            // redirect-following so the `Location` header can be read below. Confirm against
            // `HttpClient::get`.
            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
            if rpc_response.status().is_redirection() {
                rpc_url = rpc_response
                    .headers()
                    .get("Location")
                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
                    .to_str()
                    .map_err(|error| EstablishConnectionError::other(error))?
                    .to_string();
            }
            // Until we switch the zed.dev domain to point to the new Next.js app, there
            // will be no redirect required, and the app will connect directly to
            // wss://zed.dev/rpc.
            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
                Err(anyhow!(
                    "unexpected /rpc response status {}",
                    rpc_response.status()
                ))?
            }

            // NOTE(review): a relative `Location` value would fail to parse here, because
            // it is not resolved against the /rpc URL.
            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
            // Build a (host, port) pair. When the URL has no explicit port, the scheme's
            // well-known port is used.
            // NOTE(review): for IPv6 literals `host_str()` keeps the brackets (e.g. "[::1]"),
            // which likely will not resolve as a socket address. Confirm if IPv6 URLs matter.
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // Run the websocket handshake over the already-connected TCP stream: TLS for
            // https endpoints, plaintext for http endpoints. In both cases, stream and sink
            // errors are mapped to `anyhow::Error` before wrapping in a `Connection`.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
 862
 863    pub fn authenticate_with_browser(
 864        self: &Arc<Self>,
 865        cx: &AsyncAppContext,
 866    ) -> Task<Result<Credentials>> {
 867        let platform = cx.platform();
 868        let executor = cx.background();
 869        executor.clone().spawn(async move {
 870            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 871            // zed server to encrypt the user's access token, so that it can'be intercepted by
 872            // any other app running on the user's device.
 873            let (public_key, private_key) =
 874                rpc::auth::keypair().expect("failed to generate keypair for auth");
 875            let public_key_string =
 876                String::try_from(public_key).expect("failed to serialize public key for auth");
 877
 878            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 879            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 880            let port = server.server_addr().port();
 881
 882            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 883            // that the user is signing in from a Zed app running on the same device.
 884            let mut url = format!(
 885                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 886                *ZED_SERVER_URL, port, public_key_string
 887            );
 888
 889            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 890                log::info!("impersonating user @{}", impersonate_login);
 891                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 892            }
 893
 894            platform.open_url(&url);
 895
 896            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 897            // access token from the query params.
 898            //
 899            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 900            // custom URL scheme instead of this local HTTP server.
 901            let (user_id, access_token) = executor
 902                .spawn(async move {
 903                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
 904                        let path = req.url();
 905                        let mut user_id = None;
 906                        let mut access_token = None;
 907                        let url = Url::parse(&format!("http://example.com{}", path))
 908                            .context("failed to parse login notification url")?;
 909                        for (key, value) in url.query_pairs() {
 910                            if key == "access_token" {
 911                                access_token = Some(value.to_string());
 912                            } else if key == "user_id" {
 913                                user_id = Some(value.to_string());
 914                            }
 915                        }
 916
 917                        let post_auth_url =
 918                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 919                        req.respond(
 920                            tiny_http::Response::empty(302).with_header(
 921                                tiny_http::Header::from_bytes(
 922                                    &b"Location"[..],
 923                                    post_auth_url.as_bytes(),
 924                                )
 925                                .unwrap(),
 926                            ),
 927                        )
 928                        .context("failed to respond to login http request")?;
 929                        Ok((
 930                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 931                            access_token
 932                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 933                        ))
 934                    } else {
 935                        Err(anyhow!("didn't receive login redirect"))
 936                    }
 937                })
 938                .await?;
 939
 940            let access_token = private_key
 941                .decrypt_string(&access_token)
 942                .context("failed to decrypt access token")?;
 943            platform.activate(true);
 944
 945            Ok(Credentials {
 946                user_id: user_id.parse()?,
 947                access_token,
 948            })
 949        })
 950    }
 951
 952    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 953        let conn_id = self.connection_id()?;
 954        self.peer.disconnect(conn_id);
 955        self.set_status(Status::SignedOut, cx);
 956        Ok(())
 957    }
 958
 959    fn connection_id(&self) -> Result<ConnectionId> {
 960        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 961            Ok(connection_id)
 962        } else {
 963            Err(anyhow!("not connected"))
 964        }
 965    }
 966
 967    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 968        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 969        self.peer.send(self.connection_id()?, message)
 970    }
 971
 972    pub fn request<T: RequestMessage>(
 973        &self,
 974        request: T,
 975    ) -> impl Future<Output = Result<T::Response>> {
 976        let client_id = self.id;
 977        log::debug!(
 978            "rpc request start. client_id:{}. name:{}",
 979            client_id,
 980            T::NAME
 981        );
 982        let response = self
 983            .connection_id()
 984            .map(|conn_id| self.peer.request(conn_id, request));
 985        async move {
 986            let response = response?.await;
 987            log::debug!(
 988                "rpc request finish. client_id:{}. name:{}",
 989                client_id,
 990                T::NAME
 991            );
 992            response
 993        }
 994    }
 995
 996    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
 997        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 998        self.peer.respond(receipt, response)
 999    }
1000
1001    fn respond_with_error<T: RequestMessage>(
1002        &self,
1003        receipt: Receipt<T>,
1004        error: proto::Error,
1005    ) -> Result<()> {
1006        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1007        self.peer.respond_with_error(receipt, error)
1008    }
1009}
1010
1011impl AnyWeakEntityHandle {
1012    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1013        match self {
1014            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1015            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1016        }
1017    }
1018}
1019
1020fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1021    if IMPERSONATE_LOGIN.is_some() {
1022        return None;
1023    }
1024
1025    let (user_id, access_token) = cx
1026        .platform()
1027        .read_credentials(&ZED_SERVER_URL)
1028        .log_err()
1029        .flatten()?;
1030    Some(Credentials {
1031        user_id: user_id.parse().ok()?,
1032        access_token: String::from_utf8(access_token).ok()?,
1033    })
1034}
1035
1036fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1037    cx.platform().write_credentials(
1038        &ZED_SERVER_URL,
1039        &credentials.user_id.to_string(),
1040        credentials.access_token.as_bytes(),
1041    )
1042}
1043
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut encoded = String::from(WORKTREE_URL_PREFIX);
    encoded.push_str(&id.to_string());
    encoded.push('/');
    encoded.push_str(access_token);
    encoded
}

/// Parses a URL built by `encode_worktree_url` back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` in any of these cases:
/// - The prefix is missing.
/// - The id segment is not a `u64`.
/// - The token segment is missing or empty.
///
/// Any `/`-separated segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next()? {
        "" => None,
        access_token => Some((id, access_token.to_owned())),
    }
}
1060
// Unit tests for `Client`: reconnection and credential caching, worktree URL encoding, and the
// lifetime of message-handler subscriptions. They use `FakeHttpClient`/`FakeServer` from
// `crate::test`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // The client reconnects after server-side disconnects. It reuses cached credentials when
    // they are still valid, and re-authenticates once the server rolls the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, and wait until a reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Allow connections again. Advancing the clock presumably lets the client's reconnect
        // delay elapse.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Rolling the access token invalidates the cached credentials, so the client must clear
        // them and authenticate again before it can reconnect.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Round-trips a worktree URL, tolerates surrounding whitespace, and rejects foreign URLs.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Messages for remote entity ids 1 and 2 are routed to the models registered for those ids.
    // Dropping the subscription for id 3 must not affect them.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                // Signal the channel matching the model that received the message.
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a handler's subscription is dropped, a new handler for the same message type can be
    // registered on the same model and receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        // NOTE(review): `_done_rx1` keeps the first channel open, so the dropped handler's
        // `try_send` would not panic if it ran. This test only asserts that the second
        // handler fires, not that the first one is never invoked.
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler drops its own subscription (stored on the model) while it is running. The test
    // passes if the handler still completes and signals `done_tx`.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: `id` identifies which model received a message, and `subscription`
    // gives handlers a slot to hold (and drop) their own subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}