client.rs

   1#[cfg(any(test, feature = "test-support"))]
   2pub mod test;
   3
   4pub mod channel;
   5pub mod http;
   6pub mod user;
   7
   8use anyhow::{anyhow, Context, Result};
   9use async_recursion::async_recursion;
  10use async_tungstenite::tungstenite::{
  11    error::Error as WebsocketError,
  12    http::{Request, StatusCode},
  13};
  14use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
  15use gpui::{
  16    action, AnyModelHandle, AnyWeakModelHandle, AsyncAppContext, Entity, ModelContext, ModelHandle,
  17    MutableAppContext, Task,
  18};
  19use http::HttpClient;
  20use lazy_static::lazy_static;
  21use parking_lot::RwLock;
  22use postage::watch;
  23use rand::prelude::*;
  24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
  25use std::{
  26    any::TypeId,
  27    collections::HashMap,
  28    convert::TryFrom,
  29    fmt::Write as _,
  30    future::Future,
  31    sync::{
  32        atomic::{AtomicUsize, Ordering},
  33        Arc, Weak,
  34    },
  35    time::{Duration, Instant},
  36};
  37use surf::{http::Method, Url};
  38use thiserror::Error;
  39use util::{ResultExt, TryFutureExt};
  40
  41pub use channel::*;
  42pub use rpc::*;
  43pub use user::*;
  44
  45lazy_static! {
  46    static ref ZED_SERVER_URL: String =
  47        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
  48    static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
  49        .ok()
  50        .and_then(|s| if s.is_empty() { None } else { Some(s) });
  51}
  52
  53action!(Authenticate);
  54
  55pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
  56    cx.add_global_action(move |_: &Authenticate, cx| {
  57        let rpc = rpc.clone();
  58        cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
  59            .detach();
  60    });
  61}
  62
  63pub struct Client {
  64    id: usize,
  65    peer: Arc<Peer>,
  66    http: Arc<dyn HttpClient>,
  67    state: RwLock<ClientState>,
  68    authenticate:
  69        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
  70    establish_connection: Option<
  71        Box<
  72            dyn 'static
  73                + Send
  74                + Sync
  75                + Fn(
  76                    &Credentials,
  77                    &AsyncAppContext,
  78                ) -> Task<Result<Connection, EstablishConnectionError>>,
  79        >,
  80    >,
  81}
  82
  83#[derive(Error, Debug)]
  84pub enum EstablishConnectionError {
  85    #[error("upgrade required")]
  86    UpgradeRequired,
  87    #[error("unauthorized")]
  88    Unauthorized,
  89    #[error("{0}")]
  90    Other(#[from] anyhow::Error),
  91    #[error("{0}")]
  92    Io(#[from] std::io::Error),
  93    #[error("{0}")]
  94    Http(#[from] async_tungstenite::tungstenite::http::Error),
  95}
  96
  97impl From<WebsocketError> for EstablishConnectionError {
  98    fn from(error: WebsocketError) -> Self {
  99        if let WebsocketError::Http(response) = &error {
 100            match response.status() {
 101                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 102                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 103                _ => {}
 104            }
 105        }
 106        EstablishConnectionError::Other(error.into())
 107    }
 108}
 109
 110impl EstablishConnectionError {
 111    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 112        Self::Other(error.into())
 113    }
 114}
 115
 116#[derive(Copy, Clone, Debug)]
 117pub enum Status {
 118    SignedOut,
 119    UpgradeRequired,
 120    Authenticating,
 121    Connecting,
 122    ConnectionError,
 123    Connected { connection_id: ConnectionId },
 124    ConnectionLost,
 125    Reauthenticating,
 126    Reconnecting,
 127    ReconnectionError { next_reconnection: Instant },
 128}
 129
 130struct ClientState {
 131    credentials: Option<Credentials>,
 132    status: (watch::Sender<Status>, watch::Receiver<Status>),
 133    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
 134    _maintain_connection: Option<Task<()>>,
 135    heartbeat_interval: Duration,
 136
 137    pending_messages: HashMap<(TypeId, u64), Vec<Box<dyn AnyTypedEnvelope>>>,
 138    models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
 139    models_by_message_type: HashMap<TypeId, AnyModelHandle>,
 140    model_types_by_message_type: HashMap<TypeId, TypeId>,
 141    message_handlers: HashMap<
 142        TypeId,
 143        Arc<
 144            dyn Send
 145                + Sync
 146                + Fn(
 147                    AnyModelHandle,
 148                    Box<dyn AnyTypedEnvelope>,
 149                    AsyncAppContext,
 150                ) -> LocalBoxFuture<'static, Result<()>>,
 151        >,
 152    >,
 153}
 154
 155#[derive(Clone, Debug)]
 156pub struct Credentials {
 157    pub user_id: u64,
 158    pub access_token: String,
 159}
 160
 161impl Default for ClientState {
 162    fn default() -> Self {
 163        Self {
 164            credentials: None,
 165            status: watch::channel_with(Status::SignedOut),
 166            entity_id_extractors: Default::default(),
 167            _maintain_connection: None,
 168            heartbeat_interval: Duration::from_secs(5),
 169            models_by_message_type: Default::default(),
 170            models_by_entity_type_and_remote_id: Default::default(),
 171            model_types_by_message_type: Default::default(),
 172            pending_messages: Default::default(),
 173            message_handlers: Default::default(),
 174        }
 175    }
 176}
 177
 178pub enum Subscription {
 179    Entity {
 180        client: Weak<Client>,
 181        id: (TypeId, u64),
 182    },
 183    Message {
 184        client: Weak<Client>,
 185        id: TypeId,
 186    },
 187}
 188
 189impl Drop for Subscription {
 190    fn drop(&mut self) {
 191        match self {
 192            Subscription::Entity { client, id } => {
 193                if let Some(client) = client.upgrade() {
 194                    let mut state = client.state.write();
 195                    let _ = state.models_by_entity_type_and_remote_id.remove(id);
 196                }
 197            }
 198            Subscription::Message { client, id } => {
 199                if let Some(client) = client.upgrade() {
 200                    let mut state = client.state.write();
 201                    let _ = state.model_types_by_message_type.remove(id);
 202                    let _ = state.message_handlers.remove(id);
 203                }
 204            }
 205        }
 206    }
 207}
 208
 209impl Client {
 210    pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
 211        lazy_static! {
 212            static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
 213        }
 214
 215        Arc::new(Self {
 216            id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
 217            peer: Peer::new(),
 218            http,
 219            state: Default::default(),
 220            authenticate: None,
 221            establish_connection: None,
 222        })
 223    }
 224
 225    #[cfg(any(test, feature = "test-support"))]
 226    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
 227    where
 228        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
 229    {
 230        self.authenticate = Some(Box::new(authenticate));
 231        self
 232    }
 233
 234    #[cfg(any(test, feature = "test-support"))]
 235    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
 236    where
 237        F: 'static
 238            + Send
 239            + Sync
 240            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
 241    {
 242        self.establish_connection = Some(Box::new(connect));
 243        self
 244    }
 245
 246    pub fn user_id(&self) -> Option<u64> {
 247        self.state
 248            .read()
 249            .credentials
 250            .as_ref()
 251            .map(|credentials| credentials.user_id)
 252    }
 253
 254    pub fn status(&self) -> watch::Receiver<Status> {
 255        self.state.read().status.1.clone()
 256    }
 257
 258    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
 259        let mut state = self.state.write();
 260        *state.status.0.borrow_mut() = status;
 261
 262        match status {
 263            Status::Connected { .. } => {
 264                let heartbeat_interval = state.heartbeat_interval;
 265                let this = self.clone();
 266                let foreground = cx.foreground();
 267                state._maintain_connection = Some(cx.foreground().spawn(async move {
 268                    loop {
 269                        foreground.timer(heartbeat_interval).await;
 270                        let _ = this.request(proto::Ping {}).await;
 271                    }
 272                }));
 273            }
 274            Status::ConnectionLost => {
 275                let this = self.clone();
 276                let foreground = cx.foreground();
 277                let heartbeat_interval = state.heartbeat_interval;
 278                state._maintain_connection = Some(cx.spawn(|cx| async move {
 279                    let mut rng = StdRng::from_entropy();
 280                    let mut delay = Duration::from_millis(100);
 281                    while let Err(error) = this.authenticate_and_connect(&cx).await {
 282                        log::error!("failed to connect {}", error);
 283                        this.set_status(
 284                            Status::ReconnectionError {
 285                                next_reconnection: Instant::now() + delay,
 286                            },
 287                            &cx,
 288                        );
 289                        foreground.timer(delay).await;
 290                        delay = delay
 291                            .mul_f32(rng.gen_range(1.0..=2.0))
 292                            .min(heartbeat_interval);
 293                    }
 294                }));
 295            }
 296            Status::SignedOut | Status::UpgradeRequired => {
 297                state._maintain_connection.take();
 298            }
 299            _ => {}
 300        }
 301    }
 302
 303    pub fn add_model_for_remote_entity<T: Entity>(
 304        self: &Arc<Self>,
 305        remote_id: u64,
 306        cx: &mut ModelContext<T>,
 307    ) -> Subscription {
 308        let handle = AnyModelHandle::from(cx.handle());
 309        let mut state = self.state.write();
 310        let id = (TypeId::of::<T>(), remote_id);
 311        state
 312            .models_by_entity_type_and_remote_id
 313            .insert(id, handle.downgrade());
 314        let pending_messages = state.pending_messages.remove(&id);
 315        drop(state);
 316
 317        let client_id = self.id;
 318        for message in pending_messages.into_iter().flatten() {
 319            let type_id = message.payload_type_id();
 320            let type_name = message.payload_type_name();
 321            let state = self.state.read();
 322            if let Some(handler) = state.message_handlers.get(&type_id).cloned() {
 323                let future = (handler)(handle.clone(), message, cx.to_async());
 324                drop(state);
 325                log::debug!(
 326                    "deferred rpc message received. client_id:{}, name:{}",
 327                    client_id,
 328                    type_name
 329                );
 330                cx.foreground()
 331                    .spawn(async move {
 332                        match future.await {
 333                            Ok(()) => {
 334                                log::debug!(
 335                                    "deferred rpc message handled. client_id:{}, name:{}",
 336                                    client_id,
 337                                    type_name
 338                                );
 339                            }
 340                            Err(error) => {
 341                                log::error!(
 342                                    "error handling deferred message. client_id:{}, name:{}, {}",
 343                                    client_id,
 344                                    type_name,
 345                                    error
 346                                );
 347                            }
 348                        }
 349                    })
 350                    .detach();
 351            }
 352        }
 353
 354        Subscription::Entity {
 355            client: Arc::downgrade(self),
 356            id,
 357        }
 358    }
 359
 360    pub fn add_message_handler<M, E, H, F>(
 361        self: &Arc<Self>,
 362        model: ModelHandle<E>,
 363        handler: H,
 364    ) -> Subscription
 365    where
 366        M: EnvelopedMessage,
 367        E: Entity,
 368        H: 'static
 369            + Send
 370            + Sync
 371            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 372        F: 'static + Future<Output = Result<()>>,
 373    {
 374        let message_type_id = TypeId::of::<M>();
 375
 376        let client = self.clone();
 377        let mut state = self.state.write();
 378        state
 379            .models_by_message_type
 380            .insert(message_type_id, model.into());
 381
 382        let prev_handler = state.message_handlers.insert(
 383            message_type_id,
 384            Arc::new(move |handle, envelope, cx| {
 385                let model = handle.downcast::<E>().unwrap();
 386                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 387                handler(model, *envelope, client.clone(), cx).boxed_local()
 388            }),
 389        );
 390        if prev_handler.is_some() {
 391            panic!("registered handler for the same message twice");
 392        }
 393
 394        Subscription::Message {
 395            client: Arc::downgrade(self),
 396            id: message_type_id,
 397        }
 398    }
 399
 400    pub fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 401    where
 402        M: EntityMessage,
 403        E: Entity,
 404        H: 'static
 405            + Send
 406            + Sync
 407            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 408        F: 'static + Future<Output = Result<()>>,
 409    {
 410        let model_type_id = TypeId::of::<E>();
 411        let message_type_id = TypeId::of::<M>();
 412
 413        let client = self.clone();
 414        let mut state = self.state.write();
 415        state
 416            .model_types_by_message_type
 417            .insert(message_type_id, model_type_id);
 418        state
 419            .entity_id_extractors
 420            .entry(message_type_id)
 421            .or_insert_with(|| {
 422                Box::new(|envelope| {
 423                    let envelope = envelope
 424                        .as_any()
 425                        .downcast_ref::<TypedEnvelope<M>>()
 426                        .unwrap();
 427                    envelope.payload.remote_entity_id()
 428                })
 429            });
 430
 431        let prev_handler = state.message_handlers.insert(
 432            message_type_id,
 433            Arc::new(move |handle, envelope, cx| {
 434                let model = handle.downcast::<E>().unwrap();
 435                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 436                handler(model, *envelope, client.clone(), cx).boxed_local()
 437            }),
 438        );
 439        if prev_handler.is_some() {
 440            panic!("registered handler for the same message twice");
 441        }
 442    }
 443
 444    pub fn add_entity_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
 445    where
 446        M: EntityMessage + RequestMessage,
 447        E: Entity,
 448        H: 'static
 449            + Send
 450            + Sync
 451            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
 452        F: 'static + Future<Output = Result<M::Response>>,
 453    {
 454        self.add_entity_message_handler(move |model, envelope, client, cx| {
 455            let receipt = envelope.receipt();
 456            let response = handler(model, envelope, client.clone(), cx);
 457            async move {
 458                match response.await {
 459                    Ok(response) => {
 460                        client.respond(receipt, response)?;
 461                        Ok(())
 462                    }
 463                    Err(error) => {
 464                        client.respond_with_error(
 465                            receipt,
 466                            proto::Error {
 467                                message: error.to_string(),
 468                            },
 469                        )?;
 470                        Err(error)
 471                    }
 472                }
 473            }
 474        })
 475    }
 476
 477    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
 478        read_credentials_from_keychain(cx).is_some()
 479    }
 480
 481    #[async_recursion(?Send)]
 482    pub async fn authenticate_and_connect(
 483        self: &Arc<Self>,
 484        cx: &AsyncAppContext,
 485    ) -> anyhow::Result<()> {
 486        let was_disconnected = match *self.status().borrow() {
 487            Status::SignedOut => true,
 488            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
 489                false
 490            }
 491            Status::Connected { .. }
 492            | Status::Connecting { .. }
 493            | Status::Reconnecting { .. }
 494            | Status::Authenticating
 495            | Status::Reauthenticating => return Ok(()),
 496            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
 497        };
 498
 499        if was_disconnected {
 500            self.set_status(Status::Authenticating, cx);
 501        } else {
 502            self.set_status(Status::Reauthenticating, cx)
 503        }
 504
 505        let mut used_keychain = false;
 506        let credentials = self.state.read().credentials.clone();
 507        let credentials = if let Some(credentials) = credentials {
 508            credentials
 509        } else if let Some(credentials) = read_credentials_from_keychain(cx) {
 510            used_keychain = true;
 511            credentials
 512        } else {
 513            let credentials = match self.authenticate(&cx).await {
 514                Ok(credentials) => credentials,
 515                Err(err) => {
 516                    self.set_status(Status::ConnectionError, cx);
 517                    return Err(err);
 518                }
 519            };
 520            credentials
 521        };
 522
 523        if was_disconnected {
 524            self.set_status(Status::Connecting, cx);
 525        } else {
 526            self.set_status(Status::Reconnecting, cx);
 527        }
 528
 529        match self.establish_connection(&credentials, cx).await {
 530            Ok(conn) => {
 531                self.state.write().credentials = Some(credentials.clone());
 532                if !used_keychain && IMPERSONATE_LOGIN.is_none() {
 533                    write_credentials_to_keychain(&credentials, cx).log_err();
 534                }
 535                self.set_connection(conn, cx).await;
 536                Ok(())
 537            }
 538            Err(EstablishConnectionError::Unauthorized) => {
 539                self.state.write().credentials.take();
 540                if used_keychain {
 541                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
 542                    self.set_status(Status::SignedOut, cx);
 543                    self.authenticate_and_connect(cx).await
 544                } else {
 545                    self.set_status(Status::ConnectionError, cx);
 546                    Err(EstablishConnectionError::Unauthorized)?
 547                }
 548            }
 549            Err(EstablishConnectionError::UpgradeRequired) => {
 550                self.set_status(Status::UpgradeRequired, cx);
 551                Err(EstablishConnectionError::UpgradeRequired)?
 552            }
 553            Err(error) => {
 554                self.set_status(Status::ConnectionError, cx);
 555                Err(error)?
 556            }
 557        }
 558    }
 559
    /// Registers `conn` with the peer and starts servicing it.
    ///
    /// A foreground task routes each incoming message to a model:
    /// * first, the model registered for the message type via `add_message_handler`;
    /// * otherwise, the live model registered for the message's
    ///   `(model type, remote entity id)` via `add_model_for_remote_entity`.
    ///
    /// Entity messages with no live model are queued in `pending_messages`, and
    /// `add_model_for_remote_entity` replays them later. Other unroutable messages
    /// are logged and dropped.
    ///
    /// The status is then set to `Connected`. The peer's I/O future is spawned on
    /// the background executor; when it finishes, the status becomes `SignedOut`
    /// on `Ok` and `ConnectionLost` on `Err`.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        // The write guard is held while routing; it is released
                        // before any handler runs (see `drop(state)` below).
                        let mut state = this.state.write();
                        let payload_type_id = message.payload_type_id();
                        let type_name = message.payload_type_name();
                        // Both are `Some` only for entity-message types.
                        let model_type_id = state
                            .model_types_by_message_type
                            .get(&payload_type_id)
                            .copied();
                        let entity_id = state
                            .entity_id_extractors
                            .get(&message.payload_type_id())
                            .map(|extract_entity_id| (extract_entity_id)(message.as_ref()));

                        // Prefer a model registered for the whole message type. Otherwise
                        // look up the entity registry, pruning entries whose model has
                        // been released.
                        let model = state
                            .models_by_message_type
                            .get(&payload_type_id)
                            .cloned()
                            .or_else(|| {
                                let model_type_id = model_type_id?;
                                let entity_id = entity_id?;
                                let model = state
                                    .models_by_entity_type_and_remote_id
                                    .get(&(model_type_id, entity_id))?;
                                if let Some(model) = model.upgrade(&cx) {
                                    Some(model)
                                } else {
                                    state
                                        .models_by_entity_type_and_remote_id
                                        .remove(&(model_type_id, entity_id));
                                    None
                                }
                            });

                        let model = if let Some(model) = model {
                            model
                        } else {
                            // No model: entity messages are queued so they can be replayed
                            // once their model registers. This is logged even when the
                            // message is queued.
                            log::info!("unhandled message {}", type_name);
                            if let Some((model_type_id, entity_id)) = model_type_id.zip(entity_id) {
                                state
                                    .pending_messages
                                    .entry((model_type_id, entity_id))
                                    .or_default()
                                    .push(message);
                            }

                            continue;
                        };

                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
                        {
                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
                            let future = handler(model, message, cx.clone());

                            let client_id = this.id;
                            log::debug!(
                                "rpc message received. client_id:{}, name:{}",
                                client_id,
                                type_name
                            );
                            // Each message is handled in its own detached task, so this
                            // loop does not wait for handlers to finish.
                            cx.foreground()
                                .spawn(async move {
                                    match future.await {
                                        Ok(()) => {
                                            log::debug!(
                                                "rpc message handled. client_id:{}, name:{}",
                                                client_id,
                                                type_name
                                            );
                                        }
                                        Err(error) => {
                                            log::error!(
                                                "error handling message. client_id:{}, name:{}, {}",
                                                client_id,
                                                type_name,
                                                error
                                            );
                                        }
                                    }
                                })
                                .detach();
                        } else {
                            log::info!("unhandled message {}", type_name);
                        }
                    }
                }
            })
            .detach();

        self.set_status(Status::Connected { connection_id }, cx);

        // Drive the connection's I/O in the background; its outcome decides the
        // next status.
        let handle_io = cx.background().spawn(handle_io);
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => this.set_status(Status::SignedOut, &cx),
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();
    }
 672
 673    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
 674        if let Some(callback) = self.authenticate.as_ref() {
 675            callback(cx)
 676        } else {
 677            self.authenticate_with_browser(cx)
 678        }
 679    }
 680
 681    fn establish_connection(
 682        self: &Arc<Self>,
 683        credentials: &Credentials,
 684        cx: &AsyncAppContext,
 685    ) -> Task<Result<Connection, EstablishConnectionError>> {
 686        if let Some(callback) = self.establish_connection.as_ref() {
 687            callback(credentials, cx)
 688        } else {
 689            self.establish_websocket_connection(credentials, cx)
 690        }
 691    }
 692
 693    fn establish_websocket_connection(
 694        self: &Arc<Self>,
 695        credentials: &Credentials,
 696        cx: &AsyncAppContext,
 697    ) -> Task<Result<Connection, EstablishConnectionError>> {
 698        let request = Request::builder()
 699            .header(
 700                "Authorization",
 701                format!("{} {}", credentials.user_id, credentials.access_token),
 702            )
 703            .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
 704
 705        let http = self.http.clone();
 706        cx.background().spawn(async move {
 707            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
 708            let rpc_request = surf::Request::new(
 709                Method::Get,
 710                surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
 711            );
 712            let rpc_response = http.send(rpc_request).await?;
 713
 714            if rpc_response.status().is_redirection() {
 715                rpc_url = rpc_response
 716                    .header("Location")
 717                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
 718                    .as_str()
 719                    .to_string();
 720            }
 721            // Until we switch the zed.dev domain to point to the new Next.js app, there
 722            // will be no redirect required, and the app will connect directly to
 723            // wss://zed.dev/rpc.
 724            else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
 725                Err(anyhow!(
 726                    "unexpected /rpc response status {}",
 727                    rpc_response.status()
 728                ))?
 729            }
 730
 731            let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
 732            let rpc_host = rpc_url
 733                .host_str()
 734                .zip(rpc_url.port_or_known_default())
 735                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
 736            let stream = smol::net::TcpStream::connect(rpc_host).await?;
 737
 738            log::info!("connected to rpc endpoint {}", rpc_url);
 739
 740            match rpc_url.scheme() {
 741                "https" => {
 742                    rpc_url.set_scheme("wss").unwrap();
 743                    let request = request.uri(rpc_url.as_str()).body(())?;
 744                    let (stream, _) =
 745                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
 746                    Ok(Connection::new(stream))
 747                }
 748                "http" => {
 749                    rpc_url.set_scheme("ws").unwrap();
 750                    let request = request.uri(rpc_url.as_str()).body(())?;
 751                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
 752                    Ok(Connection::new(stream))
 753                }
 754                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
 755            }
 756        })
 757    }
 758
 759    pub fn authenticate_with_browser(
 760        self: &Arc<Self>,
 761        cx: &AsyncAppContext,
 762    ) -> Task<Result<Credentials>> {
 763        let platform = cx.platform();
 764        let executor = cx.background();
 765        executor.clone().spawn(async move {
 766            // Generate a pair of asymmetric encryption keys. The public key will be used by the
 767            // zed server to encrypt the user's access token, so that it can'be intercepted by
 768            // any other app running on the user's device.
 769            let (public_key, private_key) =
 770                rpc::auth::keypair().expect("failed to generate keypair for auth");
 771            let public_key_string =
 772                String::try_from(public_key).expect("failed to serialize public key for auth");
 773
 774            // Start an HTTP server to receive the redirect from Zed's sign-in page.
 775            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
 776            let port = server.server_addr().port();
 777
 778            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
 779            // that the user is signing in from a Zed app running on the same device.
 780            let mut url = format!(
 781                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
 782                *ZED_SERVER_URL, port, public_key_string
 783            );
 784
 785            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
 786                log::info!("impersonating user @{}", impersonate_login);
 787                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
 788            }
 789
 790            platform.open_url(&url);
 791
 792            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
 793            // access token from the query params.
 794            //
 795            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
 796            // custom URL scheme instead of this local HTTP server.
 797            let (user_id, access_token) = executor
 798                .spawn(async move {
 799                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
 800                        let path = req.url();
 801                        let mut user_id = None;
 802                        let mut access_token = None;
 803                        let url = Url::parse(&format!("http://example.com{}", path))
 804                            .context("failed to parse login notification url")?;
 805                        for (key, value) in url.query_pairs() {
 806                            if key == "access_token" {
 807                                access_token = Some(value.to_string());
 808                            } else if key == "user_id" {
 809                                user_id = Some(value.to_string());
 810                            }
 811                        }
 812
 813                        let post_auth_url =
 814                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
 815                        req.respond(
 816                            tiny_http::Response::empty(302).with_header(
 817                                tiny_http::Header::from_bytes(
 818                                    &b"Location"[..],
 819                                    post_auth_url.as_bytes(),
 820                                )
 821                                .unwrap(),
 822                            ),
 823                        )
 824                        .context("failed to respond to login http request")?;
 825                        Ok((
 826                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
 827                            access_token
 828                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
 829                        ))
 830                    } else {
 831                        Err(anyhow!("didn't receive login redirect"))
 832                    }
 833                })
 834                .await?;
 835
 836            let access_token = private_key
 837                .decrypt_string(&access_token)
 838                .context("failed to decrypt access token")?;
 839            platform.activate(true);
 840
 841            Ok(Credentials {
 842                user_id: user_id.parse()?,
 843                access_token,
 844            })
 845        })
 846    }
 847
 848    pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 849        let conn_id = self.connection_id()?;
 850        self.peer.disconnect(conn_id);
 851        self.set_status(Status::SignedOut, cx);
 852        Ok(())
 853    }
 854
 855    fn connection_id(&self) -> Result<ConnectionId> {
 856        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
 857            Ok(connection_id)
 858        } else {
 859            Err(anyhow!("not connected"))
 860        }
 861    }
 862
 863    pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
 864        log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
 865        self.peer.send(self.connection_id()?, message)
 866    }
 867
 868    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
 869        log::debug!(
 870            "rpc request start. client_id: {}. name:{}",
 871            self.id,
 872            T::NAME
 873        );
 874        let response = self.peer.request(self.connection_id()?, request).await;
 875        log::debug!(
 876            "rpc request finish. client_id: {}. name:{}",
 877            self.id,
 878            T::NAME
 879        );
 880        response
 881    }
 882
 883    fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
 884        log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
 885        self.peer.respond(receipt, response)
 886    }
 887
 888    fn respond_with_error<T: RequestMessage>(
 889        &self,
 890        receipt: Receipt<T>,
 891        error: proto::Error,
 892    ) -> Result<()> {
 893        log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
 894        self.peer.respond_with_error(receipt, error)
 895    }
 896}
 897
 898fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
 899    if IMPERSONATE_LOGIN.is_some() {
 900        return None;
 901    }
 902
 903    let (user_id, access_token) = cx
 904        .platform()
 905        .read_credentials(&ZED_SERVER_URL)
 906        .log_err()
 907        .flatten()?;
 908    Some(Credentials {
 909        user_id: user_id.parse().ok()?,
 910        access_token: String::from_utf8(access_token).ok()?,
 911    })
 912}
 913
 914fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
 915    cx.platform().write_credentials(
 916        &ZED_SERVER_URL,
 917        &credentials.user_id.to_string(),
 918        credentials.access_token.as_bytes(),
 919    )
 920}
 921
 922const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
 923
 924pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
 925    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
 926}
 927
 928pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
 929    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
 930    let mut parts = path.split('/');
 931    let id = parts.next()?.parse::<u64>().ok()?;
 932    let access_token = parts.next()?;
 933    if access_token.is_empty() {
 934        return None;
 935    }
 936    Some((id, access_token.to_string()))
 937}
 938
