1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use clock::SystemClock;
14use collections::HashMap;
15use futures::{
16 channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
17 TryFutureExt as _, TryStreamExt,
18};
19use gpui::{
20 actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, BorrowAppContext, Global, Model,
21 Task, WeakModel,
22};
23use lazy_static::lazy_static;
24use parking_lot::RwLock;
25use postage::watch;
26use rand::prelude::*;
27use release_channel::{AppVersion, ReleaseChannel};
28use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::{Settings, SettingsSources, SettingsStore};
32use std::fmt;
33use std::pin::Pin;
34use std::{
35 any::TypeId,
36 convert::TryFrom,
37 fmt::Write as _,
38 future::Future,
39 marker::PhantomData,
40 path::PathBuf,
41 sync::{
42 atomic::{AtomicU64, Ordering},
43 Arc, Weak,
44 },
45 time::{Duration, Instant},
46};
47use telemetry::Telemetry;
48use thiserror::Error;
49use url::Url;
50use util::http::{HttpClient, HttpClientWithUrl};
51use util::{ResultExt, TryFutureExt};
52
53pub use rpc::*;
54pub use telemetry_events::Event;
55pub use user::*;
56
/// Opaque token carried by [`Credentials::DevServer`].
///
/// It is rendered verbatim in the `dev-server-token <token>` authorization
/// header built by `Credentials::authorization_header`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DevServerToken(pub String);

impl fmt::Display for DevServerToken {
    /// Writes the raw token text. Width/fill flags on the outer formatter are
    /// not applied, exactly as with a plain `write!(f, "{}", ..)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
65
66lazy_static! {
67 static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
68 static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
69 /// An environment variable whose presence indicates that the development auth
70 /// provider should be used.
71 ///
72 /// Only works in development. Setting this environment variable in other release
73 /// channels is a no-op.
74 pub static ref ZED_DEVELOPMENT_AUTH: bool =
75 std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
76 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
77 .ok()
78 .and_then(|s| if s.is_empty() { None } else { Some(s) });
79 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
80 .ok()
81 .and_then(|s| if s.is_empty() { None } else { Some(s) });
82 pub static ref ZED_APP_PATH: Option<PathBuf> =
83 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
84 pub static ref ZED_ALWAYS_ACTIVE: bool =
85 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
86}
87
/// Delay before the first reconnection attempt after the connection is lost.
/// Later delays grow by a random factor and are capped (see `Client::set_status`).
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time budget shared by establishing a connection and receiving the server's
/// hello message in `Client::authenticate_and_connect`.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(5_000);
90
// Actions in the `client` namespace; their handlers are registered by `init`.
actions!(client, [SignIn, SignOut, Reconnect]);
92
// Raw client settings as written in settings files; merged into
// `ClientSettings` by its `Settings::load` impl.
// NOTE(review): `JsonSchema` is derived here, so any `///` doc comments on this
// type would likely surface in the generated settings schema; plain comments
// are used deliberately.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    server_url: Option<String>,
}

/// Resolved client settings.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Base URL of the server. It is used to build the HTTP client in
    /// `Client::production` and to locate `/rpc`. The `ZED_SERVER_URL`
    /// environment variable, when set, takes precedence.
    pub server_url: String,
}
102
103impl Settings for ClientSettings {
104 const KEY: Option<&'static str> = None;
105
106 type FileContent = ClientSettingsContent;
107
108 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
109 let mut result = sources.json_merge::<Self>()?;
110 if let Some(server_url) = &*ZED_SERVER_URL {
111 result.server_url.clone_from(&server_url)
112 }
113 Ok(result)
114 }
115}
116
117pub fn init_settings(cx: &mut AppContext) {
118 TelemetrySettings::register(cx);
119 cx.update_global(|store: &mut SettingsStore, cx| {
120 store.register_setting::<ClientSettings>(cx);
121 });
122}
123
124pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
125 let client = Arc::downgrade(client);
126 cx.on_action({
127 let client = client.clone();
128 move |_: &SignIn, cx| {
129 if let Some(client) = client.upgrade() {
130 cx.spawn(
131 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
132 )
133 .detach();
134 }
135 }
136 });
137
138 cx.on_action({
139 let client = client.clone();
140 move |_: &SignOut, cx| {
141 if let Some(client) = client.upgrade() {
142 cx.spawn(|cx| async move {
143 client.sign_out(&cx).await;
144 })
145 .detach();
146 }
147 }
148 });
149
150 cx.on_action({
151 let client = client.clone();
152 move |_: &Reconnect, cx| {
153 if let Some(client) = client.upgrade() {
154 cx.spawn(|cx| async move {
155 client.reconnect(&cx);
156 })
157 .detach();
158 }
159 }
160 });
161}
162
/// Newtype that stores the shared [`Client`] as a gpui global
/// (read by `Client::global`, written by `Client::set_global`).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
166
/// Client for the Zed server.
///
/// It owns:
/// - the RPC peer,
/// - the HTTP client,
/// - telemetry,
/// - the credentials provider,
/// - the lock-protected connection and subscription state.
pub struct Client {
    /// Client identifier: 0 at construction (`Client::new`), then set to the
    /// user id when user credentials are used (`authenticate_and_connect`).
    id: AtomicU64,
    /// RPC peer that connections are added to (`set_connection`).
    peer: Arc<Peer>,
    /// HTTP client, also used to discover the RPC URL (`get_rpc_url`).
    http: Arc<HttpClientWithUrl>,
    telemetry: Arc<Telemetry>,
    /// Source/sink for persisted credentials; chosen in `Client::new`.
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// Mutable connection, subscription and handler state.
    state: RwLock<ClientState>,

