client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
  15use gpui::{
  16    actions, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
  17    Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{
  32        atomic::{AtomicUsize, Ordering},
  33        Arc, Weak,
  34    },
  35    time::{Duration, Instant},
  36};
  37use thiserror::Error;
  38use url::Url;
  39use util::{ResultExt, TryFutureExt};
  40
  41pub use channel::*;
  42pub use rpc::*;
  43pub use user::*;
  44
  45lazy_static! {
  46    pub static ref ZED_SERVER_URL: String =
  47        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  48    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  49        .ok()
  50        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  51}
  52
  53pub const ZED_SECRET_CLIENT_TOKEN: &'static str = "618033988749894";
  54
// Declares the `Authenticate` action in the `client` namespace; `init` below registers the
// global handler that runs it.
actions!(client, [Authenticate]);
  56
  57pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  58    cx.add_global_action(move |_: &Authenticate, cx| {
  59        let rpc = rpc.clone();
  60        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
  61            .detach();
  62    });
  63}
  64
/// Client for the Zed server's RPC endpoint.
///
/// Owns the RPC [`Peer`], the HTTP client used to locate the endpoint, and the connection and
/// message-routing state kept in `state`. Constructed via [`Client::new`], which returns it
/// inside an `Arc`.
pub struct Client {
    /// Sequential id taken from a global counter in `Client::new`; included in log lines.
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    state: RwLock<ClientState>,

    /// Test-only replacement for `Client::authenticate`, installed via `override_authenticate`.
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,
    /// Test-only replacement for the websocket connection, installed via
    /// `override_establish_connection`.
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
  90
/// Ways in which opening a connection to the server can fail.
///
/// `Unauthorized` and `UpgradeRequired` have dedicated variants because
/// `authenticate_and_connect` reacts to them specially. The `From<WebsocketError>` impl
/// produces them from HTTP 401 and 426 handshake responses. Every other failure is carried by
/// one of the wrapping variants, which `?` fills in through the `#[from]` conversions.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the crate's `http` module.
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure, e.g. while opening the TCP stream.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure building the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
 106
 107impl From<WebsocketError> for EstablishConnectionError {
 108    fn from(error: WebsocketError) -> Self {
 109        if let WebsocketError::Http(response) = &error {
 110            match response.status() {
 111                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 112                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 113                _ => {}
 114            }
 115        }
 116        EstablishConnectionError::Other(error.into())
 117    }
 118}
 119
 120impl EstablishConnectionError {
 121    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 122        Self::Other(error.into())
 123    }
 124}
 125
/// Connection state of a [`Client`], published through a watch channel (see `Client::status`).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// No session. This is the initial state.
    SignedOut,
    /// The server rejected the client with `UpgradeRequired`; automatic reconnection stops.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening the connection, starting from `SignedOut`.
    Connecting,
    /// The most recent authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection registered with the peer.
    Connected { connection_id: ConnectionId },
    /// An established connection failed; entering this state spawns a reconnection task.
    ConnectionLost,
    /// Obtaining credentials again after an error or lost connection.
    Reauthenticating,
    /// Reopening the connection after an error or lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
 139
 140impl Status {
 141    pub fn is_connected(&self) -> bool {
 142        matches!(self, Self::Connected { .. })
 143    }
 144}
 145
