1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use clock::SystemClock;
14use collections::HashMap;
15use futures::{
16 channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
17 TryFutureExt as _, TryStreamExt,
18};
19use gpui::{
20 actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
21};
22use http::{HttpClient, HttpClientWithUrl};
23use lazy_static::lazy_static;
24use parking_lot::RwLock;
25use postage::watch;
26use rand::prelude::*;
27use release_channel::{AppVersion, ReleaseChannel};
28use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::{Settings, SettingsSources};
32use std::fmt;
33use std::pin::Pin;
34use std::{
35 any::TypeId,
36 convert::TryFrom,
37 fmt::Write as _,
38 future::Future,
39 marker::PhantomData,
40 path::PathBuf,
41 sync::{
42 atomic::{AtomicU64, Ordering},
43 Arc, Weak,
44 },
45 time::{Duration, Instant},
46};
47use telemetry::Telemetry;
48use thiserror::Error;
49use url::Url;
50use util::{ResultExt, TryFutureExt};
51
52pub use rpc::*;
53pub use telemetry_events::Event;
54pub use user::*;
55
/// Opaque token that authenticates a dev server.
///
/// Its `Display` output is the raw token text; `Credentials::authorization_header`
/// embeds it after the `dev-server-token ` prefix.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DevServerToken(pub String);

impl fmt::Display for DevServerToken {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let DevServerToken(raw) = self;
        f.write_str(raw)
    }
}
64
65lazy_static! {
66 static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
67 static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
68 /// An environment variable whose presence indicates that the development auth
69 /// provider should be used.
70 ///
71 /// Only works in development. Setting this environment variable in other release
72 /// channels is a no-op.
73 pub static ref ZED_DEVELOPMENT_AUTH: bool =
74 std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
75 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
76 .ok()
77 .and_then(|s| if s.is_empty() { None } else { Some(s) });
78 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
79 .ok()
80 .and_then(|s| if s.is_empty() { None } else { Some(s) });
81 pub static ref ZED_APP_PATH: Option<PathBuf> =
82 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
83 pub static ref ZED_ALWAYS_ACTIVE: bool =
84 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty());
85}
86
/// First delay of the reconnect backoff, and the lower clamp for every later delay.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Upper clamp for the reconnect backoff delay.
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_millis(10_000);
/// Budget shared by opening the connection and receiving the server hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(20_000);
90
// Client-level actions; `init` registers the handlers for all three.
actions!(client, [SignIn, SignOut, Reconnect]);
92
93#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
94pub struct ClientSettingsContent {
95 server_url: Option<String>,
96}
97
98#[derive(Deserialize)]
99pub struct ClientSettings {
100 pub server_url: String,
101}
102
103impl Settings for ClientSettings {
104 const KEY: Option<&'static str> = None;
105
106 type FileContent = ClientSettingsContent;
107
108 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
109 let mut result = sources.json_merge::<Self>()?;
110 if let Some(server_url) = &*ZED_SERVER_URL {
111 result.server_url.clone_from(&server_url)
112 }
113 Ok(result)
114 }
115}
116
117#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
118pub struct ProxySettingsContent {
119 proxy: Option<String>,
120}
121
122#[derive(Deserialize, Default)]
123pub struct ProxySettings {
124 pub proxy: Option<String>,
125}
126
127impl Settings for ProxySettings {
128 const KEY: Option<&'static str> = None;
129
130 type FileContent = ProxySettingsContent;
131
132 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
133 Ok(Self {
134 proxy: sources
135 .user
136 .and_then(|value| value.proxy.clone())
137 .or(sources.default.proxy.clone()),
138 })
139 }
140}
141
142pub fn init_settings(cx: &mut AppContext) {
143 TelemetrySettings::register(cx);
144 ClientSettings::register(cx);
145 ProxySettings::register(cx);
146}
147
148pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
149 let client = Arc::downgrade(client);
150 cx.on_action({
151 let client = client.clone();
152 move |_: &SignIn, cx| {
153 if let Some(client) = client.upgrade() {
154 cx.spawn(
155 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
156 )
157 .detach();
158 }
159 }
160 });
161
162 cx.on_action({
163 let client = client.clone();
164 move |_: &SignOut, cx| {
165 if let Some(client) = client.upgrade() {
166 cx.spawn(|cx| async move {
167 client.sign_out(&cx).await;
168 })
169 .detach();
170 }
171 }
172 });
173
174 cx.on_action({
175 let client = client.clone();
176 move |_: &Reconnect, cx| {
177 if let Some(client) = client.upgrade() {
178 cx.spawn(|cx| async move {
179 client.reconnect(&cx);
180 })
181 .detach();
182 }
183 }
184 });
185}
186
/// Wrapper that stores the shared [`Client`] as a gpui global
/// (see `Client::global` / `Client::set_global`).
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}

/// Connection to the collaboration server: authentication, connection lifecycle and
/// dispatch of incoming RPC messages to registered handlers.
pub struct Client {
    /// Numeric client id. It starts at 0 in `Client::new`, and `authenticate_and_connect`
    /// sets it to the user id for user credentials. It is included in status log lines.
    id: AtomicU64,
    /// RPC peer used to add connections, disconnect and send responses.
    peer: Arc<Peer>,
    /// HTTP client; also used to discover the RPC endpoint.
    http: Arc<HttpClientWithUrl>,
    telemetry: Arc<Telemetry>,
    /// Where credentials are persisted (keychain, or a file in development).
    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
    /// Mutable state: status, credentials and handler/subscription registries.
    state: RwLock<ClientState>,

