1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4mod proxy;
5pub mod telemetry;
6pub mod user;
7pub mod zed_urls;
8
9use anyhow::{Context as _, Result, anyhow};
10use async_tungstenite::tungstenite::{
11 client::IntoClientRequest,
12 error::Error as WebsocketError,
13 http::{HeaderValue, Request, StatusCode},
14};
15use clock::SystemClock;
16use cloud_api_client::CloudApiClient;
17use cloud_api_client::websocket_protocol::MessageToClient;
18use credentials_provider::CredentialsProvider;
19use feature_flags::FeatureFlagAppExt as _;
20use futures::{
21 AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
22 channel::oneshot, future::BoxFuture,
23};
24use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
25use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
26use parking_lot::RwLock;
27use postage::watch;
28use proxy::connect_proxy_stream;
29use rand::prelude::*;
30use release_channel::{AppVersion, ReleaseChannel};
31use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
32use serde::{Deserialize, Serialize};
33use settings::{RegisterSetting, Settings, SettingsContent};
34use std::{
35 any::TypeId,
36 convert::TryFrom,
37 future::Future,
38 marker::PhantomData,
39 path::PathBuf,
40 sync::{
41 Arc, LazyLock, Weak,
42 atomic::{AtomicU64, Ordering},
43 },
44 time::{Duration, Instant},
45};
46use std::{cmp, pin::Pin};
47use telemetry::Telemetry;
48use thiserror::Error;
49use tokio::net::TcpStream;
50use url::Url;
51use util::{ConnectionResult, ResultExt};
52
53pub use rpc::*;
54pub use telemetry_events::Event;
55pub use user::*;
56
/// Server URL override taken from `ZED_SERVER_URL`; wins over the `server_url` setting.
static ZED_SERVER_URL: LazyLock<Option<String>> =
    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
/// Collab RPC URL override taken from `ZED_RPC_URL`; skips the `/rpc` redirect lookup.
static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());

/// Login to impersonate, from `ZED_IMPERSONATE`; an empty value counts as unset.
pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
    std::env::var("ZED_IMPERSONATE")
        .ok()
        .filter(|login| !login.is_empty())
});

/// Enabled whenever `ZED_WEB_LOGIN` is present in the environment (even if empty).
pub static USE_WEB_LOGIN: LazyLock<bool> = LazyLock::new(|| std::env::var("ZED_WEB_LOGIN").is_ok());

/// Admin API token from `ZED_ADMIN_API_TOKEN`; an empty value counts as unset.
pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
    std::env::var("ZED_ADMIN_API_TOKEN")
        .ok()
        .filter(|token| !token.is_empty())
});

/// Application path from `ZED_APP_PATH`, if set.
pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));

/// True only when `ZED_ALWAYS_ACTIVE` is set to a non-empty value.
pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> = LazyLock::new(|| {
    std::env::var("ZED_ALWAYS_ACTIVE")
        .ok()
        .is_some_and(|value| !value.is_empty())
});

/// First backoff delay used by the reconnect loop after a lost connection.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Upper bound for the (doubling) reconnect backoff delay.
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
/// Deadline for establishing a collab connection, including the server hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
84
// Client-level actions; their handlers are registered by `init`.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
96
97#[derive(Deserialize, RegisterSetting)]
98pub struct ClientSettings {
99 pub server_url: String,
100}
101
102impl Settings for ClientSettings {
103 fn from_settings(content: &settings::SettingsContent) -> Self {
104 if let Some(server_url) = &*ZED_SERVER_URL {
105 return Self {
106 server_url: server_url.clone(),
107 };
108 }
109 Self {
110 server_url: content.server_url.clone().unwrap(),
111 }
112 }
113}
114
115#[derive(Deserialize, Default, RegisterSetting)]
116pub struct ProxySettings {
117 pub proxy: Option<String>,
118}
119
120impl ProxySettings {
121 pub fn proxy_url(&self) -> Option<Url> {
122 self.proxy
123 .as_deref()
124 .map(str::trim)
125 .filter(|input| !input.is_empty())
126 .and_then(|input| {
127 input
128 .parse::<Url>()
129 .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
130 .ok()
131 })
132 .or_else(read_proxy_from_env)
133 }
134}
135
136impl Settings for ProxySettings {
137 fn from_settings(content: &settings::SettingsContent) -> Self {
138 Self {
139 proxy: content
140 .proxy
141 .as_deref()
142 .map(str::trim)
143 .filter(|proxy| !proxy.is_empty())
144 .map(ToOwned::to_owned),
145 }
146 }
147}
148
149pub fn init(client: &Arc<Client>, cx: &mut App) {
150 let client = Arc::downgrade(client);
151 cx.on_action({
152 let client = client.clone();
153 move |_: &SignIn, cx| {
154 if let Some(client) = client.upgrade() {
155 cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
156 .detach_and_log_err(cx);
157 }
158 }
159 })
160 .on_action({
161 let client = client.clone();
162 move |_: &SignOut, cx| {
163 if let Some(client) = client.upgrade() {
164 cx.spawn(async move |cx| {
165 client.sign_out(cx).await;
166 })
167 .detach();
168 }
169 }
170 })
171 .on_action({
172 let client = client;
173 move |_: &Reconnect, cx| {
174 if let Some(client) = client.upgrade() {
175 cx.spawn(async move |cx| {
176 client.reconnect(cx);
177 })
178 .detach();
179 }
180 }
181 });
182}
183
/// Callback for [`MessageToClient`] messages coming from Cloud; stored in
/// `Client::message_to_client_handlers` (presumably invoked by `handle_message_to_client`,
/// which is not shown here).
pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;

/// GPUI global holding the shared [`Client`]; see `Client::global` and `Client::set_global`.
struct GlobalClient(Arc<Client>);

impl Global for GlobalClient {}
189
/// Client for Zed's servers: authentication, the collab RPC peer and the Cloud API client.
pub struct Client {
    /// Client id; `sign_in` sets it to the signed-in user's id.
    id: AtomicU64,
    /// RPC peer that owns the collab connection.
    peer: Arc<Peer>,
    /// HTTP client whose base URL is the configured server URL.
    http: Arc<HttpClientWithUrl>,
    /// Cloud API client (credential validation and the Cloud WebSocket).
    cloud_client: Arc<CloudApiClient>,
    telemetry: Arc<Telemetry>,
    /// Reads, writes and deletes stored credentials.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status channel and reconnect task.
    state: RwLock<ClientState>,
    /// Registered RPC message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
    /// Handlers for messages received from Cloud.
    message_to_client_handlers: parking_lot::Mutex<Vec<MessageToClientHandler>>,

    // The boxed callback type is spelled out inline, so clippy's type_complexity lint is
    // silenced rather than introducing a one-off type alias.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    /// Test-only replacement for the authentication flow used by `authenticate`.
    authenticate:
        RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,

    // See the note on `authenticate` above about type_complexity.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    /// Test-only replacement for the WebSocket connection used by `establish_connection`.
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,

    #[cfg(any(test, feature = "test-support"))]
    /// Test-only RPC URL returned by `rpc_url` when set.
    rpc_url: RwLock<Option<Url>>,
}
225
/// Errors that can occur while establishing the collab connection.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the WebSocket handshake with `426 Upgrade Required`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the WebSocket handshake with `401 Unauthorized`.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other error.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A header value could not be constructed.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// An I/O error.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Despite the name, this wraps an `http::Error` (presumably from building the handshake
    /// request); WebSocket protocol errors convert via `From<WebsocketError>` instead.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