#[cfg(test)]
mod tests {
    // Unit tests for `Client`: heartbeats, reconnection/credential caching, worktree URL
    // encoding, and message-handler subscription lifetimes. Most tests connect a client to a
    // `FakeServer` and drive time with the foreground executor's fake clock.
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // Advancing the fake clock by 10 seconds should make the client send a `Ping`, which the
    // server acknowledges; after `disconnect`, the server should no longer receive pings.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        client.disconnect(&cx.to_async()).unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    // Checks that the client reconnects after the server drops it. The cached credentials are
    // reused while they stay valid, and the client re-authenticates once the server's access
    // token is rolled.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, so the first reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Rolling the server's access token invalidates the client's cached credentials. The
        // client should clear them after authentication fails and then re-authenticate.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        // Presumably the first retry used the stale cached token, so no new authentication
        // has happened yet.
        assert_eq!(server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Round-trips a worktree URL, checks that surrounding whitespace is tolerated, and checks
    // that a URL with the wrong prefix is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity message handlers are dispatched to the model registered for the message's
    // entity id (`project_id` here).
    #[gpui::test]
    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_entity_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
                // Signal whichever model the message was routed to; model 3's subscription is
                // dropped below, so it must never be reached.
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 =
            model1.update(&mut cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 =
            model2.update(&mut cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 =
            model3.update(&mut cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After the first `Ping` handler's subscription is dropped, a second handler for the same
    // message type and model can be registered, and it receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler can drop its own subscription, stored on the model, while it is handling a
    // message, and it still completes.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                // `take()` moves the subscription out of the model; it is dropped at the end
                // of this statement, i.e. from inside the running handler.
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(&mut cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used as the target of message handlers in the tests above.
    #[derive(Default)]
    struct Model {
        // Identifies which model a handler was dispatched to.
        id: usize,
        // Optionally holds the model's own handler subscription, so a handler can drop it.
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}