    /// Test-only replacement for `Client::authenticate`
    /// (see `override_authenticate`).
    // The boxed callback type is long; allowed rather than aliased.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for `Client::establish_connection`
    /// (see `override_establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
197
/// Errors that can occur while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP 426 (Upgrade Required);
    /// see the `From<WebsocketError>` impl.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP 401 (Unauthorized),
    /// i.e. the credentials were rejected.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// I/O failure, e.g. while opening the TCP stream to the RPC host.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite the name, this wraps `http::Error`, e.g. from building the
    /// websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
213
214impl From<WebsocketError> for EstablishConnectionError {
215 fn from(error: WebsocketError) -> Self {
216 if let WebsocketError::Http(response) = &error {
217 match response.status() {
218 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
219 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
220 _ => {}
221 }
222 }
223 EstablishConnectionError::Other(error.into())
224 }
225}
226
227impl EstablishConnectionError {
228 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
229 Self::Other(error.into())
230 }
231}
232
/// Connection status of a [`Client`], published through a watch channel
/// (see `Client::status`).
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial status.
    SignedOut,
    /// The server requires a newer client; connecting is refused until upgraded.
    UpgradeRequired,
    /// Obtaining credentials for a fresh (previously signed-out) session.
    Authenticating,
    /// Opening a connection for a fresh session.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `peer_id` was assigned by the server's hello message.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; a reconnect loop is started.
    ConnectionLost,
    /// Obtaining credentials while recovering from an error or lost connection.
    Reauthenticating,
    /// Opening a connection while recovering from an error or lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled for
    /// `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
251
252impl Status {
253 pub fn is_connected(&self) -> bool {
254 matches!(self, Self::Connected { .. })
255 }
256
257 pub fn is_signed_out(&self) -> bool {
258 matches!(self, Self::SignedOut | Self::UpgradeRequired)
259 }
260}
261
/// Mutable state of a [`Client`], kept behind `Client::state`.
struct ClientState {
    /// Credentials of the current/last successful connection, or set directly
    /// via `Client::set_dev_server_token`.
    credentials: Option<Credentials>,
    /// Status watch channel. `set_status` writes through the sender, and
    /// `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from a
    /// type-erased envelope (registered by `add_entity_message_handler`).
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection loop spawned on `ConnectionLost`. It is cleared on
    /// `Connected`, `SignedOut` and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Entity subscriptions keyed by (entity `TypeId`, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model that receives messages of a given type (`add_message_handler`).
    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    /// Message `TypeId` -> entity `TypeId` for entity-scoped messages.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased message handlers keyed by message `TypeId`.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyModel,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
286
/// Target of an entity subscription, keyed in
/// `ClientState::entities_by_type_and_remote_id`.
enum WeakSubscriber {
    /// Subscription bound to a model (see `PendingEntitySubscription::set_model`).
    Entity { handle: AnyWeakModel },
    /// Reserved but not yet bound. `set_model` replays the stored envelopes.
    /// Presumably they are messages that arrived before binding; confirm in
    /// `handle_message`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
291
/// Credentials used to authenticate the RPC connection.
// NOTE(review): the derived `Debug` prints `access_token` and the dev server
// token verbatim; avoid `{:?}`-logging these values.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Credentials {
    /// Credentials of a dev server.
    DevServer { token: DevServerToken },
    /// Credentials of a signed-in user.
    User { user_id: u64, access_token: String },
}
297
298impl Credentials {
299 pub fn authorization_header(&self) -> String {
300 match self {
301 Credentials::DevServer { token } => format!("dev-server-token {}", token),
302 Credentials::User {
303 user_id,
304 access_token,
305 } => format!("{} {}", user_id, access_token),
306 }
307 }
308}
309
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain).
///
/// `Client::new` picks the implementation:
/// - `DevelopmentCredentialsProvider` (pointed at `development_auth` in the
///   config dir) when development auth is enabled;
/// - `KeychainCredentialsProvider` otherwise.
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// Returns `None` when no credentials are available.
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (`user_id` + `access_token`) can be persisted
    /// through this method.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    ///
    /// Called by `Client::authenticate_and_connect` when credentials read from
    /// the provider are rejected as unauthorized.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