241
242impl From<WebsocketError> for EstablishConnectionError {
243 fn from(error: WebsocketError) -> Self {
244 if let WebsocketError::Http(response) = &error {
245 match response.status() {
246 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
247 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
248 _ => {}
249 }
250 }
251 EstablishConnectionError::Other(error.into())
252 }
253}
254
255impl EstablishConnectionError {
256 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
257 Self::Other(error.into())
258 }
259}
260
/// Connection status of a [`Client`], published to watchers through `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial state; also set when a connection's I/O task finishes without error.
    SignedOut,
    /// The server responded that this client must be upgraded.
    UpgradeRequired,
    /// `sign_in` is running after being signed out.
    Authenticating,
    /// `sign_in` succeeded after being signed out.
    Authenticated,
    /// Validating or obtaining credentials failed.
    AuthenticationError,
    /// A connection attempt is in progress after being disconnected.
    Connecting,
    /// A connection attempt failed or timed out.
    ConnectionError,
    /// Connected to collab; identifies the live connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; entering this state starts the reconnect loop.
    ConnectionLost,
    /// `sign_in` is running while not signed out.
    Reauthenticating,
    /// `sign_in` succeeded while not signed out.
    Reauthenticated,
    /// A connection attempt is in progress after a previous failure or reauthentication.
    Reconnecting,
    /// The reconnect loop is waiting before its next attempt.
    ReconnectionError {
        /// When the next attempt is scheduled (the random jitter is not included).
        next_reconnection: Instant,
    },
}
282
283impl Status {
284 pub fn is_connected(&self) -> bool {
285 matches!(self, Self::Connected { .. })
286 }
287
288 pub fn was_connected(&self) -> bool {
289 matches!(
290 self,
291 Self::ConnectionLost
292 | Self::Reauthenticating
293 | Self::Reauthenticated
294 | Self::Reconnecting
295 )
296 }
297
298 /// Returns whether the client is currently connected or was connected at some point.
299 pub fn is_or_was_connected(&self) -> bool {
300 self.is_connected() || self.was_connected()
301 }
302
303 pub fn is_signing_in(&self) -> bool {
304 matches!(
305 self,
306 Self::Authenticating | Self::Reauthenticating | Self::Connecting | Self::Reconnecting
307 )
308 }
309
310 pub fn is_signed_out(&self) -> bool {
311 matches!(self, Self::SignedOut | Self::UpgradeRequired)
312 }
313}
314
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials from the most recent successful `sign_in`.
    credentials: Option<Credentials>,
    /// Status channel; `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Reconnect loop spawned on `ConnectionLost`; replaced or dropped by `set_status`
    /// (dropping presumably cancels it — gpui `Task` semantics).
    _reconnect_task: Option<Task<()>>,
}
320
/// A user id paired with the access token used to authenticate as that user.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl Credentials {
    /// Formats the credentials as `"<user_id> <access_token>"`, the value sent as the
    /// authorization header.
    pub fn authorization_header(&self) -> String {
        let Self {
            user_id,
            access_token,
        } = self;
        format!("{user_id} {access_token}")
    }
}
332
/// Adapter over the global [`CredentialsProvider`] that keys credentials by the configured
/// server URL and converts them to [`Credentials`].
pub struct ClientCredentialsProvider {
    provider: Arc<dyn CredentialsProvider>,
}
336
337impl ClientCredentialsProvider {
338 pub fn new(cx: &App) -> Self {
339 Self {
340 provider: <dyn CredentialsProvider>::global(cx),
341 }
342 }
343
344 fn server_url(&self, cx: &AsyncApp) -> Result<String> {
345 Ok(cx.update(|cx| ClientSettings::get_global(cx).server_url.clone()))
346 }
347
348 /// Reads the credentials from the provider.
349 fn read_credentials<'a>(
350 &'a self,
351 cx: &'a AsyncApp,
352 ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
353 async move {
354 if IMPERSONATE_LOGIN.is_some() {
355 return None;
356 }
357
358 let server_url = self.server_url(cx).ok()?;
359 let (user_id, access_token) = self
360 .provider
361 .read_credentials(&server_url, cx)
362 .await
363 .log_err()
364 .flatten()?;
365
366 Some(Credentials {
367 user_id: user_id.parse().ok()?,
368 access_token: String::from_utf8(access_token).ok()?,
369 })
370 }
371 .boxed_local()
372 }
373
374 /// Writes the credentials to the provider.
375 fn write_credentials<'a>(
376 &'a self,
377 user_id: u64,
378 access_token: String,
379 cx: &'a AsyncApp,
380 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
381 async move {
382 let server_url = self.server_url(cx)?;
383 self.provider
384 .write_credentials(
385 &server_url,
386 &user_id.to_string(),
387 access_token.as_bytes(),
388 cx,
389 )
390 .await
391 }
392 .boxed_local()
393 }
394
395 /// Deletes the credentials from the provider.
396 fn delete_credentials<'a>(
397 &'a self,
398 cx: &'a AsyncApp,
399 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
400 async move {
401 let server_url = self.server_url(cx)?;
402 self.provider.delete_credentials(&server_url, cx).await
403 }
404 .boxed_local()
405 }
406}
407
408impl Default for ClientState {
409 fn default() -> Self {
410 Self {
411 credentials: None,
412 status: watch::channel_with(Status::SignedOut),
413 _reconnect_task: None,
414 }
415 }
416}
417
/// Registration handle returned by the client's subscribe/handler APIs; dropping it removes
/// the registration from the client's handler set.
pub enum Subscription {
    /// Entity subscription keyed by (entity `TypeId`, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler registration keyed by the message `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
428
429impl Drop for Subscription {
430 fn drop(&mut self) {
431 match self {
432 Subscription::Entity { client, id } => {
433 if let Some(client) = client.upgrade() {
434 let mut state = client.handler_set.lock();
435 let _ = state.entities_by_type_and_remote_id.remove(id);
436 }
437 }
438 Subscription::Message { client, id } => {
439 if let Some(client) = client.upgrade() {
440 let mut state = client.handler_set.lock();
441 let _ = state.entity_types_by_message_type.remove(id);
442 let _ = state.message_handlers.remove(id);
443 }
444 }
445 }
446 }
447}
448
/// Subscription returned by `Client::subscribe_to_entity` before the entity exists.
///
/// While it is alive, the handler set holds an `EntityMessageSubscriber::Pending` entry for the
/// entity; `set_entity` replaces it with the real entity and handles any messages the pending
/// entry contains. Dropping it unconsumed removes the entry.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_entity` so that `Drop` leaves the (now live) entry in place.
    consumed: bool,
}
455
456impl<T: 'static> PendingEntitySubscription<T> {
457 pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
458 self.consumed = true;
459 let mut handlers = self.client.handler_set.lock();
460 let id = (TypeId::of::<T>(), self.remote_id);
461 let Some(EntityMessageSubscriber::Pending(messages)) =
462 handlers.entities_by_type_and_remote_id.remove(&id)
463 else {
464 unreachable!()
465 };
466
467 handlers.entities_by_type_and_remote_id.insert(
468 id,
469 EntityMessageSubscriber::Entity {
470 handle: entity.downgrade().into(),
471 },
472 );
473 drop(handlers);
474 for message in messages {
475 let client_id = self.client.id();
476 let type_name = message.payload_type_name();
477 let sender_id = message.original_sender_id();
478 log::debug!(
479 "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
480 client_id,
481 sender_id,
482 type_name
483 );
484 self.client.handle_message(message, cx);
485 }
486 Subscription::Entity {
487 client: Arc::downgrade(&self.client),
488 id,
489 }
490 }
491}
492
493impl<T: 'static> Drop for PendingEntitySubscription<T> {
494 fn drop(&mut self) {
495 if !self.consumed {
496 let mut state = self.client.handler_set.lock();
497 if let Some(EntityMessageSubscriber::Pending(messages)) = state
498 .entities_by_type_and_remote_id
499 .remove(&(TypeId::of::<T>(), self.remote_id))
500 {
501 for message in messages {
502 log::info!("unhandled message {}", message.payload_type_name());
503 }
504 }
505 }
506 }
507}
508
509#[derive(Copy, Clone, Deserialize, Debug, RegisterSetting)]
510pub struct TelemetrySettings {
511 pub diagnostics: bool,
512 pub metrics: bool,
513}
514
515impl settings::Settings for TelemetrySettings {
516 fn from_settings(content: &SettingsContent) -> Self {
517 Self {
518 diagnostics: content.telemetry.as_ref().unwrap().diagnostics.unwrap(),
519 metrics: content.telemetry.as_ref().unwrap().metrics.unwrap(),
520 }
521 }
522}
523
524impl Client {
525 pub fn new(
526 clock: Arc<dyn SystemClock>,
527 http: Arc<HttpClientWithUrl>,
528 cx: &mut App,
529 ) -> Arc<Self> {
530 Arc::new(Self {
531 id: AtomicU64::new(0),
532 peer: Peer::new(0),
533 telemetry: Telemetry::new(clock, http.clone(), cx),
534 cloud_client: Arc::new(CloudApiClient::new(http.clone())),
535 http,
536 credentials_provider: ClientCredentialsProvider::new(cx),
537 state: Default::default(),
538 handler_set: Default::default(),
539 message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
540
541 #[cfg(any(test, feature = "test-support"))]
542 authenticate: Default::default(),
543 #[cfg(any(test, feature = "test-support"))]
544 establish_connection: Default::default(),
545 #[cfg(any(test, feature = "test-support"))]
546 rpc_url: RwLock::default(),
547 })
548 }
549
550 pub fn production(cx: &mut App) -> Arc<Self> {
551 let clock = Arc::new(clock::RealSystemClock);
552 let http = Arc::new(HttpClientWithUrl::new_url(
553 cx.http_client(),
554 &ClientSettings::get_global(cx).server_url,
555 cx.http_client().proxy().cloned(),
556 ));
557 Self::new(clock, http, cx)
558 }
559
560 pub fn id(&self) -> u64 {
561 self.id.load(Ordering::SeqCst)
562 }
563
564 pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
565 self.http.clone()
566 }
567
568 pub fn cloud_client(&self) -> Arc<CloudApiClient> {
569 self.cloud_client.clone()
570 }
571
572 pub fn set_id(&self, id: u64) -> &Self {
573 self.id.store(id, Ordering::SeqCst);
574 self
575 }
576
577 #[cfg(any(test, feature = "test-support"))]
578 pub fn teardown(&self) {
579 let mut state = self.state.write();
580 state._reconnect_task.take();
581 self.handler_set.lock().clear();
582 self.peer.teardown();
583 }
584
585 #[cfg(any(test, feature = "test-support"))]
586 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
587 where
588 F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
589 {
590 *self.authenticate.write() = Some(Box::new(authenticate));
591 self
592 }
593
594 #[cfg(any(test, feature = "test-support"))]
595 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
596 where
597 F: 'static
598 + Send
599 + Sync
600 + Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
601 {
602 *self.establish_connection.write() = Some(Box::new(connect));
603 self
604 }
605
606 #[cfg(any(test, feature = "test-support"))]
607 pub fn override_rpc_url(&self, url: Url) -> &Self {
608 *self.rpc_url.write() = Some(url);
609 self
610 }
611
612 pub fn global(cx: &App) -> Arc<Self> {
613 cx.global::<GlobalClient>().0.clone()
614 }
615 pub fn set_global(client: Arc<Client>, cx: &mut App) {
616 cx.set_global(GlobalClient(client))
617 }
618
619 pub fn user_id(&self) -> Option<u64> {
620 self.state
621 .read()
622 .credentials
623 .as_ref()
624 .map(|credentials| credentials.user_id)
625 }
626
627 pub fn peer_id(&self) -> Option<PeerId> {
628 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
629 Some(*peer_id)
630 } else {
631 None
632 }
633 }
634
635 pub fn status(&self) -> watch::Receiver<Status> {
636 self.state.read().status.1.clone()
637 }
638
    /// Publishes `status` to all status receivers and reacts to the transition:
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task that repeatedly calls `connect`. After each
    ///   failed attempt it keeps going only if the status is `AuthenticationError` or
    ///   `ConnectionError`. In that case it sets `ReconnectionError`, sleeps for the current
    ///   delay plus random jitter in `[0, delay)`, and doubles the delay, which starts at
    ///   `INITIAL_RECONNECTION_DELAY` and is capped at `MAX_RECONNECTION_DELAY`. It stops on
    ///   success or on any other status.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and drops any pending
    ///   reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        // Broadcast to every receiver handed out by `status()`.
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let client = self.clone();
                state._reconnect_task = Some(cx.spawn(async move |cx| {
                    // A fixed seed in test builds keeps the backoff jitter deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_os_rng();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    loop {
                        match client.connect(true, cx).await {
                            ConnectionResult::Timeout => {
                                log::error!("client connect attempt timed out")
                            }
                            ConnectionResult::ConnectionReset => {
                                log::error!("client connect attempt reset")
                            }
                            ConnectionResult::Result(r) => {
                                if let Err(error) = r {
                                    log::error!("failed to connect: {error}");
                                } else {
                                    break;
                                }
                            }
                        }

                        // Retry only while the failed attempt left an error status; any other
                        // status (e.g. signed out or upgrade required) ends the loop.
                        if matches!(
                            *client.status().borrow(),
                            Status::AuthenticationError | Status::ConnectionError
                        ) {
                            // NOTE(review): the advertised time excludes the jitter added below.
                            client.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                cx,
                            );
                            let jitter = Duration::from_millis(
                                rng.random_range(0..delay.as_millis() as u64),
                            );
                            cx.background_executor().timer(delay + jitter).await;
                            // Exponential backoff, capped at MAX_RECONNECTION_DELAY.
                            delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
702
703 pub fn subscribe_to_entity<T>(
704 self: &Arc<Self>,
705 remote_id: u64,
706 ) -> Result<PendingEntitySubscription<T>>
707 where
708 T: 'static,
709 {
710 let id = (TypeId::of::<T>(), remote_id);
711
712 let mut state = self.handler_set.lock();
713 anyhow::ensure!(
714 !state.entities_by_type_and_remote_id.contains_key(&id),
715 "already subscribed to entity"
716 );
717
718 state
719 .entities_by_type_and_remote_id
720 .insert(id, EntityMessageSubscriber::Pending(Default::default()));
721
722 Ok(PendingEntitySubscription {
723 client: self.clone(),
724 remote_id,
725 consumed: false,
726 _entity_type: PhantomData,
727 })
728 }
729
730 #[track_caller]
731 pub fn add_message_handler<M, E, H, F>(
732 self: &Arc<Self>,
733 entity: WeakEntity<E>,
734 handler: H,
735 ) -> Subscription
736 where
737 M: EnvelopedMessage,
738 E: 'static,
739 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
740 F: 'static + Future<Output = Result<()>>,
741 {
742 self.add_message_handler_impl(entity, move |entity, message, _, cx| {
743 handler(entity, message, cx)
744 })
745 }
746
747 fn add_message_handler_impl<M, E, H, F>(
748 self: &Arc<Self>,
749 entity: WeakEntity<E>,
750 handler: H,
751 ) -> Subscription
752 where
753 M: EnvelopedMessage,
754 E: 'static,
755 H: 'static
756 + Sync
757 + Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
758 + Send
759 + Sync,
760 F: 'static + Future<Output = Result<()>>,
761 {
762 let message_type_id = TypeId::of::<M>();
763 let mut state = self.handler_set.lock();
764 state
765 .entities_by_message_type
766 .insert(message_type_id, entity.into());
767
768 let prev_handler = state.message_handlers.insert(
769 message_type_id,
770 Arc::new(move |subscriber, envelope, client, cx| {
771 let subscriber = subscriber.downcast::<E>().unwrap();
772 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
773 handler(subscriber, *envelope, client, cx).boxed_local()
774 }),
775 );
776 if prev_handler.is_some() {
777 let location = std::panic::Location::caller();
778 panic!(
779 "{}:{} registered handler for the same message {} twice",
780 location.file(),
781 location.line(),
782 std::any::type_name::<M>()
783 );
784 }
785
786 Subscription::Message {
787 client: Arc::downgrade(self),
788 id: message_type_id,
789 }
790 }
791
792 pub fn add_request_handler<M, E, H, F>(
793 self: &Arc<Self>,
794 entity: WeakEntity<E>,
795 handler: H,
796 ) -> Subscription
797 where
798 M: RequestMessage,
799 E: 'static,
800 H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
801 F: 'static + Future<Output = Result<M::Response>>,
802 {
803 self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
804 Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
805 })
806 }
807
808 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
809 receipt: Receipt<T>,
810 response: F,
811 client: AnyProtoClient,
812 ) -> Result<()> {
813 match response.await {
814 Ok(response) => {
815 client.send_response(receipt.message_id, response)?;
816 Ok(())
817 }
818 Err(error) => {
819 client.send_response(receipt.message_id, error.to_proto())?;
820 Err(error)
821 }
822 }
823 }
824
825 pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
826 self.credentials_provider
827 .read_credentials(cx)
828 .await
829 .is_some()
830 }
831
    /// Authenticates the user. While in progress the status is `Authenticating` (when it was
    /// signed out) or `Reauthenticating` (otherwise).
    ///
    /// Credentials are tried in this order:
    /// 1. the credentials already held in memory, if the Cloud API accepts them;
    /// 2. when `try_provider` is set, the credentials from the credentials provider. They are
    ///    deleted from the provider if the Cloud API rejects them;
    /// 3. fresh credentials from `authenticate`. They are written to the provider unless
    ///    impersonating.
    ///
    /// On success the client id, the Cloud client credentials and the in-memory credentials are
    /// updated, and the status becomes `Authenticated` / `Reauthenticated`.
    ///
    /// # Errors
    ///
    /// Fails, with the status set to `AuthenticationError`, when validating credentials fails
    /// or when `authenticate` fails. Fails with "authentication canceled" when the status
    /// changes while waiting for `authenticate`.
    pub async fn sign_in(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<Credentials> {
        let is_reauthenticating = if self.status().borrow().is_signed_out() {
            self.set_status(Status::Authenticating, cx);
            false
        } else {
            self.set_status(Status::Reauthenticating, cx);
            true
        };

        let mut credentials = None;

        // 1. Reuse the in-memory credentials if they still validate.
        let old_credentials = self.state.read().credentials.clone();
        if let Some(old_credentials) = old_credentials
            && self.validate_credentials(&old_credentials, cx).await?
        {
            credentials = Some(old_credentials);
        }

        // 2. Otherwise try the provider. Stored credentials that fail validation are deleted
        //    on a best-effort basis (errors are only logged).
        if credentials.is_none()
            && try_provider
            && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
        {
            if self.validate_credentials(&stored_credentials, cx).await? {
                credentials = Some(stored_credentials);
            } else {
                self.credentials_provider
                    .delete_credentials(cx)
                    .await
                    .log_err();
            }
        }

        // 3. Otherwise authenticate from scratch, giving up if the status changes meanwhile.
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably this consumes the current status value, so the
            // `status_rx.next()` branch below only fires on a later status change. Confirm
            // against postage `watch::Receiver` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => {
                            // Impersonated credentials are never written to the provider.
                            if IMPERSONATE_LOGIN.is_none() {
                                self.credentials_provider
                                    .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                    .await
                                    .log_err();
                            }

                            credentials = Some(creds);
                        },
                        Err(err) => {
                            self.set_status(Status::AuthenticationError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }

        // Every path above either set `credentials` or returned early.
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);
        // NOTE(review): `as u32` silently truncates user ids above `u32::MAX`.
        self.cloud_client
            .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
        self.state.write().credentials = Some(credentials.clone());
        self.set_status(
            if is_reauthenticating {
                Status::Reauthenticated
            } else {
                Status::Authenticated
            },
            cx,
        );

        Ok(credentials)
    }
