client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
  15use gpui::{
  16    action, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AsyncAppContext,
  17    Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{
  32        atomic::{AtomicUsize, Ordering},
  33        Arc, Weak,
  34    },
  35    time::{Duration, Instant},
  36};
  37use surf::{http::Method, Url};
  38use thiserror::Error;
  39use util::{ResultExt, TryFutureExt};
  40
  41pub use channel::*;
  42pub use rpc::*;
  43pub use user::*;
  44
  45lazy_static! {
  46    static ref ZED_SERVER_URL: String =
  47        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  48    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  49        .ok()
  50        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  51}
  52
  53action!(Authenticate);
  54
  55pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  56    cx.add_global_action(move |_: &Authenticate, cx| {
  57        let rpc = rpc.clone();
  58        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
  59            .detach();
  60    });
  61}
  62
  63pub struct Client {
  64    id: usize,
  65    peer: Arc<Peer>,
  66    http: Arc<dyn HttpClient>,
  67    state: RwLock<ClientState>,
  68    authenticate:
  69        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
  70    establish_connection: Option<
  71        Box<
  72            dyn 'static
  73                + Send
  74                + Sync
  75                + Fn(
  76                    &Credentials,
  77                    &AsyncAppContext,
  78                ) -> Task<Result<Connection, EstablishConnectionError>>,
  79        >,
  80    >,
  81}
  82
  83#[derive(Error, Debug)]
  84pub enum EstablishConnectionError {
  85    #[error("upgrade required")]
  86    UpgradeRequired,
  87    #[error("unauthorized")]
  88    Unauthorized,
  89    #[error("{0}")]
  90    Other(#[from] anyhow::Error),
  91    #[error("{0}")]
  92    Io(#[from] std::io::Error),
  93    #[error("{0}")]
  94    Http(#[from] async_tungstenite::tungstenite::http::Error),
  95}
  96
  97impl From<WebsocketError> for EstablishConnectionError {
  98    fn from(error: WebsocketError) -> Self {
  99        if let WebsocketError::Http(response) = &error {
 100            match response.status() {
 101                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 102                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 103                _ => {}
 104            }
 105        }
 106        EstablishConnectionError::Other(error.into())
 107    }
 108}
 109
 110impl EstablishConnectionError {
 111    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 112        Self::Other(error.into())
 113    }
 114}
 115
 116#[derive(Copy, Clone, Debug)]
 117pub enum Status {
 118    SignedOut,
 119    UpgradeRequired,
 120    Authenticating,
 121    Connecting,
 122    ConnectionError,
 123    Connected { connection_id: ConnectionId },
 124    ConnectionLost,
 125    Reauthenticating,
 126    Reconnecting,
 127    ReconnectionError { next_reconnection: Instant },
 128}
 129
 130impl Status {
 131    pub fn is_connected(&self) -> bool {
 132        matches!(self, Self::Connected { .. })
 133    }
 134}
 135
 136struct ClientState {
 137    credentials: Option<Credentials>,
 138    status: (watch::Sender<Status>, watch::Receiver<Status>),
 139    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
 140    _reconnect_task: Option<Task<()>>,
 141    reconnect_interval: Duration,
 142    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
 143    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
 144    entity_types_by_message_type: HashMap<TypeId, TypeId>,
 145    message_handlers: HashMap<
 146        TypeId,
 147        Arc<
 148            dyn Send
 149                + Sync
 150                + Fn(
 151                    AnyEntityHandle,
 152                    Box<dyn AnyTypedEnvelope>,
 153                    &Arc<Client>,
 154                    AsyncAppContext,
 155                ) -> LocalBoxFuture<'static, Result<()>>,
 156        >,
 157    >,
 158}
 159
 160enum AnyWeakEntityHandle {
 161    Model(AnyWeakModelHandle),
 162    View(AnyWeakViewHandle),
 163}
 164
 165enum AnyEntityHandle {
 166    Model(AnyModelHandle),
 167    View(AnyViewHandle),
 168}
 169
 170#[derive(Clone, Debug)]
 171pub struct Credentials {
 172    pub user_id: u64,
 173    pub access_token: String,
 174}
 175
 176impl Default for ClientState {
 177    fn default() -> Self {
 178        Self {
 179            credentials: None,
 180            status: watch::channel_with(Status::SignedOut),
 181            entity_id_extractors: Default::default(),
 182            _reconnect_task: None,
 183            reconnect_interval: Duration::from_secs(5),
 184            models_by_message_type: Default::default(),
 185            entities_by_type_and_remote_id: Default::default(),
 186            entity_types_by_message_type: Default::default(),
 187            message_handlers: Default::default(),
 188        }
 189    }
 190}
 191
 192pub enum Subscription {
 193    Entity {
 194        client: Weak<Client>,
 195        id: (TypeId, u64),
 196    },
 197    Message {
 198        client: Weak<Client>,
 199        id: TypeId,
 200    },
 201}
 202
 203impl Drop for Subscription {
 204    fn drop(&mut self) {
 205        match self {
 206            Subscription::Entity { client, id } => {
 207                if let Some(client) = client.upgrade() {
 208                    let mut state = client.state.write();
 209                    let _ = state.entities_by_type_and_remote_id.remove(id);
 210                }
 211            }
 212            Subscription::Message { client, id } => {
 213                if let Some(client) = client.upgrade() {
 214                    let mut state = client.state.write();
 215                    let _ = state.entity_types_by_message_type.remove(id);
 216                    let _ = state.message_handlers.remove(id);
 217                }
 218            }
 219        }
 220    }
 221}
 222
 223impl Client {
 224    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 225        lazy_static! {
 226            static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
 227        }
 228
 229        Arc::new(Self {
 230            id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
 231            peer: Peer::new(),
 232            http,
 233            state: Default::default(),
 234            authenticate: None,
 235            establish_connection: None,
 236        })
 237    }
 238
 239    pub fn id(&self) -> usize {
 240        self.id
 241    }
 242
 243    pub fn http_client(&self) -> Arc<dyn HttpClient> {
 244        self.http.clone()
 245    }
 246
 247    #[cfg(any(test, feature = "test-support"))]
 248    pub fn tear_down(&self) {
 249        let mut state = self.state.write();
 250        state._reconnect_task.take();
 251        state.message_handlers.clear();
 252        state.models_by_message_type.clear();
 253        state.entities_by_type_and_remote_id.clear();
 254        state.entity_id_extractors.clear();
 255        self.peer.reset();
 256    }
 257
 258    #[cfg(any(test, feature = "test-support"))]
 259    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
 260    where
 261        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 262    {
 263        self.authenticate = Some(Box::new(authenticate));
 264        self
 265    }
 266
 267    #[cfg(any(test, feature = "test-support"))]
 268    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
 269    where
 270        F: 'static
 271            + Send
 272            + Sync
 273            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 274    {
 275        self.establish_connection = Some(Box::new(connect));
 276        self
 277    }
 278
 279    pub fn user_id(&self) -> Option<u64> {
 280        self.state
 281            .read()
 282            .credentials
 283            .as_ref()
 284            .map(|credentials| credentials.user_id)
 285    }
 286
 287    pub fn status(&self) -> watch::Receiver<Status> {
 288        self.state.read().status.1.clone()
 289    }
 290
 291    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 292        let mut state = self.state.write();
 293        *state.status.0.borrow_mut() = status;
 294
 295        match status {
 296            Status::Connected { .. } => {
 297                state._reconnect_task = None;
 298            }
 299            Status::ConnectionLost => {
 300                let this = self.clone();
 301                let reconnect_interval = state.reconnect_interval;
 302                state._reconnect_task = Some(cx.spawn(|cx| async move {
 303                    let mut rng = StdRng::from_entropy();
 304                    let mut delay = Duration::from_millis(100);
 305                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
 306                        log::error!("failed to connect {}", error);
 307                        this.set_status(
 308                            Status::ReconnectionError {
 309                                next_reconnection: Instant::now() + delay,
 310                            },
 311                            &cx,
 312                        );
 313                        cx.background().timer(delay).await;
 314                        delay = delay
 315                            .mul_f32(rng.gen_range(1.0..=2.0))
 316                            .min(reconnect_interval);
 317                    }
 318                }));
 319            }
 320            Status::SignedOut | Status::UpgradeRequired => {
 321                state._reconnect_task.take();
 322            }
 323            _ => {}
 324        }
 325    }
 326
 327    pub fn add_view_for_remote_entity<T: View>(
 328        self: &Arc<Self>,
 329        remote_id: u64,
 330        cx: &mut ViewContext<T>,
 331    ) -> Subscription {
 332        let id = (TypeId::of::<T>(), remote_id);
 333        self.state
 334            .write()
 335            .entities_by_type_and_remote_id
 336            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
 337        Subscription::Entity {
 338            client: Arc::downgrade(self),
 339            id,
 340        }
 341    }
 342
 343    pub fn add_model_for_remote_entity<T: Entity>(
 344        self: &Arc<Self>,
 345        remote_id: u64,
 346        cx: &mut ModelContext<T>,
 347    ) -> Subscription {
 348        let id = (TypeId::of::<T>(), remote_id);
 349        self.state
 350            .write()
 351            .entities_by_type_and_remote_id
 352            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
 353        Subscription::Entity {
 354            client: Arc::downgrade(self),
 355            id,
 356        }
 357    }
 358
 359    pub fn add_message_handler<M, E, H, F>(
 360        self: &Arc<Self>,
 361        model: ModelHandle<E>,
 362        handler: H,
 363    ) -> Subscription
 364    where
 365        M: EnvelopedMessage,
 366        E: Entity,
 367        H: 'static
 368            + Send
 369            + Sync
 370            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 371        F: 'static + Future<Output = Result<()>>,
 372    {
 373        let message_type_id = TypeId::of::<M>();
 374
 375        let mut state = self.state.write();
 376        state
 377            .models_by_message_type
 378            .insert(message_type_id, model.downgrade().into());
 379
 380        let prev_handler = state.message_handlers.insert(
 381            message_type_id,
 382            Arc::new(move |handle, envelope, client, cx| {
 383                let handle = if let AnyEntityHandle::Model(handle) = handle {
 384                    handle
 385                } else {
 386                    unreachable!();
 387                };
 388                let model = handle.downcast::<E>().unwrap();
 389                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 390                handler(model, *envelope, client.clone(), cx).boxed_local()
 391            }),
 392        );
 393        if prev_handler.is_some() {
 394            panic!("registered handler for the same message twice");
 395        }
 396
 397        Subscription::Message {
 398            client: Arc::downgrade(self),
 399            id: message_type_id,
 400        }
 401    }
 402
 403    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 404    where
 405        M: EntityMessage,
 406        E: View,
 407        H: 'static
 408            + Send
 409            + Sync
 410            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 411        F: 'static + Future<Output = Result<()>>,
 412    {
 413        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 414            if let AnyEntityHandle::View(handle) = handle {
 415                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 416            } else {
 417                unreachable!();
 418            }
 419        })
 420    }
 421
 422    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 423    where
 424        M: EntityMessage,
 425        E: Entity,
 426        H: 'static
 427            + Send
 428            + Sync
 429            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 430        F: 'static + Future<Output = Result<()>>,
 431    {
 432        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
 433            if let AnyEntityHandle::Model(handle) = handle {
 434                handler(handle.downcast::<E>().unwrap(), message, client, cx)
 435            } else {
 436                unreachable!();
 437            }
 438        })
 439    }
 440
 441    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 442    where
 443        M: EntityMessage,
 444        E: Entity,
 445        H: 'static
 446            + Send
 447            + Sync
 448            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 449        F: 'static + Future<Output = Result<()>>,
 450    {
 451        let model_type_id = TypeId::of::<E>();
 452        let message_type_id = TypeId::of::<M>();
 453
 454        let mut state = self.state.write();
 455        state
 456            .entity_types_by_message_type
 457            .insert(message_type_id, model_type_id);
 458        state
 459            .entity_id_extractors
 460            .entry(message_type_id)
 461            .or_insert_with(|| {
 462                |envelope| {
 463                    envelope
 464                        .as_any()
 465                        .downcast_ref::<TypedEnvelope<M>>()
 466                        .unwrap()
 467                        .payload
 468                        .remote_entity_id()
 469                }
 470            });
 471        let prev_handler = state.message_handlers.insert(
 472            message_type_id,
 473            Arc::new(move |handle, envelope, client, cx| {
 474                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 475                handler(handle, *envelope, client.clone(), cx).boxed_local()
 476            }),
 477        );
 478        if prev_handler.is_some() {
 479            panic!("registered handler for the same message twice");
 480        }
 481    }
 482
 483    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 484    where
 485        M: EntityMessage + RequestMessage,
 486        E: Entity,
 487        H: 'static
 488            + Send
 489            + Sync
 490            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 491        F: 'static + Future<Output = Result<M::Response>>,
 492    {
 493        self.add_model_message_handler(move |entity, envelope, client, cx| {
 494            Self::respond_to_request::<M, _>(
 495                envelope.receipt(),
 496                handler(entity, envelope, client.clone(), cx),
 497                client,
 498            )
 499        })
 500    }
 501
 502    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 503    where
 504        M: EntityMessage + RequestMessage,
 505        E: View,
 506        H: 'static
 507            + Send
 508            + Sync
 509            + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 510        F: 'static + Future<Output = Result<M::Response>>,
 511    {
 512        self.add_view_message_handler(move |entity, envelope, client, cx| {
 513            Self::respond_to_request::<M, _>(
 514                envelope.receipt(),
 515                handler(entity, envelope, client.clone(), cx),
 516                client,
 517            )
 518        })
 519    }
 520
 521    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
 522        receipt: Receipt<T>,
 523        response: F,
 524        client: Arc<Self>,
 525    ) -> Result<()> {
 526        match response.await {
 527            Ok(response) => {
 528                client.respond(receipt, response)?;
 529                Ok(())
 530            }
 531            Err(error) => {
 532                client.respond_with_error(
 533                    receipt,
 534                    proto::Error {
 535                        message: error.to_string(),
 536                    },
 537                )?;
 538                Err(error)
 539            }
 540        }
 541    }
 542
 543    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 544        read_credentials_from_keychain(cx).is_some()
 545    }
 546
 547    #[async_recursion(?Send)]
 548    pub async fn authenticate_and_connect(
 549        self: &Arc<Self>,
 550        try_keychain: bool,
 551        cx: &AsyncAppContext,
 552    ) -> anyhow::Result<()> {
 553        let was_disconnected = match *self.status().borrow() {
 554            Status::SignedOut => true,
 555            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
 556                false
 557            }
 558            Status::Connected { .. }
 559            | Status::Connecting { .. }
 560            | Status::Reconnecting { .. }
 561            | Status::Authenticating
 562            | Status::Reauthenticating => return Ok(()),
 563            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
 564        };
 565
 566        if was_disconnected {
 567            self.set_status(Status::Authenticating, cx);
 568        } else {
 569            self.set_status(Status::Reauthenticating, cx)
 570        }
 571
 572        let mut read_from_keychain = false;
 573        let mut credentials = self.state.read().credentials.clone();
 574        if credentials.is_none() && try_keychain {
 575            credentials = read_credentials_from_keychain(cx);
 576            read_from_keychain = credentials.is_some();
 577        }
 578        if credentials.is_none() {
 579            credentials = Some(match self.authenticate(&cx).await {
 580                Ok(credentials) => credentials,
 581                Err(err) => {
 582                    self.set_status(Status::ConnectionError, cx);
 583                    return Err(err);
 584                }
 585            });
 586        }
 587        let credentials = credentials.unwrap();
 588
 589        if was_disconnected {
 590            self.set_status(Status::Connecting, cx);
 591        } else {
 592            self.set_status(Status::Reconnecting, cx);
 593        }
 594
 595        match self.establish_connection(&credentials, cx).await {
 596            Ok(conn) => {
 597                self.state.write().credentials = Some(credentials.clone());
 598                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
 599                    write_credentials_to_keychain(&credentials, cx).log_err();
 600                }
 601                self.set_connection(conn, cx).await;
 602                Ok(())
 603            }
 604            Err(EstablishConnectionError::Unauthorized) => {
 605                self.state.write().credentials.take();
 606                if read_from_keychain {
 607                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
 608                    self.set_status(Status::SignedOut, cx);
 609                    self.authenticate_and_connect(false, cx).await
 610                } else {
 611                    self.set_status(Status::ConnectionError, cx);
 612                    Err(EstablishConnectionError::Unauthorized)?
 613                }
 614            }
 615            Err(EstablishConnectionError::UpgradeRequired) => {
 616                self.set_status(Status::UpgradeRequired, cx);
 617                Err(EstablishConnectionError::UpgradeRequired)?
 618            }
 619            Err(error) => {
 620                self.set_status(Status::ConnectionError, cx);
 621                Err(error)?
 622            }
 623        }
 624    }
 625
 626    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
 627        let executor = cx.background();
 628        let (connection_id, handle_io, mut incoming) = self
 629            .peer
 630            .add_connection(conn, move |duration| executor.timer(duration))
 631            .await;
 632        cx.foreground()
 633            .spawn({
 634                let cx = cx.clone();
 635                let this = self.clone();
 636                async move {
 637                    let mut message_id = 0_usize;
 638                    while let Some(message) = incoming.next().await {
 639                        let mut state = this.state.write();
 640                        message_id += 1;
 641                        let type_name = message.payload_type_name();
 642                        let payload_type_id = message.payload_type_id();
 643                        let sender_id = message.original_sender_id().map(|id| id.0);
 644
 645                        let model = state
 646                            .models_by_message_type
 647                            .get(&payload_type_id)
 648                            .and_then(|model| model.upgrade(&cx))
 649                            .map(AnyEntityHandle::Model)
 650                            .or_else(|| {
 651                                let entity_type_id =
 652                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
 653                                let entity_id = state
 654                                    .entity_id_extractors
 655                                    .get(&message.payload_type_id())
 656                                    .map(|extract_entity_id| {
 657                                        (extract_entity_id)(message.as_ref())
 658                                    })?;
 659
 660                                let entity = state
 661                                    .entities_by_type_and_remote_id
 662                                    .get(&(entity_type_id, entity_id))?;
 663                                if let Some(entity) = entity.upgrade(&cx) {
 664                                    Some(entity)
 665                                } else {
 666                                    state
 667                                        .entities_by_type_and_remote_id
 668                                        .remove(&(entity_type_id, entity_id));
 669                                    None
 670                                }
 671                            });
 672
 673                        let model = if let Some(model) = model {
 674                            model
 675                        } else {
 676                            log::info!("unhandled message {}", type_name);
 677                            continue;
 678                        };
 679
 680                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
 681                        {
 682                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
 683                            let future = handler(model, message, &this, cx.clone());
 684
 685                            let client_id = this.id;
 686                            log::debug!(
 687                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 688                                client_id,
 689                                message_id,
 690                                sender_id,
 691                                type_name
 692                            );
 693                            cx.foreground()
 694                                .spawn(async move {
 695                                    match future.await {
 696                                        Ok(()) => {
 697                                            log::debug!(
 698                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
 699                                                client_id,
 700                                                message_id,
 701                                                sender_id,
 702                                                type_name
 703                                            );
 704                                        }
 705                                        Err(error) => {
 706                                            log::error!(
 707                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
 708                                                client_id,
 709                                                message_id,
 710                                                sender_id,
 711                                                type_name,
 712                                                error
 713                                            );
 714                                        }
 715                                    }
 716                                })
 717                                .detach();
 718                        } else {
 719                            log::info!("unhandled message {}", type_name);
 720                        }
 721
 722                        // Don't starve the main thread when receiving lots of messages at once.
 723                        smol::future::yield_now().await;
 724                    }
 725                }
 726            })
 727            .detach();
 728
 729        self.set_status(Status::Connected { connection_id }, cx);
 730
 731        let handle_io = cx.background().spawn(handle_io);
 732        let this = self.clone();
 733        let cx = cx.clone();
 734        cx.foreground()
 735            .spawn(async move {
 736                match handle_io.await {
 737                    Ok(()) => this.set_status(Status::SignedOut, &cx),
 738                    Err(err) => {
 739                        log::error!("connection error: {:?}", err);
 740                        this.set_status(Status::ConnectionLost, &cx);
 741                    }
 742                }
 743            })
 744            .detach();
 745    }
 746
 747    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 748        if let Some(callback) = self.authenticate.as_ref() {
 749            callback(cx)
 750        } else {
 751            self.authenticate_with_browser(cx)
 752        }
 753    }
 754
 755    fn establish_connection(
 756        self: &Arc<Self>,
 757        credentials: &Credentials,
 758        cx: &AsyncAppContext,
 759    ) -> Task<Result<Connection, EstablishConnectionError>> {
 760        if let Some(callback) = self.establish_connection.as_ref() {
 761            callback(credentials, cx)
 762        } else {
 763            self.establish_websocket_connection(credentials, cx)
 764        }
 765    }
 766
 767    fn establish_websocket_connection(
 768        self: &Arc<Self>,
 769        credentials: &Credentials,
 770        cx: &AsyncAppContext,
 771    ) -> Task<Result<Connection, EstablishConnectionError>> {
 772        let request = Request::builder()
 773            .header(
 774                "Authorization",
 775                format!("{} {}", credentials.user_id, credentials.access_token),
 776            )
 777            .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
 778
 779        let http = self.http.clone();
 780        cx.background().spawn(async move {
 781            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
 782            let rpc_request = surf::Request::new(
 783                Method::Get,
 784                surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
 785            );
 786            let rpc_response = http.send(rpc_request).await?;
 787
 788            if rpc_response.status().is_redirection() {
 789                rpc_url = rpc_response
 790                    .header("Location")
 791                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 792                    .as_str()
 793                    .to_string();
 794            }
 795            // Until we switch the zed.dev domain to point to the new Next.js app, there
 796            // will be no redirect required, and the app will connect directly to
 797            // wss://zed.dev/rpc.
 798            else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
 799                Err(anyhow!(
 800                    "unexpected /rpc response status {}",
 801                    rpc_response.status()
 802                ))?
 803            }
 804
 805            let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
 806            let rpc_host = rpc_url
 807                .host_str()
 808                .zip(rpc_url.port_or_known_default())
 809                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 810            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 811
 812            log::info!("connected to rpc endpoint {}", rpc_url);
 813
 814            match rpc_url.scheme() {
 815                "https" => {
 816                    rpc_url.set_scheme("wss").unwrap();
 817                    let request = request.uri(rpc_url.as_str()).body(())?;
 818                    let (stream, _) =
 819                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 820                    Ok(Connection::new(stream))
 821                }
 822                "http" => {
 823                    rpc_url.set_scheme("ws").unwrap();
 824                    let request = request.uri(rpc_url.as_str()).body(())?;
 825                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 826                    Ok(Connection::new(stream))
 827                }
 828                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 829            }
 830        })
 831    }
 832
 833    pub fn authenticate_with_browser(
 834        self: &Arc<Self>,
 835        cx: &AsyncAppContext,
 836    ) -> Task<Result<Credentials>> {
 837        let platform = cx.platform();
 838        let executor = cx.background();
 839        executor.clone().spawn(async move {
 840            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 841            // zed server to encrypt the user's access token, so that it can'be intercepted by
 842            // any other app running on the user's device.
 843            let (public_key, private_key) =
 844                rpc::auth::keypair().expect("failed to generate keypair for auth");
 845            let public_key_string =
 846                String::try_from(public_key).expect("failed to serialize public key for auth");
 847
 848            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 849            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 850            let port = server.server_addr().port();
 851
 852            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 853            // that the user is signing in from a Zed app running on the same device.
 854            let mut url = format!(
 855                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 856                *ZED_SERVER_URL, port, public_key_string
 857            );
 858
 859            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 860                log::info!("impersonating user @{}", impersonate_login);
 861                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 862            }
 863
 864            platform.open_url(&url);
 865
 866            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 867            // access token from the query params.
 868            //
 869            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 870            // custom URL scheme instead of this local HTTP server.
 871            let (user_id, access_token) = executor
 872                .spawn(async move {
 873                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
 874                        let path = req.url();
 875                        let mut user_id = None;
 876                        let mut access_token = None;
 877                        let url = Url::parse(&format!("http://example.com{}", path))
 878                            .context("failed to parse login notification url")?;
 879                        for (key, value) in url.query_pairs() {
 880                            if key == "access_token" {
 881                                access_token = Some(value.to_string());
 882                            } else if key == "user_id" {
 883                                user_id = Some(value.to_string());
 884                            }
 885                        }
 886
 887                        let post_auth_url =
 888                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 889                        req.respond(
 890                            tiny_http::Response::empty(302).with_header(
 891                                tiny_http::Header::from_bytes(
 892                                    &b"Location"[..],
 893                                    post_auth_url.as_bytes(),
 894                                )
 895                                .unwrap(),
 896                            ),
 897                        )
 898                        .context("failed to respond to login http request")?;
 899                        Ok((
 900                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 901                            access_token
 902                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 903                        ))
 904                    } else {
 905                        Err(anyhow!("didn't receive login redirect"))
 906                    }
 907                })
 908                .await?;
 909
 910            let access_token = private_key
 911                .decrypt_string(&access_token)
 912                .context("failed to decrypt access token")?;
 913            platform.activate(true);
 914
 915            Ok(Credentials {
 916                user_id: user_id.parse()?,
 917                access_token,
 918            })
 919        })
 920    }
 921
 922    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 923        let conn_id = self.connection_id()?;
 924        self.peer.disconnect(conn_id);
 925        self.set_status(Status::SignedOut, cx);
 926        Ok(())
 927    }
 928
 929    fn connection_id(&self) -> Result<ConnectionId> {
 930        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 931            Ok(connection_id)
 932        } else {
 933            Err(anyhow!("not connected"))
 934        }
 935    }
 936
 937    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 938        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 939        self.peer.send(self.connection_id()?, message)
 940    }
 941
 942    pub fn request<T: RequestMessage>(
 943        &self,
 944        request: T,
 945    ) -> impl Future<Output = Result<T::Response>> {
 946        let client_id = self.id;
 947        log::debug!(
 948            "rpc request start. client_id:{}. name:{}",
 949            client_id,
 950            T::NAME
 951        );
 952        let response = self
 953            .connection_id()
 954            .map(|conn_id| self.peer.request(conn_id, request));
 955        async move {
 956            let response = response?.await;
 957            log::debug!(
 958                "rpc request finish. client_id:{}. name:{}",
 959                client_id,
 960                T::NAME
 961            );
 962            response
 963        }
 964    }
 965
 966    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
 967        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 968        self.peer.respond(receipt, response)
 969    }
 970
 971    fn respond_with_error<T: RequestMessage>(
 972        &self,
 973        receipt: Receipt<T>,
 974        error: proto::Error,
 975    ) -> Result<()> {
 976        log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
 977        self.peer.respond_with_error(receipt, error)
 978    }
 979}
 980
 981impl AnyWeakEntityHandle {
 982    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
 983        match self {
 984            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
 985            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
 986        }
 987    }
 988}
 989
 990fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
 991    if IMPERSONATE_LOGIN.is_some() {
 992        return None;
 993    }
 994
 995    let (user_id, access_token) = cx
 996        .platform()
 997        .read_credentials(&ZED_SERVER_URL)
 998        .log_err()
 999        .flatten()?;