    /// Test-only replacement for interactive (browser) authentication.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for opening the websocket connection.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
221
222#[derive(Error, Debug)]
223pub enum EstablishConnectionError {
224 #[error("upgrade required")]
225 UpgradeRequired,
226 #[error("unauthorized")]
227 Unauthorized,
228 #[error("{0}")]
229 Other(#[from] anyhow::Error),
230 #[error("{0}")]
231 Http(#[from] http::Error),
232 #[error("{0}")]
233 Io(#[from] std::io::Error),
234 #[error("{0}")]
235 Websocket(#[from] async_tungstenite::tungstenite::http::Error),
236}
237
238impl From<WebsocketError> for EstablishConnectionError {
239 fn from(error: WebsocketError) -> Self {
240 if let WebsocketError::Http(response) = &error {
241 match response.status() {
242 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
243 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
244 _ => {}
245 }
246 }
247 EstablishConnectionError::Other(error.into())
248 }
249}
250
251impl EstablishConnectionError {
252 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
253 Self::Other(error.into())
254 }
255}
256
257#[derive(Copy, Clone, Debug, PartialEq)]
258pub enum Status {
259 SignedOut,
260 UpgradeRequired,
261 Authenticating,
262 Connecting,
263 ConnectionError,
264 Connected {
265 peer_id: PeerId,
266 connection_id: ConnectionId,
267 },
268 ConnectionLost,
269 Reauthenticating,
270 Reconnecting,
271 ReconnectionError {
272 next_reconnection: Instant,
273 },
274}
275
276impl Status {
277 pub fn is_connected(&self) -> bool {
278 matches!(self, Self::Connected { .. })
279 }
280
281 pub fn is_signed_out(&self) -> bool {
282 matches!(self, Self::SignedOut | Self::UpgradeRequired)
283 }
284}
285
/// Mutable client state, guarded by `Client::state`.
struct ClientState {
    /// Credentials for the current/last connection (cleared when the server rejects them).
    credentials: Option<Credentials>,
    /// Watch channel `(sender, receiver)` that publishes the current [`Status`].
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: extracts the remote entity id from a type-erased envelope.
    /// Populated by `add_entity_message_handler`.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `Status::ConnectionLost`, if one is pending.
    _reconnect_task: Option<Task<()>>,
    /// Entity subscribers keyed by `(entity TypeId, remote entity id)`.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model that receives each message type registered via `add_message_handler_impl`.
    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    /// Entity type that handles each message type registered via `add_entity_message_handler`.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type (at most one per type).
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyModel,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}

/// Target registered for an `(entity type, remote id)` pair.
enum WeakSubscriber {
    /// Model bound via `PendingEntitySubscription::set_model`.
    Entity { handle: AnyWeakModel },
    /// Reserved by `subscribe_to_entity` but not yet bound. The envelopes held here are
    /// replayed by `set_model`, or logged as unhandled if the pending subscription is
    /// dropped (presumably queued by `handle_message`, which is not visible here).
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
314
315#[derive(Clone, Debug, Eq, PartialEq)]
316pub enum Credentials {
317 DevServer { token: DevServerToken },
318 User { user_id: u64, access_token: String },
319}
320
321impl Credentials {
322 pub fn authorization_header(&self) -> String {
323 match self {
324 Credentials::DevServer { token } => format!("dev-server-token {}", token),
325 Credentials::User {
326 user_id,
327 access_token,
328 } => format!("{} {}", user_id, access_token),
329 }
330 }
331}
332
/// A provider for [`Credentials`].
///
/// Used to abstract over reading and writing credentials to some form of
/// persistence (like the system keychain). `Client::new` selects the implementation
/// (a development file-based provider or the keychain provider).
trait CredentialsProvider {
    /// Reads the credentials from the provider.
    ///
    /// Returns `None` when nothing is stored (or nothing could be read).
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;

    /// Writes the credentials to the provider.
    ///
    /// Only user credentials (`user_id` + `access_token`) are persisted through this
    /// method; the client calls it after a successful connection.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;

    /// Deletes the credentials from the provider.
    ///
    /// The client calls this when stored credentials are rejected as unauthorized.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
}
358
359impl Default for ClientState {
360 fn default() -> Self {
361 Self {
362 credentials: None,
363 status: watch::channel_with(Status::SignedOut),
364 entity_id_extractors: Default::default(),
365 _reconnect_task: None,
366 models_by_message_type: Default::default(),
367 entities_by_type_and_remote_id: Default::default(),
368 entity_types_by_message_type: Default::default(),
369 message_handlers: Default::default(),
370 }
371 }
372}
373
374pub enum Subscription {
375 Entity {
376 client: Weak<Client>,
377 id: (TypeId, u64),
378 },
379 Message {
380 client: Weak<Client>,
381 id: TypeId,
382 },
383}
384
385impl Drop for Subscription {
386 fn drop(&mut self) {
387 match self {
388 Subscription::Entity { client, id } => {
389 if let Some(client) = client.upgrade() {
390 let mut state = client.state.write();
391 let _ = state.entities_by_type_and_remote_id.remove(id);
392 }
393 }
394 Subscription::Message { client, id } => {
395 if let Some(client) = client.upgrade() {
396 let mut state = client.state.write();
397 let _ = state.entity_types_by_message_type.remove(id);
398 let _ = state.message_handlers.remove(id);
399 }
400 }
401 }
402 }
403}
404
405pub struct PendingEntitySubscription<T: 'static> {
406 client: Arc<Client>,
407 remote_id: u64,
408 _entity_type: PhantomData<T>,
409 consumed: bool,
410}
411
412impl<T: 'static> PendingEntitySubscription<T> {
413 pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
414 self.consumed = true;
415 let mut state = self.client.state.write();
416 let id = (TypeId::of::<T>(), self.remote_id);
417 let Some(WeakSubscriber::Pending(messages)) =
418 state.entities_by_type_and_remote_id.remove(&id)
419 else {
420 unreachable!()
421 };
422
423 state.entities_by_type_and_remote_id.insert(
424 id,
425 WeakSubscriber::Entity {
426 handle: model.downgrade().into(),
427 },
428 );
429 drop(state);
430 for message in messages {
431 self.client.handle_message(message, cx);
432 }
433 Subscription::Entity {
434 client: Arc::downgrade(&self.client),
435 id,
436 }
437 }
438}
439
440impl<T: 'static> Drop for PendingEntitySubscription<T> {
441 fn drop(&mut self) {
442 if !self.consumed {
443 let mut state = self.client.state.write();
444 if let Some(WeakSubscriber::Pending(messages)) = state
445 .entities_by_type_and_remote_id
446 .remove(&(TypeId::of::<T>(), self.remote_id))
447 {
448 for message in messages {
449 log::info!("unhandled message {}", message.payload_type_name());
450 }
451 }
452 }
453 }
454}
455
456#[derive(Copy, Clone)]
457pub struct TelemetrySettings {
458 pub diagnostics: bool,
459 pub metrics: bool,
460}
461
462/// Control what info is collected by Zed.
463#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
464pub struct TelemetrySettingsContent {
465 /// Send debug info like crash reports.
466 ///
467 /// Default: true
468 pub diagnostics: Option<bool>,
469 /// Send anonymized usage data like what languages you're using Zed with.
470 ///
471 /// Default: true
472 pub metrics: Option<bool>,
473}
474
475impl settings::Settings for TelemetrySettings {
476 const KEY: Option<&'static str> = Some("telemetry");
477
478 type FileContent = TelemetrySettingsContent;
479
480 fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
481 Ok(Self {
482 diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
483 sources
484 .default
485 .diagnostics
486 .ok_or_else(Self::missing_default)?,
487 ),
488 metrics: sources
489 .user
490 .as_ref()
491 .and_then(|v| v.metrics)
492 .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
493 })
494 }
495}
496
497impl Client {
498 pub fn new(
499 clock: Arc<dyn SystemClock>,
500 http: Arc<HttpClientWithUrl>,
501 cx: &mut AppContext,
502 ) -> Arc<Self> {
503 let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
504 Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
505 Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
506 | None => false,
507 };
508
509 let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
510 if use_zed_development_auth {
511 Arc::new(DevelopmentCredentialsProvider {
512 path: paths::config_dir().join("development_auth"),
513 })
514 } else {
515 Arc::new(KeychainCredentialsProvider)
516 };
517
518 Arc::new(Self {
519 id: AtomicU64::new(0),
520 peer: Peer::new(0),
521 telemetry: Telemetry::new(clock, http.clone(), cx),
522 http,
523 credentials_provider,
524 state: Default::default(),
525
526 #[cfg(any(test, feature = "test-support"))]
527 authenticate: Default::default(),
528 #[cfg(any(test, feature = "test-support"))]
529 establish_connection: Default::default(),
530 })
531 }
532
533 pub fn production(cx: &mut AppContext) -> Arc<Self> {
534 let clock = Arc::new(clock::RealSystemClock);
535 let http = Arc::new(HttpClientWithUrl::new(
536 &ClientSettings::get_global(cx).server_url,
537 ProxySettings::get_global(cx).proxy.clone(),
538 ));
539 Self::new(clock, http.clone(), cx)
540 }
541
542 pub fn id(&self) -> u64 {
543 self.id.load(Ordering::SeqCst)
544 }
545
546 pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
547 self.http.clone()
548 }
549
550 pub fn set_id(&self, id: u64) -> &Self {
551 self.id.store(id, Ordering::SeqCst);
552 self
553 }
554
555 #[cfg(any(test, feature = "test-support"))]
556 pub fn teardown(&self) {
557 let mut state = self.state.write();
558 state._reconnect_task.take();
559 state.message_handlers.clear();
560 state.models_by_message_type.clear();
561 state.entities_by_type_and_remote_id.clear();
562 state.entity_id_extractors.clear();
563 self.peer.teardown();
564 }
565
566 #[cfg(any(test, feature = "test-support"))]
567 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
568 where
569 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
570 {
571 *self.authenticate.write() = Some(Box::new(authenticate));
572 self
573 }
574
575 #[cfg(any(test, feature = "test-support"))]
576 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
577 where
578 F: 'static
579 + Send
580 + Sync
581 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
582 {
583 *self.establish_connection.write() = Some(Box::new(connect));
584 self
585 }
586
587 pub fn global(cx: &AppContext) -> Arc<Self> {
588 cx.global::<GlobalClient>().0.clone()
589 }
590 pub fn set_global(client: Arc<Client>, cx: &mut AppContext) {
591 cx.set_global(GlobalClient(client))
592 }
593
594 pub fn user_id(&self) -> Option<u64> {
595 if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
596 Some(*user_id)
597 } else {
598 None
599 }
600 }
601
602 pub fn peer_id(&self) -> Option<PeerId> {
603 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
604 Some(*peer_id)
605 } else {
606 None
607 }
608 }
609
610 pub fn status(&self) -> watch::Receiver<Status> {
611 self.state.read().status.1.clone()
612 }
613
    /// Publishes `status` on the status watch channel and applies its side effects.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: stores a newly spawned reconnect loop as the reconnect task.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and drops any
    ///   pending reconnect task.
    ///
    /// The state write lock is held for the whole call, including the telemetry update.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed under test for deterministic jitter; real entropy otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failed attempt left the client in
                        // `ConnectionError`. Any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Scale by a random factor in [0.5, 2.5], then clamp to
                            // [INITIAL_RECONNECTION_DELAY, MAX_RECONNECTION_DELAY].
                            delay = delay
                                .mul_f32(rng.gen_range(0.5..=2.5))
                                .max(INITIAL_RECONNECTION_DELAY)
                                .min(MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
659
660 pub fn subscribe_to_entity<T>(
661 self: &Arc<Self>,
662 remote_id: u64,
663 ) -> Result<PendingEntitySubscription<T>>
664 where
665 T: 'static,
666 {
667 let id = (TypeId::of::<T>(), remote_id);
668
669 let mut state = self.state.write();
670 if state.entities_by_type_and_remote_id.contains_key(&id) {
671 return Err(anyhow!("already subscribed to entity"));
672 }
673
674 state
675 .entities_by_type_and_remote_id
676 .insert(id, WeakSubscriber::Pending(Default::default()));
677
678 Ok(PendingEntitySubscription {
679 client: self.clone(),
680 remote_id,
681 consumed: false,
682 _entity_type: PhantomData,
683 })
684 }
685
686 #[track_caller]
687 pub fn add_message_handler<M, E, H, F>(
688 self: &Arc<Self>,
689 entity: WeakModel<E>,
690 handler: H,
691 ) -> Subscription
692 where
693 M: EnvelopedMessage,
694 E: 'static,
695 H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
696 F: 'static + Future<Output = Result<()>>,
697 {
698 self.add_message_handler_impl(entity, move |model, message, _, cx| {
699 handler(model, message, cx)
700 })
701 }
702
703 fn add_message_handler_impl<M, E, H, F>(
704 self: &Arc<Self>,
705 entity: WeakModel<E>,
706 handler: H,
707 ) -> Subscription
708 where
709 M: EnvelopedMessage,
710 E: 'static,
711 H: 'static
712 + Sync
713 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
714 + Send
715 + Sync,
716 F: 'static + Future<Output = Result<()>>,
717 {
718 let message_type_id = TypeId::of::<M>();
719 let mut state = self.state.write();
720 state
721 .models_by_message_type
722 .insert(message_type_id, entity.into());
723
724 let prev_handler = state.message_handlers.insert(
725 message_type_id,
726 Arc::new(move |subscriber, envelope, client, cx| {
727 let subscriber = subscriber.downcast::<E>().unwrap();
728 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
729 handler(subscriber, *envelope, client.clone(), cx).boxed_local()
730 }),
731 );
732 if prev_handler.is_some() {
733 let location = std::panic::Location::caller();
734 panic!(
735 "{}:{} registered handler for the same message {} twice",
736 location.file(),
737 location.line(),
738 std::any::type_name::<M>()
739 );
740 }
741
742 Subscription::Message {
743 client: Arc::downgrade(self),
744 id: message_type_id,
745 }
746 }
747
748 pub fn add_request_handler<M, E, H, F>(
749 self: &Arc<Self>,
750 model: WeakModel<E>,
751 handler: H,
752 ) -> Subscription
753 where
754 M: RequestMessage,
755 E: 'static,
756 H: 'static + Sync + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
757 F: 'static + Future<Output = Result<M::Response>>,
758 {
759 self.add_message_handler_impl(model, move |handle, envelope, this, cx| {
760 Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
761 })
762 }
763
764 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
765 where
766 M: EntityMessage,
767 E: 'static,
768 H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
769 F: 'static + Future<Output = Result<()>>,
770 {
771 self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, _, cx| {
772 handler(subscriber.downcast::<E>().unwrap(), message, cx)
773 })
774 }
775
776 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
777 where
778 M: EntityMessage,
779 E: 'static,
780 H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
781 F: 'static + Future<Output = Result<()>>,
782 {
783 let model_type_id = TypeId::of::<E>();
784 let message_type_id = TypeId::of::<M>();
785
786 let mut state = self.state.write();
787 state
788 .entity_types_by_message_type
789 .insert(message_type_id, model_type_id);
790 state
791 .entity_id_extractors
792 .entry(message_type_id)
793 .or_insert_with(|| {
794 |envelope| {
795 envelope
796 .as_any()
797 .downcast_ref::<TypedEnvelope<M>>()
798 .unwrap()
799 .payload
800 .remote_entity_id()
801 }
802 });
803 let prev_handler = state.message_handlers.insert(
804 message_type_id,
805 Arc::new(move |handle, envelope, client, cx| {
806 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
807 handler(handle, *envelope, client.clone(), cx).boxed_local()
808 }),
809 );
810 if prev_handler.is_some() {
811 panic!("registered handler for the same message twice");
812 }
813 }
814
815 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
816 where
817 M: EntityMessage + RequestMessage,
818 E: 'static,
819 H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
820 F: 'static + Future<Output = Result<M::Response>>,
821 {
822 self.add_entity_message_handler::<M, E, _, _>(move |entity, envelope, client, cx| {
823 Self::respond_to_request::<M, _>(
824 envelope.receipt(),
825 handler(entity.downcast::<E>().unwrap(), envelope, cx),
826 client,
827 )
828 })
829 }
830
831 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
832 receipt: Receipt<T>,
833 response: F,
834 client: Arc<Self>,
835 ) -> Result<()> {
836 match response.await {
837 Ok(response) => {
838 client.respond(receipt, response)?;
839 Ok(())
840 }
841 Err(error) => {
842 client.respond_with_error(receipt, error.to_proto())?;
843 Err(error)
844 }
845 }
846 }
847
848 pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
849 self.credentials_provider
850 .read_credentials(cx)
851 .await
852 .is_some()
853 }
854
855 pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
856 self.state.write().credentials = Some(Credentials::DevServer { token });
857 self
858 }
859
    /// Obtains credentials if needed and opens the RPC connection.
    ///
    /// Returns `Ok(())` right away when already connected, connecting or reconnecting. It
    /// fails right away when the status is `UpgradeRequired`. Credentials are taken, in
    /// order, from the in-memory state, then from the credentials provider (only when
    /// `try_provider` is true), then from [`Self::authenticate`].
    ///
    /// After a successful connection, credentials that did not come from the provider are
    /// written to it (unless `IMPERSONATE_LOGIN` is set). If provider-sourced credentials
    /// are rejected as `Unauthorized`, they are deleted and the flow recurses with
    /// `try_provider = false`.
    ///
    /// A single `CONNECTION_TIMEOUT` timer bounds both opening the connection and waiting
    /// for the server hello in `set_connection`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // `was_disconnected` selects the Authenticating/Connecting vs the
        // Reauthenticating/Reconnecting status pair below.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_provider = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_provider {
            credentials = self.credentials_provider.read_credentials(cx).await;
            read_from_provider = credentials.is_some();
        }

        if credentials.is_none() {
            // Consume the value already in the channel (NOTE(review): assumes the postage
            // watch receiver yields its current value first), so that the select below
            // treats any later status change as cancelling interactive authentication.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        if let Credentials::User { user_id, .. } = &credentials {
            self.set_id(*user_id);
        }

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One fused timer shared by both phases: `&mut timeout` bounds establishing the
        // connection, and whatever remains of it bounds the wait for the server hello.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                            if let Credentials::User{user_id, access_token} = credentials {
                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                            }
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_provider {
                            // NOTE(review): `SignedOut` makes `set_status` drop
                            // `_reconnect_task`. When this runs inside that reconnect
                            // loop, the task currently executing is dropped. Confirm the
                            // recursive call below is still expected to run to completion.
                            self.credentials_provider.delete_credentials(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
968
    /// Hands an opened transport to the peer and completes the handshake.
    ///
    /// The first incoming message must be a `proto::Hello` that carries this client's peer
    /// id. Otherwise the connection is disconnected and an error is returned. On success the
    /// status becomes `Connected` and two detached tasks are spawned:
    /// - one dispatches every further incoming message through `handle_message`;
    /// - one awaits the connection's I/O future. When it completes with `Ok`, the status
    ///   becomes `SignedOut`, but only if it still refers to this connection. When it
    ///   completes with an error, the status becomes `ConnectionLost`.
    ///
    /// NOTE(review): if the caller drops this future while it waits for the hello (as on
    /// the hello timeout in `authenticate_and_connect`), `peer.disconnect` is not called
    /// here. Confirm the peer cleans up that connection elsewhere.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // Handshake: the first message must be the server's Hello carrying our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if *this.status().borrow()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
1066
1067 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
1068 #[cfg(any(test, feature = "test-support"))]
1069 if let Some(callback) = self.authenticate.read().as_ref() {
1070 return callback(cx);
1071 }
1072
1073 self.authenticate_with_browser(cx)
1074 }
1075
1076 fn establish_connection(
1077 self: &Arc<Self>,
1078 credentials: &Credentials,
1079 cx: &AsyncAppContext,
1080 ) -> Task<Result<Connection, EstablishConnectionError>> {
1081 #[cfg(any(test, feature = "test-support"))]
1082 if let Some(callback) = self.establish_connection.read().as_ref() {
1083 return callback(credentials, cx);
1084 }
1085
1086 self.establish_websocket_connection(credentials, cx)
1087 }
1088
1089 async fn get_rpc_url(
1090 http: Arc<HttpClientWithUrl>,
1091 release_channel: Option<ReleaseChannel>,
1092 ) -> Result<Url> {
1093 if let Some(url) = &*ZED_RPC_URL {
1094 return Url::parse(url).context("invalid rpc url");
1095 }
1096
1097 let mut url = http.build_url("/rpc");
1098 if let Some(preview_param) =
1099 release_channel.and_then(|channel| channel.release_query_param())
1100 {
1101 url += "?";
1102 url += preview_param;
1103 }
1104 let response = http.get(&url, Default::default(), false).await?;
1105 let collab_url = if response.status().is_redirection() {
1106 response
1107 .headers()
1108 .get("Location")
1109 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1110 .to_str()
1111 .map_err(EstablishConnectionError::other)?
1112 .to_string()
1113 } else {
1114 Err(anyhow!(
1115 "unexpected /rpc response status {}",
1116 response.status()
1117 ))?
1118 };
1119
1120 Url::parse(&collab_url).context("invalid rpc url")
1121 }
1122
    /// Opens a websocket to the RPC endpoint, authenticated with `credentials`.
    ///
    /// The handshake request carries the `Authorization` header plus protocol version, app
    /// version and release channel headers. The endpoint comes from `get_rpc_url`. An
    /// `https` endpoint is connected as `wss` over TLS and an `http` endpoint as plain `ws`;
    /// any other scheme is rejected. The work runs on the background executor.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // If `cx.update` fails, fall back to no release channel and an empty app version.
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        let request = Request::builder()
            .header("Authorization", credentials.authorization_header())
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
            .header("x-zed-app-version", app_version)
            .header(
                "x-zed-release-channel",
                release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
            );

        let http = self.http.clone();
        cx.background_executor().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // The two arms differ only in TLS vs plain websocket; the stream types differ,
            // so each arm builds its own `Connection`.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_std::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1183
1184 pub fn authenticate_with_browser(
1185 self: &Arc<Self>,
1186 cx: &AsyncAppContext,
1187 ) -> Task<Result<Credentials>> {
1188 let http = self.http.clone();
1189 cx.spawn(|cx| async move {
1190 let background = cx.background_executor().clone();
1191
1192 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1193 cx.update(|cx| {
1194 cx.spawn(move |cx| async move {
1195 let url = open_url_rx.await?;
1196 cx.update(|cx| cx.open_url(&url))
1197 })
1198 .detach_and_log_err(cx);
1199 })
1200 .log_err();
1201
1202 let credentials = background
1203 .clone()
1204 .spawn(async move {
1205 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1206 // zed server to encrypt the user's access token, so that it can'be intercepted by
1207 // any other app running on the user's device.
1208 let (public_key, private_key) =
1209 rpc::auth::keypair().expect("failed to generate keypair for auth");
1210 let public_key_string = String::try_from(public_key)
1211 .expect("failed to serialize public key for auth");
1212
1213 if let Some((login, token)) =
1214 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1215 {
1216 eprintln!("authenticate as admin {login}, {token}");
1217
1218 return Self::authenticate_as_admin(http, login.clone(), token.clone())
1219 .await;
1220 }
1221
1222 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1223 let server =
1224 tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1225 let port = server.server_addr().port();
1226
1227 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1228 // that the user is signing in from a Zed app running on the same device.
1229 let mut url = http.build_url(&format!(
1230 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1231 port, public_key_string
1232 ));
1233
1234 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1235 log::info!("impersonating user @{}", impersonate_login);
1236 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1237 }
1238
1239 open_url_tx.send(url).log_err();
1240
1241 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1242 // access token from the query params.
1243 //
1244 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1245 // custom URL scheme instead of this local HTTP server.
1246 let (user_id, access_token) = background
1247 .spawn(async move {
1248 for _ in 0..100 {
1249 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1250 let path = req.url();
1251 let mut user_id = None;
1252 let mut access_token = None;
1253 let url = Url::parse(&format!("http://example.com{}", path))
1254 .context("failed to parse login notification url")?;
1255 for (key, value) in url.query_pairs() {
1256 if key == "access_token" {
1257 access_token = Some(value.to_string());
1258 } else if key == "user_id" {
1259 user_id = Some(value.to_string());
1260 }
1261 }
1262
1263 let post_auth_url =
1264 http.build_url("/native_app_signin_succeeded");
1265 req.respond(
1266 tiny_http::Response::empty(302).with_header(
1267 tiny_http::Header::from_bytes(
1268 &b"Location"[..],
1269 post_auth_url.as_bytes(),
1270 )
1271 .unwrap(),
1272 ),
1273 )
1274 .context("failed to respond to login http request")?;
1275 return Ok((
1276 user_id
1277 .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1278 access_token.ok_or_else(|| {
1279 anyhow!("missing access_token parameter")
1280 })?,
1281 ));
1282 }
1283 }
1284
1285 Err(anyhow!("didn't receive login redirect"))
1286 })
1287 .await?;
1288
1289 let access_token = private_key
1290 .decrypt_string(&access_token)
1291 .context("failed to decrypt access token")?;
1292
1293 Ok(Credentials::User {
1294 user_id: user_id.parse()?,
1295 access_token,
1296 })
1297 })
1298 .await?;
1299
1300 cx.update(|cx| cx.activate(true))?;
1301 Ok(credentials)
1302 })
1303 }
1304
1305 async fn authenticate_as_admin(
1306 http: Arc<HttpClientWithUrl>,
1307 login: String,
1308 mut api_token: String,
1309 ) -> Result<Credentials> {
1310 #[derive(Deserialize)]
1311 struct AuthenticatedUserResponse {
1312 user: User,
1313 }
1314
1315 #[derive(Deserialize)]
1316 struct User {
1317 id: u64,
1318 }
1319
1320 // Use the collab server's admin API to retrieve the id
1321 // of the impersonated user.
1322 let mut url = Self::get_rpc_url(http.clone(), None).await?;
1323 url.set_path("/user");
1324 url.set_query(Some(&format!("github_login={login}")));
1325 let request = Request::get(url.as_str())
1326 .header("Authorization", format!("token {api_token}"))
1327 .body("".into())?;
1328
1329 let mut response = http.send(request).await?;
1330 let mut body = String::new();
1331 response.body_mut().read_to_string(&mut body).await?;
1332 if !response.status().is_success() {
1333 Err(anyhow!(
1334 "admin user request failed {} - {}",
1335 response.status().as_u16(),
1336 body,
1337 ))?;
1338 }
1339 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1340
1341 // Use the admin API token to authenticate as the impersonated user.
1342 api_token.insert_str(0, "ADMIN_TOKEN:");
1343 Ok(Credentials::User {
1344 user_id: response.user.id,
1345 access_token: api_token,
1346 })
1347 }
1348
1349 pub async fn sign_out(self: &Arc<Self>, cx: &AsyncAppContext) {
1350 self.state.write().credentials = None;
1351 self.disconnect(&cx);
1352
1353 if self.has_credentials(cx).await {
1354 self.credentials_provider
1355 .delete_credentials(cx)
1356 .await
1357 .log_err();
1358 }
1359 }
1360
1361 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1362 self.peer.teardown();
1363 self.set_status(Status::SignedOut, cx);
1364 }
1365
1366 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1367 self.peer.teardown();
1368 self.set_status(Status::ConnectionLost, cx);
1369 }
1370
1371 fn connection_id(&self) -> Result<ConnectionId> {
1372 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1373 Ok(connection_id)
1374 } else {
1375 Err(anyhow!("not connected"))
1376 }
1377 }
1378
1379 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1380 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1381 self.peer.send(self.connection_id()?, message)
1382 }
1383
1384 pub fn request<T: RequestMessage>(
1385 &self,
1386 request: T,
1387 ) -> impl Future<Output = Result<T::Response>> {
1388 self.request_envelope(request)
1389 .map_ok(|envelope| envelope.payload)
1390 }
1391
1392 pub fn request_stream<T: RequestMessage>(
1393 &self,
1394 request: T,
1395 ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1396 let client_id = self.id.load(Ordering::SeqCst);
1397 log::debug!(
1398 "rpc request start. client_id:{}. name:{}",
1399 client_id,
1400 T::NAME
1401 );
1402 let response = self
1403 .connection_id()
1404 .map(|conn_id| self.peer.request_stream(conn_id, request));
1405 async move {
1406 let response = response?.await;
1407 log::debug!(
1408 "rpc request finish. client_id:{}. name:{}",
1409 client_id,
1410 T::NAME
1411 );
1412 response
1413 }
1414 }
1415
1416 pub fn request_envelope<T: RequestMessage>(
1417 &self,
1418 request: T,
1419 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1420 let client_id = self.id();
1421 log::debug!(
1422 "rpc request start. client_id:{}. name:{}",
1423 client_id,
1424 T::NAME
1425 );
1426 let response = self
1427 .connection_id()
1428 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1429 async move {
1430 let response = response?.await;
1431 log::debug!(
1432 "rpc request finish. client_id:{}. name:{}",
1433 client_id,
1434 T::NAME
1435 );
1436 response
1437 }
1438 }
1439
1440 pub fn request_dynamic(
1441 &self,
1442 envelope: proto::Envelope,
1443 request_type: &'static str,
1444 ) -> impl Future<Output = Result<proto::Envelope>> {
1445 let client_id = self.id();
1446 log::debug!(
1447 "rpc request start. client_id:{}. name:{}",
1448 client_id,
1449 request_type
1450 );
1451 let response = self
1452 .connection_id()
1453 .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1454 async move {
1455 let response = response?.await;
1456 log::debug!(
1457 "rpc request finish. client_id:{}. name:{}",
1458 client_id,
1459 request_type
1460 );
1461 Ok(response?.0)
1462 }
1463 }
1464
1465 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1466 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1467 self.peer.respond(receipt, response)
1468 }
1469
1470 fn respond_with_error<T: RequestMessage>(
1471 &self,
1472 receipt: Receipt<T>,
1473 error: proto::Error,
1474 ) -> Result<()> {
1475 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1476 self.peer.respond_with_error(receipt, error)
1477 }
1478
1479 fn handle_message(
1480 self: &Arc<Client>,
1481 message: Box<dyn AnyTypedEnvelope>,
1482 cx: &AsyncAppContext,
1483 ) {
1484 let mut state = self.state.write();
1485 let type_name = message.payload_type_name();
1486 let payload_type_id = message.payload_type_id();
1487 let sender_id = message.original_sender_id();
1488
1489 let mut subscriber = None;
1490
1491 if let Some(handle) = state
1492 .models_by_message_type
1493 .get(&payload_type_id)
1494 .and_then(|handle| handle.upgrade())
1495 {
1496 subscriber = Some(handle);
1497 } else if let Some((extract_entity_id, entity_type_id)) =
1498 state.entity_id_extractors.get(&payload_type_id).zip(
1499 state
1500 .entity_types_by_message_type
1501 .get(&payload_type_id)
1502 .copied(),
1503 )
1504 {
1505 let entity_id = (extract_entity_id)(message.as_ref());
1506
1507 match state
1508 .entities_by_type_and_remote_id
1509 .get_mut(&(entity_type_id, entity_id))
1510 {
1511 Some(WeakSubscriber::Pending(pending)) => {
1512 pending.push(message);
1513 return;
1514 }
1515 Some(weak_subscriber) => match weak_subscriber {
1516 WeakSubscriber::Entity { handle } => {
1517 subscriber = handle.upgrade();
1518 }
1519
1520 WeakSubscriber::Pending(_) => {}
1521 },
1522 _ => {}
1523 }
1524 }
1525
1526 let subscriber = if let Some(subscriber) = subscriber {
1527 subscriber
1528 } else {
1529 log::info!("unhandled message {}", type_name);
1530 self.peer.respond_with_unhandled_message(message).log_err();
1531 return;
1532 };
1533
1534 let handler = state.message_handlers.get(&payload_type_id).cloned();
1535 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1536 // It also ensures we don't hold the lock while yielding back to the executor, as
1537 // that might cause the executor thread driving this future to block indefinitely.
1538 drop(state);
1539
1540 if let Some(handler) = handler {
1541 let future = handler(subscriber, message, self, cx.clone());
1542 let client_id = self.id();
1543 log::debug!(
1544 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1545 client_id,
1546 sender_id,
1547 type_name
1548 );
1549 cx.spawn(move |_| async move {
1550 match future.await {
1551 Ok(()) => {
1552 log::debug!(
1553 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1554 client_id,
1555 sender_id,
1556 type_name
1557 );
1558 }
1559 Err(error) => {
1560 log::error!(
1561 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1562 client_id,
1563 sender_id,
1564 type_name,
1565 error
1566 );
1567 }
1568 }
1569 })
1570 .detach();
1571 } else {
1572 log::info!("unhandled message {}", type_name);
1573 self.peer.respond_with_unhandled_message(message).log_err();
1574 }
1575 }
1576
1577 pub fn telemetry(&self) -> &Arc<Telemetry> {
1578 &self.telemetry
1579 }
1580}
1581
1582#[derive(Serialize, Deserialize)]
1583struct DevelopmentCredentials {
1584 user_id: u64,
1585 access_token: String,
1586}
1587
/// A credentials provider that stores credentials in a local file.
///
/// This MUST only be used in development, as this is not a secure way of storing
/// credentials on user machines.
///
/// Its existence is purely to work around the annoyance of having to constantly
/// re-allow access to the system keychain when developing Zed.
struct DevelopmentCredentialsProvider {
    /// Path of the JSON file that credentials are read from, written to, and deleted from.
    path: PathBuf,
}
1598
1599impl CredentialsProvider for DevelopmentCredentialsProvider {
1600 fn read_credentials<'a>(
1601 &'a self,
1602 _cx: &'a AsyncAppContext,
1603 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1604 async move {
1605 if IMPERSONATE_LOGIN.is_some() {
1606 return None;
1607 }
1608
1609 let json = std::fs::read(&self.path).log_err()?;
1610
1611 let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
1612
1613 Some(Credentials::User {
1614 user_id: credentials.user_id,
1615 access_token: credentials.access_token,
1616 })
1617 }
1618 .boxed_local()
1619 }
1620
1621 fn write_credentials<'a>(
1622 &'a self,
1623 user_id: u64,
1624 access_token: String,
1625 _cx: &'a AsyncAppContext,
1626 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1627 async move {
1628 let json = serde_json::to_string(&DevelopmentCredentials {
1629 user_id,
1630 access_token,
1631 })?;
1632
1633 std::fs::write(&self.path, json)?;
1634
1635 Ok(())
1636 }
1637 .boxed_local()
1638 }
1639
1640 fn delete_credentials<'a>(
1641 &'a self,
1642 _cx: &'a AsyncAppContext,
1643 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1644 async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
1645 }
1646}
1647
/// A credentials provider that stores credentials in the system keychain.
///
/// Entries are keyed by the `server_url` from the global [`ClientSettings`].
struct KeychainCredentialsProvider;
1650
1651impl CredentialsProvider for KeychainCredentialsProvider {
1652 fn read_credentials<'a>(
1653 &'a self,
1654 cx: &'a AsyncAppContext,
1655 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
1656 async move {
1657 if IMPERSONATE_LOGIN.is_some() {
1658 return None;
1659 }
1660
1661 let (user_id, access_token) = cx
1662 .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1663 .log_err()?
1664 .await
1665 .log_err()??;
1666
1667 Some(Credentials::User {
1668 user_id: user_id.parse().ok()?,
1669 access_token: String::from_utf8(access_token).ok()?,
1670 })
1671 }
1672 .boxed_local()
1673 }
1674
1675 fn write_credentials<'a>(
1676 &'a self,
1677 user_id: u64,
1678 access_token: String,
1679 cx: &'a AsyncAppContext,
1680 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1681 async move {
1682 cx.update(move |cx| {
1683 cx.write_credentials(
1684 &ClientSettings::get_global(cx).server_url,
1685 &user_id.to_string(),
1686 access_token.as_bytes(),
1687 )
1688 })?
1689 .await
1690 }
1691 .boxed_local()
1692 }
1693
1694 fn delete_credentials<'a>(
1695 &'a self,
1696 cx: &'a AsyncAppContext,
1697 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1698 async move {
1699 cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1700 .await
1701 }
1702 .boxed_local()
1703 }
1704}
1705
/// Scheme of Zed's custom URL scheme (`zed://`), without the `://` separator.
pub static ZED_URL_SCHEME: &str = "zed";
1708
1709/// Parses the given link into a Zed link.
1710///
1711/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1712/// Returns [`None`] otherwise.
1713pub fn parse_zed_link<'a>(link: &'a str, cx: &AppContext) -> Option<&'a str> {
1714 let server_url = &ClientSettings::get_global(cx).server_url;
1715 if let Some(stripped) = link
1716 .strip_prefix(server_url)
1717 .and_then(|result| result.strip_prefix('/'))
1718 {
1719 return Some(stripped);
1720 }
1721 if let Some(stripped) = link
1722 .strip_prefix(ZED_URL_SCHEME)
1723 .and_then(|result| result.strip_prefix("://"))
1724 {
1725 return Some(stripped);
1726 }
1727
1728 None
1729}
1730
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use clock::FakeSystemClock;
    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use http::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    /// The client reconnects after the server drops it, reusing its cached credentials.
    /// It authenticates again once the server rolls the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// Both the initial connection and a reconnection attempt give up with an error after
    /// `CONNECTION_TIMEOUT` when establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials::User {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    /// Starting a second authentication drops the in-flight one instead of running both.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    /// Messages reach the model subscribed to their entity id. Dropping one entity
    /// subscription doesn't affect other entities of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    /// A handler can be registered again for a message type after the previous
    /// registration was dropped.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    /// A handler can drop its own subscription (stored on the model) while it runs.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::default()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as a message-handling target in these tests.
    #[derive(Default)]
    struct TestModel {
        id: usize,
        subscription: Option<Subscription>,
    }

    /// Installs a test settings store as a global and calls `init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}