/// Mutable state of a [`Client`], kept behind its `RwLock`.
struct ClientState {
    /// Credentials of the current session. Set after a successful connection and cleared when
    /// the server answers `Unauthorized`.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel; `Client::status` clones the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: a function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection task spawned on `ConnectionLost`; replacing or clearing it drops the task.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized delay between reconnection attempts.
    reconnect_interval: Duration,
    /// (entity `TypeId`, remote id) -> weak handle of the local entity receiving its messages.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Message `TypeId` -> the single model receiving every message of that type
    /// (registered via `add_message_handler`).
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Message `TypeId` -> `TypeId` of the entity kind that messages of that type target.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Message `TypeId` -> type-erased handler, which receives the resolved entity and the
    /// envelope and returns the future to run.
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
 169
/// Weak, type-erased handle to either a model or a view registered for remote messages.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}
 174
/// Strong, type-erased handle to a model or view, passed to message handlers.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
 179
/// User credentials for the RPC connection. Sent as the header
/// `Authorization: {user_id} {access_token}` when the websocket connection is opened.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
 185
 186impl Default for ClientState {
 187    fn default() -> Self {
 188        Self {
 189            credentials: None,
 190            status: watch::channel_with(Status::SignedOut),
 191            entity_id_extractors: Default::default(),
 192            _reconnect_task: None,
 193            reconnect_interval: Duration::from_secs(5),
 194            models_by_message_type: Default::default(),
 195            entities_by_type_and_remote_id: Default::default(),
 196            entity_types_by_message_type: Default::default(),
 197            message_handlers: Default::default(),
 198        }
 199    }
 200}
 201
/// Guard returned when registering an entity or a message handler with a [`Client`].
/// Dropping it removes the registration, provided the client is still alive.
pub enum Subscription {
    /// Registration of a local entity under its (entity `TypeId`, remote id) key.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler keyed by message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
 212
 213impl Drop for Subscription {
 214    fn drop(&mut self) {
 215        match self {
 216            Subscription::Entity { client, id } => {
 217                if let Some(client) = client.upgrade() {
 218                    let mut state = client.state.write();
 219                    let _ = state.entities_by_type_and_remote_id.remove(id);
 220                }
 221            }
 222            Subscription::Message { client, id } => {
 223                if let Some(client) = client.upgrade() {
 224                    let mut state = client.state.write();
 225                    let _ = state.entity_types_by_message_type.remove(id);
 226                    let _ = state.message_handlers.remove(id);
 227                }
 228            }
 229        }
 230    }
 231}
 232
 233impl Client {
 234    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 235        lazy_static! {
 236            static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
 237        }
 238
 239        Arc::new(Self {
 240            id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
 241            peer: Peer::new(),
 242            http,
 243            state: Default::default(),
 244
 245            #[cfg(any(test, feature = "test-support"))]
 246            authenticate: Default::default(),
 247            #[cfg(any(test, feature = "test-support"))]
 248            establish_connection: Default::default(),
 249        })
 250    }
 251
 252    pub fn id(&self) -> usize {
 253        self.id
 254    }
 255
 256    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 257        self.http.clone()
 258    }
 259
 260    #[cfg(any(test, feature = "test-support"))]
 261    pub fn tear_down(&self) {
 262        let mut state = self.state.write();
 263        state._reconnect_task.take();
 264        state.message_handlers.clear();
 265        state.models_by_message_type.clear();
 266        state.entities_by_type_and_remote_id.clear();
 267        state.entity_id_extractors.clear();
 268        self.peer.reset();
 269    }
 270
 271    #[cfg(any(test, feature = "test-support"))]
 272    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
 273    where
 274        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 275    {
 276        *self.authenticate.write() = Some(Box::new(authenticate));
 277        self
 278    }
 279
 280    #[cfg(any(test, feature = "test-support"))]
 281    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
 282    where
 283        F: 'static
 284            + Send
 285            + Sync
 286            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 287    {
 288        *self.establish_connection.write() = Some(Box::new(connect));
 289        self
 290    }
 291
 292    pub fn user_id(&self) -> Option<u64> {
 293        self.state
 294            .read()
 295            .credentials
 296            .as_ref()
 297            .map(|credentials| credentials.user_id)
 298    }
 299
 300    pub fn status(&self) -> watch::Receiver<Status> {
 301        self.state.read().status.1.clone()
 302    }
 303
 304    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 305        log::info!("set status on client {}: {:?}", self.id, status);
 306        let mut state = self.state.write();
 307        *state.status.0.borrow_mut() = status;
 308
 309        match status {
 310            Status::Connected { .. } => {
 311                state._reconnect_task = None;
 312            }
 313            Status::ConnectionLost => {
 314                let this = self.clone();
 315                let reconnect_interval = state.reconnect_interval;
 316                state._reconnect_task = Some(cx.spawn(|cx| async move {
 317                    let mut rng = StdRng::from_entropy();
 318                    let mut delay = Duration::from_millis(100);
 319                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
 320                        log::error!("failed to connect {}", error);
 321                        this.set_status(
 322                            Status::ReconnectionError {
 323                                next_reconnection: Instant::now() + delay,
 324                            },
 325                            &cx,
 326                        );
 327                        cx.background().timer(delay).await;
 328                        delay = delay
 329                            .mul_f32(rng.gen_range(1.0..=2.0))
 330                            .min(reconnect_interval);
 331                    }
 332                }));
 333            }
 334            Status::SignedOut | Status::UpgradeRequired => {
 335                state._reconnect_task.take();
 336            }
 337            _ => {}
 338        }
 339    }
 340
 341    pub fn add_view_for_remote_entity<T: View>(
 342        self: &Arc<Self>,
 343        remote_id: u64,
 344        cx: &mut ViewContext<T>,
 345    ) -> Subscription {
 346        let id = (TypeId::of::<T>(), remote_id);
 347        self.state
 348            .write()
 349            .entities_by_type_and_remote_id
 350            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 351        Subscription::Entity {
 352            client: Arc::downgrade(self),
 353            id,
 354        }
 355    }
 356
 357    pub fn add_model_for_remote_entity<T: Entity>(
 358        self: &Arc<Self>,
 359        remote_id: u64,
 360        cx: &mut ModelContext<T>,
 361    ) -> Subscription {
 362        let id = (TypeId::of::<T>(), remote_id);
 363        self.state
 364            .write()
 365            .entities_by_type_and_remote_id
 366            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 367        Subscription::Entity {
 368            client: Arc::downgrade(self),
 369            id,
 370        }
 371    }
 372
 373    pub fn add_message_handler<M, E, H, F>(
 374        self: &Arc<Self>,
 375        model: ModelHandle<E>,
 376        handler: H,
 377    ) -> Subscription
 378    where
 379        M: EnvelopedMessage,
 380        E: Entity,
 381        H: 'static
 382            + Send
 383            + Sync
 384            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 385        F: 'static + Future<Output = Result<()>>,
 386    {
 387        let message_type_id = TypeId::of::<M>();
 388
 389        let mut state = self.state.write();
 390        state
 391            .models_by_message_type
 392            .insert(message_type_id, model.downgrade().into());
 393
 394        let prev_handler = state.message_handlers.insert(
 395            message_type_id,
 396            Arc::new(move |handle, envelope, client, cx| {
 397                let handle = if let AnyEntityHandle::Model(handle) = handle {
 398                    handle
 399                } else {
 400                    unreachable!();
 401                };
 402                let model = handle.downcast::<E>().unwrap();
 403                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 404                handler(model, *envelope, client.clone(), cx).boxed_local()
 405            }),
 406        );
 407        if prev_handler.is_some() {
 408            panic!("registered handler for the same message twice");
 409        }
 410
 411        Subscription::Message {
 412            client: Arc::downgrade(self),
 413            id: message_type_id,
 414        }
 415    }
 416
 417    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 418    where
 419        M: EntityMessage,
 420        E: View,
 421        H: 'static
 422            + Send
 423            + Sync
 424            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 425        F: 'static + Future<Output = Result<()>>,
 426    {
 427        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 428            if let AnyEntityHandle::View(handle) = handle {
 429                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 430            } else {
 431                unreachable!();
 432            }
 433        })
 434    }
 435
 436    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 437    where
 438        M: EntityMessage,
 439        E: Entity,
 440        H: 'static
 441            + Send
 442            + Sync
 443            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 444        F: 'static + Future<Output = Result<()>>,
 445    {
 446        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 447            if let AnyEntityHandle::Model(handle) = handle {
 448                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 449            } else {
 450                unreachable!();
 451            }
 452        })
 453    }
 454
 455    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 456    where
 457        M: EntityMessage,
 458        E: Entity,
 459        H: 'static
 460            + Send
 461            + Sync
 462            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 463        F: 'static + Future<Output = Result<()>>,
 464    {
 465        let model_type_id = TypeId::of::<E>();
 466        let message_type_id = TypeId::of::<M>();
 467
 468        let mut state = self.state.write();
 469        state
 470            .entity_types_by_message_type
 471            .insert(message_type_id, model_type_id);
 472        state
 473            .entity_id_extractors
 474            .entry(message_type_id)
 475            .or_insert_with(|| {
 476                |envelope| {
 477                    envelope
 478                        .as_any()
 479                        .downcast_ref::<TypedEnvelope<M>>()
 480                        .unwrap()
 481                        .payload
 482                        .remote_entity_id()
 483                }
 484            });
 485        let prev_handler = state.message_handlers.insert(
 486            message_type_id,
 487            Arc::new(move |handle, envelope, client, cx| {
 488                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 489                handler(handle, *envelope, client.clone(), cx).boxed_local()
 490            }),
 491        );
 492        if prev_handler.is_some() {
 493            panic!("registered handler for the same message twice");
 494        }
 495    }
 496
 497    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 498    where
 499        M: EntityMessage + RequestMessage,
 500        E: Entity,
 501        H: 'static
 502            + Send
 503            + Sync
 504            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 505        F: 'static + Future<Output = Result<M::Response>>,
 506    {
 507        self.add_model_message_handler(move |entity, envelope, client, cx| {
 508            Self::respond_to_request::<M, _>(
 509                envelope.receipt(),
 510                handler(entity, envelope, client.clone(), cx),
 511                client,
 512            )
 513        })
 514    }
 515
 516    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 517    where
 518        M: EntityMessage + RequestMessage,
 519        E: View,
 520        H: 'static
 521            + Send
 522            + Sync
 523            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 524        F: 'static + Future<Output = Result<M::Response>>,
 525    {
 526        self.add_view_message_handler(move |entity, envelope, client, cx| {
 527            Self::respond_to_request::<M, _>(
 528                envelope.receipt(),
 529                handler(entity, envelope, client.clone(), cx),
 530                client,
 531            )
 532        })
 533    }
 534
 535    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 536        receipt: Receipt<T>,
 537        response: F,
 538        client: Arc<Self>,
 539    ) -> Result<()> {
 540        match response.await {
 541            Ok(response) => {
 542                client.respond(receipt, response)?;
 543                Ok(())
 544            }
 545            Err(error) => {
 546                client.respond_with_error(
 547                    receipt,
 548                    proto::Error {
 549                        message: error.to_string(),
 550                    },
 551                )?;
 552                Err(error)
 553            }
 554        }
 555    }
 556
 557    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 558        read_credentials_from_keychain(cx).is_some()
 559    }
 560
    /// Signs in if needed and connects to the server, updating the client status as it goes.
    ///
    /// Returns `Ok(())` without doing anything if the client is already connected or an
    /// authentication/connection attempt is already in progress.
    ///
    /// Credentials are taken from memory first. If there are none and `try_keychain` is set,
    /// they are read from the keychain. Otherwise `authenticate` obtains new ones.
    ///
    /// # Errors
    ///
    /// - `UpgradeRequired` if the status already says so or the server answers with it.
    /// - `Unauthorized` if freshly obtained credentials are rejected. Rejected keychain
    ///   credentials are deleted instead, and the call recurses once with `try_keychain = false`
    ///   (hence `async_recursion`).
    /// - Any error from `authenticate` or from establishing the connection.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // An existing connection, or an attempt already in flight, makes this call a no-op.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
                false
            }
            Status::Connected { .. }
            | Status::Connecting { .. }
            | Status::Reconnecting { .. }
            | Status::Authenticating
            | Status::Reauthenticating => return Ok(()),
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Remember whether the credentials came from the keychain. If the server rejects them
        // they are deleted and we retry; they are also not written back after success.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            credentials = Some(match self.authenticate(&cx).await {
                Ok(credentials) => credentials,
                Err(err) => {
                    self.set_status(Status::ConnectionError, cx);
                    return Err(err);
                }
            });
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        match self.establish_connection(&credentials, cx).await {
            Ok(conn) => {
                self.state.write().credentials = Some(credentials.clone());
                // Credentials are not persisted when `ZED_IMPERSONATE` is set.
                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                    write_credentials_to_keychain(&credentials, cx).log_err();
                }
                self.set_connection(conn, cx).await;
                Ok(())
            }
            Err(EstablishConnectionError::Unauthorized) => {
                self.state.write().credentials.take();
                if read_from_keychain {
                    // Stale keychain credentials: delete them and retry with fresh ones.
                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                    self.set_status(Status::SignedOut, cx);
                    self.authenticate_and_connect(false, cx).await
                } else {
                    self.set_status(Status::ConnectionError, cx);
                    Err(EstablishConnectionError::Unauthorized)?
                }
            }
            Err(EstablishConnectionError::UpgradeRequired) => {
                self.set_status(Status::UpgradeRequired, cx);
                Err(EstablishConnectionError::UpgradeRequired)?
            }
            Err(error) => {
                self.set_status(Status::ConnectionError, cx);
                Err(error)?
            }
        }
    }
 639
 640    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
 641        let executor = cx.background();
 642        log::info!("add connection to peer");
 643        let (connection_id, handle_io, mut incoming) = self
 644            .peer
 645            .add_connection(conn, move |duration| executor.timer(duration))
 646            .await;
 647        log::info!("set status to connected {}", connection_id);
 648        self.set_status(Status::Connected { connection_id }, cx);
 649        cx.foreground()
 650            .spawn({
 651                let cx = cx.clone();
 652                let this = self.clone();
 653                async move {
 654                    let mut message_id = 0_usize;
 655                    while let Some(message) = incoming.next().await {
 656                        let mut state = this.state.write();
 657                        message_id += 1;
 658                        let type_name = message.payload_type_name();
 659                        let payload_type_id = message.payload_type_id();
 660                        let sender_id = message.original_sender_id().map(|id| id.0);
 661
 662                        let model = state
 663                            .models_by_message_type
 664                            .get(&payload_type_id)
 665                            .and_then(|model| model.upgrade(&cx))
 666                            .map(AnyEntityHandle::Model)
 667                            .or_else(|| {
 668                                let entity_type_id =
 669                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
 670                                let entity_id = state
 671                                    .entity_id_extractors
 672                                    .get(&message.payload_type_id())
 673                                    .map(|extract_entity_id| {
 674                                        (extract_entity_id)(message.as_ref())
 675                                    })?;
 676
 677                                let entity = state
 678                                    .entities_by_type_and_remote_id
 679                                    .get(&(entity_type_id, entity_id))?;
 680                                if let Some(entity) = entity.upgrade(&cx) {
 681                                    Some(entity)
 682                                } else {
 683                                    state
 684                                        .entities_by_type_and_remote_id
 685                                        .remove(&(entity_type_id, entity_id));
 686                                    None
 687                                }
 688                            });
 689
 690                        let model = if let Some(model) = model {
 691                            model
 692                        } else {
 693                            log::info!("unhandled message {}", type_name);
 694                            continue;
 695                        };
 696
 697                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
 698                        {
 699                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
 700                            let future = handler(model, message, &this, cx.clone());
 701
 702                            let client_id = this.id;
 703                            log::debug!(
 704                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 705                                client_id,
 706                                message_id,
 707                                sender_id,
 708                                type_name
 709                            );
 710                            cx.foreground()
 711                                .spawn(async move {
 712                                    match future.await {
 713                                        Ok(()) => {
 714                                            log::debug!(
 715                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 716                                                client_id,
 717                                                message_id,
 718                                                sender_id,
 719                                                type_name
 720                                            );
 721                                        }
 722                                        Err(error) => {
 723                                            log::error!(
 724                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
 725                                                client_id,
 726                                                message_id,
 727                                                sender_id,
 728                                                type_name,
 729                                                error
 730                                            );
 731                                        }
 732                                    }
 733                                })
 734                                .detach();
 735                        } else {
 736                            log::info!("unhandled message {}", type_name);
 737                        }
 738
 739                        // Don't starve the main thread when receiving lots of messages at once.
 740                        smol::future::yield_now().await;
 741                    }
 742                }
 743            })
 744            .detach();
 745
 746        let handle_io = cx.background().spawn(handle_io);
 747        let this = self.clone();
 748        let cx = cx.clone();
 749        cx.foreground()
 750            .spawn(async move {
 751                match handle_io.await {
 752                    Ok(()) => {
 753                        if *this.status().borrow() == (Status::Connected { connection_id }) {
 754                            this.set_status(Status::SignedOut, &cx);
 755                        }
 756                    }
 757                    Err(err) => {
 758                        log::error!("connection error: {:?}", err);
 759                        this.set_status(Status::ConnectionLost, &cx);
 760                    }
 761                }
 762            })
 763            .detach();
 764    }
 765
 766    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 767        #[cfg(any(test, feature = "test-support"))]
 768        if let Some(callback) = self.authenticate.read().as_ref() {
 769            return callback(cx);
 770        }
 771
 772        self.authenticate_with_browser(cx)
 773    }
 774
 775    fn establish_connection(
 776        self: &Arc<Self>,
 777        credentials: &Credentials,
 778        cx: &AsyncAppContext,
 779    ) -> Task<Result<Connection, EstablishConnectionError>> {
 780        #[cfg(any(test, feature = "test-support"))]
 781        if let Some(callback) = self.establish_connection.read().as_ref() {
 782            return callback(credentials, cx);
 783        }
 784
 785        self.establish_websocket_connection(credentials, cx)
 786    }
 787
    /// Opens a websocket connection to the server's RPC endpoint on the background executor.
    ///
    /// The handshake request carries `Authorization: {user_id} {access_token}` and the client's
    /// `x-zed-protocol-version`. The endpoint is discovered by requesting `{ZED_SERVER_URL}/rpc`:
    /// a redirect's `Location` header names the endpoint, and a 426 response means `/rpc`
    /// itself is the endpoint. `https`/`http` endpoints are dialed as `wss`/`ws` respectively.
    ///
    /// Errors are converted into [`EstablishConnectionError`] via its `From` impls. A handshake
    /// answered with 401 or 426 therefore surfaces as `Unauthorized` / `UpgradeRequired`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            // Ask the server where the RPC endpoint lives.
            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
            if rpc_response.status().is_redirection() {
                rpc_url = rpc_response
                    .headers()
                    .get("Location")
                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
                    .to_str()
                    .map_err(|error| EstablishConnectionError::other(error))?
                    .to_string();
            }
            // Until we switch the zed.dev domain to point to the new Next.js app, there
            // will be no redirect required, and the app will connect directly to
            // wss://zed.dev/rpc.
            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
                Err(anyhow!(
                    "unexpected /rpc response status {}",
                    rpc_response.status()
                ))?
            }

            // Open the raw TCP stream first; the websocket handshake runs over it below.
            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // The two arms differ only in scheme and in whether the stream is wrapped in TLS.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
 858
 859    pub fn authenticate_with_browser(
 860        self: &Arc<Self>,
 861        cx: &AsyncAppContext,
 862    ) -> Task<Result<Credentials>> {
 863        let platform = cx.platform();
 864        let executor = cx.background();
 865        executor.clone().spawn(async move {
 866            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 867            // zed server to encrypt the user's access token, so that it can'be intercepted by
 868            // any other app running on the user's device.
 869            let (public_key, private_key) =
 870                rpc::auth::keypair().expect("failed to generate keypair for auth");
 871            let public_key_string =
 872                String::try_from(public_key).expect("failed to serialize public key for auth");
 873
 874            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 875            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 876            let port = server.server_addr().port();
 877
 878            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 879            // that the user is signing in from a Zed app running on the same device.
 880            let mut url = format!(
 881                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 882                *ZED_SERVER_URL, port, public_key_string
 883            );
 884
 885            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 886                log::info!("impersonating user @{}", impersonate_login);
 887                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 888            }
 889
 890            platform.open_url(&url);
 891
 892            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 893            // access token from the query params.
 894            //
 895            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 896            // custom URL scheme instead of this local HTTP server.
 897            let (user_id, access_token) = executor
 898                .spawn(async move {
 899                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
 900                        let path = req.url();
 901                        let mut user_id = None;
 902                        let mut access_token = None;
 903                        let url = Url::parse(&format!("http://example.com{}", path))
 904                            .context("failed to parse login notification url")?;
 905                        for (key, value) in url.query_pairs() {
 906                            if key == "access_token" {
 907                                access_token = Some(value.to_string());
 908                            } else if key == "user_id" {
 909                                user_id = Some(value.to_string());
 910                            }
 911                        }
 912
 913                        let post_auth_url =
 914                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 915                        req.respond(
 916                            tiny_http::Response::empty(302).with_header(
 917                                tiny_http::Header::from_bytes(
 918                                    &b"Location"[..],
 919                                    post_auth_url.as_bytes(),
 920                                )
 921                                .unwrap(),
 922                            ),
 923                        )
 924                        .context("failed to respond to login http request")?;
 925                        Ok((
 926                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 927                            access_token
 928                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 929                        ))
 930                    } else {
 931                        Err(anyhow!("didn't receive login redirect"))
 932                    }
 933                })
 934                .await?;
 935
 936            let access_token = private_key
 937                .decrypt_string(&access_token)
 938                .context("failed to decrypt access token")?;
 939            platform.activate(true);
 940
 941            Ok(Credentials {
 942                user_id: user_id.parse()?,
 943                access_token,
 944            })
 945        })
 946    }
 947
 948    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 949        let conn_id = self.connection_id()?;
 950        self.peer.disconnect(conn_id);
 951        self.set_status(Status::SignedOut, cx);
 952        Ok(())
 953    }
 954
 955    fn connection_id(&self) -> Result<ConnectionId> {
 956        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 957            Ok(connection_id)
 958        } else {
 959            Err(anyhow!("not connected"))
 960        }
 961    }
 962
 963    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 964        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 965        self.peer.send(self.connection_id()?, message)
 966    }
 967
 968    pub fn request<T: RequestMessage>(
 969        &self,
 970        request: T,
 971    ) -> impl Future<Output = Result<T::Response>> {
 972        let client_id = self.id;
 973        log::debug!(
 974            "rpc request start. client_id:{}. name:{}",
 975            client_id,
 976            T::NAME
 977        );
 978        let response = self
 979            .connection_id()
 980            .map(|conn_id| self.peer.request(conn_id, request));
 981        async move {
 982            let response = response?.await;
 983            log::debug!(
 984                "rpc request finish. client_id:{}. name:{}",
 985                client_id,
 986                T::NAME
 987            );
 988            response
 989        }
 990    }
 991
 992    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
 993        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 994        self.peer.respond(receipt, response)
 995    }
 996
 997    fn respond_with_error<T: RequestMessage>(
 998        &self,
 999        receipt: Receipt<T>,
1000        error: proto::Error,
1001    ) -> Result<()> {
1002        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1003        self.peer.respond_with_error(receipt, error)
1004    }
1005}
1006
1007impl AnyWeakEntityHandle {
1008    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1009        match self {
1010            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1011            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1012        }
1013    }
1014}
1015
1016fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1017    if IMPERSONATE_LOGIN.is_some() {
1018        return None;
1019    }
1020
1021    let (user_id, access_token) = cx
1022        .platform()
1023        .read_credentials(&ZED_SERVER_URL)
1024        .log_err()
1025        .flatten()?;
1026    Some(Credentials {
1027        user_id: user_id.parse().ok()?,
1028        access_token: String::from_utf8(access_token).ok()?,
1029    })
1030}
1031
1032fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1033    cx.platform().write_credentials(
1034        &ZED_SERVER_URL,
1035        &credentials.user_id.to_string(),
1036        credentials.access_token.as_bytes(),
1037    )
1038}
1039
/// Scheme and path prefix shared by every encoded worktree URL.
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";

