1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4mod proxy;
5pub mod telemetry;
6pub mod user;
7pub mod zed_urls;
8
9use anyhow::{Context as _, Result, anyhow};
10use async_tungstenite::tungstenite::{
11 client::IntoClientRequest,
12 error::Error as WebsocketError,
13 http::{HeaderValue, Request, StatusCode},
14};
15use clock::SystemClock;
16use cloud_api_client::CloudApiClient;
17use cloud_api_client::websocket_protocol::MessageToClient;
18use credentials_provider::CredentialsProvider;
19use feature_flags::FeatureFlagAppExt as _;
20use futures::{
21 AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
22 channel::oneshot, future::BoxFuture,
23};
24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
26use parking_lot::RwLock;
27use postage::watch;
28use proxy::connect_proxy_stream;
29use rand::prelude::*;
30use release_channel::{AppVersion, ReleaseChannel};
31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
32use serde::{Deserialize, Serialize};
33use settings::{RegisterSetting, Settings, SettingsContent};
34use std::{
35 any::TypeId,
36 convert::TryFrom,
37 fmt::Write as _,
38 future::Future,
39 marker::PhantomData,
40 path::PathBuf,
41 sync::{
42 Arc, LazyLock, Weak,
43 atomic::{AtomicU64, Ordering},
44 },
45 time::{Duration, Instant},
46};
47use std::{cmp, pin::Pin};
48use telemetry::Telemetry;
49use thiserror::Error;
50use tokio::net::TcpStream;
51use url::Url;
52use util::{ConnectionResult, ResultExt};
53
54pub use rpc::*;
55pub use telemetry_events::Event;
56pub use user::*;
57
/// Overrides `ClientSettings::server_url` when the `ZED_SERVER_URL` environment variable is set.
static ZED_SERVER_URL: LazyLock<Option<String>> =
    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
/// Overrides the Collab RPC URL resolved by `Client::rpc_url` when `ZED_RPC_URL` is set.
static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());

/// Reads the environment variable `name`, treating an unset, non-UTF-8 or empty value as absent.
fn non_empty_env_var(name: &str) -> Option<String> {
    std::env::var(name).ok().filter(|value| !value.is_empty())
}

/// Login to impersonate, from `ZED_IMPERSONATE` (ignored when empty).
///
/// While this is set, stored credentials are neither read nor written.
pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> =
    LazyLock::new(|| non_empty_env_var("ZED_IMPERSONATE"));

/// `true` when `ZED_WEB_LOGIN` is set to any value, including the empty string.
pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());

/// Admin API token from `ZED_ADMIN_API_TOKEN` (ignored when empty).
pub static ADMIN_API_TOKEN: LazyLock<Option<String>> =
    LazyLock::new(|| non_empty_env_var("ZED_ADMIN_API_TOKEN"));

/// Application path from `ZED_APP_PATH`, if set.
pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));

/// `true` when `ZED_ALWAYS_ACTIVE` is set to a non-empty value.
pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
    LazyLock::new(|| non_empty_env_var("ZED_ALWAYS_ACTIVE").is_some());