912
913 async fn validate_credentials(
914 self: &Arc<Self>,
915 credentials: &Credentials,
916 cx: &AsyncApp,
917 ) -> Result<bool> {
918 match self
919 .cloud_client
920 .validate_credentials(credentials.user_id as u32, &credentials.access_token)
921 .await
922 {
923 Ok(valid) => Ok(valid),
924 Err(err) => {
925 self.set_status(Status::AuthenticationError, cx);
926 Err(anyhow!("failed to validate credentials: {}", err))
927 }
928 }
929 }
930
931 /// Establishes a WebSocket connection with Cloud for receiving updates from the server.
932 async fn connect_to_cloud(self: &Arc<Self>, cx: &AsyncApp) -> Result<()> {
933 let connect_task = cx.update({
934 let cloud_client = self.cloud_client.clone();
935 move |cx| cloud_client.connect(cx)
936 })?;
937 let connection = connect_task.await?;
938
939 let (mut messages, task) = cx.update(|cx| connection.spawn(cx));
940 task.detach();
941
942 cx.spawn({
943 let this = self.clone();
944 async move |cx| {
945 while let Some(message) = messages.next().await {
946 if let Some(message) = message.log_err() {
947 this.handle_message_to_client(message, cx);
948 }
949 }
950 }
951 })
952 .detach();
953
954 Ok(())
955 }
956
    /// Performs a sign-in and also (optionally) connects to Collab.
    ///
    /// Only Zed staff automatically connect to Collab.
    ///
    /// Returns `Ok(())` right away if already connected. Otherwise it signs in (propagating
    /// sign-in errors), then tries to connect to Cloud (errors are only logged). Finally it
    /// spawns a detached task that waits for the feature flags and, for staff, connects to
    /// Collab with the new credentials; that task's errors are logged, not returned.
    pub async fn sign_in_with_optional_connect(
        self: &Arc<Self>,
        try_provider: bool,
        cx: &AsyncApp,
    ) -> Result<()> {
        // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
        if self.status().borrow().is_connected() {
            return Ok(());
        }

        // Forward the `is_staff` flag from the first feature-flags-ready callback; the sender
        // is taken out of the Option so it can only be used once.
        let (is_staff_tx, is_staff_rx) = oneshot::channel::<bool>();
        let mut is_staff_tx = Some(is_staff_tx);
        cx.update(|cx| {
            cx.on_flags_ready(move |state, _cx| {
                if let Some(is_staff_tx) = is_staff_tx.take() {
                    is_staff_tx.send(state.is_staff).log_err();
                }
            })
            .detach();
        });

        let credentials = self.sign_in(try_provider, cx).await?;

        // Failing to reach Cloud does not fail the sign-in.
        self.connect_to_cloud(cx).await.log_err();

        cx.update(move |cx| {
            cx.spawn({
                let client = self.clone();
                async move |cx| {
                    // Errors if the flags callback was dropped without ever sending.
                    let is_staff = is_staff_rx.await?;
                    if is_staff {
                        match client.connect_with_credentials(credentials, cx).await {
                            ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
                            ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
                            ConnectionResult::Result(result) => {
                                result.context("client auth and connect")
                            }
                        }
                    } else {
                        Ok(())
                    }
                }
            })
            .detach_and_log_err(cx);
        });

        Ok(())
    }