1000    Some(Credentials {
1001        user_id: user_id.parse().ok()?,
1002        access_token: String::from_utf8(access_token).ok()?,
1003    })
1004}
1005
1006fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1007    cx.platform().write_credentials(
1008        &ZED_SERVER_URL,
1009        &credentials.user_id.to_string(),
1010        credentials.access_token.as_bytes(),
1011    )
1012}
1013
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Leading and trailing whitespace is ignored. Everything after the first `/` that follows
/// the id is taken as the access token, so a token containing `/` round-trips intact.
///
/// Returns `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only at the first '/', so an access token that itself contains '/' is
    // returned whole instead of being truncated at its first slash.
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1030
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Checks automatic reconnection after the server drops the connection. The client should
    // reconnect using its cached credentials, and authenticate again only once the server has
    // rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, then wait until the client reports
        // a failed reconnection attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock (presumably past the client's
        // reconnection delay) so it can reconnect.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails: rolling the token makes the
        // cached credentials invalid, so the next successful connection requires a fresh auth.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A worktree URL survives an encode/decode round trip, surrounding whitespace is
    // ignored, and a URL with the wrong prefix is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Model message handlers are routed by remote entity id. `UnshareProject` messages for
    // projects 1 and 2 must reach the models registered under those ids (the handler reports
    // which model it ran on through a per-model channel).
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a message-handler subscription is dropped, a new handler for the same message
    // type can be registered on the same model, and it receives the next message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while it runs. Here the subscription is stored
    // on the model, and the handler takes it out before signalling completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity:
    // - `id` lets a handler tell which model it ran on;
    // - `subscription` lets a handler drop its own subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}