81
/// Wait before the first reconnection attempt after a lost connection; doubled after each
/// failed attempt.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Upper bound of the doubling reconnection delay (random jitter is added on top of it).
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_millis(30_000);
/// Time allowed for establishing a Collab connection before it is reported as timed out.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(20_000);
85
// Client-level actions; their handlers are registered by `init`.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
97
/// Settings describing which Zed server the client talks to.
#[derive(Deserialize, RegisterSetting)]
pub struct ClientSettings {
    /// Base URL of the Zed server. The `ZED_SERVER_URL` environment variable, when set,
    /// takes precedence over the value from the settings files.
    pub server_url: String,
}
102
103impl Settings for ClientSettings {
104 fn from_settings(content: &settings::SettingsContent) -> Self {
105 if let Some(server_url) = &*ZED_SERVER_URL {
106 return Self {
107 server_url: server_url.clone(),
108 };
109 }
110 Self {
111 server_url: content.server_url.clone().unwrap(),
112 }
113 }
114}
115
/// User-configured proxy settings.
#[derive(Deserialize, Default, RegisterSetting)]
pub struct ProxySettings {
    /// Raw proxy URL from the settings; parsed by [`ProxySettings::proxy_url`].
    pub proxy: Option<String>,
}
120
121impl ProxySettings {
122 pub fn proxy_url(&self) -> Option<Url> {
123 self.proxy
124 .as_ref()
125 .and_then(|input| {
126 input
127 .parse::<Url>()
128 .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
129 .ok()
130 })
131 .or_else(read_proxy_from_env)
132 }
133}
134
135impl Settings for ProxySettings {
136 fn from_settings(content: &settings::SettingsContent) -> Self {
137 Self {
138 proxy: content.proxy.clone(),
139 }
140 }
141}
142
143pub fn init(client: &Arc<Client>, cx: &mut App) {
144 let client = Arc::downgrade(client);
145 cx.on_action({
146 let client = client.clone();
147 move |_: &SignIn, cx| {
148 if let Some(client) = client.upgrade() {
149 cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
150 .detach_and_log_err(cx);
151 }
152 }
153 })
154 .on_action({
155 let client = client.clone();
156 move |_: &SignOut, cx| {
157 if let Some(client) = client.upgrade() {
158 cx.spawn(async move |cx| {
159 client.sign_out(cx).await;
160 })
161 .detach();
162 }
163 }
164 })
165 .on_action({
166 let client = client;
167 move |_: &Reconnect, cx| {
168 if let Some(client) = client.upgrade() {
169 cx.spawn(async move |cx| {
170 client.reconnect(cx);
171 })
172 .detach();
173 }
174 }
175 });
176}
177
/// Callback type stored in `Client::message_to_client_handlers`; receives a [`MessageToClient`]
/// together with mutable access to the [`App`].
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
179
/// gpui global holding the app-wide [`Client`]; accessed through `Client::global` and
/// `Client::set_global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
183
/// Connection to Zed's servers: owns the Collab RPC peer, the HTTP and Cloud API clients,
/// telemetry, access to stored credentials, and the connection status.
pub struct Client {
    /// Client id used in log messages; `sign_in` sets it to the signed-in user's id.
    id: AtomicU64,
    /// RPC peer that owns the Collab connection.
    peer: Arc<Peer>,
    /// HTTP client; `Client::production` points it at `ClientSettings::server_url`.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client, used for credential validation and the Cloud websocket.
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    /// Reads, writes and deletes credentials through the global `CredentialsProvider`.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status channel and reconnect task, all behind one lock.
    state: RwLock<ClientState>,
    /// Registered RPC message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Handlers for Cloud `MessageToClient` messages (presumably invoked for messages received
    /// over the websocket opened by `connect_to_cloud`).
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    /// Test-only replacement for the interactive authentication flow (see `authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    /// Test-only replacement for opening the Collab connection (see `establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    /// Test-only override for the Collab RPC URL (see `rpc_url`).
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock<Option<Url>>,
}
219
/// Errors produced while establishing the Collab connection.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Produced when the websocket handshake fails with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// Produced when the websocket handshake fails with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors with other HTTP statuses.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
235
236impl From<WebsocketError> for EstablishConnectionError {
237 fn from(error: WebsocketError) -> Self {
238 if let WebsocketError::Http(response) = &error {
239 match response.status() {
240 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
241 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
242 _ => {}
243 }
244 }
245 EstablishConnectionError::Other(error.into())
246 }
247}
248
249impl EstablishConnectionError {
250 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
251 Self::Other(error.into())
252 }
253}
254
/// Connection/authentication state of a [`Client`], published through `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    SignedOut,
    /// The server rejected this client version; `connect` refuses to proceed from here.
    UpgradeRequired,
    Authenticating,
    Authenticated,
    AuthenticationError,
    Connecting,
    ConnectionError,
    /// Live Collab connection; `peer_id` comes from the server's `Hello` message.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; entering this state starts the reconnect loop.
    ConnectionLost,
    Reauthenticating,
    Reauthenticated,
    Reconnecting,
    /// A reconnect attempt failed.
    ReconnectionError {
        /// `Instant::now() + delay` at the time of failure. The actual wait also adds a
        /// random jitter.
        next_reconnection: Instant,
    },
}
276
277impl Status {
278 pub fn is_connected(&self) -> bool {
279 matches!(self, Self::Connected { .. })
280 }
281
282 pub fn was_connected(&self) -> bool {
283 matches!(
284 self,
285 Self::ConnectionLost
286 | Self::Reauthenticating
287 | Self::Reauthenticated
288 | Self::Reconnecting
289 )
290 }
291
292 /// Returns whether the client is currently connected or was connected at some point.
293 pub fn is_or_was_connected(&self) -> bool {
294 self.is_connected() || self.was_connected()
295 }
296
297 pub fn is_signing_in(&self) -> bool {
298 matches!(
299 self,
300 Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
301 )
302 }
303
304 pub fn is_signed_out(&self) -> bool {
305 matches!(self, Self::SignedOut | Self::UpgradeRequired)
306 }
307}
308
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials recorded by `sign_in` on success.
    credentials: Option<Credentials>,
    /// Sender/receiver of the status watch channel; `Client::status` hands out receiver clones.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect loop spawned by `set_status` on `ConnectionLost`; replaced or cleared there too.
    _reconnect_task: Option<Task<()>>,
}
314
/// Credentials identifying a signed-in user.
///
/// NOTE(review): the derived `Debug` prints `access_token` verbatim, so avoid formatting
/// values of this type with `{:?}` in logs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Credentials {
    /// The user's id. It is cast with `as u32` when handed to the Cloud API client.
    pub user_id: u64,
    /// Access token belonging to `user_id`.
    pub access_token: String,
}
320
321impl Credentials {
322 pub fn authorization_header(&self) -> String {
323 format!("{} {}", self.user_id, self.access_token)
324 }
325}
326
/// Adapter over the global [`CredentialsProvider`] that stores credentials under the configured
/// server URL and converts them to and from [`Credentials`].
pub struct ClientCredentialsProvider {
    provider: Arc<dyn CredentialsProvider>,
}
330
331impl ClientCredentialsProvider {
332 pub fn new(cx: &App) -> Self {
333 Self {
334 provider: <dyn CredentialsProvider>::global(cx),
335 }
336 }
337
338 fn server_url(&self, cx: &AsyncApp) -> Result<String> {
339 cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
340 }
341
342 /// Reads the credentials from the provider.
343 fn read_credentials<'a>(
344 &'a self,
345 cx: &'a AsyncApp,
346 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
347 async move {
348 if IMPERSONATE_LOGIN.is_some() {
349 return None;
350 }
351
352 let server_url = self.server_url(cx).ok()?;
353 let (user_id, access_token) = self
354 .provider
355 .read_credentials(&server_url, cx)
356 .await
357 .log_err()
358 .flatten()?;
359
360 Some(Credentials {
361 user_id: user_id.parse().ok()?,
362 access_token: String::from_utf8(access_token).ok()?,
363 })
364 }
365 .boxed_local()
366 }
367
368 /// Writes the credentials to the provider.
369 fn write_credentials<'a>(
370 &'a self,
371 user_id: u64,
372 access_token: String,
373 cx: &'a AsyncApp,
374 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
375 async move {
376 let server_url = self.server_url(cx)?;
377 self.provider
378 .write_credentials(
379 &server_url,
380 &user_id.to_string(),
381 access_token.as_bytes(),
382 cx,
383 )
384 .await
385 }
386 .boxed_local()
387 }
388
389 /// Deletes the credentials from the provider.
390 fn delete_credentials<'a>(
391 &'a self,
392 cx: &'a AsyncApp,
393 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
394 async move {
395 let server_url = self.server_url(cx)?;
396 self.provider.delete_credentials(&server_url, cx).await
397 }
398 .boxed_local()
399 }
400}
401
402impl Default for ClientState {
403 fn default() -> Self {
404 Self {
405 credentials: None,
406 status: watch::channel_with(Status::SignedOut),
407 _reconnect_task: None,
408 }
409 }
410}
411
/// Guard for a registration in a [`Client`]'s handler set; dropping it removes the
/// registration (if the client is still alive).
pub enum Subscription {
    /// An entity registration keyed by `(entity TypeId, remote id)`.
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message-handler registration keyed by the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
422
423impl Drop for Subscription {
424 fn drop(&mut self) {
425 match self {
426 Subscription::Entity { client, id } => {
427 if let Some(client) = client.upgrade() {
428 let mut state = client.handler_set.lock();
429 let _ = state.entities_by_type_and_remote_id.remove(id);
430 }
431 }
432 Subscription::Message { client, id } => {
433 if let Some(client) = client.upgrade() {
434 let mut state = client.handler_set.lock();
435 let _ = state.entity_types_by_message_type.remove(id);
436 let _ = state.message_handlers.remove(id);
437 }
438 }
439 }
440 }
441}
442
/// An entity subscription reserved by `Client::subscribe_to_entity` but not yet bound to an
/// entity.
///
/// Call [`PendingEntitySubscription::set_entity`] to complete it. If it is dropped unconsumed,
/// the reservation is removed and any queued messages are logged as unhandled.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_entity` so that `Drop` leaves the now-bound registration in place.
    consumed: bool,
}
449
450impl<T: 'static> PendingEntitySubscription<T> {
451 pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
452 self.consumed = true;
453 let mut handlers = self.client.handler_set.lock();
454 let id = (TypeId::of::<T>(), self.remote_id);
455 let Some(EntityMessageSubscriber::Pending(messages)) =
456 handlers.entities_by_type_and_remote_id.remove(&id)
457 else {
458 unreachable!()
459 };
460
461 handlers.entities_by_type_and_remote_id.insert(
462 id,
463 EntityMessageSubscriber::Entity {
464 handle: entity.downgrade().into(),
465 },
466 );
467 drop(handlers);
468 for message in messages {
469 let client_id = self.client.id();
470 let type_name = message.payload_type_name();
471 let sender_id = message.original_sender_id();
472 log::debug!(
473 "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
474 client_id,
475 sender_id,
476 type_name
477 );
478 self.client.handle_message(message, cx);
479 }
480 Subscription::Entity {
481 client: Arc::downgrade(&self.client),
482 id,
483 }
484 }
485}
486
487impl<T: 'static> Drop for PendingEntitySubscription<T> {
488 fn drop(&mut self) {
489 if !self.consumed {
490 let mut state = self.client.handler_set.lock();
491 if let Some(EntityMessageSubscriber::Pending(messages)) = state
492 .entities_by_type_and_remote_id
493 .remove(&(TypeId::of::<T>(), self.remote_id))
494 {
495 for message in messages {
496 log::info!("unhandled message {}", message.payload_type_name());
497 }
498 }
499 }
500 }
501}
502
/// Telemetry opt-in flags.
#[derive(Copy, Clone, Deserialize, Debug, RegisterSetting)]
pub struct TelemetrySettings {
    /// Value of the `telemetry.diagnostics` setting.
    pub diagnostics: bool,
    /// Value of the `telemetry.metrics` setting.
    pub metrics: bool,
}
508
509impl settings::Settings for TelemetrySettings {
510 fn from_settings(content: &SettingsContent) -> Self {
511 Self {
512 diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
513 metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
514 }
515 }
516}
517
518impl Client {
519 pub fn new(
520 clock: Arc<dyn SystemClock>,
521 http: Arc<HttpClientWithUrl>,
522 cx: &mut App,
523 ) -> Arc<Self> {
524 Arc::new(Self {
525 id: AtomicU64::new(0),
526 peer: Peer::new(0),
527 telemetry: Telemetry::new(clock, http.clone(), cx),
528 cloud_client: Arc::new(CloudApiClient::new(http.clone())),
529 http,
530 credentials_provider: ClientCredentialsProvider::new(cx),
531 state: Default::default(),
532 handler_set: Default::default(),
533 message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
534
535 #[cfg(any(test, feature = "test-support"))]
536 authenticate: Default::default(),
537 #[cfg(any(test, feature = "test-support"))]
538 establish_connection: Default::default(),
539 #[cfg(any(test, feature = "test-support"))]
540 rpc_url: RwLock::default(),
541 })
542 }
543
544 pub fn production(cx: &mut App) -> Arc<Self> {
545 let clock = Arc::new(clock::RealSystemClock);
546 let http = Arc::new(HttpClientWithUrl::new_url(
547 cx.http_client(),
548 &ClientSettings::get_global(cx).server_url,
549 cx.http_client().proxy().cloned(),
550 ));
551 Self::new(clock, http, cx)
552 }
553
554 pub fn id(&self) -> u64 {
555 self.id.load(Ordering::SeqCst)
556 }
557
558 pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
559 self.http.clone()
560 }
561
562 pub fn cloud_client(&self) -> Arc<CloudApiClient> {
563 self.cloud_client.clone()
564 }
565
566 pub fn set_id(&self, id: u64) -> &Self {
567 self.id.store(id, Ordering::SeqCst);
568 self
569 }
570
571 #[cfg(any(test, feature = "test-support"))]
572 pub fn teardown(&self) {
573 let mut state = self.state.write();
574 state._reconnect_task.take();
575 self.handler_set.lock().clear();
576 self.peer.teardown();
577 }
578
579 #[cfg(any(test, feature = "test-support"))]
580 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
581 where
582 F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
583 {
584 *self.authenticate.write() = Some(Box::new(authenticate));
585 self
586 }
587
588 #[cfg(any(test, feature = "test-support"))]
589 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
590 where
591 F: 'static
592 + Send
593 + Sync
594 + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
595 {
596 *self.establish_connection.write() = Some(Box::new(connect));
597 self
598 }
599
600 #[cfg(any(test, feature = "test-support"))]
601 pub fn override_rpc_url(&self, url: Url) -> &Self {
602 *self.rpc_url.write() = Some(url);
603 self
604 }
605
606 pub fn global(cx: &App) -> Arc<Self> {
607 cx.global::<GlobalClient>().0.clone()
608 }
609 pub fn set_global(client: Arc<Client>, cx: &mut App) {
610 cx.set_global(GlobalClient(client))
611 }
612
613 pub fn user_id(&self) -> Option<u64> {
614 self.state
615 .read()
616 .credentials
617 .as_ref()
618 .map(|credentials| credentials.user_id)
619 }
620
621 pub fn peer_id(&self) -> Option<PeerId> {
622 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
623 Some(*peer_id)
624 } else {
625 None
626 }
627 }
628
629 pub fn status(&self) -> watch::Receiver<Status> {
630 self.state.read().status.1.clone()
631 }
632
    /// Publishes `status` on the status watch channel and applies the side effects tied to it.
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect loop that calls `connect(true, ..)`. The loop ends
    ///   on success, or when the status after a failed attempt is neither
    ///   `AuthenticationError` nor `ConnectionError`. Between attempts it reports
    ///   `ReconnectionError` and waits `delay` plus a random jitter in `[0, delay)`. `delay`
    ///   starts at `INITIAL_RECONNECTION_DELAY` and doubles up to `MAX_RECONNECTION_DELAY`.
    /// - `SignedOut` / `UpgradeRequired`: calls
    ///   `telemetry.set_authenticated_user_info(None, false)` and drops any pending reconnect
    ///   task.
    ///
    /// The `state` write lock is held for the entire call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // A fixed seed in test builds keeps the jitter deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_os_rng();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        match client.connect(true, cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    break;
                                }
                            }
                        }

                        // Keep retrying only while the failed attempt left the client in an
                        // error state; any other status ends the loop.
                        if matches!(
                            *client.status().borrow(),
                            Status::AuthenticationError | Status::ConnectionError
                        ) {
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                cx,
                            );
                            let jitter = Duration::from_millis(
                                rng.random_range(0..delay.as_millis() as u64),
                            );
                            cx.background_executor().timer(delay + jitter).await;
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
696
697 pub fn subscribe_to_entity<T>(
698 self: &Arc<Self>,
699 remote_id: u64,
700 ) -> Result<PendingEntitySubscription<T>>
701 where
702 T: 'static,
703 {
704 let id = (TypeId::of::<T>(), remote_id);
705
706 let mut state = self.handler_set.lock();
707 anyhow::ensure!(
708 !state.entities_by_type_and_remote_id.contains_key(&id),
709 "already subscribed to entity"
710 );
711
712 state
713 .entities_by_type_and_remote_id
714 .insert(id, EntityMessageSubscriber::Pending(Default::default()));
715
716 Ok(PendingEntitySubscription {
717 client: self.clone(),
718 remote_id,
719 consumed: false,
720 _entity_type: PhantomData,
721 })
722 }
723
724 #[track_caller]
725 pub fn add_message_handler<M, E, H, F>(
726 self: &Arc<Self>,
727 entity: WeakEntity<E>,
728 handler: H,
729 ) -> Subscription
730 where
731 M: EnvelopedMessage,
732 E: 'static,
733 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
734 F: 'static + Future<Output = Result<()>>,
735 {
736 self.add_message_handler_impl(entity, move |entity, message, _, cx| {
737 handler(entity, message, cx)
738 })
739 }
740
741 fn add_message_handler_impl<M, E, H, F>(
742 self: &Arc<Self>,
743 entity: WeakEntity<E>,
744 handler: H,
745 ) -> Subscription
746 where
747 M: EnvelopedMessage,
748 E: 'static,
749 H: 'static
750 + Sync
751 + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
752 + Send
753 + Sync,
754 F: 'static + Future<Output = Result<()>>,
755 {
756 let message_type_id = TypeId::of::<M>();
757 let mut state = self.handler_set.lock();
758 state
759 .entities_by_message_type
760 .insert(message_type_id, entity.into());
761
762 let prev_handler = state.message_handlers.insert(
763 message_type_id,
764 Arc::new(move |subscriber, envelope, client, cx| {
765 let subscriber = subscriber.downcast::<E>().unwrap();
766 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
767 handler(subscriber, *envelope, client, cx).boxed_local()
768 }),
769 );
770 if prev_handler.is_some() {
771 let location = std::panic::Location::caller();
772 panic!(
773 "{}:{} registered handler for the same message {} twice",
774 location.file(),
775 location.line(),
776 std::any::type_name::<M>()
777 );
778 }
779
780 Subscription::Message {
781 client: Arc::downgrade(self),
782 id: message_type_id,
783 }
784 }
785
786 pub fn add_request_handler<M, E, H, F>(
787 self: &Arc<Self>,
788 entity: WeakEntity<E>,
789 handler: H,
790 ) -> Subscription
791 where
792 M: RequestMessage,
793 E: 'static,
794 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
795 F: 'static + Future<Output = Result<M::Response>>,
796 {
797 self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
798 Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
799 })
800 }
801
802 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
803 receipt: Receipt<T>,
804 response: F,
805 client: AnyProtoClient,
806 ) -> Result<()> {
807 match response.await {
808 Ok(response) => {
809 client.send_response(receipt.message_id, response)?;
810 Ok(())
811 }
812 Err(error) => {
813 client.send_response(receipt.message_id, error.to_proto())?;
814 Err(error)
815 }
816 }
817 }
818
819 pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
820 self.credentials_provider
821 .read_credentials(cx)
822 .await
823 .is_some()
824 }
825
    /// Authenticates the client and returns the credentials in use.
    ///
    /// Sets the status to `Authenticating` (when signed out) or `Reauthenticating` (otherwise).
    /// It then tries, in order:
    /// 1. the credentials already recorded in `state`, if they still validate;
    /// 2. when `try_provider` is true, the credentials from the credentials provider (invalid
    ///    ones are deleted from the provider);
    /// 3. the interactive `authenticate` flow. Newly obtained credentials are written to the
    ///    provider unless `IMPERSONATE_LOGIN` is set.
    ///
    /// On success it records the credentials in three places (client id, Cloud API client,
    /// `state`). It then sets the status to `Authenticated` / `Reauthenticated`.
    ///
    /// # Errors
    ///
    /// - Fails with status `AuthenticationError` when credential validation errors or the
    ///   interactive flow fails.
    /// - Fails with "authentication canceled" when the status changes while the interactive
    ///   flow is running.
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        if credentials.is_none() {
            // NOTE(review): presumably the first `next()` yields the current status right
            // away. Consuming it makes the `select_biased!` below react only to a *later*
            // status change, which is treated as cancellation.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path that leaves `credentials` as `None` has already returned above.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `as u32` silently truncates user ids above `u32::MAX`.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
906
907 async fn validate_credentials(
908 self: &Arc<Self>,
909 credentials: &Credentials,
910 cx: &AsyncApp,
911 ) -> Result<bool> {
912 match self
913 .cloud_client
914 .validate_credentials(credentials.user_id as u32, &credentials.access_token)
915 .await
916 {
917 Ok(valid) => Ok(valid),
918 Err(err) => {
919 self.set_status(Status::AuthenticationError, cx);
920 Err(anyhow!("failed to validate credentials: {}", err))
921 }
922 }
923 }
924
925 /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
926 async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
927 let connect_task = cx.update({
928 let cloud_client = self.cloud_client.clone();
929 move |cx| cloud_client.connect(cx)
930 })??;
931 let connection = connect_task.await?;
932
933 let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
934 task.detach();
935
936 cx.spawn({
937 let this = self.clone();
938 async move |cx| {
939 while let Some(message) = messages.next().await {
940 if let Some(message) = message.log_err() {
941 this.handle_message_to_client(message, cx);
942 }
943 }
944 }
945 })
946 .detach();
947
948 Ok(())
949 }
950
951 /// Performs a sign-in and also (optionally) connects to Collab.
952 ///
953 /// Only Zed staff automatically connect to Collab.
954 pub async fn sign_in_with_optional_connect(
955 self: &Arc<Self>,
956 try_provider: bool,
957 cx: &AsyncApp,
958 ) -> Result<()> {
959 // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
960 if self.status().borrow().is_connected() {
961 return Ok(());
962 }
963
964 let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
965 let mut is_staff_tx = Some(is_staff_tx);
966 cx.update(|cx| {
967 cx.on_flags_ready(move |state, _cx| {
968 if let Some(is_staff_tx) = is_staff_tx.take() {
969 is_staff_tx.send(state.is_staff).log_err();
970 }
971 })
972 .detach();
973 })
974 .log_err();
975
976 let credentials = self.sign_in(try_provider, cx).await?;
977
978 self.connect_to_cloud(cx).await.log_err();
979
980 cx.update(move |cx| {
981 cx.spawn({
982 let client = self.clone();
983 async move |cx| {
984 let is_staff = is_staff_rx.await?;
985 if is_staff {
986 match client.connect_with_credentials(credentials, cx).await {
987 ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
988 ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
989 ConnectionResult::Result(result) => {
990 result.context("client auth and connect")
991 }
992 }
993 } else {
994 Ok(())
995 }
996 }
997 })
998 .detach_and_log_err(cx);
999 })
1000 .log_err();
1001
1002 Ok(())
1003 }
1004
1005 pub async fn connect(
1006 self: &Arc<Self>,
1007 try_provider: bool,
1008 cx: &AsyncApp,
1009 ) -> ConnectionResult<()> {
1010 let was_disconnected = match *self.status().borrow() {
1011 Status::SignedOut | Status::Authenticated => true,
1012 Status::ConnectionError
1013 | Status::ConnectionLost
1014 | Status::Authenticating
1015 | Status::AuthenticationError
1016 | Status::Reauthenticating
1017 | Status::Reauthenticated
1018 | Status::ReconnectionError { .. } => false,
1019 Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1020 return ConnectionResult::Result(Ok(()));
1021 }
1022 Status::UpgradeRequired => {
1023 return ConnectionResult::Result(
1024 Err(EstablishConnectionError::UpgradeRequired)
1025 .context("client auth and connect"),
1026 );
1027 }
1028 };
1029 let credentials = match self.sign_in(try_provider, cx).await {
1030 Ok(credentials) => credentials,
1031 Err(err) => return ConnectionResult::Result(Err(err)),
1032 };
1033
1034 if was_disconnected {
1035 self.set_status(Status::Connecting, cx);
1036 } else {
1037 self.set_status(Status::Reconnecting, cx);
1038 }
1039
1040 self.connect_with_credentials(credentials, cx).await
1041 }
1042
    /// Opens the Collab connection with `credentials` and installs it via `set_connection`.
    ///
    /// One `CONNECTION_TIMEOUT` timer is raced against both steps, so the handshake in
    /// `set_connection` only gets whatever time the connection step left over.
    ///
    /// On failure the status is set to `ConnectionError`, except for an `UpgradeRequired`
    /// rejection, which sets `UpgradeRequired`.
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        // `select_biased!` requires fused futures; the same timer is shared by both selects.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1090
    /// Installs `conn` as the active Collab connection.
    ///
    /// Steps:
    /// 1. Adds the connection to the RPC peer.
    /// 2. Waits for the server's `Hello` message to learn this client's `PeerId`.
    /// 3. Sets the status to `Connected`.
    /// 4. Spawns two detached tasks: one dispatches incoming messages through
    ///    `handle_message`, the other awaits the connection's I/O task.
    ///
    /// If that I/O task finishes with `Ok` while the status still names this connection, the
    /// status becomes `SignedOut`. If it fails, the status becomes `ConnectionLost`, which
    /// makes `set_status` start reconnecting.
    ///
    /// # Errors
    ///
    /// If no message arrives, the first message is not a `Hello`, or it carries no peer id, the
    /// connection is disconnected from the peer and the error is returned.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        // The first message on a new connection must be `Hello`, carrying our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Only sign out if the status still refers to this connection (presumably a
                    // newer connection may have replaced it in the meantime).
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1174
1175 fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1176 #[cfg(any(test, feature = "test-support"))]
1177 if let Some(callback) = self.authenticate.read().as_ref() {
1178 return callback(cx);
1179 }
1180
1181 self.authenticate_with_browser(cx)
1182 }
1183
1184 fn establish_connection(
1185 self: &Arc<Self>,
1186 credentials: &Credentials,
1187 cx: &AsyncApp,
1188 ) -> Task<Result<Connection, EstablishConnectionError>> {
1189 #[cfg(any(test, feature = "test-support"))]
1190 if let Some(callback) = self.establish_connection.read().as_ref() {
1191 return callback(credentials, cx);
1192 }
1193
1194 self.establish_websocket_connection(credentials, cx)
1195 }
1196
1197 fn rpc_url(
1198 &self,
1199 http: Arc<HttpClientWithUrl>,
1200 release_channel: Option<ReleaseChannel>,
1201 ) -> impl Future<Output = Result<url::Url>> + use<> {
1202 #[cfg(any(test, feature = "test-support"))]
1203 let url_override = self.rpc_url.read().clone();
1204
1205 async move {
1206 #[cfg(any(test, feature = "test-support"))]
1207 if let Some(url) = url_override {
1208 return Ok(url);
1209 }
1210
1211 if let Some(url) = &*ZED_RPC_URL {
1212 return Url::parse(url).context("invalid rpc url");
1213 }
1214
1215 let mut url = http.build_url("/rpc");
1216 if let Some(preview_param) =
1217 release_channel.and_then(|channel| channel.release_query_param())
1218 {
1219 url += "?";
1220 url += preview_param;
1221 }
1222
1223 let response = http.get(&url, Default::default(), false).await?;
1224 anyhow::ensure!(
1225 response.status().is_redirection(),
1226 "unexpected /rpc response status {}",
1227 response.status()
1228 );
1229 let collab_url = response
1230 .headers()
1231 .get("Location")
1232 .context("missing location header in /rpc response")?
1233 .to_str()
1234 .map_err(EstablishConnectionError::other)?
1235 .to_string();
1236 Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1237 }
1238 }
1239
    /// Opens a WebSocket connection to the collab RPC endpoint, authenticated with `credentials`.
    ///
    /// The spawned task:
    /// 1. resolves the endpoint with [`Self::rpc_url`] and rejects schemes other than
    ///    `http`/`https`;
    /// 2. opens a byte stream to the endpoint's host and port on the Tokio runtime. It goes
    ///    through the HTTP client's proxy when one is configured, and uses a direct
    ///    `tokio::net::TcpStream` otherwise;
    /// 3. rewrites the scheme to `ws`/`wss` and performs the WebSocket handshake over that
    ///    stream. The request carries the authorization header plus protocol version, app
    ///    version, release channel and user agent, and the system/metrics ids when available.
    ///
    /// The release channel and app version are read before spawning. If `cx.update` fails,
    /// they fall back to `None` and an empty string.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let release_channel = cx
            .update(|cx| ReleaseChannel::try_global(cx))
            .ok()
            .flatten();
        let app_version = cx
            .update(|cx| AppVersion::global(cx).to_string())
            .ok()
            .unwrap_or_default();

        // Capture owned copies of everything the spawned task needs.
        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            // Remembers the original scheme so it can be mapped to `ws`/`wss` after the
            // underlying stream has been opened.
            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // The proxy/TCP connectors are Tokio-based, so run them on the Tokio runtime.
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    // Fall back to the scheme's default port when the URL has none.
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })?
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // http/https/ws/wss are all "special" schemes, and `Url::set_scheme` permits
            // switching between special schemes, so this unwrap is not expected to fire.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            // Perform the WebSocket handshake over the already-open stream, supplying the
            // shared TLS config (presumably only used for `wss` URLs — confirm).
            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Normalize both read and write errors to `anyhow::Error` for `Connection`.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1347