1008
1009 pub async fn connect(
1010 self: &Arc<Self>,
1011 try_provider: bool,
1012 cx: &AsyncApp,
1013 ) -> ConnectionResult<()> {
1014 let was_disconnected = match *self.status().borrow() {
1015 Status::SignedOut | Status::Authenticated => true,
1016 Status::ConnectionError
1017 | Status::ConnectionLost
1018 | Status::Authenticating
1019 | Status::AuthenticationError
1020 | Status::Reauthenticating
1021 | Status::Reauthenticated
1022 | Status::ReconnectionError { .. } => false,
1023 Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
1024 return ConnectionResult::Result(Ok(()));
1025 }
1026 Status::UpgradeRequired => {
1027 return ConnectionResult::Result(
1028 Err(EstablishConnectionError::UpgradeRequired)
1029 .context("client auth and connect"),
1030 );
1031 }
1032 };
1033 let credentials = match self.sign_in(try_provider, cx).await {
1034 Ok(credentials) => credentials,
1035 Err(err) => return ConnectionResult::Result(Err(err)),
1036 };
1037
1038 if was_disconnected {
1039 self.set_status(Status::Connecting, cx);
1040 } else {
1041 self.set_status(Status::Reconnecting, cx);
1042 }
1043
1044 self.connect_with_credentials(credentials, cx).await
1045 }
1046
    /// Establishes the collab connection with `credentials` and completes the handshake.
    ///
    /// A single `CONNECTION_TIMEOUT` timer is raced against both establishing the connection
    /// and the subsequent `set_connection` handshake, so the whole sequence shares one
    /// deadline. Because `select_biased!` polls branches in order, a ready connection or
    /// handshake wins over a simultaneously expired timer.
    ///
    /// Resulting status: `UpgradeRequired` when the server demands an upgrade,
    /// `ConnectionError` on every other failure or on timeout. On success `set_connection` has
    /// set `Connected`.
    async fn connect_with_credentials(
        self: &Arc<Self>,
        credentials: Credentials,
        cx: &AsyncApp,
    ) -> ConnectionResult<()> {
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Reuse the same timer: the handshake must finish within what is left
                        // of the original deadline.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => {
                                match result.context("client auth and connect") {
                                    Ok(()) => ConnectionResult::Result(Ok(())),
                                    Err(err) => {
                                        self.set_status(Status::ConnectionError, cx);
                                        ConnectionResult::Result(Err(err))
                                    },
                                }
                            },
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                ConnectionResult::Timeout
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        ConnectionResult::Result(Err(error).context("client auth and connect"))
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                ConnectionResult::Timeout
            }
        }
    }