/// Builds a shareable worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the access token is absent or empty.
///
/// Everything after the first `/` that follows the id is treated as the token, so tokens
/// that themselves contain `/` survive an encode/decode round trip.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only at the first '/': the id never contains one, but the token might.
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1056
// Unit tests for the RPC client: reconnection with cached or refreshed credentials,
// worktree URL encoding, and registration/removal of message handlers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Drives the client through two server-side disconnects.
    // - After the first, the client reconnects with its cached credentials, so the
    //   server's auth count stays at 1.
    // - Before the second reconnect, the server rolls the access token. The client must
    //   then authenticate again, and the auth count becomes 2.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, and wait until the client
        // reports a failed reconnection attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // NOTE(review): advancing the fake clock by 10s is presumably enough for the
        // client's reconnect delay to elapse — confirm against the reconnection backoff.
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Invalidate the cached token, so that reconnecting with it fails and the client
        // has to clear it and authenticate again.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Covers three cases:
    // - an encoded worktree URL decodes back to its id and token;
    // - surrounding whitespace is tolerated;
    // - a URL with the wrong prefix is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Registers one model message handler for `JoinProject` and three models for remote
    // entity ids 1, 2 and 3, then drops entity 3's subscription. It checks that messages
    // for entities 1 and 2 still reach the handler with the matching model.
    // NOTE(review): presumably `JoinProject::project_id` is the entity id used for
    // routing; confirm against the `EntityMessage` impl.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signal a per-model channel depending on which model the message was routed to.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Registers a `Ping` handler, drops its subscription, and registers a second `Ping`
    // handler for the same model. It checks that the second handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // Stores a handler's own subscription on the model, and has the handler take
    // (and so drop) that subscription while it runs. The test finishes only once the
    // handler has signalled completion after doing so.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal entity used by the tests above.
    #[derive(Default)]
    struct Model {
        // Distinguishes model instances inside shared handlers.
        id: usize,
        // Lets a handler take (and thereby drop) its own subscription.
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}