1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use clock::SystemClock;
14use collections::HashMap;
15use futures::{
16 channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
17 TryFutureExt as _, TryStreamExt,
18};
19use gpui::{
20 actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
21};
22use http::{HttpClient, HttpClientWithUrl};
23use lazy_static::lazy_static;
24use parking_lot::RwLock;
25use postage::watch;
26use rand::prelude::*;
27use release_channel::{AppVersion, ReleaseChannel};
28use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::{Settings, SettingsSources};
32use std::fmt;
33use std::pin::Pin;
34use std::{
35 any::TypeId,
36 convert::TryFrom,
37 fmt::Write as _,
38 future::Future,
39 marker::PhantomData,
40 path::PathBuf,
41 sync::{
42 atomic::{AtomicU64, Ordering},
43 Arc, Weak,
44 },
45 time::{Duration, Instant},
46};
47use telemetry::Telemetry;
48use thiserror::Error;
49use url::Url;
50use util::{ResultExt, TryFutureExt};
51
52pub use rpc::*;
53pub use telemetry_events::Event;
54pub use user::*;
55
/// Newtype around the raw token string carried by [`Credentials::DevServer`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DevServerToken(pub String);

impl fmt::Display for DevServerToken {
    /// Writes the wrapped token verbatim; width and padding flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
64
65lazy_static! {
66 static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
67 static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
68 /// An environment variable whose presence indicates that the development auth
69 /// provider should be used.
70 ///
71 /// Only works in development. Setting this environment variable in other release
72 /// channels is a no-op.
73 pub static ref ZED_DEVELOPMENT_AUTH: bool =
74 std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
75 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
76 .ok()
77 .and_then(|s| if s.is_empty() { None } else { Some(s) });
78 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
79 .ok()
80 .and_then(|s| if s.is_empty() { None } else { Some(s) });
81 pub static ref ZED_APP_PATH: Option<PathBuf> =
82 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
83 pub static ref ZED_ALWAYS_ACTIVE: bool =
84 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
85}
86
/// First delay used by the reconnect loop, and the lower bound for every later delay.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Upper bound for the delay between reconnection attempts.
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
/// Time budget for establishing a connection and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
90
91actions!(client, [SignIn, SignOut, Reconnect]);
92
93#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
94pub struct ClientSettingsContent {
95 server_url: Option<String>,
96}
97
98#[derive(Deserialize)]
99pub struct ClientSettings {
100 pub server_url: String,
101}
102
103impl Settings for ClientSettings {
104 const KEY: Option<&'static str> = None;
105
106 type FileContent = ClientSettingsContent;
107
108 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
109 let mut result = sources.json_merge::<Self>()?;
110 if let Some(server_url) = &*ZED_SERVER_URL {
111 result.server_url.clone_from(&server_url)
112 }
113 Ok(result)
114 }
115}
116
117#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
118pub struct ProxySettingsContent {
119 proxy: Option<String>,
120}
121
122#[derive(Deserialize, Default)]
123pub struct ProxySettings {
124 pub proxy: Option<String>,
125}
126
127impl Settings for ProxySettings {
128 const KEY: Option<&'static str> = None;
129
130 type FileContent = ProxySettingsContent;
131
132 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
133 Ok(Self {
134 proxy: sources
135 .user
136 .and_then(|value| value.proxy.clone())
137 .or(sources.default.proxy.clone()),
138 })
139 }
140}
141
142pub fn init_settings(cx: &mut AppContext) {
143 TelemetrySettings::register(cx);
144 ClientSettings::register(cx);
145 ProxySettings::register(cx);
146}
147
148pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
149 let client = Arc::downgrade(client);
150 cx.on_action({
151 let client = client.clone();
152 move |_: &SignIn, cx| {
153 if let Some(client) = client.upgrade() {
154 cx.spawn(
155 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
156 )
157 .detach();
158 }
159 }
160 });
161
162 cx.on_action({
163 let client = client.clone();
164 move |_: &SignOut, cx| {
165 if let Some(client) = client.upgrade() {
166 cx.spawn(|cx| async move {
167 client.sign_out(&cx).await;
168 })
169 .detach();
170 }
171 }
172 });
173
174 cx.on_action({
175 let client = client.clone();
176 move |_: &Reconnect, cx| {
177 if let Some(client) = client.upgrade() {
178 cx.spawn(|cx| async move {
179 client.reconnect(&cx);
180 })
181 .detach();
182 }
183 }
184 });
185}
186
187struct GlobalClient(Arc<Client>);
188
189impl Global for GlobalClient {}
190
191pub struct Client {
192 id: AtomicU64,
193 peer: Arc<Peer>,
194 http: Arc<HttpClientWithUrl>,
195 telemetry: Arc<Telemetry>,
196 credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
197 state: RwLock<ClientState>,
198
199 #[allow(clippy::type_complexity)]
200 #[cfg(any(test, feature = "test-support"))]
201 authenticate: RwLock<
202 Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
203 >,
204
205 #[allow(clippy::type_complexity)]
206 #[cfg(any(test, feature = "test-support"))]
207 establish_connection: RwLock<
208 Option<
209 Box<
210 dyn 'static
211 + Send
212 + Sync
213 + Fn(
214 &Credentials,
215 &AsyncAppContext,
216 ) -> Task<Result<Connection, EstablishConnectionError>>,
217 >,
218 >,
219 >,
220}
221
222#[derive(Error, Debug)]
223pub enum EstablishConnectionError {
224 #[error("upgrade required")]
225 UpgradeRequired,
226 #[error("unauthorized")]
227 Unauthorized,
228 #[error("{0}")]
229 Other(#[from] anyhow::Error),
230 #[error("{0}")]
231 Http(#[from] http::Error),
232 #[error("{0}")]
233 Io(#[from] std::io::Error),
234 #[error("{0}")]
235 Websocket(#[from] async_tungstenite::tungstenite::http::Error),
236}
237
238impl From<WebsocketError> for EstablishConnectionError {
239 fn from(error: WebsocketError) -> Self {
240 if let WebsocketError::Http(response) = &error {
241 match response.status() {
242 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
243 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
244 _ => {}
245 }
246 }
247 EstablishConnectionError::Other(error.into())
248 }
249}
250
251impl EstablishConnectionError {
252 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
253 Self::Other(error.into())
254 }
255}
256
257#[derive(Copy, Clone, Debug, PartialEq)]
258pub enum Status {
259 SignedOut,
260 UpgradeRequired,
261 Authenticating,
262 Connecting,
263 ConnectionError,
264 Connected {
265 peer_id: PeerId,
266 connection_id: ConnectionId,
267 },
268 ConnectionLost,
269 Reauthenticating,
270 Reconnecting,
271 ReconnectionError {
272 next_reconnection: Instant,
273 },
274}
275
276impl Status {
277 pub fn is_connected(&self) -> bool {
278 matches!(self, Self::Connected { .. })
279 }
280
281 pub fn is_signed_out(&self) -> bool {
282 matches!(self, Self::SignedOut | Self::UpgradeRequired)
283 }
284}
285
286struct ClientState {
287 credentials: Option<Credentials>,
288 status: (watch::Sender<Status>, watch::Receiver<Status>),
289 entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
290 _reconnect_task: Option<Task<()>>,
291 entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
292 models_by_message_type: HashMap<TypeId, AnyWeakModel>,
293 entity_types_by_message_type: HashMap<TypeId, TypeId>,
294 #[allow(clippy::type_complexity)]
295 message_handlers: HashMap<
296 TypeId,
297 Arc<
298 dyn Send
299 + Sync
300 + Fn(
301 AnyModel,
302 Box<dyn AnyTypedEnvelope>,
303 &Arc<Client>,
304 AsyncAppContext,
305 ) -> LocalBoxFuture<'static, Result<()>>,
306 >,
307 >,
308}
309
310enum WeakSubscriber {
311 Entity { handle: AnyWeakModel },
312 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
313}
314
315#[derive(Clone, Debug, Eq, PartialEq)]
316pub enum Credentials {
317 DevServer { token: DevServerToken },
318 User { user_id: u64, access_token: String },
319}
320
321impl Credentials {
322 pub fn authorization_header(&self) -> String {
323 match self {
324 Credentials::DevServer { token } => format!("dev-server-token {}", token),
325 Credentials::User {
326 user_id,
327 access_token,
328 } => format!("{} {}", user_id, access_token),
329 }
330 }
331}
332
333/// A provider for [`Credentials`].
334///
335/// Used to abstract over reading and writing credentials to some form of
336/// persistence (like the system keychain).
337trait CredentialsProvider {
338 /// Reads the credentials from the provider.
339 fn read_credentials<'a>(
340 &'a self,
341 cx: &'a AsyncAppContext,
342 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;
343
344 /// Writes the credentials to the provider.
345 fn write_credentials<'a>(
346 &'a self,
347 user_id: u64,
348 access_token: String,
349 cx: &'a AsyncAppContext,
350 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
351
352 /// Deletes the credentials from the provider.
353 fn delete_credentials<'a>(
354 &'a self,
355 cx: &'a AsyncAppContext,
356 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
357}
358
359impl Default for ClientState {
360 fn default() -> Self {
361 Self {
362 credentials: None,
363 status: watch::channel_with(Status::SignedOut),
364 entity_id_extractors: Default::default(),
365 _reconnect_task: None,
366 models_by_message_type: Default::default(),
367 entities_by_type_and_remote_id: Default::default(),
368 entity_types_by_message_type: Default::default(),
369 message_handlers: Default::default(),
370 }
371 }
372}
373
374pub enum Subscription {
375 Entity {
376 client: Weak<Client>,
377 id: (TypeId, u64),
378 },
379 Message {
380 client: Weak<Client>,
381 id: TypeId,
382 },
383}
384
385impl Drop for Subscription {
386 fn drop(&mut self) {
387 match self {
388 Subscription::Entity { client, id } => {
389 if let Some(client) = client.upgrade() {
390 let mut state = client.state.write();
391 let _ = state.entities_by_type_and_remote_id.remove(id);
392 }
393 }
394 Subscription::Message { client, id } => {
395 if let Some(client) = client.upgrade() {
396 let mut state = client.state.write();
397 let _ = state.entity_types_by_message_type.remove(id);
398 let _ = state.message_handlers.remove(id);
399 }
400 }
401 }
402 }
403}
404
405pub struct PendingEntitySubscription<T: 'static> {
406 client: Arc<Client>,
407 remote_id: u64,
408 _entity_type: PhantomData<T>,
409 consumed: bool,
410}
411
412impl<T: 'static> PendingEntitySubscription<T> {
413 pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
414 self.consumed = true;
415 let mut state = self.client.state.write();
416 let id = (TypeId::of::<T>(), self.remote_id);
417 let Some(WeakSubscriber::Pending(messages)) =
418 state.entities_by_type_and_remote_id.remove(&id)
419 else {
420 unreachable!()
421 };
422
423 state.entities_by_type_and_remote_id.insert(
424 id,
425 WeakSubscriber::Entity {
426 handle: model.downgrade().into(),
427 },
428 );
429 drop(state);
430 for message in messages {
431 self.client.handle_message(message, cx);
432 }
433 Subscription::Entity {
434 client: Arc::downgrade(&self.client),
435 id,
436 }
437 }
438}
439
440impl<T: 'static> Drop for PendingEntitySubscription<T> {
441 fn drop(&mut self) {
442 if !self.consumed {
443 let mut state = self.client.state.write();
444 if let Some(WeakSubscriber::Pending(messages)) = state
445 .entities_by_type_and_remote_id
446 .remove(&(TypeId::of::<T>(), self.remote_id))
447 {
448 for message in messages {
449 log::info!("unhandled message {}", message.payload_type_name());
450 }
451 }
452 }
453 }
454}
455
456#[derive(Copy, Clone)]
457pub struct TelemetrySettings {
458 pub diagnostics: bool,
459 pub metrics: bool,
460}
461
462/// Control what info is collected by Zed.
463#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
464pub struct TelemetrySettingsContent {
465 /// Send debug info like crash reports.
466 ///
467 /// Default: true
468 pub diagnostics: Option<bool>,
469 /// Send anonymized usage data like what languages you're using Zed with.
470 ///
471 /// Default: true
472 pub metrics: Option<bool>,
473}
474
475impl settings::Settings for TelemetrySettings {
476 const KEY: Option<&'static str> = Some("telemetry");
477
478 type FileContent = TelemetrySettingsContent;
479
480 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
481 Ok(Self {
482 diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
483 sources
484 .default
485 .diagnostics
486 .ok_or_else(Self::missing_default)?,
487 ),
488 metrics: sources
489 .user
490 .as_ref()
491 .and_then(|v| v.metrics)
492 .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
493 })
494 }
495}
496
497impl Client {
498 pub fn new(
499 clock: Arc<dyn SystemClock>,
500 http: Arc<HttpClientWithUrl>,
501 cx: &mut AppContext,
502 ) -> Arc<Self> {
503 let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
504 Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
505 Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
506 | None => false,
507 };
508
509 let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
510 if use_zed_development_auth {
511 Arc::new(DevelopmentCredentialsProvider {
512 path: util::paths::CONFIG_DIR.join("development_auth"),
513 })
514 } else {
515 Arc::new(KeychainCredentialsProvider)
516 };
517
518 Arc::new(Self {
519 id: AtomicU64::new(0),
520 peer: Peer::new(0),
521 telemetry: Telemetry::new(clock, http.clone(), cx),
522 http,
523 credentials_provider,
524 state: Default::default(),
525
526 #[cfg(any(test, feature = "test-support"))]
527 authenticate: Default::default(),
528 #[cfg(any(test, feature = "test-support"))]
529 establish_connection: Default::default(),
530 })
531 }
532
533 pub fn production(cx: &mut AppContext) -> Arc<Self> {
534 let clock = Arc::new(clock::RealSystemClock);
535 let http = Arc::new(HttpClientWithUrl::new(
536 &ClientSettings::get_global(cx).server_url,
537 ProxySettings::get_global(cx).proxy.clone(),
538 ));
539 Self::new(clock, http.clone(), cx)
540 }
541
542 pub fn id(&self) -> u64 {
543 self.id.load(Ordering::SeqCst)
544 }
545
546 pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
547 self.http.clone()
548 }
549
550 pub fn set_id(&self, id: u64) -> &Self {
551 self.id.store(id, Ordering::SeqCst);
552 self
553 }
554
555 #[cfg(any(test, feature = "test-support"))]
556 pub fn teardown(&self) {
557 let mut state = self.state.write();
558 state._reconnect_task.take();
559 state.message_handlers.clear();
560 state.models_by_message_type.clear();
561 state.entities_by_type_and_remote_id.clear();
562 state.entity_id_extractors.clear();
563 self.peer.teardown();
564 }
565
566 #[cfg(any(test, feature = "test-support"))]
567 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
568 where
569 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
570 {
571 *self.authenticate.write() = Some(Box::new(authenticate));
572 self
573 }
574
575 #[cfg(any(test, feature = "test-support"))]
576 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
577 where
578 F: 'static
579 + Send
580 + Sync
581 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
582 {
583 *self.establish_connection.write() = Some(Box::new(connect));
584 self
585 }
586
587 pub fn global(cx: &AppContext) -> Arc<Self> {
588 cx.global::<GlobalClient>().0.clone()
589 }
590 pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
591 cx.set_global(GlobalClient(client))
592 }
593
594 pub fn user_id(&self) -> Option<u64> {
595 if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
596 Some(*user_id)
597 } else {
598 None
599 }
600 }
601
602 pub fn peer_id(&self) -> Option<PeerId> {
603 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
604 Some(*peer_id)
605 } else {
606 None
607 }
608 }
609
610 pub fn status(&self) -> watch::Receiver<Status> {
611 self.state.read().status.1.clone()
612 }
613
614 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
615 log::info!("set status on client {}: {:?}", self.id(), status);
616 let mut state = self.state.write();
617 *state.status.0.borrow_mut() = status;
618
619 match status {
620 Status::Connected { .. } => {
621 state._reconnect_task = None;
622 }
623 Status::ConnectionLost => {
624 let this = self.clone();
625 state._reconnect_task = Some(cx.spawn(move |cx| async move {
626 #[cfg(any(test, feature = "test-support"))]
627 let mut rng = StdRng::seed_from_u64(0);
628 #[cfg(not(any(test, feature = "test-support")))]
629 let mut rng = StdRng::from_entropy();
630
631 let mut delay = INITIAL_RECONNECTION_DELAY;
632 while let Err(error) = this.authenticate_and_connect(true, &cx).await {
633 log::error!("failed to connect {}", error);
634 if matches!(*this.status().borrow(), Status::ConnectionError) {
635 this.set_status(
636 Status::ReconnectionError {
637 next_reconnection: Instant::now() + delay,
638 },
639 &cx,
640 );
641 cx.background_executor().timer(delay).await;
642 delay = delay
643 .mul_f32(rng.gen_range(0.5..=2.5))
644 .max(INITIAL_RECONNECTION_DELAY)
645 .min(MAX_RECONNECTION_DELAY);
646 } else {
647 break;
648 }
649 }
650 }));
651 }
652 Status::SignedOut | Status::UpgradeRequired => {
653 self.telemetry.set_authenticated_user_info(None, false);
654 state._reconnect_task.take();
655 }
656 _ => {}
657 }
658 }
659
660 pub fn subscribe_to_entity<T>(
661 self: &Arc<Self>,
662 remote_id: u64,
663 ) -> Result<PendingEntitySubscription<T>>
664 where
665 T: 'static,
666 {
667 let id = (TypeId::of::<T>(), remote_id);
668
669 let mut state = self.state.write();
670 if state.entities_by_type_and_remote_id.contains_key(&id) {
671 return Err(anyhow!("already subscribed to entity"));
672 }
673
674 state
675 .entities_by_type_and_remote_id
676 .insert(id, WeakSubscriber::Pending(Default::default()));
677
678 Ok(PendingEntitySubscription {
679 client: self.clone(),
680 remote_id,
681 consumed: false,
682 _entity_type: PhantomData,
683 })
684 }
685
686 #[track_caller]
687 pub fn add_message_handler<M, E, H, F>(
688 self: &Arc<Self>,
689 entity: WeakModel<E>,
690 handler: H,
691 ) -> Subscription
692 where
693 M: EnvelopedMessage,
694 E: 'static,
695 H: 'static
696 + Sync
697 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
698 + Send
699 + Sync,
700 F: 'static + Future<Output = Result<()>>,
701 {
702 let message_type_id = TypeId::of::<M>();
703 let mut state = self.state.write();
704 state
705 .models_by_message_type
706 .insert(message_type_id, entity.into());
707
708 let prev_handler = state.message_handlers.insert(
709 message_type_id,
710 Arc::new(move |subscriber, envelope, client, cx| {
711 let subscriber = subscriber.downcast::<E>().unwrap();
712 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
713 handler(subscriber, *envelope, client.clone(), cx).boxed_local()
714 }),
715 );
716 if prev_handler.is_some() {
717 let location = std::panic::Location::caller();
718 panic!(
719 "{}:{} registered handler for the same message {} twice",
720 location.file(),
721 location.line(),
722 std::any::type_name::<M>()
723 );
724 }
725
726 Subscription::Message {
727 client: Arc::downgrade(self),
728 id: message_type_id,
729 }
730 }
731
732 pub fn add_request_handler<M, E, H, F>(
733 self: &Arc<Self>,
734 model: WeakModel<E>,
735 handler: H,
736 ) -> Subscription
737 where
738 M: RequestMessage,
739 E: 'static,
740 H: 'static
741 + Sync
742 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
743 + Send
744 + Sync,
745 F: 'static + Future<Output = Result<M::Response>>,
746 {
747 self.add_message_handler(model, move |handle, envelope, this, cx| {
748 Self::respond_to_request(
749 envelope.receipt(),
750 handler(handle, envelope, this.clone(), cx),
751 this,
752 )
753 })
754 }
755
756 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
757 where
758 M: EntityMessage,
759 E: 'static,
760 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
761 F: 'static + Future<Output = Result<()>>,
762 {
763 self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
764 handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
765 })
766 }
767
768 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
769 where
770 M: EntityMessage,
771 E: 'static,
772 H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
773 F: 'static + Future<Output = Result<()>>,
774 {
775 let model_type_id = TypeId::of::<E>();
776 let message_type_id = TypeId::of::<M>();
777
778 let mut state = self.state.write();
779 state
780 .entity_types_by_message_type
781 .insert(message_type_id, model_type_id);
782 state
783 .entity_id_extractors
784 .entry(message_type_id)
785 .or_insert_with(|| {
786 |envelope| {
787 envelope
788 .as_any()
789 .downcast_ref::<TypedEnvelope<M>>()
790 .unwrap()
791 .payload
792 .remote_entity_id()
793 }
794 });
795 let prev_handler = state.message_handlers.insert(
796 message_type_id,
797 Arc::new(move |handle, envelope, client, cx| {
798 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
799 handler(handle, *envelope, client.clone(), cx).boxed_local()
800 }),
801 );
802 if prev_handler.is_some() {
803 panic!("registered handler for the same message twice");
804 }
805 }
806
807 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
808 where
809 M: EntityMessage + RequestMessage,
810 E: 'static,
811 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
812 F: 'static + Future<Output = Result<M::Response>>,
813 {
814 self.add_model_message_handler(move |entity, envelope, client, cx| {
815 Self::respond_to_request::<M, _>(
816 envelope.receipt(),
817 handler(entity, envelope, client.clone(), cx),
818 client,
819 )
820 })
821 }
822
823 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
824 receipt: Receipt<T>,
825 response: F,
826 client: Arc<Self>,
827 ) -> Result<()> {
828 match response.await {
829 Ok(response) => {
830 client.respond(receipt, response)?;
831 Ok(())
832 }
833 Err(error) => {
834 client.respond_with_error(receipt, error.to_proto())?;
835 Err(error)
836 }
837 }
838 }
839
840 pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
841 self.credentials_provider
842 .read_credentials(cx)
843 .await
844 .is_some()
845 }
846
847 pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
848 self.state.write().credentials = Some(Credentials::DevServer { token });
849 self
850 }
851
852 #[async_recursion(?Send)]
853 pub async fn authenticate_and_connect(
854 self: &Arc<Self>,
855 try_provider: bool,
856 cx: &AsyncAppContext,
857 ) -> anyhow::Result<()> {
858 let was_disconnected = match *self.status().borrow() {
859 Status::SignedOut => true,
860 Status::ConnectionError
861 | Status::ConnectionLost
862 | Status::Authenticating { .. }
863 | Status::Reauthenticating { .. }
864 | Status::ReconnectionError { .. } => false,
865 Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
866 return Ok(())
867 }
868 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
869 };
870 if was_disconnected {
871 self.set_status(Status::Authenticating, cx);
872 } else {
873 self.set_status(Status::Reauthenticating, cx)
874 }
875
876 let mut read_from_provider = false;
877 let mut credentials = self.state.read().credentials.clone();
878 if credentials.is_none() && try_provider {
879 credentials = self.credentials_provider.read_credentials(cx).await;
880 read_from_provider = credentials.is_some();
881 }
882
883 if credentials.is_none() {
884 let mut status_rx = self.status();
885 let _ = status_rx.next().await;
886 futures::select_biased! {
887 authenticate = self.authenticate(cx).fuse() => {
888 match authenticate {
889 Ok(creds) => credentials = Some(creds),
890 Err(err) => {
891 self.set_status(Status::ConnectionError, cx);
892 return Err(err);
893 }
894 }
895 }
896 _ = status_rx.next().fuse() => {
897 return Err(anyhow!("authentication canceled"));
898 }
899 }
900 }
901 let credentials = credentials.unwrap();
902 if let Credentials::User { user_id, .. } = &credentials {
903 self.set_id(*user_id);
904 }
905
906 if was_disconnected {
907 self.set_status(Status::Connecting, cx);
908 } else {
909 self.set_status(Status::Reconnecting, cx);
910 }
911
912 let mut timeout =
913 futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
914 futures::select_biased! {
915 connection = self.establish_connection(&credentials, cx).fuse() => {
916 match connection {
917 Ok(conn) => {
918 self.state.write().credentials = Some(credentials.clone());
919 if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
920 if let Credentials::User{user_id, access_token} = credentials {
921 self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
922 }
923 }
924
925 futures::select_biased! {
926 result = self.set_connection(conn, cx).fuse() => result,
927 _ = timeout => {
928 self.set_status(Status::ConnectionError, cx);
929 Err(anyhow!("timed out waiting on hello message from server"))
930 }
931 }
932 }
933 Err(EstablishConnectionError::Unauthorized) => {
934 self.state.write().credentials.take();
935 if read_from_provider {
936 self.credentials_provider.delete_credentials(cx).await.log_err();
937 self.set_status(Status::SignedOut, cx);
938 self.authenticate_and_connect(false, cx).await
939 } else {
940 self.set_status(Status::ConnectionError, cx);
941 Err(EstablishConnectionError::Unauthorized)?
942 }
943 }
944 Err(EstablishConnectionError::UpgradeRequired) => {
945 self.set_status(Status::UpgradeRequired, cx);
946 Err(EstablishConnectionError::UpgradeRequired)?
947 }
948 Err(error) => {
949 self.set_status(Status::ConnectionError, cx);
950 Err(error)?
951 }
952 }
953 }
954 _ = &mut timeout => {
955 self.set_status(Status::ConnectionError, cx);
956 Err(anyhow!("timed out trying to establish connection"))
957 }
958 }
959 }
960
961 async fn set_connection(
962 self: &Arc<Self>,
963 conn: Connection,
964 cx: &AsyncAppContext,
965 ) -> Result<()> {
966 let executor = cx.background_executor();
967 log::info!("add connection to peer");
968 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
969 let executor = executor.clone();
970 move |duration| executor.timer(duration)
971 });
972 let handle_io = executor.spawn(handle_io);
973
974 let peer_id = async {
975 log::info!("waiting for server hello");
976 let message = incoming
977 .next()
978 .await
979 .ok_or_else(|| anyhow!("no hello message received"))?;
980 log::info!("got server hello");
981 let hello_message_type_name = message.payload_type_name().to_string();
982 let hello = message
983 .into_any()
984 .downcast::<TypedEnvelope<proto::Hello>>()
985 .map_err(|_| {
986 anyhow!(
987 "invalid hello message received: {:?}",
988 hello_message_type_name
989 )
990 })?;
991 let peer_id = hello
992 .payload
993 .peer_id
994 .ok_or_else(|| anyhow!("invalid peer id"))?;
995 Ok(peer_id)
996 };
997
998 let peer_id = match peer_id.await {
999 Ok(peer_id) => peer_id,
1000 Err(error) => {
1001 self.peer.disconnect(connection_id);
1002 return Err(error);
1003 }
1004 };
1005
1006 log::info!(
1007 "set status to connected (connection id: {:?}, peer id: {:?})",
1008 connection_id,
1009 peer_id
1010 );
1011 self.set_status(
1012 Status::Connected {
1013 peer_id,
1014 connection_id,
1015 },
1016 cx,
1017 );
1018
1019 cx.spawn({
1020 let this = self.clone();
1021 |cx| {
1022 async move {
1023 while let Some(message) = incoming.next().await {
1024 this.handle_message(message, &cx);
1025 // Don't starve the main thread when receiving lots of messages at once.
1026 smol::future::yield_now().await;
1027 }
1028 }
1029 }
1030 })
1031 .detach();
1032
1033 cx.spawn({
1034 let this = self.clone();
1035 move |cx| async move {
1036 match handle_io.await {
1037 Ok(()) => {
1038 if *this.status().borrow()
1039 == (Status::Connected {
1040 connection_id,
1041 peer_id,
1042 })
1043 {
1044 this.set_status(Status::SignedOut, &cx);
1045 }
1046 }
1047 Err(err) => {
1048 log::error!("connection error: {:?}", err);
1049 this.set_status(Status::ConnectionLost, &cx);
1050 }
1051 }
1052 }
1053 })
1054 .detach();
1055
1056 Ok(())
1057 }
1058
1059 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1060 #[cfg(any(test, feature = "test-support"))]
1061 if let Some(callback) = self.authenticate.read().as_ref() {
1062 return callback(cx);
1063 }
1064
1065 self.authenticate_with_browser(cx)
1066 }
1067
1068 fn establish_connection(
1069 self: &Arc<Self>,
1070 credentials: &Credentials,
1071 cx: &AsyncAppContext,
1072 ) -> Task<Result<Connection, EstablishConnectionError>> {
1073 #[cfg(any(test, feature = "test-support"))]
1074 if let Some(callback) = self.establish_connection.read().as_ref() {
1075 return callback(credentials, cx);
1076 }
1077
1078 self.establish_websocket_connection(credentials, cx)
1079 }
1080
1081 async fn get_rpc_url(
1082 http: Arc<HttpClientWithUrl>,
1083 release_channel: Option<ReleaseChannel>,
1084 ) -> Result<Url> {
1085 if let Some(url) = &*ZED_RPC_URL {
1086 return Url::parse(url).context("invalid rpc url");
1087 }
1088
1089 let mut url = http.build_url("/rpc");
1090 if let Some(preview_param) =
1091 release_channel.and_then(|channel| channel.release_query_param())
1092 {
1093 url += "?";
1094 url += preview_param;
1095 }
1096 let response = http.get(&url, Default::default(), false).await?;
1097 let collab_url = if response.status().is_redirection() {
1098 response
1099 .headers()
1100 .get("Location")
1101 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1102 .to_str()
1103 .map_err(EstablishConnectionError::other)?
1104 .to_string()
1105 } else {
1106 Err(anyhow!(
1107 "unexpected /rpc response status {}",
1108 response.status()
1109 ))?
1110 };
1111
1112 Url::parse(&collab_url).context("invalid rpc url")
1113 }
1114
1115 fn establish_websocket_connection(
1116 self: &Arc<Self>,
1117 credentials: &Credentials,
1118 cx: &AsyncAppContext,
1119 ) -> Task<Result<Connection, EstablishConnectionError>> {
1120 let release_channel = cx
1121 .update(|cx| ReleaseChannel::try_global(cx))
1122 .ok()
1123 .flatten();
1124 let app_version = cx
1125 .update(|cx| AppVersion::global(cx).to_string())
1126 .ok()
1127 .unwrap_or_default();
1128
1129 let request = Request::builder()
1130 .header("Authorization", credentials.authorization_header())
1131 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
1132 .header("x-zed-app-version", app_version)
1133 .header(
1134 "x-zed-release-channel",
1135 release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
1136 );
1137
1138 let http = self.http.clone();
1139 cx.background_executor().spawn(async move {
1140 let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
1141 let rpc_host = rpc_url
1142 .host_str()
1143 .zip(rpc_url.port_or_known_default())
1144 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1145 let stream = smol::net::TcpStream::connect(rpc_host).await?;
1146
1147 log::info!("connected to rpc endpoint {}", rpc_url);
1148
1149 match rpc_url.scheme() {
1150 "https" => {
1151 rpc_url.set_scheme("wss").unwrap();
1152 let request = request.uri(rpc_url.as_str()).body(())?;
1153 let (stream, _) =
1154 async_tungstenite::async_std::client_async_tls(request, stream).await?;
1155 Ok(Connection::new(
1156 stream
1157 .map_err(|error| anyhow!(error))
1158 .sink_map_err(|error| anyhow!(error)),
1159 ))
1160 }
1161 "http" => {
1162 rpc_url.set_scheme("ws").unwrap();
1163 let request = request.uri(rpc_url.as_str()).body(())?;
1164 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1165 Ok(Connection::new(
1166 stream
1167 .map_err(|error| anyhow!(error))
1168 .sink_map_err(|error| anyhow!(error)),
1169 ))
1170 }
1171 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1172 }
1173 })
1174 }
1175
1176 pub fn authenticate_with_browser(
1177 self: &Arc<Self>,
1178 cx: &AsyncAppContext,
1179 ) -> Task<Result<Credentials>> {
1180 let http = self.http.clone();
1181 cx.spawn(|cx| async move {
1182 let background = cx.background_executor().clone();
1183
1184 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1185 cx.update(|cx| {
1186 cx.spawn(move |cx| async move {
1187 let url = open_url_rx.await?;
1188 cx.update(|cx| cx.open_url(&url))
1189 })
1190 .detach_and_log_err(cx);
1191 })
1192 .log_err();
1193
1194 let credentials = background
1195 .clone()
1196 .spawn(async move {
1197 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1198 // zed server to encrypt the user's access token, so that it can'be intercepted by
1199 // any other app running on the user's device.
1200 let (public_key, private_key) =
1201 rpc::auth::keypair().expect("failed to generate keypair for auth");
1202 let public_key_string = String::try_from(public_key)
1203 .expect("failed to serialize public key for auth");
1204
1205 if let Some((login, token)) =
1206 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1207 {
1208 eprintln!("authenticate as admin {login}, {token}");
1209
1210 return Self::authenticate_as_admin(http, login.clone(), token.clone())
1211 .await;
1212 }
1213
1214 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1215 let server =
1216 tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1217 let port = server.server_addr().port();
1218
1219 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1220 // that the user is signing in from a Zed app running on the same device.
1221 let mut url = http.build_url(&format!(
1222 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1223 port, public_key_string
1224 ));
1225
1226 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1227 log::info!("impersonating user @{}", impersonate_login);
1228 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1229 }
1230
1231 open_url_tx.send(url).log_err();
1232
1233 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1234 // access token from the query params.
1235 //
1236 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1237 // custom URL scheme instead of this local HTTP server.
1238 let (user_id, access_token) = background
1239 .spawn(async move {
1240 for _ in 0..100 {
1241 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1242 let path = req.url();
1243 let mut user_id = None;
1244 let mut access_token = None;
1245 let url = Url::parse(&format!("http://example.com{}", path))
1246 .context("failed to parse login notification url")?;
1247 for (key, value) in url.query_pairs() {
1248 if key == "access_token" {
1249 access_token = Some(value.to_string());
1250 } else if key == "user_id" {
1251 user_id = Some(value.to_string());
1252 }
1253 }
1254
1255 let post_auth_url =
1256 http.build_url("/native_app_signin_succeeded");
1257 req.respond(
1258 tiny_http::Response::empty(302).with_header(
1259 tiny_http::Header::from_bytes(
1260 &b"Location"[..],
1261 post_auth_url.as_bytes(),
1262 )
1263 .unwrap(),
1264 ),
1265 )
1266 .context("failed to respond to login http request")?;
1267 return Ok((
1268 user_id
1269 .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1270 access_token.ok_or_else(|| {
1271 anyhow!("missing access_token parameter")
1272 })?,
1273 ));
1274 }
1275 }
1276
1277 Err(anyhow!("didn't receive login redirect"))
1278 })
1279 .await?;
1280
1281 let access_token = private_key
1282 .decrypt_string(&access_token)
1283 .context("failed to decrypt access token")?;
1284
1285 Ok(Credentials::User {
1286 user_id: user_id.parse()?,
1287 access_token,
1288 })
1289 })
1290 .await?;
1291
1292 cx.update(|cx| cx.activate(true))?;
1293 Ok(credentials)
1294 })
1295 }
1296
1297 async fn authenticate_as_admin(
1298 http: Arc<HttpClientWithUrl>,
1299 login: String,
1300 mut api_token: String,
1301 ) -> Result<Credentials> {
1302 #[derive(Deserialize)]
1303 struct AuthenticatedUserResponse {
1304 user: User,
1305 }
1306
1307 #[derive(Deserialize)]
1308 struct User {
1309 id: u64,
1310 }
1311
1312 // Use the collab server's admin API to retrieve the id
1313 // of the impersonated user.
1314 let mut url = Self::get_rpc_url(http.clone(), None).await?;
1315 url.set_path("/user");
1316 url.set_query(Some(&format!("github_login={login}")));
1317 let request = Request::get(url.as_str())
1318 .header("Authorization", format!("token {api_token}"))
1319 .body("".into())?;
1320
1321 let mut response = http.send(request).await?;
1322 let mut body = String::new();
1323 response.body_mut().read_to_string(&mut body).await?;
1324 if !response.status().is_success() {
1325 Err(anyhow!(
1326 "admin user request failed {} - {}",
1327 response.status().as_u16(),
1328 body,
1329 ))?;
1330 }
1331 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1332
1333 // Use the admin API token to authenticate as the impersonated user.
1334 api_token.insert_str(0, "ADMIN_TOKEN:");
1335 Ok(Credentials::User {
1336 user_id: response.user.id,
1337 access_token: api_token,
1338 })
1339 }
1340
1341 pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1342 self.state.write().credentials = None;
1343 self.disconnect(&cx);
1344
1345 if self.has_credentials(cx).await {
1346 self.credentials_provider
1347 .delete_credentials(cx)
1348 .await
1349 .log_err();
1350 }
1351 }
1352
1353 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1354 self.peer.teardown();
1355 self.set_status(Status::SignedOut, cx);
1356 }
1357
1358 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1359 self.peer.teardown();
1360 self.set_status(Status::ConnectionLost, cx);
1361 }
1362
1363 fn connection_id(&self) -> Result<ConnectionId> {
1364 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1365 Ok(connection_id)
1366 } else {
1367 Err(anyhow!("not connected"))
1368 }
1369 }
1370
1371 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1372 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1373 self.peer.send(self.connection_id()?, message)
1374 }
1375
1376 pub fn request<T: RequestMessage>(
1377 &self,
1378 request: T,
1379 ) -> impl Future<Output = Result<T::Response>> {
1380 self.request_envelope(request)
1381 .map_ok(|envelope| envelope.payload)
1382 }
1383
1384 pub fn request_stream<T: RequestMessage>(
1385 &self,
1386 request: T,
1387 ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1388 let client_id = self.id.load(Ordering::SeqCst);
1389 log::debug!(
1390 "rpc request start. client_id:{}. name:{}",
1391 client_id,
1392 T::NAME
1393 );
1394 let response = self
1395 .connection_id()
1396 .map(|conn_id| self.peer.request_stream(conn_id, request));
1397 async move {
1398 let response = response?.await;
1399 log::debug!(
1400 "rpc request finish. client_id:{}. name:{}",
1401 client_id,
1402 T::NAME
1403 );
1404 response
1405 }
1406 }
1407
1408 pub fn request_envelope<T: RequestMessage>(
1409 &self,
1410 request: T,
1411 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1412 let client_id = self.id();
1413 log::debug!(
1414 "rpc request start. client_id:{}. name:{}",
1415 client_id,
1416 T::NAME
1417 );
1418 let response = self
1419 .connection_id()
1420 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1421 async move {
1422 let response = response?.await;
1423 log::debug!(
1424 "rpc request finish. client_id:{}. name:{}",
1425 client_id,
1426 T::NAME
1427 );
1428 response
1429 }
1430 }
1431
1432 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1433 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1434 self.peer.respond(receipt, response)
1435 }
1436
1437 fn respond_with_error<T: RequestMessage>(
1438 &self,
1439 receipt: Receipt<T>,
1440 error: proto::Error,
1441 ) -> Result<()> {
1442 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1443 self.peer.respond_with_error(receipt, error)
1444 }
1445
1446 fn handle_message(
1447 self: &Arc<Client>,
1448 message: Box<dyn AnyTypedEnvelope>,
1449 cx: &AsyncAppContext,
1450 ) {
1451 let mut state = self.state.write();
1452 let type_name = message.payload_type_name();
1453 let payload_type_id = message.payload_type_id();
1454 let sender_id = message.original_sender_id();
1455
1456 let mut subscriber = None;
1457
1458 if let Some(handle) = state
1459 .models_by_message_type
1460 .get(&payload_type_id)
1461 .and_then(|handle| handle.upgrade())
1462 {
1463 subscriber = Some(handle);
1464 } else if let Some((extract_entity_id, entity_type_id)) =
1465 state.entity_id_extractors.get(&payload_type_id).zip(
1466 state
1467 .entity_types_by_message_type
1468 .get(&payload_type_id)
1469 .copied(),
1470 )
1471 {
1472 let entity_id = (extract_entity_id)(message.as_ref());
1473
1474 match state
1475 .entities_by_type_and_remote_id
1476 .get_mut(&(entity_type_id, entity_id))
1477 {
1478 Some(WeakSubscriber::Pending(pending)) => {
1479 pending.push(message);
1480 return;
1481 }
1482 Some(weak_subscriber) => match weak_subscriber {
1483 WeakSubscriber::Entity { handle } => {
1484 subscriber = handle.upgrade();
1485 }
1486
1487 WeakSubscriber::Pending(_) => {}
1488 },
1489 _ => {}
1490 }
1491 }
1492
1493 let subscriber = if let Some(subscriber) = subscriber {
1494 subscriber
1495 } else {
1496 log::info!("unhandled message {}", type_name);
1497 self.peer.respond_with_unhandled_message(message).log_err();
1498 return;
1499 };
1500
1501 let handler = state.message_handlers.get(&payload_type_id).cloned();
1502 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1503 // It also ensures we don't hold the lock while yielding back to the executor, as
1504 // that might cause the executor thread driving this future to block indefinitely.
1505 drop(state);
1506
1507 if let Some(handler) = handler {
1508 let future = handler(subscriber, message, self, cx.clone());
1509 let client_id = self.id();
1510 log::debug!(
1511 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1512 client_id,
1513 sender_id,
1514 type_name
1515 );
1516 cx.spawn(move |_| async move {
1517 match future.await {
1518 Ok(()) => {
1519 log::debug!(
1520 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1521 client_id,
1522 sender_id,
1523 type_name
1524 );
1525 }
1526 Err(error) => {
1527 log::error!(
1528 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1529 client_id,
1530 sender_id,
1531 type_name,
1532 error
1533 );
1534 }
1535 }
1536 })
1537 .detach();
1538 } else {
1539 log::info!("unhandled message {}", type_name);
1540 self.peer.respond_with_unhandled_message(message).log_err();
1541 }
1542 }
1543
/// Returns a reference to the client's shared [`Telemetry`] handle.
pub fn telemetry(&self) -> &Arc<Telemetry> {
    &self.telemetry
}
1547}
1548
/// The JSON representation of the credentials that
/// `DevelopmentCredentialsProvider` reads from and writes to its file.
#[derive(Serialize, Deserialize)]
struct DevelopmentCredentials {
    /// Id of the signed-in user.
    user_id: u64,
    /// The user's access token, stored as-is.
    access_token: String,
}
1554
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
    /// Location of the JSON file holding the credentials.
    path: PathBuf,
}
1565
1566impl CredentialsProvider for DevelopmentCredentialsProvider {
1567 fn read_credentials<'a>(
1568 &'a self,
1569 _cx: &'a AsyncAppContext,
1570 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1571 async move {
1572 if IMPERSONATE_LOGIN.is_some() {
1573 return None;
1574 }
1575
1576 let json = std::fs::read(&self.path).log_err()?;
1577
1578 let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1579
1580 Some(Credentials::User {
1581 user_id: credentials.user_id,
1582 access_token: credentials.access_token,
1583 })
1584 }
1585 .boxed_local()
1586 }
1587
1588 fn write_credentials<'a>(
1589 &'a self,
1590 user_id: u64,
1591 access_token: String,
1592 _cx: &'a AsyncAppContext,
1593 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1594 async move {
1595 let json = serde_json::to_string(&DevelopmentCredentials {
1596 user_id,
1597 access_token,
1598 })?;
1599
1600 std::fs::write(&self.path, json)?;
1601
1602 Ok(())
1603 }
1604 .boxed_local()
1605 }
1606
1607 fn delete_credentials<'a>(
1608 &'a self,
1609 _cx: &'a AsyncAppContext,
1610 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1611 async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1612 }
1613}
1614
/// A credentials provider that stores credentials in the system keychain.
///
/// Credentials are read, written and deleted through the app context, keyed by
/// the `server_url` from [`ClientSettings`].
struct KeychainCredentialsProvider;
1617
1618impl CredentialsProvider for KeychainCredentialsProvider {
1619 fn read_credentials<'a>(
1620 &'a self,
1621 cx: &'a AsyncAppContext,
1622 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1623 async move {
1624 if IMPERSONATE_LOGIN.is_some() {
1625 return None;
1626 }
1627
1628 let (user_id, access_token) = cx
1629 .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1630 .log_err()?
1631 .await
1632 .log_err()??;
1633
1634 Some(Credentials::User {
1635 user_id: user_id.parse().ok()?,
1636 access_token: String::from_utf8(access_token).ok()?,
1637 })
1638 }
1639 .boxed_local()
1640 }
1641
1642 fn write_credentials<'a>(
1643 &'a self,
1644 user_id: u64,
1645 access_token: String,
1646 cx: &'a AsyncAppContext,
1647 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1648 async move {
1649 cx.update(move |cx| {
1650 cx.write_credentials(
1651 &ClientSettings::get_global(cx).server_url,
1652 &user_id.to_string(),
1653 access_token.as_bytes(),
1654 )
1655 })?
1656 .await
1657 }
1658 .boxed_local()
1659 }
1660
1661 fn delete_credentials<'a>(
1662 &'a self,
1663 cx: &'a AsyncAppContext,
1664 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1665 async move {
1666 cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1667 .await
1668 }
1669 .boxed_local()
1670 }
1671}
1672
/// The scheme of `zed://` URLs, without the `://` separator.
pub static ZED_URL_SCHEME: &str = "zed";
1675
1676/// Parses the given link into a Zed link.
1677///
1678/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1679/// Returns [`None`] otherwise.
1680pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1681 let server_url = &ClientSettings::get_global(cx).server_url;
1682 if let Some(stripped) = link
1683 .strip_prefix(server_url)
1684 .and_then(|result| result.strip_prefix('/'))
1685 {
1686 return Some(stripped);
1687 }
1688 if let Some(stripped) = link
1689 .strip_prefix(ZED_URL_SCHEME)
1690 .and_then(|result| result.strip_prefix("://"))
1691 {
1692 return Some(stripped);
1693 }
1694
1695 None
1696}
1697
1698#[cfg(test)]
1699mod tests {
1700 use super::*;
1701 use crate::test::FakeServer;
1702
1703 use clock::FakeSystemClock;
1704 use gpui::{BackgroundExecutor, Context, TestAppContext};
1705 use http::FakeHttpClient;
1706 use parking_lot::Mutex;
1707 use settings::SettingsStore;
1708 use std::future;
1709
// Checks that the client reconnects after its connection is dropped: with the
// same credentials while they are still accepted, and by authenticating again
// once the server has rolled its access token.
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: &mut TestAppContext) {
    init_test(cx);
    let user_id = 5;
    let client = cx.update(|cx| {
        Client::new(
            Arc::new(FakeSystemClock::default()),
            FakeHttpClient::with_404_response(),
            cx,
        )
    });
    let server = FakeServer::for_client(user_id, &client, cx).await;
    let mut status = client.status();
    assert!(matches!(
        status.next().await,
        Some(Status::Connected { .. })
    ));
    assert_eq!(server.auth_count(), 1);

    // Drop the connection while new connections are refused, and wait until the
    // client reports a failed reconnection attempt.
    server.forbid_connections();
    server.disconnect();
    while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

    // Once connections are allowed again, the client should get back online.
    server.allow_connections();
    cx.executor().advance_clock(Duration::from_secs(10));
    while !matches!(status.next().await, Some(Status::Connected { .. })) {}
    assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

    server.forbid_connections();
    server.disconnect();
    while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

    // Clear cached credentials after authentication fails
    server.roll_access_token();
    server.allow_connections();
    cx.executor().run_until_parked();
    cx.executor().advance_clock(Duration::from_secs(10));
    while !matches!(status.next().await, Some(Status::Connected { .. })) {}
    assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
}
1750
// Checks that a connection attempt which never completes is abandoned after
// `CONNECTION_TIMEOUT`, both on the initial connection and when reconnecting.
#[gpui::test(iterations = 10)]
async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
    init_test(cx);
    let user_id = 5;
    let client = cx.update(|cx| {
        Client::new(
            Arc::new(FakeSystemClock::default()),
            FakeHttpClient::with_404_response(),
            cx,
        )
    });
    let mut status = client.status();

    // Time out when client tries to connect.
    client.override_authenticate(move |cx| {
        cx.background_executor().spawn(async move {
            Ok(Credentials::User {
                user_id,
                access_token: "token".into(),
            })
        })
    });
    // Establishing the connection never finishes.
    client.override_establish_connection(|_, cx| {
        cx.background_executor().spawn(async move {
            future::pending::<()>().await;
            unreachable!()
        })
    });
    let auth_and_connect = cx.spawn({
        let client = client.clone();
        |cx| async move { client.authenticate_and_connect(false, &cx).await }
    });
    executor.run_until_parked();
    assert!(matches!(status.next().await, Some(Status::Connecting)));

    // Advancing the clock by the timeout turns the pending attempt into an error.
    executor.advance_clock(CONNECTION_TIMEOUT);
    assert!(matches!(
        status.next().await,
        Some(Status::ConnectionError { .. })
    ));
    auth_and_connect.await.unwrap_err();

    // Allow the connection to be established.
    let server = FakeServer::for_client(user_id, &client, cx).await;
    assert!(matches!(
        status.next().await,
        Some(Status::Connected { .. })
    ));

    // Disconnect client.
    server.forbid_connections();
    server.disconnect();
    while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

    // Time out when re-establishing the connection.
    server.allow_connections();
    client.override_establish_connection(|_, cx| {
        cx.background_executor().spawn(async move {
            future::pending::<()>().await;
            unreachable!()
        })
    });
    executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
    assert!(matches!(
        status.next().await,
        Some(Status::Reconnecting { .. })
    ));

    executor.advance_clock(CONNECTION_TIMEOUT);
    assert!(matches!(
        status.next().await,
        Some(Status::ReconnectionError { .. })
    ));
}
1825
// Checks that starting a second authentication while the first one is still in
// flight starts a new attempt and drops the earlier one.
#[gpui::test(iterations = 10)]
async fn test_authenticating_more_than_once(
    cx: &mut TestAppContext,
    executor: BackgroundExecutor,
) {
    init_test(cx);
    // Number of authentication attempts started, and number of attempts whose
    // future has been dropped.
    let auth_count = Arc::new(Mutex::new(0));
    let dropped_auth_count = Arc::new(Mutex::new(0));
    let client = cx.update(|cx| {
        Client::new(
            Arc::new(FakeSystemClock::default()),
            FakeHttpClient::with_404_response(),
            cx,
        )
    });
    // Authentication never completes: each attempt bumps `auth_count` when it
    // starts, and `dropped_auth_count` when its `_drop` guard is dropped.
    client.override_authenticate({
        let auth_count = auth_count.clone();
        let dropped_auth_count = dropped_auth_count.clone();
        move |cx| {
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            cx.background_executor().spawn(async move {
                *auth_count.lock() += 1;
                let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                future::pending::<()>().await;
                unreachable!()
            })
        }
    });

    let _authenticate = cx.spawn({
        let client = client.clone();
        move |cx| async move { client.authenticate_and_connect(false, &cx).await }
    });
    executor.run_until_parked();
    assert_eq!(*auth_count.lock(), 1);
    assert_eq!(*dropped_auth_count.lock(), 0);

    // Rebinding `_authenticate` drops the first task's handle.
    let _authenticate = cx.spawn({
        let client = client.clone();
        |cx| async move { client.authenticate_and_connect(false, &cx).await }
    });
    executor.run_until_parked();
    assert_eq!(*auth_count.lock(), 2);
    assert_eq!(*dropped_auth_count.lock(), 1);
}
1872
// Checks that entity messages are routed to the model subscribed under the
// matching remote entity id, and that dropping one subscription does not
// affect the others.
#[gpui::test]
async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
    init_test(cx);
    let user_id = 5;
    let client = cx.update(|cx| {
        Client::new(
            Arc::new(FakeSystemClock::default()),
            FakeHttpClient::with_404_response(),
            cx,
        )
    });
    let server = FakeServer::for_client(user_id, &client, cx).await;

    // One completion channel per model that is expected to receive a message.
    let (done_tx1, mut done_rx1) = smol::channel::unbounded();
    let (done_tx2, mut done_rx2) = smol::channel::unbounded();
    client.add_model_message_handler(
        move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
            match model.update(&mut cx, |model, _| model.id).unwrap() {
                1 => done_tx1.try_send(()).unwrap(),
                2 => done_tx2.try_send(()).unwrap(),
                // Model 3's subscription is dropped below, and no message is
                // sent for its id.
                _ => unreachable!(),
            }
            async { Ok(()) }
        },
    );
    let model1 = cx.new_model(|_| TestModel {
        id: 1,
        subscription: None,
    });
    let model2 = cx.new_model(|_| TestModel {
        id: 2,
        subscription: None,
    });
    let model3 = cx.new_model(|_| TestModel {
        id: 3,
        subscription: None,
    });

    let _subscription1 = client
        .subscribe_to_entity(1)
        .unwrap()
        .set_model(&model1, &mut cx.to_async());
    let _subscription2 = client
        .subscribe_to_entity(2)
        .unwrap()
        .set_model(&model2, &mut cx.to_async());
    // Ensure dropping a subscription for the same entity type still allows receiving of
    // messages for other entity IDs of the same type.
    let subscription3 = client
        .subscribe_to_entity(3)
        .unwrap()
        .set_model(&model3, &mut cx.to_async());
    drop(subscription3);

    server.send(proto::JoinProject { project_id: 1 });
    server.send(proto::JoinProject { project_id: 2 });
    done_rx1.next().await.unwrap();
    done_rx2.next().await.unwrap();
}
1932
// Checks that a message handler can be registered for a message type again
// after an earlier subscription for that type was dropped, and that the new
// handler is the one that receives the message.
#[gpui::test]
async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
    init_test(cx);
    let user_id = 5;
    let client = cx.update(|cx| {
        Client::new(
            Arc::new(FakeSystemClock::default()),
            FakeHttpClient::with_404_response(),
            cx,
        )
    });
    let server = FakeServer::for_client(user_id, &client, cx).await;

    let model = cx.new_model(|_| TestModel::default());
    // `_done_rx1` is never read from; only the second handler's channel is awaited.
    let (done_tx1, _done_rx1) = smol::channel::unbounded();
    let (done_tx2, mut done_rx2) = smol::channel::unbounded();
    let subscription1 = client.add_message_handler(
        model.downgrade(),
        move |_, _: TypedEnvelope<proto::Ping>, _, _| {
            done_tx1.try_send(()).unwrap();
            async { Ok(()) }
        },
    );
    drop(subscription1);
    let _subscription2 = client.add_message_handler(
        model.downgrade(),
        move |_, _: TypedEnvelope<proto::Ping>, _, _| {
            done_tx2.try_send(()).unwrap();
            async { Ok(()) }
        },
    );
    server.send(proto::Ping {});
    done_rx2.next().await.unwrap();
}
1967
// Checks that a message handler may drop its own subscription while it is
// handling a message.
#[gpui::test]
async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
    init_test(cx);
    let user_id = 5;
    let client = cx.update(|cx| {
        Client::new(
            Arc::new(FakeSystemClock::default()),
            FakeHttpClient::with_404_response(),
            cx,
        )
    });
    let server = FakeServer::for_client(user_id, &client, cx).await;

    let model = cx.new_model(|_| TestModel::default());
    let (done_tx, mut done_rx) = smol::channel::unbounded();
    let subscription = client.add_message_handler(
        model.clone().downgrade(),
        move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
            // Taking the subscription out of the model drops it at the end of
            // this statement, i.e. from inside the handler it belongs to.
            model
                .update(&mut cx, |model, _| model.subscription.take())
                .unwrap();
            done_tx.try_send(()).unwrap();
            async { Ok(()) }
        },
    );
    // Store the subscription in the model so that the handler can reach it.
    model.update(cx, |model, _| {
        model.subscription = Some(subscription);
    });
    server.send(proto::Ping {});
    done_rx.next().await.unwrap();
}
1999
// Minimal model used as a message subscriber in the tests above.
#[derive(Default)]
struct TestModel {
    // Lets a shared handler tell the models apart.
    id: usize,
    // Holds the model's own subscription so that a handler can drop it.
    subscription: Option<Subscription>,
}
2005
// Installs a test `SettingsStore` as a global and then calls `init_settings`.
// Every test in this module calls this before creating a `Client`.
fn init_test(cx: &mut TestAppContext) {
    cx.update(|cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        init_settings(cx);
    });
}
2013}