1094
    /// Registers `conn` with the peer, waits for the server's `Hello`, and starts serving the
    /// connection.
    ///
    /// On success the status becomes `Connected`. Two detached tasks are then spawned: one
    /// dispatches incoming messages, and one watches the connection's I/O. The watcher sets
    /// `SignedOut` when the I/O ends cleanly (only if this connection is still the current
    /// one) and `ConnectionLost` when it fails, which starts the reconnect loop.
    ///
    /// # Errors
    ///
    /// Fails, after disconnecting the connection, if no valid `Hello` with a peer id arrives.
    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
        let executor = cx.background_executor();
        log::debug!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor.
        let handle_io = executor.spawn(handle_io);

        // The first message from the server must be a `Hello` carrying our peer id.
        let peer_id = async {
            log::debug!("waiting for server hello");
            let message = incoming.next().await.context("no hello message received")?;
            log::debug!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello.payload.peer_id.context("invalid peer id")?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::debug!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            async move |cx| {
                while let Some(message) = incoming.next().await {
                    this.handle_message(message, cx);
                    // Don't starve the main thread when receiving lots of messages at once.
                    smol::future::yield_now().await;
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            async move |cx| match handle_io.await {
                Ok(()) => {
                    // Only sign out if no newer connection has replaced this one meanwhile.
                    if *this.status().borrow()
                        == (Status::Connected {
                            connection_id,
                            peer_id,
                        })
                    {
                        this.set_status(Status::SignedOut, cx);
                    }
                }
                Err(err) => {
                    log::error!("connection error: {:?}", err);
                    this.set_status(Status::ConnectionLost, cx);
                }
            }
        })
        .detach();

        Ok(())
    }