1348 pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1349 let http = self.http.clone();
1350 let this = self.clone();
1351 cx.spawn(async move |cx| {
1352 let background = cx.background_executor().clone();
1353
1354 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1355 cx.update(|cx| {
1356 cx.spawn(async move |cx| {
1357 let url = open_url_rx.await?;
1358 cx.update(|cx| cx.open_url(&url))
1359 })
1360 .detach_and_log_err(cx);
1361 })
1362 .log_err();
1363
1364 let credentials = background
1365 .clone()
1366 .spawn(async move {
1367 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1368 // zed server to encrypt the user's access token, so that it can'be intercepted by
1369 // any other app running on the user's device.
1370 let (public_key, private_key) =
1371 rpc::auth::keypair().expect("failed to generate keypair for auth");
1372 let public_key_string = String::try_from(public_key)
1373 .expect("failed to serialize public key for auth");
1374
1375 if let Some((login, token)) =
1376 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1377 {
1378 if !*USE_WEB_LOGIN {
1379 eprintln!("authenticate as admin {login}, {token}");
1380
1381 return this
1382 .authenticate_as_admin(http, login.clone(), token.clone())
1383 .await;
1384 }
1385 }
1386
1387 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1388 let server =
1389 tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1390 let port = server.server_addr().port();
1391
1392 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1393 // that the user is signing in from a Zed app running on the same device.
1394 let mut url = http.build_url(&format!(
1395 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1396 port, public_key_string
1397 ));
1398
1399 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1400 log::info!("impersonating user @{}", impersonate_login);
1401 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1402 }
1403
1404 open_url_tx.send(url).log_err();
1405
1406 #[derive(Deserialize)]
1407 struct CallbackParams {
1408 pub user_id: String,
1409 pub access_token: String,
1410 }
1411
1412 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1413 // access token from the query params.
1414 //
1415 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1416 // custom URL scheme instead of this local HTTP server.
1417 let (user_id, access_token) = background
1418 .spawn(async move {
1419 for _ in 0..100 {
1420 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1421 let path = req.url();
1422 let url = Url::parse(&format!("http://example.com{}", path))
1423 .context("failed to parse login notification url")?;
1424 let callback_params: CallbackParams =
1425 serde_urlencoded::from_str(url.query().unwrap_or_default())
1426 .context(
1427 "failed to parse sign-in callback query parameters",
1428 )?;
1429
1430 let post_auth_url =
1431 http.build_url("/native_app_signin_succeeded");
1432 req.respond(
1433 tiny_http::Response::empty(302).with_header(
1434 tiny_http::Header::from_bytes(
1435 &b"Location"[..],
1436 post_auth_url.as_bytes(),
1437 )
1438 .unwrap(),
1439 ),
1440 )
1441 .context("failed to respond to login http request")?;
1442 return Ok((
1443 callback_params.user_id,
1444 callback_params.access_token,
1445 ));
1446 }
1447 }
1448
1449 anyhow::bail!("didn't receive login redirect");
1450 })
1451 .await?;
1452
1453 let access_token = private_key
1454 .decrypt_string(&access_token)
1455 .context("failed to decrypt access token")?;
1456
1457 Ok(Credentials {
1458 user_id: user_id.parse()?,
1459 access_token,
1460 })
1461 })
1462 .await?;
1463
1464 cx.update(|cx| cx.activate(true))?;
1465 Ok(credentials)
1466 })
1467 }
1468
1469 async fn authenticate_as_admin(
1470 self: &Arc<Self>,
1471 http: Arc<HttpClientWithUrl>,
1472 login: String,
1473 api_token: String,
1474 ) -> Result<Credentials> {
1475 #[derive(Serialize)]
1476 struct ImpersonateUserBody {
1477 github_login: String,
1478 }
1479
1480 #[derive(Deserialize)]
1481 struct ImpersonateUserResponse {
1482 user_id: u64,
1483 access_token: String,
1484 }
1485
1486 let url = self
1487 .http
1488 .build_zed_cloud_url("/internal/users/impersonate")?;
1489 let request = Request::post(url.as_str())
1490 .header("Content-Type", "application/json")
1491 .header("Authorization", format!("Bearer {api_token}"))
1492 .body(
1493 serde_json::to_string(&ImpersonateUserBody {
1494 github_login: login,
1495 })?
1496 .into(),
1497 )?;
1498
1499 let mut response = http.send(request).await?;
1500 let mut body = String::new();
1501 response.body_mut().read_to_string(&mut body).await?;
1502 anyhow::ensure!(
1503 response.status().is_success(),
1504 "admin user request failed {} - {}",
1505 response.status().as_u16(),
1506 body,
1507 );
1508 let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1509
1510 Ok(Credentials {
1511 user_id: response.user_id,
1512 access_token: response.access_token,
1513 })
1514 }
1515
1516 pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1517 self.state.write().credentials = None;
1518 self.cloud_client.clear_credentials();
1519 self.disconnect(cx);
1520
1521 if self.has_credentials(cx).await {
1522 self.credentials_provider
1523 .delete_credentials(cx)
1524 .await
1525 .log_err();
1526 }
1527 }
1528
1529 pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
1530 self.peer.teardown();
1531 self.set_status(Status::SignedOut, cx);
1532 }
1533
1534 pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
1535 self.peer.teardown();
1536 self.set_status(Status::ConnectionLost, cx);
1537 }
1538
1539 fn connection_id(&self) -> Result<ConnectionId> {
1540 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1541 Ok(connection_id)
1542 } else {
1543 anyhow::bail!("not connected");
1544 }
1545 }
1546
1547 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1548 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1549 self.peer.send(self.connection_id()?, message)
1550 }
1551
1552 pub fn request<T: RequestMessage>(
1553 &self,
1554 request: T,
1555 ) -> impl Future<Output = Result<T::Response>> + use<T> {
1556 self.request_envelope(request)
1557 .map_ok(|envelope| envelope.payload)
1558 }
1559
1560 pub fn request_stream<T: RequestMessage>(
1561 &self,
1562 request: T,
1563 ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1564 let client_id = self.id.load(Ordering::SeqCst);
1565 log::debug!(
1566 "rpc request start. client_id:{}. name:{}",
1567 client_id,
1568 T::NAME
1569 );
1570 let response = self
1571 .connection_id()
1572 .map(|conn_id| self.peer.request_stream(conn_id, request));
1573 async move {
1574 let response = response?.await;
1575 log::debug!(
1576 "rpc request finish. client_id:{}. name:{}",
1577 client_id,
1578 T::NAME
1579 );
1580 response
1581 }
1582 }
1583
1584 pub fn request_envelope<T: RequestMessage>(
1585 &self,
1586 request: T,
1587 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1588 let client_id = self.id();
1589 log::debug!(
1590 "rpc request start. client_id:{}. name:{}",
1591 client_id,
1592 T::NAME
1593 );
1594 let response = self
1595 .connection_id()
1596 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1597 async move {
1598 let response = response?.await;
1599 log::debug!(
1600 "rpc request finish. client_id:{}. name:{}",
1601 client_id,
1602 T::NAME
1603 );
1604 response
1605 }
1606 }
1607
1608 pub fn request_dynamic(
1609 &self,
1610 envelope: proto::Envelope,
1611 request_type: &'static str,
1612 ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1613 let client_id = self.id();
1614 log::debug!(
1615 "rpc request start. client_id:{}. name:{}",
1616 client_id,
1617 request_type
1618 );
1619 let response = self
1620 .connection_id()
1621 .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1622 async move {
1623 let response = response?.await;
1624 log::debug!(
1625 "rpc request finish. client_id:{}. name:{}",
1626 client_id,
1627 request_type
1628 );
1629 Ok(response?.0)
1630 }
1631 }
1632
1633 fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1634 let sender_id = message.sender_id();
1635 let request_id = message.message_id();
1636 let type_name = message.payload_type_name();
1637 let original_sender_id = message.original_sender_id();
1638
1639 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1640 &self.handler_set,
1641 message,
1642 self.clone().into(),
1643 cx.clone(),
1644 ) {
1645 let client_id = self.id();
1646 log::debug!(
1647 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1648 client_id,
1649 original_sender_id,
1650 type_name
1651 );
1652 cx.spawn(async move |_| match future.await {
1653 Ok(()) => {
1654 log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1655 }
1656 Err(error) => {
1657 log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1658 }
1659 })
1660 .detach();
1661 } else {
1662 log::info!("unhandled message {}", type_name);
1663 self.peer
1664 .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1665 .log_err();
1666 }
1667 }
1668
1669 pub fn add_message_to_client_handler(
1670 self: &Arc<Client>,
1671 handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1672 ) {
1673 self.message_to_client_handlers
1674 .lock()
1675 .push(Box::new(handler));
1676 }
1677
1678 fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1679 cx.update(|cx| {
1680 for handler in self.message_to_client_handlers.lock().iter() {
1681 handler(&message, cx);
1682 }
1683 })
1684 .ok();
1685 }
1686
1687 pub fn telemetry(&self) -> &Arc<Telemetry> {
1688 &self.telemetry
1689 }
1690}
1691
1692impl ProtoClient for Client {
1693 fn request(
1694 &self,
1695 envelope: proto::Envelope,
1696 request_type: &'static str,
1697 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1698 self.request_dynamic(envelope, request_type).boxed()
1699 }
1700
1701 fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1702 log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1703 let connection_id = self.connection_id()?;
1704 self.peer.send_dynamic(connection_id, envelope)
1705 }
1706
1707 fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1708 log::debug!(
1709 "rpc respond. client_id:{}, name:{}",
1710 self.id(),
1711 message_type
1712 );
1713 let connection_id = self.connection_id()?;
1714 self.peer.send_dynamic(connection_id, envelope)
1715 }
1716
1717 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1718 &self.handler_set
1719 }
1720
1721 fn is_via_collab(&self) -> bool {
1722 true
1723 }
1724
1725 fn has_wsl_interop(&self) -> bool {
1726 false
1727 }
1728}
1729
1730/// prefix for the zed:// url scheme
1731pub const ZED_URL_SCHEME: &str = "zed";
1732
1733/// Parses the given link into a Zed link.
1734///
1735/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
1736/// Returns [`None`] otherwise.
1737pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
1738 let server_url = &ClientSettings::get_global(cx).server_url;
1739 if let Some(stripped) = link
1740 .strip_prefix(server_url)
1741 .and_then(|result| result.strip_prefix('/'))
1742 {
1743 return Some(stripped);
1744 }
1745 if let Some(stripped) = link
1746 .strip_prefix(ZED_URL_SCHEME)
1747 .and_then(|result| result.strip_prefix("://"))
1748 {
1749 return Some(stripped);
1750 }
1751
1752 None
1753}
1754
// Unit tests for `Client` covering:
// - connection and reconnection status transitions;
// - credential reuse and re-authentication;
// - connection timeouts;
// - the lifetime of message-handler subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    /// After the server drops the connection, the client reconnects using its cached
    /// credentials. It re-authenticates only once the server's access token has been rolled.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// If HTTP requests fail with 503 while reconnecting, the status reaches
    /// `ReconnectionError`. Once requests succeed again, the client reconnects without
    /// re-authenticating.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    /// Both the initial connect and a later reconnect give up after `CONNECTION_TIMEOUT`
    /// when establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    /// `sign_in` keeps the cached credentials while they are valid or while the server is
    /// unavailable (503). It triggers authentication again only after the server answers 401.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        // Each authentication bumps the counter and uses the new count as the access token.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    /// Starting a second `connect` while the first is still authenticating drops the first
    /// authentication future (its deferred drop counter increments).
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        // Authentication never completes; the deferred closure records when it is dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    /// Entity message handlers deliver each message to the entity subscribed under the
    /// message's id. Dropping the subscription for one id leaves the other ids working.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    /// A handler for a message type can be registered again after the previous subscription
    /// for it is dropped. The new handler receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    /// A handler may take and drop its own subscription (stored on the entity) while handling
    /// a message, and handling still completes.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    /// Minimal entity used as a subscription target: an id, plus an optional subscription it
    /// can own.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    /// Installs a test `SettingsStore` as a global before a test runs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}