335
336impl Default for ClientState {
337 fn default() -> Self {
338 Self {
339 credentials: None,
340 status: watch::channel_with(Status::SignedOut),
341 entity_id_extractors: Default::default(),
342 _reconnect_task: None,
343 reconnect_interval: Duration::from_secs(5),
344 models_by_message_type: Default::default(),
345 entities_by_type_and_remote_id: Default::default(),
346 entity_types_by_message_type: Default::default(),
347 message_handlers: Default::default(),
348 }
349 }
350}
351
/// Handle for a registration in a [`Client`]. Dropping it removes the
/// registration, provided the client is still alive.
pub enum Subscription {
    /// An entity subscription keyed by (entity `TypeId`, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message handler keyed by message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
362
363impl Drop for Subscription {
364 fn drop(&mut self) {
365 match self {
366 Subscription::Entity { client, id } => {
367 if let Some(client) = client.upgrade() {
368 let mut state = client.state.write();
369 let _ = state.entities_by_type_and_remote_id.remove(id);
370 }
371 }
372 Subscription::Message { client, id } => {
373 if let Some(client) = client.upgrade() {
374 let mut state = client.state.write();
375 let _ = state.entity_types_by_message_type.remove(id);
376 let _ = state.message_handlers.remove(id);
377 }
378 }
379 }
380 }
381}
382
/// An entity subscription that has been reserved (by
/// `Client::subscribe_to_entity`) but not yet bound to a model.
///
/// Bind it with [`PendingEntitySubscription::set_model`]. If it is dropped
/// unbound, the reservation is removed and any stored messages are logged as
/// unhandled.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the (now bound) entry alone.
    consumed: bool,
}
389
390impl<T: 'static> PendingEntitySubscription<T> {
391 pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
392 self.consumed = true;
393 let mut state = self.client.state.write();
394 let id = (TypeId::of::<T>(), self.remote_id);
395 let Some(WeakSubscriber::Pending(messages)) =
396 state.entities_by_type_and_remote_id.remove(&id)
397 else {
398 unreachable!()
399 };
400
401 state.entities_by_type_and_remote_id.insert(
402 id,
403 WeakSubscriber::Entity {
404 handle: model.downgrade().into(),
405 },
406 );
407 drop(state);
408 for message in messages {
409 self.client.handle_message(message, cx);
410 }
411 Subscription::Entity {
412 client: Arc::downgrade(&self.client),
413 id,
414 }
415 }
416}
417
418impl<T: 'static> Drop for PendingEntitySubscription<T> {
419 fn drop(&mut self) {
420 if !self.consumed {
421 let mut state = self.client.state.write();
422 if let Some(WeakSubscriber::Pending(messages)) = state
423 .entities_by_type_and_remote_id
424 .remove(&(TypeId::of::<T>(), self.remote_id))
425 {
426 for message in messages {
427 log::info!("unhandled message {}", message.payload_type_name());
428 }
429 }
430 }
431 }
432}
433
/// Resolved telemetry settings, read from the `telemetry` settings key.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
439
// NOTE(review): `JsonSchema` is derived on this type, so the `///` comments
// below are likely emitted as user-visible schema descriptions; edit with care.
// Each `None` field falls back to the default settings (see `TelemetrySettings::load`).
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
452
453impl settings::Settings for TelemetrySettings {
454 const KEY: Option<&'static str> = Some("telemetry");
455
456 type FileContent = TelemetrySettingsContent;
457
458 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
459 Ok(Self {
460 diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
461 sources
462 .default
463 .diagnostics
464 .ok_or_else(Self::missing_default)?,
465 ),
466 metrics: sources
467 .user
468 .as_ref()
469 .and_then(|v| v.metrics)
470 .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
471 })
472 }
473}
474
475impl Client {
476 pub fn new(
477 clock: Arc<dyn SystemClock>,
478 http: Arc<HttpClientWithUrl>,
479 cx: &mut AppContext,
480 ) -> Arc<Self> {
481 let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
482 Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
483 Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
484 | None => false,
485 };
486
487 let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
488 if use_zed_development_auth {
489 Arc::new(DevelopmentCredentialsProvider {
490 path: util::paths::CONFIG_DIR.join("development_auth"),
491 })
492 } else {
493 Arc::new(KeychainCredentialsProvider)
494 };
495
496 Arc::new(Self {
497 id: AtomicU64::new(0),
498 peer: Peer::new(0),
499 telemetry: Telemetry::new(clock, http.clone(), cx),
500 http,
501 credentials_provider,
502 state: Default::default(),
503
504 #[cfg(any(test, feature = "test-support"))]
505 authenticate: Default::default(),
506 #[cfg(any(test, feature = "test-support"))]
507 establish_connection: Default::default(),
508 })
509 }
510
511 pub fn production(cx: &mut AppContext) -> Arc<Self> {
512 let clock = Arc::new(clock::RealSystemClock);
513 let http = Arc::new(HttpClientWithUrl::new(
514 &ClientSettings::get_global(cx).server_url,
515 ));
516 Self::new(clock, http.clone(), cx)
517 }
518
519 pub fn id(&self) -> u64 {
520 self.id.load(Ordering::SeqCst)
521 }
522
523 pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
524 self.http.clone()
525 }
526
527 pub fn set_id(&self, id: u64) -> &Self {
528 self.id.store(id, Ordering::SeqCst);
529 self
530 }
531
532 #[cfg(any(test, feature = "test-support"))]
533 pub fn teardown(&self) {
534 let mut state = self.state.write();
535 state._reconnect_task.take();
536 state.message_handlers.clear();
537 state.models_by_message_type.clear();
538 state.entities_by_type_and_remote_id.clear();
539 state.entity_id_extractors.clear();
540 self.peer.teardown();
541 }
542
543 #[cfg(any(test, feature = "test-support"))]
544 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
545 where
546 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
547 {
548 *self.authenticate.write() = Some(Box::new(authenticate));
549 self
550 }
551
552 #[cfg(any(test, feature = "test-support"))]
553 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
554 where
555 F: 'static
556 + Send
557 + Sync
558 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
559 {
560 *self.establish_connection.write() = Some(Box::new(connect));
561 self
562 }
563
564 pub fn global(cx: &AppContext) -> Arc<Self> {
565 cx.global::<GlobalClient>().0.clone()
566 }
567 pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
568 cx.set_global(GlobalClient(client))
569 }
570
571 pub fn user_id(&self) -> Option<u64> {
572 if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
573 Some(*user_id)
574 } else {
575 None
576 }
577 }
578
579 pub fn peer_id(&self) -> Option<PeerId> {
580 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
581 Some(*peer_id)
582 } else {
583 None
584 }
585 }
586
587 pub fn status(&self) -> watch::Receiver<Status> {
588 self.state.read().status.1.clone()
589 }
590
    /// Publishes `status` to all status watchers, then reacts to the transition:
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop. It retries
    ///   `authenticate_and_connect` for as long as each failure leaves the
    ///   status at `ConnectionError`. Between attempts it sleeps a delay that
    ///   starts at [`INITIAL_RECONNECTION_DELAY`], grows by a random factor in
    ///   `[1, 2]`, and is capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the authenticated user in
    ///   telemetry and drops any pending reconnect task.
    ///
    /// The state write lock is held for the whole transition.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                // NOTE(review): the task keeps a strong `Arc<Client>` and is stored
                // in the client's own state, forming a reference cycle until the
                // loop finishes or the task is replaced.
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Deterministic jitter under test; real randomness otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failed attempt left us in
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
636
637 pub fn subscribe_to_entity<T>(
638 self: &Arc<Self>,
639 remote_id: u64,
640 ) -> Result<PendingEntitySubscription<T>>
641 where
642 T: 'static,
643 {
644 let id = (TypeId::of::<T>(), remote_id);
645
646 let mut state = self.state.write();
647 if state.entities_by_type_and_remote_id.contains_key(&id) {
648 return Err(anyhow!("already subscribed to entity"));
649 }
650
651 state
652 .entities_by_type_and_remote_id
653 .insert(id, WeakSubscriber::Pending(Default::default()));
654
655 Ok(PendingEntitySubscription {
656 client: self.clone(),
657 remote_id,
658 consumed: false,
659 _entity_type: PhantomData,
660 })
661 }
662
663 #[track_caller]
664 pub fn add_message_handler<M, E, H, F>(
665 self: &Arc<Self>,
666 entity: WeakModel<E>,
667 handler: H,
668 ) -> Subscription
669 where
670 M: EnvelopedMessage,
671 E: 'static,
672 H: 'static
673 + Sync
674 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
675 + Send
676 + Sync,
677 F: 'static + Future<Output = Result<()>>,
678 {
679 let message_type_id = TypeId::of::<M>();
680 let mut state = self.state.write();
681 state
682 .models_by_message_type
683 .insert(message_type_id, entity.into());
684
685 let prev_handler = state.message_handlers.insert(
686 message_type_id,
687 Arc::new(move |subscriber, envelope, client, cx| {
688 let subscriber = subscriber.downcast::<E>().unwrap();
689 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
690 handler(subscriber, *envelope, client.clone(), cx).boxed_local()
691 }),
692 );
693 if prev_handler.is_some() {
694 let location = std::panic::Location::caller();
695 panic!(
696 "{}:{} registered handler for the same message {} twice",
697 location.file(),
698 location.line(),
699 std::any::type_name::<M>()
700 );
701 }
702
703 Subscription::Message {
704 client: Arc::downgrade(self),
705 id: message_type_id,
706 }
707 }
708
709 pub fn add_request_handler<M, E, H, F>(
710 self: &Arc<Self>,
711 model: WeakModel<E>,
712 handler: H,
713 ) -> Subscription
714 where
715 M: RequestMessage,
716 E: 'static,
717 H: 'static
718 + Sync
719 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
720 + Send
721 + Sync,
722 F: 'static + Future<Output = Result<M::Response>>,
723 {
724 self.add_message_handler(model, move |handle, envelope, this, cx| {
725 Self::respond_to_request(
726 envelope.receipt(),
727 handler(handle, envelope, this.clone(), cx),
728 this,
729 )
730 })
731 }
732
733 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
734 where
735 M: EntityMessage,
736 E: 'static,
737 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
738 F: 'static + Future<Output = Result<()>>,
739 {
740 self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
741 handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
742 })
743 }
744
745 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
746 where
747 M: EntityMessage,
748 E: 'static,
749 H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
750 F: 'static + Future<Output = Result<()>>,
751 {
752 let model_type_id = TypeId::of::<E>();
753 let message_type_id = TypeId::of::<M>();
754
755 let mut state = self.state.write();
756 state
757 .entity_types_by_message_type
758 .insert(message_type_id, model_type_id);
759 state
760 .entity_id_extractors
761 .entry(message_type_id)
762 .or_insert_with(|| {
763 |envelope| {
764 envelope
765 .as_any()
766 .downcast_ref::<TypedEnvelope<M>>()
767 .unwrap()
768 .payload
769 .remote_entity_id()
770 }
771 });
772 let prev_handler = state.message_handlers.insert(
773 message_type_id,
774 Arc::new(move |handle, envelope, client, cx| {
775 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
776 handler(handle, *envelope, client.clone(), cx).boxed_local()
777 }),
778 );
779 if prev_handler.is_some() {
780 panic!("registered handler for the same message twice");
781 }
782 }
783
784 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
785 where
786 M: EntityMessage + RequestMessage,
787 E: 'static,
788 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
789 F: 'static + Future<Output = Result<M::Response>>,
790 {
791 self.add_model_message_handler(move |entity, envelope, client, cx| {
792 Self::respond_to_request::<M, _>(
793 envelope.receipt(),
794 handler(entity, envelope, client.clone(), cx),
795 client,
796 )
797 })
798 }
799
800 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
801 receipt: Receipt<T>,
802 response: F,
803 client: Arc<Self>,
804 ) -> Result<()> {
805 match response.await {
806 Ok(response) => {
807 client.respond(receipt, response)?;
808 Ok(())
809 }
810 Err(error) => {
811 client.respond_with_error(receipt, error.to_proto())?;
812 Err(error)
813 }
814 }
815 }
816
817 pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
818 self.credentials_provider
819 .read_credentials(cx)
820 .await
821 .is_some()
822 }
823
824 pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
825 self.state.write().credentials = Some(Credentials::DevServer { token });
826 self
827 }
828
    /// Authenticates if needed, then connects this client to the server.
    ///
    /// Early exits:
    /// - Returns `Ok(())` right away if connected or a connection attempt is in
    ///   progress (`Connected`, `Connecting`, `Reconnecting`).
    /// - Fails right away with [`EstablishConnectionError::UpgradeRequired`]
    ///   when the status is `UpgradeRequired`.
    ///
    /// Credentials are looked up in this order:
    /// 1. the in-memory state;
    /// 2. the credentials provider (only when `try_provider` is true);
    /// 3. the interactive `authenticate` flow.
    ///
    /// After a successful connection, user credentials that did not come from
    /// the provider are written back to it, unless `ZED_IMPERSONATE` is set.
    /// If provider-supplied credentials are rejected as unauthorized, they are
    /// deleted and the whole flow runs once more with `try_provider = false`.
    ///
    /// Establishing the connection and waiting for the server hello share a
    /// single [`CONNECTION_TIMEOUT`].
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            let mut status_rx = self.status();
            // Presumably drains the status already held by the receiver, so that
            // the `next()` below only fires on a later status change (treated as
            // cancellation of the interactive flow) — confirm postage semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        if let Credentials::User { user_id, .. } = &credentials {
            self.set_id(*user_id);
        }

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer covers both establishing the connection and the hello wait.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist credentials that did not come from the provider
                        // (never while impersonating).
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            if let Credentials::User{user_id, access_token} = credentials {
                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                            }
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_provider {
                            // Stored credentials are stale: forget them and retry
                            // once, skipping the provider this time.
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
937
    /// Installs an established connection. The steps are:
    /// 1. register it with the peer;
    /// 2. wait for the server's `Hello` to learn our peer id;
    /// 3. mark the client `Connected`;
    /// 4. spawn detached tasks that dispatch incoming messages and watch the
    ///    connection's I/O.
    ///
    /// # Errors
    /// Fails, and disconnects the connection from the peer, if the first
    /// incoming message is missing, is not a `Hello`, or carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // The first incoming message must be the server's hello.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch every subsequent incoming message.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // Watch the connection's I/O:
        // - Clean shutdown: sign out, but only if we are still on this exact
        //   connection.
        // - I/O error: mark the connection lost, which starts the reconnect loop
        //   in `set_status`.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1035
1036 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1037 #[cfg(any(test, feature = "test-support"))]
1038 if let Some(callback) = self.authenticate.read().as_ref() {
1039 return callback(cx);
1040 }
1041
1042 self.authenticate_with_browser(cx)
1043 }
1044
1045 fn establish_connection(
1046 self: &Arc<Self>,
1047 credentials: &Credentials,
1048 cx: &AsyncAppContext,
1049 ) -> Task<Result<Connection, EstablishConnectionError>> {
1050 #[cfg(any(test, feature = "test-support"))]
1051 if let Some(callback) = self.establish_connection.read().as_ref() {
1052 return callback(credentials, cx);
1053 }
1054
1055 self.establish_websocket_connection(credentials, cx)
1056 }
1057
1058 async fn get_rpc_url(
1059 http: Arc<HttpClientWithUrl>,
1060 release_channel: Option<ReleaseChannel>,
1061 ) -> Result<Url> {
1062 if let Some(url) = &*ZED_RPC_URL {
1063 return Url::parse(url).context("invalid rpc url");
1064 }
1065
1066 let mut url = http.build_url("/rpc");
1067 if let Some(preview_param) =
1068 release_channel.and_then(|channel| channel.release_query_param())
1069 {
1070 url += "?";
1071 url += preview_param;
1072 }
1073 let response = http.get(&url, Default::default(), false).await?;
1074 let collab_url = if response.status().is_redirection() {
1075 response
1076 .headers()
1077 .get("Location")
1078 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1079 .to_str()
1080 .map_err(EstablishConnectionError::other)?
1081 .to_string()
1082 } else {
1083 Err(anyhow!(
1084 "unexpected /rpc response status {}",
1085 response.status()
1086 ))?
1087 };
1088
1089 Url::parse(&collab_url).context("invalid rpc url")
1090 }
1091
    /// Opens a websocket connection to the RPC endpoint using `credentials`.
    ///
    /// The handshake request carries:
    /// - the authorization header;
    /// - the protocol version;
    /// - the app version;
    /// - the release channel.
    ///
    /// The RPC URL comes from `get_rpc_url`. After a TCP connection to its
    /// host and port:
    /// - an `https` URL is upgraded over TLS as `wss`;
    /// - an `http` URL is upgraded as plain `ws`;
    /// - any other scheme is an error.
    ///
    /// The work runs on the background executor.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Fall back to "no channel" / empty version if the app context is gone.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let request = Request::builder()
            .header("Authorization", credentials.authorization_header())
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
            .header("x-zed-app-version", app_version)
            .header(
                "x-zed-release-channel",
                release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
            );

        let http = self.http.clone();
        cx.background_executor().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            // Logged once the TCP connection is up, before the websocket handshake.
            log::info!("connected to rpc endpoint {}", rpc_url);

            // NOTE(review): each `set_scheme(..).unwrap()` presumably cannot fail,
            // since https->wss and http->ws stay within the URL spec's "special"
            // schemes — confirm against the `url` crate docs.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1152
