Detailed changes
@@ -1449,6 +1449,43 @@ dependencies = [
"uuid 1.4.1",
]
+[[package]]
+name = "client2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-recursion 0.3.2",
+ "async-tungstenite",
+ "collections",
+ "db",
+ "feature_flags",
+ "futures 0.3.28",
+ "gpui",
+ "gpui2",
+ "image",
+ "lazy_static",
+ "log",
+ "parking_lot 0.11.2",
+ "postage",
+ "rand 0.8.5",
+ "rpc",
+ "schemars",
+ "serde",
+ "serde_derive",
+ "settings",
+ "smol",
+ "sum_tree",
+ "sysinfo",
+ "tempfile",
+ "text",
+ "thiserror",
+ "time",
+ "tiny_http",
+ "url",
+ "util",
+ "uuid 1.4.1",
+]
+
[[package]]
name = "clock"
version = "0.1.0"
@@ -10379,6 +10416,7 @@ dependencies = [
"backtrace",
"chrono",
"cli",
+ "client2",
"collections",
"ctor",
"env_logger 0.9.3",
@@ -10,6 +10,7 @@ members = [
"crates/channel",
"crates/cli",
"crates/client",
+ "crates/client2",
"crates/clock",
"crates/collab",
"crates/collab_ui",
@@ -0,0 +1,52 @@
+[package]
+name = "client2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/client2.rs"
+doctest = false
+
+[features]
+test-support = ["collections/test-support", "gpui/test-support", "rpc/test-support"]
+
+[dependencies]
+collections = { path = "../collections" }
+db = { path = "../db" }
+gpui2 = { path = "../gpui2" }
+util = { path = "../util" }
+rpc = { path = "../rpc" }
+text = { path = "../text" }
+settings = { path = "../settings" }
+feature_flags = { path = "../feature_flags" }
+sum_tree = { path = "../sum_tree" }
+
+anyhow.workspace = true
+async-recursion = "0.3"
+async-tungstenite = { version = "0.16", features = ["async-tls"] }
+futures.workspace = true
+image = "0.23"
+lazy_static.workspace = true
+log.workspace = true
+parking_lot.workspace = true
+postage.workspace = true
+rand.workspace = true
+schemars.workspace = true
+serde.workspace = true
+serde_derive.workspace = true
+smol.workspace = true
+sysinfo.workspace = true
+tempfile = "3"
+thiserror.workspace = true
+time.workspace = true
+tiny_http = "0.8"
+uuid.workspace = true
+url = "2.2"
+
+[dev-dependencies]
+collections = { path = "../collections", features = ["test-support"] }
+gpui2 = { path = "../gpui2", features = ["test-support"] }
+rpc = { path = "../rpc", features = ["test-support"] }
+settings = { path = "../settings", features = ["test-support"] }
+util = { path = "../util", features = ["test-support"] }
@@ -0,0 +1,1723 @@
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
+
+pub mod telemetry;
+pub mod user;
+
+use anyhow::{anyhow, Context, Result};
+use async_recursion::async_recursion;
+use async_tungstenite::tungstenite::{
+ error::Error as WebsocketError,
+ http::{Request, StatusCode},
+};
+use futures::{
+ future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
+ TryStreamExt,
+};
+use gpui::{
+ actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
+ AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
+ WeakViewHandle,
+};
+use lazy_static::lazy_static;
+use parking_lot::RwLock;
+use postage::watch;
+use rand::prelude::*;
+use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::{
+ any::TypeId,
+ collections::HashMap,
+ convert::TryFrom,
+ fmt::Write as _,
+ future::Future,
+ marker::PhantomData,
+ path::PathBuf,
+ sync::{atomic::AtomicU64, Arc, Weak},
+ time::{Duration, Instant},
+};
+use telemetry::Telemetry;
+use thiserror::Error;
+use url::Url;
+use util::channel::ReleaseChannel;
+use util::http::HttpClient;
+use util::{ResultExt, TryFutureExt};
+
+pub use rpc::*;
+pub use telemetry::ClickhouseEvent;
+pub use user::*;
+
+lazy_static! {
+ pub static ref ZED_SERVER_URL: String =
+ std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
+ pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
+ .ok()
+ .and_then(|s| if s.is_empty() { None } else { Some(s) });
+ pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
+ .ok()
+ .and_then(|s| if s.is_empty() { None } else { Some(s) });
+ pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
+ .ok()
+ .and_then(|v| v.parse().ok());
+ pub static ref ZED_APP_PATH: Option<PathBuf> =
+ std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
+ pub static ref ZED_ALWAYS_ACTIVE: bool =
+ std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
+}
+
+pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
+pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
+pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
+
+actions!(client, [SignIn, SignOut, Reconnect]);
+
+pub fn init_settings(cx: &mut AppContext) {
+ settings::register::<TelemetrySettings>(cx);
+}
+
+pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
+ init_settings(cx);
+
+ let client = Arc::downgrade(client);
+ cx.add_global_action({
+ let client = client.clone();
+ move |_: &SignIn, cx| {
+ if let Some(client) = client.upgrade() {
+ cx.spawn(
+ |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
+ )
+ .detach();
+ }
+ }
+ });
+ cx.add_global_action({
+ let client = client.clone();
+ move |_: &SignOut, cx| {
+ if let Some(client) = client.upgrade() {
+ cx.spawn(|cx| async move {
+ client.disconnect(&cx);
+ })
+ .detach();
+ }
+ }
+ });
+ cx.add_global_action({
+ let client = client.clone();
+ move |_: &Reconnect, cx| {
+ if let Some(client) = client.upgrade() {
+ cx.spawn(|cx| async move {
+ client.reconnect(&cx);
+ })
+ .detach();
+ }
+ }
+ });
+}
+
+pub struct Client {
+ id: AtomicU64,
+ peer: Arc<Peer>,
+ http: Arc<dyn HttpClient>,
+ telemetry: Arc<Telemetry>,
+ state: RwLock<ClientState>,
+
+ #[allow(clippy::type_complexity)]
+ #[cfg(any(test, feature = "test-support"))]
+ authenticate: RwLock<
+ Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
+ >,
+
+ #[allow(clippy::type_complexity)]
+ #[cfg(any(test, feature = "test-support"))]
+ establish_connection: RwLock<
+ Option<
+ Box<
+ dyn 'static
+ + Send
+ + Sync
+ + Fn(
+ &Credentials,
+ &AsyncAppContext,
+ ) -> Task<Result<Connection, EstablishConnectionError>>,
+ >,
+ >,
+ >,
+}
+
+#[derive(Error, Debug)]
+pub enum EstablishConnectionError {
+ #[error("upgrade required")]
+ UpgradeRequired,
+ #[error("unauthorized")]
+ Unauthorized,
+ #[error("{0}")]
+ Other(#[from] anyhow::Error),
+ #[error("{0}")]
+ Http(#[from] util::http::Error),
+ #[error("{0}")]
+ Io(#[from] std::io::Error),
+ #[error("{0}")]
+ Websocket(#[from] async_tungstenite::tungstenite::http::Error),
+}
+
+impl From<WebsocketError> for EstablishConnectionError {
+ fn from(error: WebsocketError) -> Self {
+ if let WebsocketError::Http(response) = &error {
+ match response.status() {
+ StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
+ StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
+ _ => {}
+ }
+ }
+ EstablishConnectionError::Other(error.into())
+ }
+}
+
+impl EstablishConnectionError {
+ pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
+ Self::Other(error.into())
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Status {
+ SignedOut,
+ UpgradeRequired,
+ Authenticating,
+ Connecting,
+ ConnectionError,
+ Connected {
+ peer_id: PeerId,
+ connection_id: ConnectionId,
+ },
+ ConnectionLost,
+ Reauthenticating,
+ Reconnecting,
+ ReconnectionError {
+ next_reconnection: Instant,
+ },
+}
+
+impl Status {
+ pub fn is_connected(&self) -> bool {
+ matches!(self, Self::Connected { .. })
+ }
+
+ pub fn is_signed_out(&self) -> bool {
+ matches!(self, Self::SignedOut | Self::UpgradeRequired)
+ }
+}
+
+struct ClientState {
+ credentials: Option<Credentials>,
+ status: (watch::Sender<Status>, watch::Receiver<Status>),
+ entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
+ _reconnect_task: Option<Task<()>>,
+ reconnect_interval: Duration,
+ entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
+ models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
+ entity_types_by_message_type: HashMap<TypeId, TypeId>,
+ #[allow(clippy::type_complexity)]
+ message_handlers: HashMap<
+ TypeId,
+ Arc<
+ dyn Send
+ + Sync
+ + Fn(
+ Subscriber,
+ Box<dyn AnyTypedEnvelope>,
+ &Arc<Client>,
+ AsyncAppContext,
+ ) -> LocalBoxFuture<'static, Result<()>>,
+ >,
+ >,
+}
+
+enum WeakSubscriber {
+ Model(AnyWeakModelHandle),
+ View(AnyWeakViewHandle),
+ Pending(Vec<Box<dyn AnyTypedEnvelope>>),
+}
+
+enum Subscriber {
+ Model(AnyModelHandle),
+ View(AnyWeakViewHandle),
+}
+
+#[derive(Clone, Debug)]
+pub struct Credentials {
+ pub user_id: u64,
+ pub access_token: String,
+}
+
+impl Default for ClientState {
+ fn default() -> Self {
+ Self {
+ credentials: None,
+ status: watch::channel_with(Status::SignedOut),
+ entity_id_extractors: Default::default(),
+ _reconnect_task: None,
+ reconnect_interval: Duration::from_secs(5),
+ models_by_message_type: Default::default(),
+ entities_by_type_and_remote_id: Default::default(),
+ entity_types_by_message_type: Default::default(),
+ message_handlers: Default::default(),
+ }
+ }
+}
+
+pub enum Subscription {
+ Entity {
+ client: Weak<Client>,
+ id: (TypeId, u64),
+ },
+ Message {
+ client: Weak<Client>,
+ id: TypeId,
+ },
+}
+
+impl Drop for Subscription {
+ fn drop(&mut self) {
+ match self {
+ Subscription::Entity { client, id } => {
+ if let Some(client) = client.upgrade() {
+ let mut state = client.state.write();
+ let _ = state.entities_by_type_and_remote_id.remove(id);
+ }
+ }
+ Subscription::Message { client, id } => {
+ if let Some(client) = client.upgrade() {
+ let mut state = client.state.write();
+ let _ = state.entity_types_by_message_type.remove(id);
+ let _ = state.message_handlers.remove(id);
+ }
+ }
+ }
+ }
+}
+
+pub struct PendingEntitySubscription<T: Entity> {
+ client: Arc<Client>,
+ remote_id: u64,
+ _entity_type: PhantomData<T>,
+ consumed: bool,
+}
+
+impl<T: Entity> PendingEntitySubscription<T> {
+ pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
+ self.consumed = true;
+ let mut state = self.client.state.write();
+ let id = (TypeId::of::<T>(), self.remote_id);
+ let Some(WeakSubscriber::Pending(messages)) =
+ state.entities_by_type_and_remote_id.remove(&id)
+ else {
+ unreachable!()
+ };
+
+ state
+ .entities_by_type_and_remote_id
+ .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
+ drop(state);
+ for message in messages {
+ self.client.handle_message(message, cx);
+ }
+ Subscription::Entity {
+ client: Arc::downgrade(&self.client),
+ id,
+ }
+ }
+}
+
+impl<T: Entity> Drop for PendingEntitySubscription<T> {
+ fn drop(&mut self) {
+ if !self.consumed {
+ let mut state = self.client.state.write();
+ if let Some(WeakSubscriber::Pending(messages)) = state
+ .entities_by_type_and_remote_id
+ .remove(&(TypeId::of::<T>(), self.remote_id))
+ {
+ for message in messages {
+ log::info!("unhandled message {}", message.payload_type_name());
+ }
+ }
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct TelemetrySettings {
+ pub diagnostics: bool,
+ pub metrics: bool,
+}
+
+#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct TelemetrySettingsContent {
+ pub diagnostics: Option<bool>,
+ pub metrics: Option<bool>,
+}
+
+impl settings::Setting for TelemetrySettings {
+ const KEY: Option<&'static str> = Some("telemetry");
+
+ type FileContent = TelemetrySettingsContent;
+
+ fn load(
+ default_value: &Self::FileContent,
+ user_values: &[&Self::FileContent],
+ _: &AppContext,
+ ) -> Result<Self> {
+ Ok(Self {
+ diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
+ default_value
+ .diagnostics
+ .ok_or_else(Self::missing_default)?,
+ ),
+ metrics: user_values
+ .first()
+ .and_then(|v| v.metrics)
+ .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
+ })
+ }
+}
+
+impl Client {
+ pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
+ Arc::new(Self {
+ id: AtomicU64::new(0),
+ peer: Peer::new(0),
+ telemetry: Telemetry::new(http.clone(), cx),
+ http,
+ state: Default::default(),
+
+ #[cfg(any(test, feature = "test-support"))]
+ authenticate: Default::default(),
+ #[cfg(any(test, feature = "test-support"))]
+ establish_connection: Default::default(),
+ })
+ }
+
+ pub fn id(&self) -> u64 {
+ self.id.load(std::sync::atomic::Ordering::SeqCst)
+ }
+
+ pub fn http_client(&self) -> Arc<dyn HttpClient> {
+ self.http.clone()
+ }
+
+ pub fn set_id(&self, id: u64) -> &Self {
+ self.id.store(id, std::sync::atomic::Ordering::SeqCst);
+ self
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn teardown(&self) {
+ let mut state = self.state.write();
+ state._reconnect_task.take();
+ state.message_handlers.clear();
+ state.models_by_message_type.clear();
+ state.entities_by_type_and_remote_id.clear();
+ state.entity_id_extractors.clear();
+ self.peer.teardown();
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
+ where
+ F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
+ {
+ *self.authenticate.write() = Some(Box::new(authenticate));
+ self
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn override_establish_connection<F>(&self, connect: F) -> &Self
+ where
+ F: 'static
+ + Send
+ + Sync
+ + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
+ {
+ *self.establish_connection.write() = Some(Box::new(connect));
+ self
+ }
+
+ pub fn user_id(&self) -> Option<u64> {
+ self.state
+ .read()
+ .credentials
+ .as_ref()
+ .map(|credentials| credentials.user_id)
+ }
+
+ pub fn peer_id(&self) -> Option<PeerId> {
+ if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
+ Some(*peer_id)
+ } else {
+ None
+ }
+ }
+
+ pub fn status(&self) -> watch::Receiver<Status> {
+ self.state.read().status.1.clone()
+ }
+
+ fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
+ log::info!("set status on client {}: {:?}", self.id(), status);
+ let mut state = self.state.write();
+ *state.status.0.borrow_mut() = status;
+
+ match status {
+ Status::Connected { .. } => {
+ state._reconnect_task = None;
+ }
+ Status::ConnectionLost => {
+ let this = self.clone();
+ let reconnect_interval = state.reconnect_interval;
+ state._reconnect_task = Some(cx.spawn(|cx| async move {
+ #[cfg(any(test, feature = "test-support"))]
+ let mut rng = StdRng::seed_from_u64(0);
+ #[cfg(not(any(test, feature = "test-support")))]
+ let mut rng = StdRng::from_entropy();
+
+ let mut delay = INITIAL_RECONNECTION_DELAY;
+ while let Err(error) = this.authenticate_and_connect(true, &cx).await {
+ log::error!("failed to connect {}", error);
+ if matches!(*this.status().borrow(), Status::ConnectionError) {
+ this.set_status(
+ Status::ReconnectionError {
+ next_reconnection: Instant::now() + delay,
+ },
+ &cx,
+ );
+ cx.background().timer(delay).await;
+ delay = delay
+ .mul_f32(rng.gen_range(1.0..=2.0))
+ .min(reconnect_interval);
+ } else {
+ break;
+ }
+ }
+ }));
+ }
+ Status::SignedOut | Status::UpgradeRequired => {
+ cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
+ state._reconnect_task.take();
+ }
+ _ => {}
+ }
+ }
+
+ pub fn add_view_for_remote_entity<T: View>(
+ self: &Arc<Self>,
+ remote_id: u64,
+ cx: &mut ViewContext<T>,
+ ) -> Subscription {
+ let id = (TypeId::of::<T>(), remote_id);
+ self.state
+ .write()
+ .entities_by_type_and_remote_id
+ .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
+ Subscription::Entity {
+ client: Arc::downgrade(self),
+ id,
+ }
+ }
+
+ pub fn subscribe_to_entity<T: Entity>(
+ self: &Arc<Self>,
+ remote_id: u64,
+ ) -> Result<PendingEntitySubscription<T>> {
+ let id = (TypeId::of::<T>(), remote_id);
+
+ let mut state = self.state.write();
+ if state.entities_by_type_and_remote_id.contains_key(&id) {
+ return Err(anyhow!("already subscribed to entity"));
+ } else {
+ state
+ .entities_by_type_and_remote_id
+ .insert(id, WeakSubscriber::Pending(Default::default()));
+ Ok(PendingEntitySubscription {
+ client: self.clone(),
+ remote_id,
+ consumed: false,
+ _entity_type: PhantomData,
+ })
+ }
+ }
+
+ #[track_caller]
+ pub fn add_message_handler<M, E, H, F>(
+ self: &Arc<Self>,
+ model: ModelHandle<E>,
+ handler: H,
+ ) -> Subscription
+ where
+ M: EnvelopedMessage,
+ E: Entity,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<()>>,
+ {
+ let message_type_id = TypeId::of::<M>();
+
+ let mut state = self.state.write();
+ state
+ .models_by_message_type
+ .insert(message_type_id, model.downgrade().into_any());
+
+ let prev_handler = state.message_handlers.insert(
+ message_type_id,
+ Arc::new(move |handle, envelope, client, cx| {
+ let handle = if let Subscriber::Model(handle) = handle {
+ handle
+ } else {
+ unreachable!();
+ };
+ let model = handle.downcast::<E>().unwrap();
+ let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
+ handler(model, *envelope, client.clone(), cx).boxed_local()
+ }),
+ );
+ if prev_handler.is_some() {
+ let location = std::panic::Location::caller();
+ panic!(
+ "{}:{} registered handler for the same message {} twice",
+ location.file(),
+ location.line(),
+ std::any::type_name::<M>()
+ );
+ }
+
+ Subscription::Message {
+ client: Arc::downgrade(self),
+ id: message_type_id,
+ }
+ }
+
+ pub fn add_request_handler<M, E, H, F>(
+ self: &Arc<Self>,
+ model: ModelHandle<E>,
+ handler: H,
+ ) -> Subscription
+ where
+ M: RequestMessage,
+ E: Entity,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<M::Response>>,
+ {
+ self.add_message_handler(model, move |handle, envelope, this, cx| {
+ Self::respond_to_request(
+ envelope.receipt(),
+ handler(handle, envelope, this.clone(), cx),
+ this,
+ )
+ })
+ }
+
+ pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
+ where
+ M: EntityMessage,
+ E: View,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<()>>,
+ {
+ self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
+ if let Subscriber::View(handle) = handle {
+ handler(handle.downcast::<E>().unwrap(), message, client, cx)
+ } else {
+ unreachable!();
+ }
+ })
+ }
+
+ pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
+ where
+ M: EntityMessage,
+ E: Entity,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<()>>,
+ {
+ self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
+ if let Subscriber::Model(handle) = handle {
+ handler(handle.downcast::<E>().unwrap(), message, client, cx)
+ } else {
+ unreachable!();
+ }
+ })
+ }
+
+ fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
+ where
+ M: EntityMessage,
+ E: Entity,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<()>>,
+ {
+ let model_type_id = TypeId::of::<E>();
+ let message_type_id = TypeId::of::<M>();
+
+ let mut state = self.state.write();
+ state
+ .entity_types_by_message_type
+ .insert(message_type_id, model_type_id);
+ state
+ .entity_id_extractors
+ .entry(message_type_id)
+ .or_insert_with(|| {
+ |envelope| {
+ envelope
+ .as_any()
+ .downcast_ref::<TypedEnvelope<M>>()
+ .unwrap()
+ .payload
+ .remote_entity_id()
+ }
+ });
+ let prev_handler = state.message_handlers.insert(
+ message_type_id,
+ Arc::new(move |handle, envelope, client, cx| {
+ let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
+ handler(handle, *envelope, client.clone(), cx).boxed_local()
+ }),
+ );
+ if prev_handler.is_some() {
+ panic!("registered handler for the same message twice");
+ }
+ }
+
+ pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
+ where
+ M: EntityMessage + RequestMessage,
+ E: Entity,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<M::Response>>,
+ {
+ self.add_model_message_handler(move |entity, envelope, client, cx| {
+ Self::respond_to_request::<M, _>(
+ envelope.receipt(),
+ handler(entity, envelope, client.clone(), cx),
+ client,
+ )
+ })
+ }
+
+ pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
+ where
+ M: EntityMessage + RequestMessage,
+ E: View,
+ H: 'static
+ + Send
+ + Sync
+ + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<M::Response>>,
+ {
+ self.add_view_message_handler(move |entity, envelope, client, cx| {
+ Self::respond_to_request::<M, _>(
+ envelope.receipt(),
+ handler(entity, envelope, client.clone(), cx),
+ client,
+ )
+ })
+ }
+
+ async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
+ receipt: Receipt<T>,
+ response: F,
+ client: Arc<Self>,
+ ) -> Result<()> {
+ match response.await {
+ Ok(response) => {
+ client.respond(receipt, response)?;
+ Ok(())
+ }
+ Err(error) => {
+ client.respond_with_error(
+ receipt,
+ proto::Error {
+ message: format!("{:?}", error),
+ },
+ )?;
+ Err(error)
+ }
+ }
+ }
+
+ pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
+ read_credentials_from_keychain(cx).is_some()
+ }
+
+ #[async_recursion(?Send)]
+ pub async fn authenticate_and_connect(
+ self: &Arc<Self>,
+ try_keychain: bool,
+ cx: &AsyncAppContext,
+ ) -> anyhow::Result<()> {
+ let was_disconnected = match *self.status().borrow() {
+ Status::SignedOut => true,
+ Status::ConnectionError
+ | Status::ConnectionLost
+ | Status::Authenticating { .. }
+ | Status::Reauthenticating { .. }
+ | Status::ReconnectionError { .. } => false,
+ Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
+ return Ok(())
+ }
+ Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
+ };
+
+ if was_disconnected {
+ self.set_status(Status::Authenticating, cx);
+ } else {
+ self.set_status(Status::Reauthenticating, cx)
+ }
+
+ let mut read_from_keychain = false;
+ let mut credentials = self.state.read().credentials.clone();
+ if credentials.is_none() && try_keychain {
+ credentials = read_credentials_from_keychain(cx);
+ read_from_keychain = credentials.is_some();
+ }
+ if credentials.is_none() {
+ let mut status_rx = self.status();
+ let _ = status_rx.next().await;
+ futures::select_biased! {
+ authenticate = self.authenticate(cx).fuse() => {
+ match authenticate {
+ Ok(creds) => credentials = Some(creds),
+ Err(err) => {
+ self.set_status(Status::ConnectionError, cx);
+ return Err(err);
+ }
+ }
+ }
+ _ = status_rx.next().fuse() => {
+ return Err(anyhow!("authentication canceled"));
+ }
+ }
+ }
+ let credentials = credentials.unwrap();
+ self.set_id(credentials.user_id);
+
+ if was_disconnected {
+ self.set_status(Status::Connecting, cx);
+ } else {
+ self.set_status(Status::Reconnecting, cx);
+ }
+
+ let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
+ futures::select_biased! {
+ connection = self.establish_connection(&credentials, cx).fuse() => {
+ match connection {
+ Ok(conn) => {
+ self.state.write().credentials = Some(credentials.clone());
+ if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
+ write_credentials_to_keychain(&credentials, cx).log_err();
+ }
+
+ futures::select_biased! {
+ result = self.set_connection(conn, cx).fuse() => result,
+ _ = timeout => {
+ self.set_status(Status::ConnectionError, cx);
+ Err(anyhow!("timed out waiting on hello message from server"))
+ }
+ }
+ }
+ Err(EstablishConnectionError::Unauthorized) => {
+ self.state.write().credentials.take();
+ if read_from_keychain {
+ cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+ self.set_status(Status::SignedOut, cx);
+ self.authenticate_and_connect(false, cx).await
+ } else {
+ self.set_status(Status::ConnectionError, cx);
+ Err(EstablishConnectionError::Unauthorized)?
+ }
+ }
+ Err(EstablishConnectionError::UpgradeRequired) => {
+ self.set_status(Status::UpgradeRequired, cx);
+ Err(EstablishConnectionError::UpgradeRequired)?
+ }
+ Err(error) => {
+ self.set_status(Status::ConnectionError, cx);
+ Err(error)?
+ }
+ }
+ }
+ _ = &mut timeout => {
+ self.set_status(Status::ConnectionError, cx);
+ Err(anyhow!("timed out trying to establish connection"))
+ }
+ }
+ }
+
+ async fn set_connection(
+ self: &Arc<Self>,
+ conn: Connection,
+ cx: &AsyncAppContext,
+ ) -> Result<()> {
+ let executor = cx.background();
+ log::info!("add connection to peer");
+ let (connection_id, handle_io, mut incoming) = self
+ .peer
+ .add_connection(conn, move |duration| executor.timer(duration));
+ let handle_io = cx.background().spawn(handle_io);
+
+ let peer_id = async {
+ log::info!("waiting for server hello");
+ let message = incoming
+ .next()
+ .await
+ .ok_or_else(|| anyhow!("no hello message received"))?;
+ log::info!("got server hello");
+ let hello_message_type_name = message.payload_type_name().to_string();
+ let hello = message
+ .into_any()
+ .downcast::<TypedEnvelope<proto::Hello>>()
+ .map_err(|_| {
+ anyhow!(
+ "invalid hello message received: {:?}",
+ hello_message_type_name
+ )
+ })?;
+ let peer_id = hello
+ .payload
+ .peer_id
+ .ok_or_else(|| anyhow!("invalid peer id"))?;
+ Ok(peer_id)
+ };
+
+ let peer_id = match peer_id.await {
+ Ok(peer_id) => peer_id,
+ Err(error) => {
+ self.peer.disconnect(connection_id);
+ return Err(error);
+ }
+ };
+
+ log::info!(
+ "set status to connected (connection id: {:?}, peer id: {:?})",
+ connection_id,
+ peer_id
+ );
+ self.set_status(
+ Status::Connected {
+ peer_id,
+ connection_id,
+ },
+ cx,
+ );
+ cx.foreground()
+ .spawn({
+ let cx = cx.clone();
+ let this = self.clone();
+ async move {
+ while let Some(message) = incoming.next().await {
+ this.handle_message(message, &cx);
+ // Don't starve the main thread when receiving lots of messages at once.
+ smol::future::yield_now().await;
+ }
+ }
+ })
+ .detach();
+
+ let this = self.clone();
+ let cx = cx.clone();
+ cx.foreground()
+ .spawn(async move {
+ match handle_io.await {
+ Ok(()) => {
+ if this.status().borrow().clone()
+ == (Status::Connected {
+ connection_id,
+ peer_id,
+ })
+ {
+ this.set_status(Status::SignedOut, &cx);
+ }
+ }
+ Err(err) => {
+ log::error!("connection error: {:?}", err);
+ this.set_status(Status::ConnectionLost, &cx);
+ }
+ }
+ })
+ .detach();
+
+ Ok(())
+ }
+
+ fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Some(callback) = self.authenticate.read().as_ref() {
+ return callback(cx);
+ }
+
+ self.authenticate_with_browser(cx)
+ }
+
+ fn establish_connection(
+ self: &Arc<Self>,
+ credentials: &Credentials,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<Connection, EstablishConnectionError>> {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Some(callback) = self.establish_connection.read().as_ref() {
+ return callback(credentials, cx);
+ }
+
+ self.establish_websocket_connection(credentials, cx)
+ }
+
+ async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
+ let preview_param = if is_preview { "?preview=1" } else { "" };
+ let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
+ let response = http.get(&url, Default::default(), false).await?;
+
+ // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
+ // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
+ // which requires authorization via an HTTP header.
+ //
+ // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
+ // of a collab server. In that case, a request to the /rpc endpoint will
+ // return an 'unauthorized' response.
+ let collab_url = if response.status().is_redirection() {
+ response
+ .headers()
+ .get("Location")
+ .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
+ .to_str()
+ .map_err(EstablishConnectionError::other)?
+ .to_string()
+ } else if response.status() == StatusCode::UNAUTHORIZED {
+ url
+ } else {
+ Err(anyhow!(
+ "unexpected /rpc response status {}",
+ response.status()
+ ))?
+ };
+
+ Url::parse(&collab_url).context("invalid rpc url")
+ }
+
+ fn establish_websocket_connection(
+ self: &Arc<Self>,
+ credentials: &Credentials,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<Connection, EstablishConnectionError>> {
+ let use_preview_server = cx.read(|cx| {
+ if cx.has_global::<ReleaseChannel>() {
+ *cx.global::<ReleaseChannel>() != ReleaseChannel::Stable
+ } else {
+ false
+ }
+ });
+
+ let request = Request::builder()
+ .header(
+ "Authorization",
+ format!("{} {}", credentials.user_id, credentials.access_token),
+ )
+ .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
+
+ let http = self.http.clone();
+ cx.background().spawn(async move {
+ let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
+ let rpc_host = rpc_url
+ .host_str()
+ .zip(rpc_url.port_or_known_default())
+ .ok_or_else(|| anyhow!("missing host in rpc url"))?;
+ let stream = smol::net::TcpStream::connect(rpc_host).await?;
+
+ log::info!("connected to rpc endpoint {}", rpc_url);
+
+ match rpc_url.scheme() {
+ "https" => {
+ rpc_url.set_scheme("wss").unwrap();
+ let request = request.uri(rpc_url.as_str()).body(())?;
+ let (stream, _) =
+ async_tungstenite::async_tls::client_async_tls(request, stream).await?;
+ Ok(Connection::new(
+ stream
+ .map_err(|error| anyhow!(error))
+ .sink_map_err(|error| anyhow!(error)),
+ ))
+ }
+ "http" => {
+ rpc_url.set_scheme("ws").unwrap();
+ let request = request.uri(rpc_url.as_str()).body(())?;
+ let (stream, _) = async_tungstenite::client_async(request, stream).await?;
+ Ok(Connection::new(
+ stream
+ .map_err(|error| anyhow!(error))
+ .sink_map_err(|error| anyhow!(error)),
+ ))
+ }
+ _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
+ }
+ })
+ }
+
+ pub fn authenticate_with_browser(
+ self: &Arc<Self>,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<Credentials>> {
+ let platform = cx.platform();
+ let executor = cx.background();
+ let http = self.http.clone();
+
+ executor.clone().spawn(async move {
+ // Generate a pair of asymmetric encryption keys. The public key will be used by the
+ // zed server to encrypt the user's access token, so that it can'be intercepted by
+ // any other app running on the user's device.
+ let (public_key, private_key) =
+ rpc::auth::keypair().expect("failed to generate keypair for auth");
+ let public_key_string =
+ String::try_from(public_key).expect("failed to serialize public key for auth");
+
+ if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
+ return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
+ }
+
+ // Start an HTTP server to receive the redirect from Zed's sign-in page.
+ let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
+ let port = server.server_addr().port();
+
+ // Open the Zed sign-in page in the user's browser, with query parameters that indicate
+ // that the user is signing in from a Zed app running on the same device.
+ let mut url = format!(
+ "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
+ *ZED_SERVER_URL, port, public_key_string
+ );
+
+ if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
+ log::info!("impersonating user @{}", impersonate_login);
+ write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
+ }
+
+ platform.open_url(&url);
+
+ // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
+ // access token from the query params.
+ //
+ // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
+ // custom URL scheme instead of this local HTTP server.
+ let (user_id, access_token) = executor
+ .spawn(async move {
+ for _ in 0..100 {
+ if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
+ let path = req.url();
+ let mut user_id = None;
+ let mut access_token = None;
+ let url = Url::parse(&format!("http://example.com{}", path))
+ .context("failed to parse login notification url")?;
+ for (key, value) in url.query_pairs() {
+ if key == "access_token" {
+ access_token = Some(value.to_string());
+ } else if key == "user_id" {
+ user_id = Some(value.to_string());
+ }
+ }
+
+ let post_auth_url =
+ format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
+ req.respond(
+ tiny_http::Response::empty(302).with_header(
+ tiny_http::Header::from_bytes(
+ &b"Location"[..],
+ post_auth_url.as_bytes(),
+ )
+ .unwrap(),
+ ),
+ )
+ .context("failed to respond to login http request")?;
+ return Ok((
+ user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
+ access_token
+ .ok_or_else(|| anyhow!("missing access_token parameter"))?,
+ ));
+ }
+ }
+
+ Err(anyhow!("didn't receive login redirect"))
+ })
+ .await?;
+
+ let access_token = private_key
+ .decrypt_string(&access_token)
+ .context("failed to decrypt access token")?;
+ platform.activate(true);
+
+ Ok(Credentials {
+ user_id: user_id.parse()?,
+ access_token,
+ })
+ })
+ }
+
+ async fn authenticate_as_admin(
+ http: Arc<dyn HttpClient>,
+ login: String,
+ mut api_token: String,
+ ) -> Result<Credentials> {
+ #[derive(Deserialize)]
+ struct AuthenticatedUserResponse {
+ user: User,
+ }
+
+ #[derive(Deserialize)]
+ struct User {
+ id: u64,
+ }
+
+ // Use the collab server's admin API to retrieve the id
+ // of the impersonated user.
+ let mut url = Self::get_rpc_url(http.clone(), false).await?;
+ url.set_path("/user");
+ url.set_query(Some(&format!("github_login={login}")));
+ let request = Request::get(url.as_str())
+ .header("Authorization", format!("token {api_token}"))
+ .body("".into())?;
+
+ let mut response = http.send(request).await?;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ if !response.status().is_success() {
+ Err(anyhow!(
+ "admin user request failed {} - {}",
+ response.status().as_u16(),
+ body,
+ ))?;
+ }
+ let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
+
+ // Use the admin API token to authenticate as the impersonated user.
+ api_token.insert_str(0, "ADMIN_TOKEN:");
+ Ok(Credentials {
+ user_id: response.user.id,
+ access_token: api_token,
+ })
+ }
+
+ pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
+ self.peer.teardown();
+ self.set_status(Status::SignedOut, cx);
+ }
+
+ pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
+ self.peer.teardown();
+ self.set_status(Status::ConnectionLost, cx);
+ }
+
+ fn connection_id(&self) -> Result<ConnectionId> {
+ if let Status::Connected { connection_id, .. } = *self.status().borrow() {
+ Ok(connection_id)
+ } else {
+ Err(anyhow!("not connected"))
+ }
+ }
+
+ pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
+ log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
+ self.peer.send(self.connection_id()?, message)
+ }
+
+ pub fn request<T: RequestMessage>(
+ &self,
+ request: T,
+ ) -> impl Future<Output = Result<T::Response>> {
+ self.request_envelope(request)
+ .map_ok(|envelope| envelope.payload)
+ }
+
+ pub fn request_envelope<T: RequestMessage>(
+ &self,
+ request: T,
+ ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
+ let client_id = self.id();
+ log::debug!(
+ "rpc request start. client_id:{}. name:{}",
+ client_id,
+ T::NAME
+ );
+ let response = self
+ .connection_id()
+ .map(|conn_id| self.peer.request_envelope(conn_id, request));
+ async move {
+ let response = response?.await;
+ log::debug!(
+ "rpc request finish. client_id:{}. name:{}",
+ client_id,
+ T::NAME
+ );
+ response
+ }
+ }
+
+ fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
+ log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
+ self.peer.respond(receipt, response)
+ }
+
+ fn respond_with_error<T: RequestMessage>(
+ &self,
+ receipt: Receipt<T>,
+ error: proto::Error,
+ ) -> Result<()> {
+ log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
+ self.peer.respond_with_error(receipt, error)
+ }
+
+ fn handle_message(
+ self: &Arc<Client>,
+ message: Box<dyn AnyTypedEnvelope>,
+ cx: &AsyncAppContext,
+ ) {
+ let mut state = self.state.write();
+ let type_name = message.payload_type_name();
+ let payload_type_id = message.payload_type_id();
+ let sender_id = message.original_sender_id();
+
+ let mut subscriber = None;
+
+ if let Some(message_model) = state
+ .models_by_message_type
+ .get(&payload_type_id)
+ .and_then(|model| model.upgrade(cx))
+ {
+ subscriber = Some(Subscriber::Model(message_model));
+ } else if let Some((extract_entity_id, entity_type_id)) =
+ state.entity_id_extractors.get(&payload_type_id).zip(
+ state
+ .entity_types_by_message_type
+ .get(&payload_type_id)
+ .copied(),
+ )
+ {
+ let entity_id = (extract_entity_id)(message.as_ref());
+
+ match state
+ .entities_by_type_and_remote_id
+ .get_mut(&(entity_type_id, entity_id))
+ {
+ Some(WeakSubscriber::Pending(pending)) => {
+ pending.push(message);
+ return;
+ }
+ Some(weak_subscriber @ _) => match weak_subscriber {
+ WeakSubscriber::Model(handle) => {
+ subscriber = handle.upgrade(cx).map(Subscriber::Model);
+ }
+ WeakSubscriber::View(handle) => {
+ subscriber = Some(Subscriber::View(handle.clone()));
+ }
+ WeakSubscriber::Pending(_) => {}
+ },
+ _ => {}
+ }
+ }
+
+ let subscriber = if let Some(subscriber) = subscriber {
+ subscriber
+ } else {
+ log::info!("unhandled message {}", type_name);
+ self.peer.respond_with_unhandled_message(message).log_err();
+ return;
+ };
+
+ let handler = state.message_handlers.get(&payload_type_id).cloned();
+ // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
+ // It also ensures we don't hold the lock while yielding back to the executor, as
+ // that might cause the executor thread driving this future to block indefinitely.
+ drop(state);
+
+ if let Some(handler) = handler {
+ let future = handler(subscriber, message, &self, cx.clone());
+ let client_id = self.id();
+ log::debug!(
+ "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
+ client_id,
+ sender_id,
+ type_name
+ );
+ cx.foreground()
+ .spawn(async move {
+ match future.await {
+ Ok(()) => {
+ log::debug!(
+ "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
+ client_id,
+ sender_id,
+ type_name
+ );
+ }
+ Err(error) => {
+ log::error!(
+ "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
+ client_id,
+ sender_id,
+ type_name,
+ error
+ );
+ }
+ }
+ })
+ .detach();
+ } else {
+ log::info!("unhandled message {}", type_name);
+ self.peer.respond_with_unhandled_message(message).log_err();
+ }
+ }
+
+ pub fn telemetry(&self) -> &Arc<Telemetry> {
+ &self.telemetry
+ }
+}
+
+fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
+ if IMPERSONATE_LOGIN.is_some() {
+ return None;
+ }
+
+ let (user_id, access_token) = cx
+ .platform()
+ .read_credentials(&ZED_SERVER_URL)
+ .log_err()
+ .flatten()?;
+ Some(Credentials {
+ user_id: user_id.parse().ok()?,
+ access_token: String::from_utf8(access_token).ok()?,
+ })
+}
+
+fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
+ cx.platform().write_credentials(
+ &ZED_SERVER_URL,
+ &credentials.user_id.to_string(),
+ credentials.access_token.as_bytes(),
+ )
+}
+
+const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";
+
+pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
+ format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
+}
+
+pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
+ let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
+ let mut parts = path.split('/');
+ let id = parts.next()?.parse::<u64>().ok()?;
+ let access_token = parts.next()?;
+ if access_token.is_empty() {
+ return None;
+ }
+ Some((id, access_token.to_string()))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::FakeServer;
+ use gpui::{executor::Deterministic, TestAppContext};
+ use parking_lot::Mutex;
+ use std::future;
+ use util::http::FakeHttpClient;
+
+ #[gpui::test(iterations = 10)]
+ async fn test_reconnection(cx: &mut TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+ let server = FakeServer::for_client(user_id, &client, cx).await;
+ let mut status = client.status();
+ assert!(matches!(
+ status.next().await,
+ Some(Status::Connected { .. })
+ ));
+ assert_eq!(server.auth_count(), 1);
+
+ server.forbid_connections();
+ server.disconnect();
+ while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
+
+ server.allow_connections();
+ cx.foreground().advance_clock(Duration::from_secs(10));
+ while !matches!(status.next().await, Some(Status::Connected { .. })) {}
+ assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
+
+ server.forbid_connections();
+ server.disconnect();
+ while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
+
+ // Clear cached credentials after authentication fails
+ server.roll_access_token();
+ server.allow_connections();
+ cx.foreground().advance_clock(Duration::from_secs(10));
+ while !matches!(status.next().await, Some(Status::Connected { .. })) {}
+ assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
+ deterministic.forbid_parking();
+
+ let user_id = 5;
+ let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+ let mut status = client.status();
+
+ // Time out when client tries to connect.
+ client.override_authenticate(move |cx| {
+ cx.foreground().spawn(async move {
+ Ok(Credentials {
+ user_id,
+ access_token: "token".into(),
+ })
+ })
+ });
+ client.override_establish_connection(|_, cx| {
+ cx.foreground().spawn(async move {
+ future::pending::<()>().await;
+ unreachable!()
+ })
+ });
+ let auth_and_connect = cx.spawn({
+ let client = client.clone();
+ |cx| async move { client.authenticate_and_connect(false, &cx).await }
+ });
+ deterministic.run_until_parked();
+ assert!(matches!(status.next().await, Some(Status::Connecting)));
+
+ deterministic.advance_clock(CONNECTION_TIMEOUT);
+ assert!(matches!(
+ status.next().await,
+ Some(Status::ConnectionError { .. })
+ ));
+ auth_and_connect.await.unwrap_err();
+
+ // Allow the connection to be established.
+ let server = FakeServer::for_client(user_id, &client, cx).await;
+ assert!(matches!(
+ status.next().await,
+ Some(Status::Connected { .. })
+ ));
+
+ // Disconnect client.
+ server.forbid_connections();
+ server.disconnect();
+ while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
+
+ // Time out when re-establishing the connection.
+ server.allow_connections();
+ client.override_establish_connection(|_, cx| {
+ cx.foreground().spawn(async move {
+ future::pending::<()>().await;
+ unreachable!()
+ })
+ });
+ deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
+ assert!(matches!(
+ status.next().await,
+ Some(Status::Reconnecting { .. })
+ ));
+
+ deterministic.advance_clock(CONNECTION_TIMEOUT);
+ assert!(matches!(
+ status.next().await,
+ Some(Status::ReconnectionError { .. })
+ ));
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_authenticating_more_than_once(
+ cx: &mut TestAppContext,
+ deterministic: Arc<Deterministic>,
+ ) {
+ cx.foreground().forbid_parking();
+
+ let auth_count = Arc::new(Mutex::new(0));
+ let dropped_auth_count = Arc::new(Mutex::new(0));
+ let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+ client.override_authenticate({
+ let auth_count = auth_count.clone();
+ let dropped_auth_count = dropped_auth_count.clone();
+ move |cx| {
+ let auth_count = auth_count.clone();
+ let dropped_auth_count = dropped_auth_count.clone();
+ cx.foreground().spawn(async move {
+ *auth_count.lock() += 1;
+ let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
+ future::pending::<()>().await;
+ unreachable!()
+ })
+ }
+ });
+
+ let _authenticate = cx.spawn(|cx| {
+ let client = client.clone();
+ async move { client.authenticate_and_connect(false, &cx).await }
+ });
+ deterministic.run_until_parked();
+ assert_eq!(*auth_count.lock(), 1);
+ assert_eq!(*dropped_auth_count.lock(), 0);
+
+ let _authenticate = cx.spawn(|cx| {
+ let client = client.clone();
+ async move { client.authenticate_and_connect(false, &cx).await }
+ });
+ deterministic.run_until_parked();
+ assert_eq!(*auth_count.lock(), 2);
+ assert_eq!(*dropped_auth_count.lock(), 1);
+ }
+
+ #[test]
+ fn test_encode_and_decode_worktree_url() {
+ let url = encode_worktree_url(5, "deadbeef");
+ assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
+ assert_eq!(
+ decode_worktree_url(&format!("\n {}\t", url)),
+ Some((5, "deadbeef".to_string()))
+ );
+ assert_eq!(decode_worktree_url("not://the-right-format"), None);
+ }
+
+ #[gpui::test]
+ async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+ let server = FakeServer::for_client(user_id, &client, cx).await;
+
+ let (done_tx1, mut done_rx1) = smol::channel::unbounded();
+ let (done_tx2, mut done_rx2) = smol::channel::unbounded();
+ client.add_model_message_handler(
+ move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
+ match model.read_with(&cx, |model, _| model.id) {
+ 1 => done_tx1.try_send(()).unwrap(),
+ 2 => done_tx2.try_send(()).unwrap(),
+ _ => unreachable!(),
+ }
+ async { Ok(()) }
+ },
+ );
+ let model1 = cx.add_model(|_| Model {
+ id: 1,
+ subscription: None,
+ });
+ let model2 = cx.add_model(|_| Model {
+ id: 2,
+ subscription: None,
+ });
+ let model3 = cx.add_model(|_| Model {
+ id: 3,
+ subscription: None,
+ });
+
+ let _subscription1 = client
+ .subscribe_to_entity(1)
+ .unwrap()
+ .set_model(&model1, &mut cx.to_async());
+ let _subscription2 = client
+ .subscribe_to_entity(2)
+ .unwrap()
+ .set_model(&model2, &mut cx.to_async());
+ // Ensure dropping a subscription for the same entity type still allows receiving of
+ // messages for other entity IDs of the same type.
+ let subscription3 = client
+ .subscribe_to_entity(3)
+ .unwrap()
+ .set_model(&model3, &mut cx.to_async());
+ drop(subscription3);
+
+ server.send(proto::JoinProject { project_id: 1 });
+ server.send(proto::JoinProject { project_id: 2 });
+ done_rx1.next().await.unwrap();
+ done_rx2.next().await.unwrap();
+ }
+
+ #[gpui::test]
+ async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+ let server = FakeServer::for_client(user_id, &client, cx).await;
+
+ let model = cx.add_model(|_| Model::default());
+ let (done_tx1, _done_rx1) = smol::channel::unbounded();
+ let (done_tx2, mut done_rx2) = smol::channel::unbounded();
+ let subscription1 = client.add_message_handler(
+ model.clone(),
+ move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+ done_tx1.try_send(()).unwrap();
+ async { Ok(()) }
+ },
+ );
+ drop(subscription1);
+ let _subscription2 = client.add_message_handler(
+ model.clone(),
+ move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+ done_tx2.try_send(()).unwrap();
+ async { Ok(()) }
+ },
+ );
+ server.send(proto::Ping {});
+ done_rx2.next().await.unwrap();
+ }
+
+ #[gpui::test]
+ async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
+ let server = FakeServer::for_client(user_id, &client, cx).await;
+
+ let model = cx.add_model(|_| Model::default());
+ let (done_tx, mut done_rx) = smol::channel::unbounded();
+ let subscription = client.add_message_handler(
+ model.clone(),
+ move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
+ model.update(&mut cx, |model, _| model.subscription.take());
+ done_tx.try_send(()).unwrap();
+ async { Ok(()) }
+ },
+ );
+ model.update(cx, |model, _| {
+ model.subscription = Some(subscription);
+ });
+ server.send(proto::Ping {});
+ done_rx.next().await.unwrap();
+ }
+
+ #[derive(Default)]
+ struct Model {
+ id: usize,
+ subscription: Option<Subscription>,
+ }
+
+ impl Entity for Model {
+ type Event = ();
+ }
+}
@@ -0,0 +1,317 @@
+use crate::{TelemetrySettings, ZED_SECRET_CLIENT_TOKEN, ZED_SERVER_URL};
+use gpui::{executor::Background, serde_json, AppContext, Task};
+use lazy_static::lazy_static;
+use parking_lot::Mutex;
+use serde::Serialize;
+use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
+use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};
+use tempfile::NamedTempFile;
+use util::http::HttpClient;
+use util::{channel::ReleaseChannel, TryFutureExt};
+
+pub struct Telemetry {
+ http_client: Arc<dyn HttpClient>,
+ executor: Arc<Background>,
+ state: Mutex<TelemetryState>,
+}
+
+#[derive(Default)]
+struct TelemetryState {
+ metrics_id: Option<Arc<str>>, // Per logged-in user
+ installation_id: Option<Arc<str>>, // Per app installation (different for dev, preview, and stable)
+ session_id: Option<Arc<str>>, // Per app launch
+ app_version: Option<Arc<str>>,
+ release_channel: Option<&'static str>,
+ os_name: &'static str,
+ os_version: Option<Arc<str>>,
+ architecture: &'static str,
+ clickhouse_events_queue: Vec<ClickhouseEventWrapper>,
+ flush_clickhouse_events_task: Option<Task<()>>,
+ log_file: Option<NamedTempFile>,
+ is_staff: Option<bool>,
+}
+
+const CLICKHOUSE_EVENTS_URL_PATH: &'static str = "/api/events";
+
+lazy_static! {
+ static ref CLICKHOUSE_EVENTS_URL: String =
+ format!("{}{}", *ZED_SERVER_URL, CLICKHOUSE_EVENTS_URL_PATH);
+}
+
+#[derive(Serialize, Debug)]
+struct ClickhouseEventRequestBody {
+ token: &'static str,
+ installation_id: Option<Arc<str>>,
+ session_id: Option<Arc<str>>,
+ is_staff: Option<bool>,
+ app_version: Option<Arc<str>>,
+ os_name: &'static str,
+ os_version: Option<Arc<str>>,
+ architecture: &'static str,
+ release_channel: Option<&'static str>,
+ events: Vec<ClickhouseEventWrapper>,
+}
+
+#[derive(Serialize, Debug)]
+struct ClickhouseEventWrapper {
+ signed_in: bool,
+ #[serde(flatten)]
+ event: ClickhouseEvent,
+}
+
+#[derive(Serialize, Debug)]
+#[serde(rename_all = "snake_case")]
+pub enum AssistantKind {
+ Panel,
+ Inline,
+}
+
+#[derive(Serialize, Debug)]
+#[serde(tag = "type")]
+pub enum ClickhouseEvent {
+ Editor {
+ operation: &'static str,
+ file_extension: Option<String>,
+ vim_mode: bool,
+ copilot_enabled: bool,
+ copilot_enabled_for_language: bool,
+ },
+ Copilot {
+ suggestion_id: Option<String>,
+ suggestion_accepted: bool,
+ file_extension: Option<String>,
+ },
+ Call {
+ operation: &'static str,
+ room_id: Option<u64>,
+ channel_id: Option<u64>,
+ },
+ Assistant {
+ conversation_id: Option<String>,
+ kind: AssistantKind,
+ model: &'static str,
+ },
+ Cpu {
+ usage_as_percentage: f32,
+ core_count: u32,
+ },
+ Memory {
+ memory_in_bytes: u64,
+ virtual_memory_in_bytes: u64,
+ },
+}
+
+#[cfg(debug_assertions)]
+const MAX_QUEUE_LEN: usize = 1;
+
+#[cfg(not(debug_assertions))]
+const MAX_QUEUE_LEN: usize = 10;
+
+#[cfg(debug_assertions)]
+const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(1);
+
+#[cfg(not(debug_assertions))]
+const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(30);
+
+impl Telemetry {
+ pub fn new(client: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
+ let platform = cx.platform();
+ let release_channel = if cx.has_global::<ReleaseChannel>() {
+ Some(cx.global::<ReleaseChannel>().display_name())
+ } else {
+ None
+ };
+ // TODO: Replace all hardware stuff with nested SystemSpecs json
+ let this = Arc::new(Self {
+ http_client: client,
+ executor: cx.background().clone(),
+ state: Mutex::new(TelemetryState {
+ os_name: platform.os_name().into(),
+ os_version: platform.os_version().ok().map(|v| v.to_string().into()),
+ architecture: env::consts::ARCH,
+ app_version: platform.app_version().ok().map(|v| v.to_string().into()),
+ release_channel,
+ installation_id: None,
+ metrics_id: None,
+ session_id: None,
+ clickhouse_events_queue: Default::default(),
+ flush_clickhouse_events_task: Default::default(),
+ log_file: None,
+ is_staff: None,
+ }),
+ });
+
+ this
+ }
+
+ pub fn log_file_path(&self) -> Option<PathBuf> {
+ Some(self.state.lock().log_file.as_ref()?.path().to_path_buf())
+ }
+
+ pub fn start(
+ self: &Arc<Self>,
+ installation_id: Option<String>,
+ session_id: String,
+ cx: &mut AppContext,
+ ) {
+ let mut state = self.state.lock();
+ state.installation_id = installation_id.map(|id| id.into());
+ state.session_id = Some(session_id.into());
+ let has_clickhouse_events = !state.clickhouse_events_queue.is_empty();
+ drop(state);
+
+ if has_clickhouse_events {
+ self.flush_clickhouse_events();
+ }
+
+ let this = self.clone();
+ cx.spawn(|mut cx| async move {
+ let mut system = System::new_all();
+ system.refresh_all();
+
+ loop {
+ // Waiting some amount of time before the first query is important to get a reasonable value
+ // https://docs.rs/sysinfo/0.29.10/sysinfo/trait.ProcessExt.html#tymethod.cpu_usage
+ const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
+ smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
+
+ system.refresh_memory();
+ system.refresh_processes();
+
+ let current_process = Pid::from_u32(std::process::id());
+ let Some(process) = system.processes().get(¤t_process) else {
+ let process = current_process;
+ log::error!("Failed to find own process {process:?} in system process table");
+ // TODO: Fire an error telemetry event
+ return;
+ };
+
+ let memory_event = ClickhouseEvent::Memory {
+ memory_in_bytes: process.memory(),
+ virtual_memory_in_bytes: process.virtual_memory(),
+ };
+
+ let cpu_event = ClickhouseEvent::Cpu {
+ usage_as_percentage: process.cpu_usage(),
+ core_count: system.cpus().len() as u32,
+ };
+
+ let telemetry_settings = cx.update(|cx| *settings::get::<TelemetrySettings>(cx));
+
+ this.report_clickhouse_event(memory_event, telemetry_settings);
+ this.report_clickhouse_event(cpu_event, telemetry_settings);
+ }
+ })
+ .detach();
+ }
+
+ pub fn set_authenticated_user_info(
+ self: &Arc<Self>,
+ metrics_id: Option<String>,
+ is_staff: bool,
+ cx: &AppContext,
+ ) {
+ if !settings::get::<TelemetrySettings>(cx).metrics {
+ return;
+ }
+
+ let mut state = self.state.lock();
+ let metrics_id: Option<Arc<str>> = metrics_id.map(|id| id.into());
+ state.metrics_id = metrics_id.clone();
+ state.is_staff = Some(is_staff);
+ drop(state);
+ }
+
+ pub fn report_clickhouse_event(
+ self: &Arc<Self>,
+ event: ClickhouseEvent,
+ telemetry_settings: TelemetrySettings,
+ ) {
+ if !telemetry_settings.metrics {
+ return;
+ }
+
+ let mut state = self.state.lock();
+ let signed_in = state.metrics_id.is_some();
+ state
+ .clickhouse_events_queue
+ .push(ClickhouseEventWrapper { signed_in, event });
+
+ if state.installation_id.is_some() {
+ if state.clickhouse_events_queue.len() >= MAX_QUEUE_LEN {
+ drop(state);
+ self.flush_clickhouse_events();
+ } else {
+ let this = self.clone();
+ let executor = self.executor.clone();
+ state.flush_clickhouse_events_task = Some(self.executor.spawn(async move {
+ executor.timer(DEBOUNCE_INTERVAL).await;
+ this.flush_clickhouse_events();
+ }));
+ }
+ }
+ }
+
+ pub fn metrics_id(self: &Arc<Self>) -> Option<Arc<str>> {
+ self.state.lock().metrics_id.clone()
+ }
+
+ pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
+ self.state.lock().installation_id.clone()
+ }
+
+ pub fn is_staff(self: &Arc<Self>) -> Option<bool> {
+ self.state.lock().is_staff
+ }
+
+ fn flush_clickhouse_events(self: &Arc<Self>) {
+ let mut state = self.state.lock();
+ let mut events = mem::take(&mut state.clickhouse_events_queue);
+ state.flush_clickhouse_events_task.take();
+ drop(state);
+
+ let this = self.clone();
+ self.executor
+ .spawn(
+ async move {
+ let mut json_bytes = Vec::new();
+
+ if let Some(file) = &mut this.state.lock().log_file {
+ let file = file.as_file_mut();
+ for event in &mut events {
+ json_bytes.clear();
+ serde_json::to_writer(&mut json_bytes, event)?;
+ file.write_all(&json_bytes)?;
+ file.write(b"\n")?;
+ }
+ }
+
+ {
+ let state = this.state.lock();
+ let request_body = ClickhouseEventRequestBody {
+ token: ZED_SECRET_CLIENT_TOKEN,
+ installation_id: state.installation_id.clone(),
+ session_id: state.session_id.clone(),
+ is_staff: state.is_staff.clone(),
+ app_version: state.app_version.clone(),
+ os_name: state.os_name,
+ os_version: state.os_version.clone(),
+ architecture: state.architecture,
+
+ release_channel: state.release_channel,
+ events,
+ };
+ json_bytes.clear();
+ serde_json::to_writer(&mut json_bytes, &request_body)?;
+ }
+
+ this.http_client
+ .post_json(CLICKHOUSE_EVENTS_URL.as_str(), json_bytes.into())
+ .await?;
+ anyhow::Ok(())
+ }
+ .log_err(),
+ )
+ .detach();
+ }
+}
@@ -0,0 +1,215 @@
+use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
+use anyhow::{anyhow, Result};
+use futures::{stream::BoxStream, StreamExt};
+use gpui::{executor, ModelHandle, TestAppContext};
+use parking_lot::Mutex;
+use rpc::{
+ proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
+ ConnectionId, Peer, Receipt, TypedEnvelope,
+};
+use std::{rc::Rc, sync::Arc};
+use util::http::FakeHttpClient;
+
+pub struct FakeServer {
+ peer: Arc<Peer>,
+ state: Arc<Mutex<FakeServerState>>,
+ user_id: u64,
+ executor: Rc<executor::Foreground>,
+}
+
+#[derive(Default)]
+struct FakeServerState {
+ incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
+ connection_id: Option<ConnectionId>,
+ forbid_connections: bool,
+ auth_count: usize,
+ access_token: usize,
+}
+
+impl FakeServer {
+ pub async fn for_client(
+ client_user_id: u64,
+ client: &Arc<Client>,
+ cx: &TestAppContext,
+ ) -> Self {
+ let server = Self {
+ peer: Peer::new(0),
+ state: Default::default(),
+ user_id: client_user_id,
+ executor: cx.foreground(),
+ };
+
+ client
+ .override_authenticate({
+ let state = Arc::downgrade(&server.state);
+ move |cx| {
+ let state = state.clone();
+ cx.spawn(move |_| async move {
+ let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
+ let mut state = state.lock();
+ state.auth_count += 1;
+ let access_token = state.access_token.to_string();
+ Ok(Credentials {
+ user_id: client_user_id,
+ access_token,
+ })
+ })
+ }
+ })
+ .override_establish_connection({
+ let peer = Arc::downgrade(&server.peer);
+ let state = Arc::downgrade(&server.state);
+ move |credentials, cx| {
+ let peer = peer.clone();
+ let state = state.clone();
+ let credentials = credentials.clone();
+ cx.spawn(move |cx| async move {
+ let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
+ let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
+ if state.lock().forbid_connections {
+ Err(EstablishConnectionError::Other(anyhow!(
+ "server is forbidding connections"
+ )))?
+ }
+
+ assert_eq!(credentials.user_id, client_user_id);
+
+ if credentials.access_token != state.lock().access_token.to_string() {
+ Err(EstablishConnectionError::Unauthorized)?
+ }
+
+ let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
+ let (connection_id, io, incoming) =
+ peer.add_test_connection(server_conn, cx.background());
+ cx.background().spawn(io).detach();
+ {
+ let mut state = state.lock();
+ state.connection_id = Some(connection_id);
+ state.incoming = Some(incoming);
+ }
+ peer.send(
+ connection_id,
+ proto::Hello {
+ peer_id: Some(connection_id.into()),
+ },
+ )
+ .unwrap();
+
+ Ok(client_conn)
+ })
+ }
+ });
+
+ client
+ .authenticate_and_connect(false, &cx.to_async())
+ .await
+ .unwrap();
+
+ server
+ }
+
+ pub fn disconnect(&self) {
+ if self.state.lock().connection_id.is_some() {
+ self.peer.disconnect(self.connection_id());
+ let mut state = self.state.lock();
+ state.connection_id.take();
+ state.incoming.take();
+ }
+ }
+
+ pub fn auth_count(&self) -> usize {
+ self.state.lock().auth_count
+ }
+
+ pub fn roll_access_token(&self) {
+ self.state.lock().access_token += 1;
+ }
+
+ pub fn forbid_connections(&self) {
+ self.state.lock().forbid_connections = true;
+ }
+
+ pub fn allow_connections(&self) {
+ self.state.lock().forbid_connections = false;
+ }
+
+ pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
+ self.peer.send(self.connection_id(), message).unwrap();
+ }
+
+ #[allow(clippy::await_holding_lock)]
+ pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
+ self.executor.start_waiting();
+
+ loop {
+ let message = self
+ .state
+ .lock()
+ .incoming
+ .as_mut()
+ .expect("not connected")
+ .next()
+ .await
+ .ok_or_else(|| anyhow!("other half hung up"))?;
+ self.executor.finish_waiting();
+ let type_name = message.payload_type_name();
+ let message = message.into_any();
+
+ if message.is::<TypedEnvelope<M>>() {
+ return Ok(*message.downcast().unwrap());
+ }
+
+ if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
+ self.respond(
+ message
+ .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
+ .unwrap()
+ .receipt(),
+ GetPrivateUserInfoResponse {
+ metrics_id: "the-metrics-id".into(),
+ staff: false,
+ flags: Default::default(),
+ },
+ );
+ continue;
+ }
+
+ panic!(
+ "fake server received unexpected message type: {:?}",
+ type_name
+ );
+ }
+ }
+
+ pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
+ self.peer.respond(receipt, response).unwrap()
+ }
+
+ fn connection_id(&self) -> ConnectionId {
+ self.state.lock().connection_id.expect("not connected")
+ }
+
+ pub async fn build_user_store(
+ &self,
+ client: Arc<Client>,
+ cx: &mut TestAppContext,
+ ) -> ModelHandle<UserStore> {
+ let http_client = FakeHttpClient::with_404_response();
+ let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
+ assert_eq!(
+ self.receive::<proto::GetUsers>()
+ .await
+ .unwrap()
+ .payload
+ .user_ids,
+ &[self.user_id]
+ );
+ user_store
+ }
+}
+
+impl Drop for FakeServer {
+ fn drop(&mut self) {
+ self.disconnect();
+ }
+}
@@ -0,0 +1,737 @@
+use super::{proto, Client, Status, TypedEnvelope};
+use anyhow::{anyhow, Context, Result};
+use collections::{hash_map::Entry, HashMap, HashSet};
+use feature_flags::FeatureFlagAppExt;
+use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
+use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
+use postage::{sink::Sink, watch};
+use rpc::proto::{RequestMessage, UsersResponse};
+use std::sync::{Arc, Weak};
+use text::ReplicaId;
+use util::http::HttpClient;
+use util::TryFutureExt as _;
+
+pub type UserId = u64;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ParticipantIndex(pub u32);
+
+#[derive(Default, Debug)]
+pub struct User {
+ pub id: UserId,
+ pub github_login: String,
+ pub avatar: Option<Arc<ImageData>>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Collaborator {
+ pub peer_id: proto::PeerId,
+ pub replica_id: ReplicaId,
+ pub user_id: UserId,
+}
+
+impl PartialOrd for User {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl Ord for User {
+ fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+ self.github_login.cmp(&other.github_login)
+ }
+}
+
+impl PartialEq for User {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id && self.github_login == other.github_login
+ }
+}
+
+impl Eq for User {}
+
+#[derive(Debug, PartialEq)]
+pub struct Contact {
+ pub user: Arc<User>,
+ pub online: bool,
+ pub busy: bool,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ContactRequestStatus {
+ None,
+ RequestSent,
+ RequestReceived,
+ RequestAccepted,
+}
+
+pub struct UserStore {
+ users: HashMap<u64, Arc<User>>,
+ participant_indices: HashMap<u64, ParticipantIndex>,
+ update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
+ current_user: watch::Receiver<Option<Arc<User>>>,
+ contacts: Vec<Arc<Contact>>,
+ incoming_contact_requests: Vec<Arc<User>>,
+ outgoing_contact_requests: Vec<Arc<User>>,
+ pending_contact_requests: HashMap<u64, usize>,
+ invite_info: Option<InviteInfo>,
+ client: Weak<Client>,
+ http: Arc<dyn HttpClient>,
+ _maintain_contacts: Task<()>,
+ _maintain_current_user: Task<()>,
+}
+
+#[derive(Clone)]
+pub struct InviteInfo {
+ pub count: u32,
+ pub url: Arc<str>,
+}
+
+pub enum Event {
+ Contact {
+ user: Arc<User>,
+ kind: ContactEventKind,
+ },
+ ShowContacts,
+ ParticipantIndicesChanged,
+}
+
+#[derive(Clone, Copy)]
+pub enum ContactEventKind {
+ Requested,
+ Accepted,
+ Cancelled,
+}
+
+impl Entity for UserStore {
+ type Event = Event;
+}
+
+enum UpdateContacts {
+ Update(proto::UpdateContacts),
+ Wait(postage::barrier::Sender),
+ Clear(postage::barrier::Sender),
+}
+
+impl UserStore {
+ pub fn new(
+ client: Arc<Client>,
+ http: Arc<dyn HttpClient>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let (mut current_user_tx, current_user_rx) = watch::channel();
+ let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
+ let rpc_subscriptions = vec![
+ client.add_message_handler(cx.handle(), Self::handle_update_contacts),
+ client.add_message_handler(cx.handle(), Self::handle_update_invite_info),
+ client.add_message_handler(cx.handle(), Self::handle_show_contacts),
+ ];
+ Self {
+ users: Default::default(),
+ current_user: current_user_rx,
+ contacts: Default::default(),
+ incoming_contact_requests: Default::default(),
+ participant_indices: Default::default(),
+ outgoing_contact_requests: Default::default(),
+ invite_info: None,
+ client: Arc::downgrade(&client),
+ update_contacts_tx,
+ http,
+ _maintain_contacts: cx.spawn_weak(|this, mut cx| async move {
+ let _subscriptions = rpc_subscriptions;
+ while let Some(message) = update_contacts_rx.next().await {
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| this.update_contacts(message, cx))
+ .log_err()
+ .await;
+ }
+ }
+ }),
+ _maintain_current_user: cx.spawn_weak(|this, mut cx| async move {
+ let mut status = client.status();
+ while let Some(status) = status.next().await {
+ match status {
+ Status::Connected { .. } => {
+ if let Some((this, user_id)) = this.upgrade(&cx).zip(client.user_id()) {
+ let fetch_user = this
+ .update(&mut cx, |this, cx| this.get_user(user_id, cx))
+ .log_err();
+ let fetch_metrics_id =
+ client.request(proto::GetPrivateUserInfo {}).log_err();
+ let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
+
+ if let Some(info) = info {
+ cx.update(|cx| {
+ cx.update_flags(info.staff, info.flags);
+ client.telemetry.set_authenticated_user_info(
+ Some(info.metrics_id.clone()),
+ info.staff,
+ cx,
+ )
+ });
+ } else {
+ cx.read(|cx| {
+ client
+ .telemetry
+ .set_authenticated_user_info(None, false, cx)
+ });
+ }
+
+ current_user_tx.send(user).await.ok();
+
+ this.update(&mut cx, |_, cx| {
+ cx.notify();
+ });
+ }
+ }
+ Status::SignedOut => {
+ current_user_tx.send(None).await.ok();
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| {
+ cx.notify();
+ this.clear_contacts()
+ })
+ .await;
+ }
+ }
+ Status::ConnectionLost => {
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, cx| {
+ cx.notify();
+ this.clear_contacts()
+ })
+ .await;
+ }
+ }
+ _ => {}
+ }
+ }
+ }),
+ pending_contact_requests: Default::default(),
+ }
+ }
+
+ #[cfg(feature = "test-support")]
+ pub fn clear_cache(&mut self) {
+ self.users.clear();
+ }
+
+ async fn handle_update_invite_info(
+ this: ModelHandle<Self>,
+ message: TypedEnvelope<proto::UpdateInviteInfo>,
+ _: Arc<Client>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |this, cx| {
+ this.invite_info = Some(InviteInfo {
+ url: Arc::from(message.payload.url),
+ count: message.payload.count,
+ });
+ cx.notify();
+ });
+ Ok(())
+ }
+
+ async fn handle_show_contacts(
+ this: ModelHandle<Self>,
+ _: TypedEnvelope<proto::ShowContacts>,
+ _: Arc<Client>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts));
+ Ok(())
+ }
+
+ pub fn invite_info(&self) -> Option<&InviteInfo> {
+ self.invite_info.as_ref()
+ }
+
+ async fn handle_update_contacts(
+ this: ModelHandle<Self>,
+ message: TypedEnvelope<proto::UpdateContacts>,
+ _: Arc<Client>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |this, _| {
+ this.update_contacts_tx
+ .unbounded_send(UpdateContacts::Update(message.payload))
+ .unwrap();
+ });
+ Ok(())
+ }
+
+ fn update_contacts(
+ &mut self,
+ message: UpdateContacts,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
+ match message {
+ UpdateContacts::Wait(barrier) => {
+ drop(barrier);
+ Task::ready(Ok(()))
+ }
+ UpdateContacts::Clear(barrier) => {
+ self.contacts.clear();
+ self.incoming_contact_requests.clear();
+ self.outgoing_contact_requests.clear();
+ drop(barrier);
+ Task::ready(Ok(()))
+ }
+ UpdateContacts::Update(message) => {
+ let mut user_ids = HashSet::default();
+ for contact in &message.contacts {
+ user_ids.insert(contact.user_id);
+ }
+ user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id));
+ user_ids.extend(message.outgoing_requests.iter());
+
+ let load_users = self.get_users(user_ids.into_iter().collect(), cx);
+ cx.spawn(|this, mut cx| async move {
+ load_users.await?;
+
+ // Users are fetched in parallel above and cached in call to get_users
+ // No need to paralellize here
+ let mut updated_contacts = Vec::new();
+ for contact in message.contacts {
+ let should_notify = contact.should_notify;
+ updated_contacts.push((
+ Arc::new(Contact::from_proto(contact, &this, &mut cx).await?),
+ should_notify,
+ ));
+ }
+
+ let mut incoming_requests = Vec::new();
+ for request in message.incoming_requests {
+ incoming_requests.push({
+ let user = this
+ .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx))
+ .await?;
+ (user, request.should_notify)
+ });
+ }
+
+ let mut outgoing_requests = Vec::new();
+ for requested_user_id in message.outgoing_requests {
+ outgoing_requests.push(
+ this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))
+ .await?,
+ );
+ }
+
+ let removed_contacts =
+ HashSet::<u64>::from_iter(message.remove_contacts.iter().copied());
+ let removed_incoming_requests =
+ HashSet::<u64>::from_iter(message.remove_incoming_requests.iter().copied());
+ let removed_outgoing_requests =
+ HashSet::<u64>::from_iter(message.remove_outgoing_requests.iter().copied());
+
+ this.update(&mut cx, |this, cx| {
+ // Remove contacts
+ this.contacts
+ .retain(|contact| !removed_contacts.contains(&contact.user.id));
+ // Update existing contacts and insert new ones
+ for (updated_contact, should_notify) in updated_contacts {
+ if should_notify {
+ cx.emit(Event::Contact {
+ user: updated_contact.user.clone(),
+ kind: ContactEventKind::Accepted,
+ });
+ }
+ match this.contacts.binary_search_by_key(
+ &&updated_contact.user.github_login,
+ |contact| &contact.user.github_login,
+ ) {
+ Ok(ix) => this.contacts[ix] = updated_contact,
+ Err(ix) => this.contacts.insert(ix, updated_contact),
+ }
+ }
+
+ // Remove incoming contact requests
+ this.incoming_contact_requests.retain(|user| {
+ if removed_incoming_requests.contains(&user.id) {
+ cx.emit(Event::Contact {
+ user: user.clone(),
+ kind: ContactEventKind::Cancelled,
+ });
+ false
+ } else {
+ true
+ }
+ });
+ // Update existing incoming requests and insert new ones
+ for (user, should_notify) in incoming_requests {
+ if should_notify {
+ cx.emit(Event::Contact {
+ user: user.clone(),
+ kind: ContactEventKind::Requested,
+ });
+ }
+
+ match this
+ .incoming_contact_requests
+ .binary_search_by_key(&&user.github_login, |contact| {
+ &contact.github_login
+ }) {
+ Ok(ix) => this.incoming_contact_requests[ix] = user,
+ Err(ix) => this.incoming_contact_requests.insert(ix, user),
+ }
+ }
+
+ // Remove outgoing contact requests
+ this.outgoing_contact_requests
+ .retain(|user| !removed_outgoing_requests.contains(&user.id));
+ // Update existing incoming requests and insert new ones
+ for request in outgoing_requests {
+ match this
+ .outgoing_contact_requests
+ .binary_search_by_key(&&request.github_login, |contact| {
+ &contact.github_login
+ }) {
+ Ok(ix) => this.outgoing_contact_requests[ix] = request,
+ Err(ix) => this.outgoing_contact_requests.insert(ix, request),
+ }
+ }
+
+ cx.notify();
+ });
+
+ Ok(())
+ })
+ }
+ }
+ }
+
+ pub fn contacts(&self) -> &[Arc<Contact>] {
+ &self.contacts
+ }
+
+ pub fn has_contact(&self, user: &Arc<User>) -> bool {
+ self.contacts
+ .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
+ .is_ok()
+ }
+
+ pub fn incoming_contact_requests(&self) -> &[Arc<User>] {
+ &self.incoming_contact_requests
+ }
+
+ pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
+ &self.outgoing_contact_requests
+ }
+
+ pub fn is_contact_request_pending(&self, user: &User) -> bool {
+ self.pending_contact_requests.contains_key(&user.id)
+ }
+
+ pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus {
+ if self
+ .contacts
+ .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login)
+ .is_ok()
+ {
+ ContactRequestStatus::RequestAccepted
+ } else if self
+ .outgoing_contact_requests
+ .binary_search_by_key(&&user.github_login, |user| &user.github_login)
+ .is_ok()
+ {
+ ContactRequestStatus::RequestSent
+ } else if self
+ .incoming_contact_requests
+ .binary_search_by_key(&&user.github_login, |user| &user.github_login)
+ .is_ok()
+ {
+ ContactRequestStatus::RequestReceived
+ } else {
+ ContactRequestStatus::None
+ }
+ }
+
+ pub fn request_contact(
+ &mut self,
+ responder_id: u64,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
+ self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx)
+ }
+
+ pub fn remove_contact(
+ &mut self,
+ user_id: u64,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
+ self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx)
+ }
+
+ pub fn respond_to_contact_request(
+ &mut self,
+ requester_id: u64,
+ accept: bool,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
+ self.perform_contact_request(
+ requester_id,
+ proto::RespondToContactRequest {
+ requester_id,
+ response: if accept {
+ proto::ContactRequestResponse::Accept
+ } else {
+ proto::ContactRequestResponse::Decline
+ } as i32,
+ },
+ cx,
+ )
+ }
+
+ pub fn dismiss_contact_request(
+ &mut self,
+ requester_id: u64,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
+ let client = self.client.upgrade();
+ cx.spawn_weak(|_, _| async move {
+ client
+ .ok_or_else(|| anyhow!("can't upgrade client reference"))?
+ .request(proto::RespondToContactRequest {
+ requester_id,
+ response: proto::ContactRequestResponse::Dismiss as i32,
+ })
+ .await?;
+ Ok(())
+ })
+ }
+
+ fn perform_contact_request<T: RequestMessage>(
+ &mut self,
+ user_id: u64,
+ request: T,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
+ let client = self.client.upgrade();
+ *self.pending_contact_requests.entry(user_id).or_insert(0) += 1;
+ cx.notify();
+
+ cx.spawn(|this, mut cx| async move {
+ let response = client
+ .ok_or_else(|| anyhow!("can't upgrade client reference"))?
+ .request(request)
+ .await;
+ this.update(&mut cx, |this, cx| {
+ if let Entry::Occupied(mut request_count) =
+ this.pending_contact_requests.entry(user_id)
+ {
+ *request_count.get_mut() -= 1;
+ if *request_count.get() == 0 {
+ request_count.remove();
+ }
+ }
+ cx.notify();
+ });
+ response?;
+ Ok(())
+ })
+ }
+
+ pub fn clear_contacts(&mut self) -> impl Future<Output = ()> {
+ let (tx, mut rx) = postage::barrier::channel();
+ self.update_contacts_tx
+ .unbounded_send(UpdateContacts::Clear(tx))
+ .unwrap();
+ async move {
+ rx.next().await;
+ }
+ }
+
+ pub fn contact_updates_done(&mut self) -> impl Future<Output = ()> {
+ let (tx, mut rx) = postage::barrier::channel();
+ self.update_contacts_tx
+ .unbounded_send(UpdateContacts::Wait(tx))
+ .unwrap();
+ async move {
+ rx.next().await;
+ }
+ }
+
+ pub fn get_users(
+ &mut self,
+ user_ids: Vec<u64>,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<Arc<User>>>> {
+ let mut user_ids_to_fetch = user_ids.clone();
+ user_ids_to_fetch.retain(|id| !self.users.contains_key(id));
+
+ cx.spawn(|this, mut cx| async move {
+ if !user_ids_to_fetch.is_empty() {
+ this.update(&mut cx, |this, cx| {
+ this.load_users(
+ proto::GetUsers {
+ user_ids: user_ids_to_fetch,
+ },
+ cx,
+ )
+ })
+ .await?;
+ }
+
+ this.read_with(&cx, |this, _| {
+ user_ids
+ .iter()
+ .map(|user_id| {
+ this.users
+ .get(user_id)
+ .cloned()
+ .ok_or_else(|| anyhow!("user {} not found", user_id))
+ })
+ .collect()
+ })
+ })
+ }
+
+ pub fn fuzzy_search_users(
+ &mut self,
+ query: String,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<Arc<User>>>> {
+ self.load_users(proto::FuzzySearchUsers { query }, cx)
+ }
+
+ pub fn get_cached_user(&self, user_id: u64) -> Option<Arc<User>> {
+ self.users.get(&user_id).cloned()
+ }
+
+ pub fn get_user(
+ &mut self,
+ user_id: u64,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Arc<User>>> {
+ if let Some(user) = self.users.get(&user_id).cloned() {
+ return cx.foreground().spawn(async move { Ok(user) });
+ }
+
+ let load_users = self.get_users(vec![user_id], cx);
+ cx.spawn(|this, mut cx| async move {
+ load_users.await?;
+ this.update(&mut cx, |this, _| {
+ this.users
+ .get(&user_id)
+ .cloned()
+ .ok_or_else(|| anyhow!("server responded with no users"))
+ })
+ })
+ }
+
+ pub fn current_user(&self) -> Option<Arc<User>> {
+ self.current_user.borrow().clone()
+ }
+
+ pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
+ self.current_user.clone()
+ }
+
+ fn load_users(
+ &mut self,
+ request: impl RequestMessage<Response = UsersResponse>,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<Arc<User>>>> {
+ let client = self.client.clone();
+ let http = self.http.clone();
+ cx.spawn_weak(|this, mut cx| async move {
+ if let Some(rpc) = client.upgrade() {
+ let response = rpc.request(request).await.context("error loading users")?;
+ let users = future::join_all(
+ response
+ .users
+ .into_iter()
+ .map(|user| User::new(user, http.as_ref())),
+ )
+ .await;
+
+ if let Some(this) = this.upgrade(&cx) {
+ this.update(&mut cx, |this, _| {
+ for user in &users {
+ this.users.insert(user.id, user.clone());
+ }
+ });
+ }
+ Ok(users)
+ } else {
+ Ok(Vec::new())
+ }
+ })
+ }
+
+ pub fn set_participant_indices(
+ &mut self,
+ participant_indices: HashMap<u64, ParticipantIndex>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ if participant_indices != self.participant_indices {
+ self.participant_indices = participant_indices;
+ cx.emit(Event::ParticipantIndicesChanged);
+ }
+ }
+
+ pub fn participant_indices(&self) -> &HashMap<u64, ParticipantIndex> {
+ &self.participant_indices
+ }
+}
+
+impl User {
+ async fn new(message: proto::User, http: &dyn HttpClient) -> Arc<Self> {
+ Arc::new(User {
+ id: message.id,
+ github_login: message.github_login,
+ avatar: fetch_avatar(http, &message.avatar_url).warn_on_err().await,
+ })
+ }
+}
+
+impl Contact {
+ async fn from_proto(
+ contact: proto::Contact,
+ user_store: &ModelHandle<UserStore>,
+ cx: &mut AsyncAppContext,
+ ) -> Result<Self> {
+ let user = user_store
+ .update(cx, |user_store, cx| {
+ user_store.get_user(contact.user_id, cx)
+ })
+ .await?;
+ Ok(Self {
+ user,
+ online: contact.online,
+ busy: contact.busy,
+ })
+ }
+}
+
+impl Collaborator {
+ pub fn from_proto(message: proto::Collaborator) -> Result<Self> {
+ Ok(Self {
+ peer_id: message.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?,
+ replica_id: message.replica_id as ReplicaId,
+ user_id: message.user_id as UserId,
+ })
+ }
+}
+
+async fn fetch_avatar(http: &dyn HttpClient, url: &str) -> Result<Arc<ImageData>> {
+ let mut response = http
+ .get(url, Default::default(), true)
+ .await
+ .map_err(|e| anyhow!("failed to send user avatar request: {}", e))?;
+
+ if !response.status().is_success() {
+ return Err(anyhow!("avatar request failed {:?}", response.status()));
+ }
+
+ let mut body = Vec::new();
+ response
+ .body_mut()
+ .read_to_end(&mut body)
+ .await
+ .map_err(|e| anyhow!("failed to read user avatar response body: {}", e))?;
+ let format = image::guess_format(&body)?;
+ let image = image::load_from_memory_with_format(&body, format)?.into_bgra8();
+ Ok(ImageData::new(image))
+}
@@ -11,8 +11,8 @@ use smallvec::SmallVec;
use crate::{
current_platform, image_cache::ImageCache, Action, AssetSource, Context, DisplayId, Executor,
FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly,
- Platform, SharedString, SubscriberSet, SvgRenderer, Task, TextStyle, TextStyleRefinement,
- TextSystem, View, Window, WindowContext, WindowHandle, WindowId,
+ Platform, SemanticVersion, SharedString, SubscriberSet, SvgRenderer, Task, TextStyle,
+ TextStyleRefinement, TextSystem, View, Window, WindowContext, WindowHandle, WindowId,
};
use anyhow::{anyhow, Result};
use collections::{HashMap, HashSet, VecDeque};
@@ -125,6 +125,18 @@ impl App {
self
}
+ pub fn app_version(&self) -> Result<SemanticVersion> {
+ self.0.lock().platform.borrow_on_main_thread().app_version()
+ }
+
+ pub fn os_name(&self) -> &'static str {
+ self.0.lock().platform.borrow_on_main_thread().os_name()
+ }
+
+ pub fn os_version(&self) -> Result<SemanticVersion> {
+ self.0.lock().platform.borrow_on_main_thread().os_version()
+ }
+
pub fn executor(&self) -> Executor {
self.0.lock().executor.clone()
}
@@ -348,7 +360,7 @@ impl AppContext {
.ok();
}
- pub fn apply_refresh(&mut self) {
+ fn apply_refresh(&mut self) {
for window in self.windows.values_mut() {
if let Some(window) = window.as_mut() {
window.dirty = true;
@@ -27,7 +27,7 @@ collections = { path = "../collections" }
# command_palette = { path = "../command_palette" }
# component_test = { path = "../component_test" }
# context_menu = { path = "../context_menu" }
-# client = { path = "../client" }
+client2 = { path = "../client2" }
# clock = { path = "../clock" }
# copilot = { path = "../copilot" }
# copilot_button = { path = "../copilot_button" }
@@ -3,13 +3,15 @@
use crate::open_listener::{OpenListener, OpenRequest};
use anyhow::{anyhow, Context, Result};
+use backtrace::Backtrace;
use cli::{
ipc::{self, IpcSender},
CliRequest, CliResponse, IpcHandshake, FORCE_CLI_MODE_ENV_VAR_NAME,
};
use fs::RealFs;
use futures::{channel::mpsc, SinkExt, StreamExt};
-use gpui2::{App, AssetSource, AsyncAppContext, Task};
+use gpui2::{App, AppContext, AssetSource, AsyncAppContext, SemanticVersion, Task};
+use isahc::{prelude::Configurable, Request};
use log::LevelFilter;
use parking_lot::Mutex;
@@ -19,17 +21,21 @@ use simplelog::ConfigBuilder;
use smol::process::Command;
use std::{
env,
+ ffi::OsStr,
fs::OpenOptions,
- io::IsTerminal,
+ io::{IsTerminal, Write},
+ panic,
path::Path,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
thread,
+ time::{SystemTime, UNIX_EPOCH},
};
use util::{
- channel::{parse_zed_link, RELEASE_CHANNEL},
+ channel::{parse_zed_link, ReleaseChannel, RELEASE_CHANNEL},
+ http::HttpClient,
paths, ResultExt,
};
use zed2::{ensure_only_instance, AppState, Assets, IsOnlyInstance};
@@ -75,8 +81,8 @@ fn main() {
let (listener, mut open_rx) = OpenListener::new();
let listener = Arc::new(listener);
- let callback_listener = listener.clone();
- app.on_open_urls(move |urls, _| callback_listener.open_urls(urls));
+ let open_listener = listener.clone();
+ app.on_open_urls(move |urls, _| open_listener.open_urls(urls));
app.on_reopen(move |_cx| {
// todo!("workspace")
// if cx.has_global::<Weak<AppState>>() {
@@ -394,192 +400,184 @@ struct PanicRequest {
token: String,
}
-static _PANIC_COUNT: AtomicU32 = AtomicU32::new(0);
+static PANIC_COUNT: AtomicU32 = AtomicU32::new(0);
-// fn init_panic_hook(app: &App, installation_id: Option<String>, session_id: String) {
-// let is_pty = stdout_is_a_pty();
-// let platform = app.platform();
+fn init_panic_hook(app: &App, installation_id: Option<String>, session_id: String) {
+ let is_pty = stdout_is_a_pty();
+ let app_version = app.app_version().ok();
+ let os_name = app.os_name();
+ let os_version = app.os_version().ok();
-// panic::set_hook(Box::new(move |info| {
-// let prior_panic_count = PANIC_COUNT.fetch_add(1, Ordering::SeqCst);
-// if prior_panic_count > 0 {
-// // Give the panic-ing thread time to write the panic file
-// loop {
-// std::thread::yield_now();
-// }
-// }
+ panic::set_hook(Box::new(move |info| {
+ let prior_panic_count = PANIC_COUNT.fetch_add(1, Ordering::SeqCst);
+ if prior_panic_count > 0 {
+ // Give the panic-ing thread time to write the panic file
+ loop {
+ std::thread::yield_now();
+ }
+ }
-// let thread = thread::current();
-// let thread_name = thread.name().unwrap_or("<unnamed>");
+ let thread = thread::current();
+ let thread_name = thread.name().unwrap_or("<unnamed>");
+
+ let payload = info
+ .payload()
+ .downcast_ref::<&str>()
+ .map(|s| s.to_string())
+ .or_else(|| info.payload().downcast_ref::<String>().map(|s| s.clone()))
+ .unwrap_or_else(|| "Box<Any>".to_string());
+
+ if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Dev {
+ let location = info.location().unwrap();
+ let backtrace = Backtrace::new();
+ eprintln!(
+ "Thread {:?} panicked with {:?} at {}:{}:{}\n{:?}",
+ thread_name,
+ payload,
+ location.file(),
+ location.line(),
+ location.column(),
+ backtrace,
+ );
+ std::process::exit(-1);
+ }
-// let payload = info
-// .payload()
-// .downcast_ref::<&str>()
-// .map(|s| s.to_string())
-// .or_else(|| info.payload().downcast_ref::<String>().map(|s| s.clone()))
-// .unwrap_or_else(|| "Box<Any>".to_string());
+ let app_version = client::ZED_APP_VERSION
+ .or(app_version)
+ .map_or("dev".to_string(), |v| v.to_string());
+
+ let backtrace = Backtrace::new();
+ let mut backtrace = backtrace
+ .frames()
+ .iter()
+ .filter_map(|frame| Some(format!("{:#}", frame.symbols().first()?.name()?)))
+ .collect::<Vec<_>>();
+
+ // Strip out leading stack frames for rust panic-handling.
+ if let Some(ix) = backtrace
+ .iter()
+ .position(|name| name == "rust_begin_unwind")
+ {
+ backtrace.drain(0..=ix);
+ }
-// if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Dev {
-// let location = info.location().unwrap();
-// let backtrace = Backtrace::new();
-// eprintln!(
-// "Thread {:?} panicked with {:?} at {}:{}:{}\n{:?}",
-// thread_name,
-// payload,
-// location.file(),
-// location.line(),
-// location.column(),
-// backtrace,
-// );
-// std::process::exit(-1);
-// }
+ let panic_data = Panic {
+ thread: thread_name.into(),
+ payload: payload.into(),
+ location_data: info.location().map(|location| LocationData {
+ file: location.file().into(),
+ line: location.line(),
+ }),
+ app_version: app_version.clone(),
+ release_channel: RELEASE_CHANNEL.display_name().into(),
+ os_name: os_name.into(),
+ os_version: os_version.as_ref().map(SemanticVersion::to_string),
+ architecture: env::consts::ARCH.into(),
+ panicked_on: SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_millis(),
+ backtrace,
+ installation_id: installation_id.clone(),
+ session_id: session_id.clone(),
+ };
+
+ if let Some(panic_data_json) = serde_json::to_string_pretty(&panic_data).log_err() {
+ log::error!("{}", panic_data_json);
+ }
-// let app_version = ZED_APP_VERSION
-// .or_else(|| platform.app_version().ok())
-// .map_or("dev".to_string(), |v| v.to_string());
-
-// let backtrace = Backtrace::new();
-// let mut backtrace = backtrace
-// .frames()
-// .iter()
-// .filter_map(|frame| Some(format!("{:#}", frame.symbols().first()?.name()?)))
-// .collect::<Vec<_>>();
-
-// // Strip out leading stack frames for rust panic-handling.
-// if let Some(ix) = backtrace
-// .iter()
-// .position(|name| name == "rust_begin_unwind")
-// {
-// backtrace.drain(0..=ix);
-// }
+ if !is_pty {
+ if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() {
+ let timestamp = chrono::Utc::now().format("%Y_%m_%d %H_%M_%S").to_string();
+ let panic_file_path = paths::LOGS_DIR.join(format!("zed-{}.panic", timestamp));
+ let panic_file = std::fs::OpenOptions::new()
+ .append(true)
+ .create(true)
+ .open(&panic_file_path)
+ .log_err();
+ if let Some(mut panic_file) = panic_file {
+ writeln!(&mut panic_file, "{}", panic_data_json).log_err();
+ panic_file.flush().log_err();
+ }
+ }
+ }
-// let panic_data = Panic {
-// thread: thread_name.into(),
-// payload: payload.into(),
-// location_data: info.location().map(|location| LocationData {
-// file: location.file().into(),
-// line: location.line(),
-// }),
-// app_version: app_version.clone(),
-// release_channel: RELEASE_CHANNEL.display_name().into(),
-// os_name: platform.os_name().into(),
-// os_version: platform
-// .os_version()
-// .ok()
-// .map(|os_version| os_version.to_string()),
-// architecture: env::consts::ARCH.into(),
-// panicked_on: SystemTime::now()
-// .duration_since(UNIX_EPOCH)
-// .unwrap()
-// .as_millis(),
-// backtrace,
-// installation_id: installation_id.clone(),
-// session_id: session_id.clone(),
-// };
-
-// if let Some(panic_data_json) = serde_json::to_string_pretty(&panic_data).log_err() {
-// log::error!("{}", panic_data_json);
-// }
+ std::process::abort();
+ }));
+}
-// if !is_pty {
-// if let Some(panic_data_json) = serde_json::to_string(&panic_data).log_err() {
-// let timestamp = chrono::Utc::now().format("%Y_%m_%d %H_%M_%S").to_string();
-// let panic_file_path = paths::LOGS_DIR.join(format!("zed-{}.panic", timestamp));
-// let panic_file = std::fs::OpenOptions::new()
-// .append(true)
-// .create(true)
-// .open(&panic_file_path)
-// .log_err();
-// if let Some(mut panic_file) = panic_file {
-// writeln!(&mut panic_file, "{}", panic_data_json).log_err();
-// panic_file.flush().log_err();
-// }
-// }
-// }
+fn upload_previous_panics(http: Arc<dyn HttpClient>, cx: &mut AppContext) {
+ let telemetry_settings = *settings2::get::<client::TelemetrySettings>(cx);
-// std::process::abort();
-// }));
-// }
+ cx.executor()
+ .spawn(async move {
+ let panic_report_url = format!("{}/api/panic", &*client::ZED_SERVER_URL);
+ let mut children = smol::fs::read_dir(&*paths::LOGS_DIR).await?;
+ while let Some(child) = children.next().await {
+ let child = child?;
+ let child_path = child.path();
-// fn upload_previous_panics(http: Arc<dyn HttpClient>, cx: &mut AppContext) {
-// let telemetry_settings = *settings::get::<TelemetrySettings>(cx);
-
-// cx.background()
-// .spawn({
-// async move {
-// let panic_report_url = format!("{}/api/panic", &*client::ZED_SERVER_URL);
-// let mut children = smol::fs::read_dir(&*paths::LOGS_DIR).await?;
-// while let Some(child) = children.next().await {
-// let child = child?;
-// let child_path = child.path();
-
-// if child_path.extension() != Some(OsStr::new("panic")) {
-// continue;
-// }
-// let filename = if let Some(filename) = child_path.file_name() {
-// filename.to_string_lossy()
-// } else {
-// continue;
-// };
-
-// if !filename.starts_with("zed") {
-// continue;
-// }
-
-// if telemetry_settings.diagnostics {
-// let panic_file_content = smol::fs::read_to_string(&child_path)
-// .await
-// .context("error reading panic file")?;
-
-// let panic = serde_json::from_str(&panic_file_content)
-// .ok()
-// .or_else(|| {
-// panic_file_content
-// .lines()
-// .next()
-// .and_then(|line| serde_json::from_str(line).ok())
-// })
-// .unwrap_or_else(|| {
-// log::error!(
-// "failed to deserialize panic file {:?}",
-// panic_file_content
-// );
-// None
-// });
-
-// if let Some(panic) = panic {
-// let body = serde_json::to_string(&PanicRequest {
-// panic,
-// token: ZED_SECRET_CLIENT_TOKEN.into(),
-// })
-// .unwrap();
-
-// let request = Request::post(&panic_report_url)
-// .redirect_policy(isahc::config::RedirectPolicy::Follow)
-// .header("Content-Type", "application/json")
-// .body(body.into())?;
-// let response =
-// http.send(request).await.context("error sending panic")?;
-// if !response.status().is_success() {
-// log::error!(
-// "Error uploading panic to server: {}",
-// response.status()
-// );
-// }
-// }
-// }
-
-// // We've done what we can, delete the file
-// std::fs::remove_file(child_path)
-// .context("error removing panic")
-// .log_err();
-// }
-// Ok::<_, anyhow::Error>(())
-// }
-// .log_err()
-// })
-// .detach();
-// }
+ if child_path.extension() != Some(OsStr::new("panic")) {
+ continue;
+ }
+ let filename = if let Some(filename) = child_path.file_name() {
+ filename.to_string_lossy()
+ } else {
+ continue;
+ };
+
+ if !filename.starts_with("zed") {
+ continue;
+ }
+
+ if telemetry_settings.diagnostics {
+ let panic_file_content = smol::fs::read_to_string(&child_path)
+ .await
+ .context("error reading panic file")?;
+
+ let panic = serde_json::from_str(&panic_file_content)
+ .ok()
+ .or_else(|| {
+ panic_file_content
+ .lines()
+ .next()
+ .and_then(|line| serde_json::from_str(line).ok())
+ })
+ .unwrap_or_else(|| {
+ log::error!(
+ "failed to deserialize panic file {:?}",
+ panic_file_content
+ );
+ None
+ });
+
+ if let Some(panic) = panic {
+ let body = serde_json::to_string(&PanicRequest {
+ panic,
+ token: client::ZED_SECRET_CLIENT_TOKEN.into(),
+ })
+ .unwrap();
+
+ let request = Request::post(&panic_report_url)
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .header("Content-Type", "application/json")
+ .body(body.into())?;
+ let response = http.send(request).await.context("error sending panic")?;
+ if !response.status().is_success() {
+ log::error!("Error uploading panic to server: {}", response.status());
+ }
+ }
+ }
+
+ // We've done what we can, delete the file
+ std::fs::remove_file(child_path)
+ .context("error removing panic")
+ .log_err();
+ }
+ Ok::<_, anyhow::Error>(())
+ })
+ .detach_and_log_err(cx);
+}
async fn load_login_shell_environment() -> Result<()> {
let marker = "ZED_LOGIN_SHELL_START";