1178
1179 fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1180 #[cfg(any(test, feature = "test-support"))]
1181 if let Some(callback) = self.authenticate.read().as_ref() {
1182 return callback(cx);
1183 }
1184
1185 self.authenticate_with_browser(cx)
1186 }
1187
1188 fn establish_connection(
1189 self: &Arc<Self>,
1190 credentials: &Credentials,
1191 cx: &AsyncApp,
1192 ) -> Task<Result<Connection, EstablishConnectionError>> {
1193 #[cfg(any(test, feature = "test-support"))]
1194 if let Some(callback) = self.establish_connection.read().as_ref() {
1195 return callback(credentials, cx);
1196 }
1197
1198 self.establish_websocket_connection(credentials, cx)
1199 }
1200
1201 fn rpc_url(
1202 &self,
1203 http: Arc<HttpClientWithUrl>,
1204 release_channel: Option<ReleaseChannel>,
1205 ) -> impl Future<Output = Result<url::Url>> + use<> {
1206 #[cfg(any(test, feature = "test-support"))]
1207 let url_override = self.rpc_url.read().clone();
1208
1209 async move {
1210 #[cfg(any(test, feature = "test-support"))]
1211 if let Some(url) = url_override {
1212 return Ok(url);
1213 }
1214
1215 if let Some(url) = &*ZED_RPC_URL {
1216 return Url::parse(url).context("invalid rpc url");
1217 }
1218
1219 let mut url = http.build_url("/rpc");
1220 if let Some(preview_param) =
1221 release_channel.and_then(|channel| channel.release_query_param())
1222 {
1223 url += "?";
1224 url += preview_param;
1225 }
1226
1227 let response = http.get(&url, Default::default(), false).await?;
1228 anyhow::ensure!(
1229 response.status().is_redirection(),
1230 "unexpected /rpc response status {}",
1231 response.status()
1232 );
1233 let collab_url = response
1234 .headers()
1235 .get("Location")
1236 .context("missing location header in /rpc response")?
1237 .to_str()
1238 .map_err(EstablishConnectionError::other)?
1239 .to_string();
1240 Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
1241 }
1242 }
1243
    /// Opens a WebSocket connection to the collab RPC endpoint.
    ///
    /// Steps:
    /// 1. Resolve the RPC URL with `rpc_url`.
    /// 2. Open a TCP stream to its host and port on the Tokio runtime
    ///    (`gpui_tokio::Tokio::spawn_result`). The stream goes through the HTTP
    ///    client's proxy when one is configured.
    /// 3. Switch the URL scheme from `https`/`http` to `wss`/`ws`.
    /// 4. Build the upgrade request with authorization and Zed identification headers.
    /// 5. Perform the WebSocket handshake, passing `http_client_tls::tls_config()` as
    ///    the TLS connector.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - The RPC URL cannot be resolved, or its scheme is not `http`/`https`.
    /// - The URL has no host.
    /// - The TCP or proxy connection fails.
    /// - A header value is invalid.
    /// - The handshake fails.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncApp,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Read everything needed from `self` and the app up front. The spawned task
        // below only owns these copies.
        let release_channel = cx.update(|cx| ReleaseChannel::try_global(cx));
        let app_version = cx.update(|cx| AppVersion::global(cx).to_string());

        let http = self.http.clone();
        let proxy = http.proxy().cloned();
        let user_agent = http.user_agent().cloned();
        let credentials = credentials.clone();
        let rpc_url = self.rpc_url(http, release_channel);
        let system_id = self.telemetry.system_id();
        let metrics_id = self.telemetry.metrics_id();
        cx.spawn(async move |cx| {
            use HttpOrHttps::*;

            // Remember the original scheme so it can be mapped to the matching
            // WebSocket scheme once the TCP connection is open.
            #[derive(Debug)]
            enum HttpOrHttps {
                Http,
                Https,
            }

            let mut rpc_url = rpc_url.await?;
            let url_scheme = match rpc_url.scheme() {
                "https" => Https,
                "http" => Http,
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            };

            // Connect on the Tokio runtime. Use the proxy if configured, otherwise a
            // direct TCP connection to `host:port` (the port falls back to the
            // scheme's default).
            let stream = gpui_tokio::Tokio::spawn_result(cx, {
                let rpc_url = rpc_url.clone();
                async move {
                    let rpc_host = rpc_url
                        .host_str()
                        .zip(rpc_url.port_or_known_default())
                        .context("missing host in rpc url")?;
                    Ok(match proxy {
                        Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                        None => Box::new(TcpStream::connect(rpc_host).await?),
                    })
                }
            })
            .await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // `Url::set_scheme` allows switching between special schemes
            // (https <-> wss, http <-> ws), so this `unwrap` cannot fail.
            rpc_url
                .set_scheme(match url_scheme {
                    Https => "wss",
                    Http => "ws",
                })
                .unwrap();

            // We call `into_client_request` to let `tungstenite` construct the WebSocket request
            // for us from the RPC URL.
            //
            // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
            let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;

            // We then modify the request to add our desired headers.
            let request_headers = request.headers_mut();
            request_headers.insert(
                http::header::AUTHORIZATION,
                HeaderValue::from_str(&credentials.authorization_header())?,
            );
            request_headers.insert(
                "x-zed-protocol-version",
                HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
            );
            request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
            // When no release channel global is set, the header is sent as "unknown".
            request_headers.insert(
                "x-zed-release-channel",
                HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
            );
            // The following headers are optional and sent only when a value is available.
            if let Some(user_agent) = user_agent {
                request_headers.insert(http::header::USER_AGENT, user_agent);
            }
            if let Some(system_id) = system_id {
                request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
            }
            if let Some(metrics_id) = metrics_id {
                request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
            }

            let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
                request,
                stream,
                Some(Arc::new(http_client_tls::tls_config()).into()),
                None,
            )
            .await?;

            // Wrap the socket so both reads and writes surface `anyhow` errors.
            Ok(Connection::new(
                stream
                    .map_err(|error| anyhow!(error))
                    .sink_map_err(|error| anyhow!(error)),
            ))
        })
    }
1345
1346 pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
1347 let http = self.http.clone();
1348 let this = self.clone();
1349 cx.spawn(async move |cx| {
1350 let background = cx.background_executor().clone();
1351
1352 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1353 cx.update(|cx| {
1354 cx.spawn(async move |cx| {
1355 if let Ok(url) = open_url_rx.await {
1356 cx.update(|cx| cx.open_url(&url));
1357 }
1358 })
1359 .detach();
1360 });
1361
1362 let credentials = background
1363 .clone()
1364 .spawn(async move {
1365 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1366 // zed server to encrypt the user's access token, so that it can'be intercepted by
1367 // any other app running on the user's device.
1368 let (public_key, private_key) =
1369 rpc::auth::keypair().expect("failed to generate keypair for auth");
1370 let public_key_string = String::try_from(public_key)
1371 .expect("failed to serialize public key for auth");
1372
1373 if let Some((login, token)) =
1374 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1375 {
1376 if !*USE_WEB_LOGIN {
1377 eprintln!("authenticate as admin {login}, {token}");
1378
1379 return this
1380 .authenticate_as_admin(http, login.clone(), token.clone())
1381 .await;
1382 }
1383 }
1384
1385 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1386 let server =
1387 tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1388 let port = server.server_addr().port();
1389
1390 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1391 // that the user is signing in from a Zed app running on the same device.
1392 let url = http.build_url(&format!(
1393 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1394 port, public_key_string
1395 ));
1396
1397 open_url_tx.send(url).log_err();
1398
1399 #[derive(Deserialize)]
1400 struct CallbackParams {
1401 pub user_id: String,
1402 pub access_token: String,
1403 }
1404
1405 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1406 // access token from the query params.
1407 //
1408 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1409 // custom URL scheme instead of this local HTTP server.
1410 let (user_id, access_token) = background
1411 .spawn(async move {
1412 for _ in 0..100 {
1413 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1414 let path = req.url();
1415 let url = Url::parse(&format!("http://example.com{}", path))
1416 .context("failed to parse login notification url")?;
1417 let callback_params: CallbackParams =
1418 serde_urlencoded::from_str(url.query().unwrap_or_default())
1419 .context(
1420 "failed to parse sign-in callback query parameters",
1421 )?;
1422
1423 let post_auth_url =
1424 http.build_url("/native_app_signin_succeeded");
1425 req.respond(
1426 tiny_http::Response::empty(302).with_header(
1427 tiny_http::Header::from_bytes(
1428 &b"Location"[..],
1429 post_auth_url.as_bytes(),
1430 )
1431 .unwrap(),
1432 ),
1433 )
1434 .context("failed to respond to login http request")?;
1435 return Ok((
1436 callback_params.user_id,
1437 callback_params.access_token,
1438 ));
1439 }
1440 }
1441
1442 anyhow::bail!("didn't receive login redirect");
1443 })
1444 .await?;
1445
1446 let access_token = private_key
1447 .decrypt_string(&access_token)
1448 .context("failed to decrypt access token")?;
1449
1450 Ok(Credentials {
1451 user_id: user_id.parse()?,
1452 access_token,
1453 })
1454 })
1455 .await?;
1456
1457 cx.update(|cx| cx.activate(true));
1458 Ok(credentials)
1459 })
1460 }
1461
1462 async fn authenticate_as_admin(
1463 self: &Arc<Self>,
1464 http: Arc<HttpClientWithUrl>,
1465 login: String,
1466 api_token: String,
1467 ) -> Result<Credentials> {
1468 #[derive(Serialize)]
1469 struct ImpersonateUserBody {
1470 github_login: String,
1471 }
1472
1473 #[derive(Deserialize)]
1474 struct ImpersonateUserResponse {
1475 user_id: u64,
1476 access_token: String,
1477 }
1478
1479 let url = self
1480 .http
1481 .build_zed_cloud_url("/internal/users/impersonate")?;
1482 let request = Request::post(url.as_str())
1483 .header("Content-Type", "application/json")
1484 .header("Authorization", format!("Bearer {api_token}"))
1485 .body(
1486 serde_json::to_string(&ImpersonateUserBody {
1487 github_login: login,
1488 })?
1489 .into(),
1490 )?;
1491
1492 let mut response = http.send(request).await?;
1493 let mut body = String::new();
1494 response.body_mut().read_to_string(&mut body).await?;
1495 anyhow::ensure!(
1496 response.status().is_success(),
1497 "admin user request failed {} - {}",
1498 response.status().as_u16(),
1499 body,
1500 );
1501 let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
1502
1503 Ok(Credentials {
1504 user_id: response.user_id,
1505 access_token: response.access_token,
1506 })
1507 }
1508
1509 pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
1510 self.state.write().credentials = None;
1511 self.cloud_client.clear_credentials();
1512 self.disconnect(cx);
1513
1514 if self.has_credentials(cx).await {
1515 self.credentials_provider
1516 .delete_credentials(cx)
1517 .await
1518 .log_err();
1519 }
1520 }
1521
    /// Tears down the peer connection and sets the status to `Status::SignedOut`.
    ///
    /// Stored credentials are left untouched. `sign_out` clears them before calling this.
    pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
        self.peer.teardown();
        self.set_status(Status::SignedOut, cx);
    }