1153 pub fn authenticate_with_browser(
1154 self: &Arc<Self>,
1155 cx: &AsyncAppContext,
1156 ) -> Task<Result<Credentials>> {
1157 let http = self.http.clone();
1158 cx.spawn(|cx| async move {
1159 let background = cx.background_executor().clone();
1160
1161 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1162 cx.update(|cx| {
1163 cx.spawn(move |cx| async move {
1164 let url = open_url_rx.await?;
1165 cx.update(|cx| cx.open_url(&url))
1166 })
1167 .detach_and_log_err(cx);
1168 })
1169 .log_err();
1170
1171 let credentials = background
1172 .clone()
1173 .spawn(async move {
1174 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1175 // zed server to encrypt the user's access token, so that it can'be intercepted by
1176 // any other app running on the user's device.
1177 let (public_key, private_key) =
1178 rpc::auth::keypair().expect("failed to generate keypair for auth");
1179 let public_key_string = String::try_from(public_key)
1180 .expect("failed to serialize public key for auth");
1181
1182 if let Some((login, token)) =
1183 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1184 {
1185 eprintln!("authenticate as admin {login}, {token}");
1186
1187 return Self::authenticate_as_admin(http, login.clone(), token.clone())
1188 .await;
1189 }
1190
1191 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1192 let server =
1193 tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1194 let port = server.server_addr().port();
1195
1196 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1197 // that the user is signing in from a Zed app running on the same device.
1198 let mut url = http.build_url(&format!(
1199 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1200 port, public_key_string
1201 ));
1202
1203 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1204 log::info!("impersonating user @{}", impersonate_login);
1205 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1206 }
1207
1208 open_url_tx.send(url).log_err();
1209
1210 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1211 // access token from the query params.
1212 //
1213 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1214 // custom URL scheme instead of this local HTTP server.
1215 let (user_id, access_token) = background
1216 .spawn(async move {
1217 for _ in 0..100 {
1218 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1219 let path = req.url();
1220 let mut user_id = None;
1221 let mut access_token = None;
1222 let url = Url::parse(&format!("http://example.com{}", path))
1223 .context("failed to parse login notification url")?;
1224 for (key, value) in url.query_pairs() {
1225 if key == "access_token" {
1226 access_token = Some(value.to_string());
1227 } else if key == "user_id" {
1228 user_id = Some(value.to_string());
1229 }
1230 }
1231
1232 let post_auth_url =
1233 http.build_url("/native_app_signin_succeeded");
1234 req.respond(
1235 tiny_http::Response::empty(302).with_header(
1236 tiny_http::Header::from_bytes(
1237 &b"Location"[..],
1238 post_auth_url.as_bytes(),
1239 )
1240 .unwrap(),
1241 ),
1242 )
1243 .context("failed to respond to login http request")?;
1244 return Ok((
1245 user_id
1246 .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1247 access_token.ok_or_else(|| {
1248 anyhow!("missing access_token parameter")
1249 })?,
1250 ));
1251 }
1252 }
1253
1254 Err(anyhow!("didn't receive login redirect"))
1255 })
1256 .await?;
1257
1258 let access_token = private_key
1259 .decrypt_string(&access_token)
1260 .context("failed to decrypt access token")?;
1261
1262 Ok(Credentials::User {
1263 user_id: user_id.parse()?,
1264 access_token,
1265 })
1266 })
1267 .await?;
1268
1269 cx.update(|cx| cx.activate(true))?;
1270 Ok(credentials)
1271 })
1272 }
1273
1274 async fn authenticate_as_admin(
1275 http: Arc<HttpClientWithUrl>,
1276 login: String,
1277 mut api_token: String,
1278 ) -> Result<Credentials> {
1279 #[derive(Deserialize)]
1280 struct AuthenticatedUserResponse {
1281 user: User,
1282 }
1283
1284 #[derive(Deserialize)]
1285 struct User {
1286 id: u64,
1287 }
1288
1289 // Use the collab server's admin API to retrieve the id
1290 // of the impersonated user.
1291 let mut url = Self::get_rpc_url(http.clone(), None).await?;
1292 url.set_path("/user");
1293 url.set_query(Some(&format!("github_login={login}")));
1294 let request = Request::get(url.as_str())
1295 .header("Authorization", format!("token {api_token}"))
1296 .body("".into())?;
1297
1298 let mut response = http.send(request).await?;
1299 let mut body = String::new();
1300 response.body_mut().read_to_string(&mut body).await?;
1301 if !response.status().is_success() {
1302 Err(anyhow!(
1303 "admin user request failed {} - {}",
1304 response.status().as_u16(),
1305 body,
1306 ))?;
1307 }
1308 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1309
1310 // Use the admin API token to authenticate as the impersonated user.
1311 api_token.insert_str(0, "ADMIN_TOKEN:");
1312 Ok(Credentials::User {
1313 user_id: response.user.id,
1314 access_token: api_token,
1315 })
1316 }
1317
1318 pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1319 self.state.write().credentials = None;
1320 self.disconnect(&cx);
1321
1322 if self.has_credentials(cx).await {
1323 self.credentials_provider
1324 .delete_credentials(cx)
1325 .await
1326 .log_err();
1327 }
1328 }
1329
1330 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1331 self.peer.teardown();
1332 self.set_status(Status::SignedOut, cx);
1333 }
1334
1335 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1336 self.peer.teardown();
1337 self.set_status(Status::ConnectionLost, cx);
1338 }
1339
1340 fn connection_id(&self) -> Result<ConnectionId> {
1341 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1342 Ok(connection_id)
1343 } else {
1344 Err(anyhow!("not connected"))
1345 }
1346 }
1347
1348 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1349 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1350 self.peer.send(self.connection_id()?, message)
1351 }
1352
1353 pub fn request<T: RequestMessage>(
1354 &self,
1355 request: T,
1356 ) -> impl Future<Output = Result<T::Response>> {
1357 self.request_envelope(request)
1358 .map_ok(|envelope| envelope.payload)
1359 }
1360
1361 pub fn request_stream<T: RequestMessage>(
1362 &self,
1363 request: T,
1364 ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1365 let client_id = self.id.load(Ordering::SeqCst);
1366 log::debug!(
1367 "rpc request start. client_id:{}. name:{}",
1368 client_id,
1369 T::NAME
1370 );
1371 let response = self
1372 .connection_id()
1373 .map(|conn_id| self.peer.request_stream(conn_id, request));
1374 async move {
1375 let response = response?.await;
1376 log::debug!(
1377 "rpc request finish. client_id:{}. name:{}",
1378 client_id,
1379 T::NAME
1380 );
1381 response
1382 }
1383 }
1384
1385 pub fn request_envelope<T: RequestMessage>(
1386 &self,
1387 request: T,
1388 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1389 let client_id = self.id();
1390 log::debug!(
1391 "rpc request start. client_id:{}. name:{}",
1392 client_id,
1393 T::NAME
1394 );
1395 let response = self
1396 .connection_id()
1397 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1398 async move {
1399 let response = response?.await;
1400 log::debug!(
1401 "rpc request finish. client_id:{}. name:{}",
1402 client_id,
1403 T::NAME
1404 );
1405 response
1406 }
1407 }
1408
1409 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1410 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1411 self.peer.respond(receipt, response)
1412 }
1413
1414 fn respond_with_error<T: RequestMessage>(
1415 &self,
1416 receipt: Receipt<T>,
1417 error: proto::Error,
1418 ) -> Result<()> {
1419 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1420 self.peer.respond_with_error(receipt, error)
1421 }
1422
1423 fn handle_message(
1424 self: &Arc<Client>,
1425 message: Box<dyn AnyTypedEnvelope>,
1426 cx: &AsyncAppContext,
1427 ) {
1428 let mut state = self.state.write();
1429 let type_name = message.payload_type_name();
1430 let payload_type_id = message.payload_type_id();
1431 let sender_id = message.original_sender_id();
1432
1433 let mut subscriber = None;
1434
1435 if let Some(handle) = state
1436 .models_by_message_type
1437 .get(&payload_type_id)
1438 .and_then(|handle| handle.upgrade())
1439 {
1440 subscriber = Some(handle);
1441 } else if let Some((extract_entity_id, entity_type_id)) =
1442 state.entity_id_extractors.get(&payload_type_id).zip(
1443 state
1444 .entity_types_by_message_type
1445 .get(&payload_type_id)
1446 .copied(),
1447 )
1448 {
1449 let entity_id = (extract_entity_id)(message.as_ref());
1450
1451 match state
1452 .entities_by_type_and_remote_id
1453 .get_mut(&(entity_type_id, entity_id))
1454 {
1455 Some(WeakSubscriber::Pending(pending)) => {
1456 pending.push(message);
1457 return;
1458 }
1459 Some(weak_subscriber) => match weak_subscriber {
1460 WeakSubscriber::Entity { handle } => {
1461 subscriber = handle.upgrade();
1462 }
1463
1464 WeakSubscriber::Pending(_) => {}
1465 },
1466 _ => {}
1467 }
1468 }
1469
1470 let subscriber = if let Some(subscriber) = subscriber {
1471 subscriber
1472 } else {
1473 log::info!("unhandled message {}", type_name);
1474 self.peer.respond_with_unhandled_message(message).log_err();
1475 return;
1476 };
1477
1478 let handler = state.message_handlers.get(&payload_type_id).cloned();
1479 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1480 // It also ensures we don't hold the lock while yielding back to the executor, as
1481 // that might cause the executor thread driving this future to block indefinitely.
1482 drop(state);
1483
1484 if let Some(handler) = handler {
1485 let future = handler(subscriber, message, self, cx.clone());
1486 let client_id = self.id();
1487 log::debug!(
1488 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1489 client_id,
1490 sender_id,
1491 type_name
1492 );
1493 cx.spawn(move |_| async move {
1494 match future.await {
1495 Ok(()) => {
1496 log::debug!(
1497 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1498 client_id,
1499 sender_id,
1500 type_name
1501 );
1502 }
1503 Err(error) => {
1504 log::error!(
1505 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1506 client_id,
1507 sender_id,
1508 type_name,
1509 error
1510 );
1511 }
1512 }
1513 })
1514 .detach();
1515 } else {
1516 log::info!("unhandled message {}", type_name);
1517 self.peer.respond_with_unhandled_message(message).log_err();
1518 }
1519 }
1520
1521 pub fn telemetry(&self) -> &Arc<Telemetry> {
1522 &self.telemetry
1523 }
1524}
1525
1526#[derive(Serialize, Deserialize)]
1527struct DevelopmentCredentials {
1528 user_id: u64,
1529 access_token: String,
1530}
1531
1532/// A credentials provider that stores credentials in a local file.
1533///
1534/// This MUST only be used in development, as this is not a secure way of storing
1535/// credentials on user machines.
1536///
1537/// Its existence is purely to work around the annoyance of having to constantly
1538/// re-allow access to the system keychain when developing Zed.
1539struct DevelopmentCredentialsProvider {
1540 path: PathBuf,
1541}
1542
1543impl CredentialsProvider for DevelopmentCredentialsProvider {
1544 fn read_credentials<'a>(
1545 &'a self,
1546 _cx: &'a AsyncAppContext,
1547 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1548 async move {
1549 if IMPERSONATE_LOGIN.is_some() {
1550 return None;
1551 }
1552
1553 let json = std::fs::read(&self.path).log_err()?;
1554
1555 let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1556
1557 Some(Credentials::User {
1558 user_id: credentials.user_id,
1559 access_token: credentials.access_token,
1560 })
1561 }
1562 .boxed_local()
1563 }
1564
1565 fn write_credentials<'a>(
1566 &'a self,
1567 user_id: u64,
1568 access_token: String,
1569 _cx: &'a AsyncAppContext,
1570 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1571 async move {
1572 let json = serde_json::to_string(&DevelopmentCredentials {
1573 user_id,
1574 access_token,
1575 })?;
1576
1577 std::fs::write(&self.path, json)?;
1578
1579 Ok(())
1580 }
1581 .boxed_local()
1582 }
1583
1584 fn delete_credentials<'a>(
1585 &'a self,
1586 _cx: &'a AsyncAppContext,
1587 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1588 async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1589 }
1590}
1591
1592/// A credentials provider that stores credentials in the system keychain.
1593struct KeychainCredentialsProvider;
1594
1595impl CredentialsProvider for KeychainCredentialsProvider {
1596 fn read_credentials<'a>(
1597 &'a self,
1598 cx: &'a AsyncAppContext,
1599 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1600 async move {
1601 if IMPERSONATE_LOGIN.is_some() {
1602 return None;
1603 }
1604
1605 let (user_id, access_token) = cx
1606 .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1607 .log_err()?
1608 .await
1609 .log_err()??;
1610
1611 Some(Credentials::User {
1612 user_id: user_id.parse().ok()?,
1613 access_token: String::from_utf8(access_token).ok()?,
1614 })
1615 }
1616 .boxed_local()
1617 }
1618
1619 fn write_credentials<'a>(
1620 &'a self,
1621 user_id: u64,
1622 access_token: String,
1623 cx: &'a AsyncAppContext,
1624 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1625 async move {
1626 cx.update(move |cx| {
1627 cx.write_credentials(
1628 &ClientSettings::get_global(cx).server_url,
1629 &user_id.to_string(),
1630 access_token.as_bytes(),
1631 )
1632 })?
1633 .await
1634 }
1635 .boxed_local()
1636 }
1637
1638 fn delete_credentials<'a>(
1639 &'a self,
1640 cx: &'a AsyncAppContext,
1641 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1642 async move {
1643 cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1644 .await
1645 }
1646 .boxed_local()
1647 }
1648}
1649
1650/// prefix for the zed:// url scheme
1651pub static ZED_URL_SCHEME: &str = "zed";
1652
1653/// Parses the given link into a Zed link.
1654///
1655/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1656/// Returns [`None`] otherwise.
1657pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1658 let server_url = &ClientSettings::get_global(cx).server_url;
1659 if let Some(stripped) = link
1660 .strip_prefix(server_url)
1661 .and_then(|result| result.strip_prefix('/'))
1662 {
1663 return Some(stripped);
1664 }
1665 if let Some(stripped) = link
1666 .strip_prefix(ZED_URL_SCHEME)
1667 .and_then(|result| result.strip_prefix("://"))
1668 {
1669 return Some(stripped);
1670 }
1671
1672 None
1673}
1674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use parking_lot::Mutex;
    use settings::SettingsStore;
    use std::future;
    use util::http::FakeHttpClient;

    /// The client reconnects after the server drops the connection.
    ///
    /// Cached credentials are reused while they are still valid (auth count stays 1).
    /// After the server rolls the access token, the client re-authenticates (auth count
    /// becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, and wait for a failed reconnect.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again and advance the fake clock past the retry delay.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// An attempt that never establishes a connection fails once `CONNECTION_TIMEOUT`
    /// elapses. This holds both for the initial connection (`ConnectionError`) and for a
    /// reconnection (`ReconnectionError`).
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // A connection attempt that never completes.
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    /// Starting a second `authenticate_and_connect` while an authentication is still
    /// pending causes exactly one pending authentication future to be dropped.
    /// The overridden authentication never completes; it only counts how many times it
    /// was started and dropped.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    // Increments the drop counter when this future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing keeps the first task alive; any drop observed below is done by the client.
        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    /// `JoinProject` messages are routed to the model subscribed to the matching
    /// entity id. Dropping the subscription for one entity id does not stop delivery
    /// to other entities of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signal the channel that corresponds to whichever model received the message.
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    /// After one message-handler subscription is dropped, a new handler can be
    /// registered for the same message type, and it receives subsequent messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    /// A handler can drop its own subscription (stored on the model) while it is
    /// handling a message, and the message is still delivered.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone().downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                // Take the subscription out of the model and drop it from inside the handler.
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as a message subscriber in these tests.
    #[derive(Default)]
    struct TestModel {
        id: usize,
        subscription: Option<Subscription>,
    }

    /// Installs a test `SettingsStore` as a global, then calls `init_settings`
    /// (presumably registers this crate's settings — confirm against its definition).
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}