1526
    /// Tears down the peer connection and sets the status to `Status::ConnectionLost`.
    ///
    /// NOTE(review): presumably the status change is what triggers a fresh connection
    /// attempt elsewhere in the client. Confirm against the status-handling code.
    pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
        self.peer.teardown();
        self.set_status(Status::ConnectionLost, cx);
    }
1531
1532 fn connection_id(&self) -> Result<ConnectionId> {
1533 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1534 Ok(connection_id)
1535 } else {
1536 anyhow::bail!("not connected");
1537 }
1538 }
1539
1540 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1541 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1542 self.peer.send(self.connection_id()?, message)
1543 }
1544
1545 pub fn request<T: RequestMessage>(
1546 &self,
1547 request: T,
1548 ) -> impl Future<Output = Result<T::Response>> + use<T> {
1549 self.request_envelope(request)
1550 .map_ok(|envelope| envelope.payload)
1551 }
1552
1553 pub fn request_stream<T: RequestMessage>(
1554 &self,
1555 request: T,
1556 ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
1557 let client_id = self.id.load(Ordering::SeqCst);
1558 log::debug!(
1559 "rpc request start. client_id:{}. name:{}",
1560 client_id,
1561 T::NAME
1562 );
1563 let response = self
1564 .connection_id()
1565 .map(|conn_id| self.peer.request_stream(conn_id, request));
1566 async move {
1567 let response = response?.await;
1568 log::debug!(
1569 "rpc request finish. client_id:{}. name:{}",
1570 client_id,
1571 T::NAME
1572 );
1573 response
1574 }
1575 }
1576
1577 pub fn request_envelope<T: RequestMessage>(
1578 &self,
1579 request: T,
1580 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
1581 let client_id = self.id();
1582 log::debug!(
1583 "rpc request start. client_id:{}. name:{}",
1584 client_id,
1585 T::NAME
1586 );
1587 let response = self
1588 .connection_id()
1589 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1590 async move {
1591 let response = response?.await;
1592 log::debug!(
1593 "rpc request finish. client_id:{}. name:{}",
1594 client_id,
1595 T::NAME
1596 );
1597 response
1598 }
1599 }
1600
1601 pub fn request_dynamic(
1602 &self,
1603 envelope: proto::Envelope,
1604 request_type: &'static str,
1605 ) -> impl Future<Output = Result<proto::Envelope>> + use<> {
1606 let client_id = self.id();
1607 log::debug!(
1608 "rpc request start. client_id:{}. name:{}",
1609 client_id,
1610 request_type
1611 );
1612 let response = self
1613 .connection_id()
1614 .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
1615 async move {
1616 let response = response?.await;
1617 log::debug!(
1618 "rpc request finish. client_id:{}. name:{}",
1619 client_id,
1620 request_type
1621 );
1622 Ok(response?.0)
1623 }
1624 }
1625
1626 fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
1627 let sender_id = message.sender_id();
1628 let request_id = message.message_id();
1629 let type_name = message.payload_type_name();
1630 let original_sender_id = message.original_sender_id();
1631
1632 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1633 &self.handler_set,
1634 message,
1635 self.clone().into(),
1636 cx.clone(),
1637 ) {
1638 let client_id = self.id();
1639 log::debug!(
1640 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1641 client_id,
1642 original_sender_id,
1643 type_name
1644 );
1645 cx.spawn(async move |_| match future.await {
1646 Ok(()) => {
1647 log::debug!("rpc message handled. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}");
1648 }
1649 Err(error) => {
1650 log::error!("error handling message. client_id:{client_id}, sender_id:{original_sender_id:?}, type:{type_name}, error:{error:#}");
1651 }
1652 })
1653 .detach();
1654 } else {
1655 log::info!("unhandled message {}", type_name);
1656 self.peer
1657 .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
1658 .log_err();
1659 }
1660 }
1661
1662 pub fn add_message_to_client_handler(
1663 self: &Arc<Client>,
1664 handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
1665 ) {
1666 self.message_to_client_handlers
1667 .lock()
1668 .push(Box::new(handler));
1669 }
1670
1671 fn handle_message_to_client(self: &Arc<Client>, message: MessageToClient, cx: &AsyncApp) {
1672 cx.update(|cx| {
1673 for handler in self.message_to_client_handlers.lock().iter() {
1674 handler(&message, cx);
1675 }
1676 });
1677 }
1678
    /// Returns the client's shared telemetry handle.
    pub fn telemetry(&self) -> &Arc<Telemetry> {
        &self.telemetry
    }
1682}
1683
1684impl ProtoClient for Client {
1685 fn request(
1686 &self,
1687 envelope: proto::Envelope,
1688 request_type: &'static str,
1689 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1690 self.request_dynamic(envelope, request_type).boxed()
1691 }
1692
1693 fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1694 log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
1695 let connection_id = self.connection_id()?;
1696 self.peer.send_dynamic(connection_id, envelope)
1697 }
1698
1699 fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
1700 log::debug!(
1701 "rpc respond. client_id:{}, name:{}",
1702 self.id(),
1703 message_type
1704 );
1705 let connection_id = self.connection_id()?;
1706 self.peer.send_dynamic(connection_id, envelope)
1707 }
1708
1709 fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
1710 &self.handler_set
1711 }
1712
1713 fn is_via_collab(&self) -> bool {
1714 true
1715 }
1716
1717 fn has_wsl_interop(&self) -> bool {
1718 false
1719 }
1720}
1721
/// The URL scheme (without the `://` separator) used for `zed://` links.
pub const ZED_URL_SCHEME: &str = "zed";
1724
/// A parsed Zed link that can be handled internally by the application.
///
/// Produced by [`parse_zed_link`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ZedLink {
    /// Join a channel: `zed.dev/channel/channel-name-123` or `zed://channel/channel-name-123`
    ///
    /// `channel_id` is the numeric suffix after the last `-` of the channel slug.
    Channel { channel_id: u64 },
    /// Open channel notes: `zed.dev/channel/channel-name-123/notes` or with heading `notes#heading`
    ChannelNotes {
        /// The channel whose notes should be opened.
        channel_id: u64,
        /// The heading to jump to: the text after `notes#`, taken verbatim.
        heading: Option<String>,
    },
}
1736
1737/// Parses the given link into a Zed link.
1738///
1739/// Returns a [`Some`] containing the parsed link if the link is a recognized Zed link
1740/// that should be handled internally by the application.
1741/// Returns [`None`] for links that should be opened in the browser.
1742pub fn parse_zed_link(link: &str, cx: &App) -> Option<ZedLink> {
1743 let server_url = &ClientSettings::get_global(cx).server_url;
1744 let path = link
1745 .strip_prefix(server_url)
1746 .and_then(|result| result.strip_prefix('/'))
1747 .or_else(|| {
1748 link.strip_prefix(ZED_URL_SCHEME)
1749 .and_then(|result| result.strip_prefix("://"))
1750 })?;
1751
1752 let mut parts = path.split('/');
1753
1754 if parts.next() != Some("channel") {
1755 return None;
1756 }
1757
1758 let slug = parts.next()?;
1759 let id_str = slug.split('-').next_back()?;
1760 let channel_id = id_str.parse::<u64>().ok()?;
1761
1762 let Some(next) = parts.next() else {
1763 return Some(ZedLink::Channel { channel_id });
1764 };
1765
1766 if let Some(heading) = next.strip_prefix("notes#") {
1767 return Some(ZedLink::ChannelNotes {
1768 channel_id,
1769 heading: Some(heading.to_string()),
1770 });
1771 }
1772
1773 if next == "notes" {
1774 return Some(ZedLink::ChannelNotes {
1775 channel_id,
1776 heading: None,
1777 });
1778 }
1779
1780 None
1781}
1782
#[cfg(test)]
mod tests {
    // Unit tests for `Client`. They cover:
    // - proxy setting normalization;
    // - reconnection and re-authentication;
    // - connection timeouts;
    // - bookkeeping of message handler and entity subscriptions.
    // Most tests drive the client against a `FakeServer` and a `FakeHttpClient`.
    use super::*;
    use crate::test::{FakeServer, parse_authorization_header};

    use clock::FakeSystemClock;
    use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use proto::TypedEnvelope;
    use settings::SettingsStore;
    use std::future;

    // A whitespace-only proxy setting is treated as unset. A real proxy URL is kept as-is.
    #[test]
    fn test_proxy_settings_trims_and_ignores_empty_proxy() {
        let mut content = SettingsContent::default();
        content.proxy = Some(" ".to_owned());
        assert_eq!(ProxySettings::from_settings(&content).proxy, None);

        content.proxy = Some("http://127.0.0.1:10809".to_owned());
        assert_eq!(
            ProxySettings::from_settings(&content).proxy.as_deref(),
            Some("http://127.0.0.1:10809")
        );
    }

    // After the server drops the connection, the client keeps retrying. It reuses cached
    // credentials while they are valid, and re-authenticates only after the server rolls
    // the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // While the HTTP client answers 503 during reconnection, the client reports
    // `ReconnectionError`. Once HTTP recovers, it reconnects with the cached credentials
    // and does not re-authenticate.
    #[gpui::test(iterations = 10)]
    async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let http_client = FakeHttpClient::with_200_response();
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        let server = FakeServer::for_client(42, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Simulate an auth failure during reconnection.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Restore the ability to authenticate.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(200)
                    .body("".into())
                    .unwrap())
            });
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
    }

    // A connection attempt that never completes times out after `CONNECTION_TIMEOUT`.
    // On the initial connect the status becomes `ConnectionError`; on a reconnect it
    // becomes `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(status.next().await, Some(Status::ConnectionError)));
        auth_and_connect.await.into_response().unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(status.next().await, Some(Status::Reconnecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // `sign_in` re-runs authentication only when the server rejects the cached
    // credentials with 401. A 503 (server unavailable) fails without re-authenticating.
    #[gpui::test(iterations = 10)]
    async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let http_client = FakeHttpClient::create(|_request| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body("".into())
                .unwrap())
        });
        let client =
            cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    Ok(Credentials {
                        user_id: 1,
                        access_token: auth_count.lock().to_string(),
                    })
                })
            }
        });

        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If credentials are still valid, signing in doesn't trigger authentication.
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(credentials.access_token, "1");

        // If the server is unavailable, signing in doesn't trigger authentication.
        http_client
            .as_fake()
            .replace_handler(|_, _request| async move {
                Ok(http_client::Response::builder()
                    .status(503)
                    .body("".into())
                    .unwrap())
            });
        client.sign_in(false, &cx.to_async()).await.unwrap_err();
        assert_eq!(*auth_count.lock(), 1);

        // If credentials became invalid, signing in triggers authentication.
        http_client
            .as_fake()
            .replace_handler(|_, request| async move {
                let credentials = parse_authorization_header(&request).unwrap();
                if credentials.access_token == "2" {
                    Ok(http_client::Response::builder()
                        .status(200)
                        .body("".into())
                        .unwrap())
                } else {
                    Ok(http_client::Response::builder()
                        .status(401)
                        .body("".into())
                        .unwrap())
                }
            });
        let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(credentials.access_token, "2");
    }

    // Starting a second `connect` while the first one is still authenticating drops the
    // in-flight authentication task: its deferred drop counter increments.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // An entity message handler routes each message to the entity subscribed under the
    // message's id (here `project_id`). Dropping one entity's subscription leaves other
    // entities of the same type unaffected.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        AnyProtoClient::from(client.clone()).add_entity_message_handler(
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, cx| {
                match entity.read_with(&cx, |entity, _| entity.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let entity1 = cx.new(|_| TestEntity {
            id: 1,
            subscription: None,
        });
        let entity2 = cx.new(|_| TestEntity {
            id: 2,
            subscription: None,
        });
        let entity3 = cx.new(|_| TestEntity {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_entity(&entity1, &cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_entity(&entity2, &cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_entity(&entity3, &cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject {
            project_id: 1,
            committer_name: None,
            committer_email: None,
        });
        server.send(proto::JoinProject {
            project_id: 2,
            committer_name: None,
            committer_email: None,
        });
        done_rx1.recv().await.unwrap();
        done_rx2.recv().await.unwrap();
    }

    // A handler registered after an earlier handler for the same message type was
    // dropped still receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            entity.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.recv().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the entity) while it runs.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| {
            Client::new(
                Arc::new(FakeSystemClock::new()),
                FakeHttpClient::with_404_response(),
                cx,
            )
        });
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let entity = cx.new(|_| TestEntity::default());
        let (done_tx, done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            entity.clone().downgrade(),
            move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
                entity
                    .update(&mut cx, |entity, _| entity.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        entity.update(cx, |entity, _| {
            entity.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.recv().await.unwrap();
    }

    // Minimal entity used as a message-handler target. `id` identifies which entity a
    // handler received. `subscription` lets a test keep a registration alive on the
    // entity, or drop it from there.
    #[derive(Default)]
    struct TestEntity {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` global so that settings lookups made during the
    // tests can succeed.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}