1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
15 TryFutureExt as _, TryStreamExt,
16};
17use gpui::{
18 actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task,
19 WeakModel,
20};
21use lazy_static::lazy_static;
22use parking_lot::RwLock;
23use postage::watch;
24use rand::prelude::*;
25use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use serde_json;
29use settings::{Settings, SettingsStore};
30use std::{
31 any::TypeId,
32 collections::HashMap,
33 convert::TryFrom,
34 fmt::Write as _,
35 future::Future,
36 marker::PhantomData,
37 path::PathBuf,
38 sync::{atomic::AtomicU64, Arc, Weak},
39 time::{Duration, Instant},
40};
41use telemetry::Telemetry;
42use thiserror::Error;
43use url::Url;
44use util::http::HttpClient;
45use util::{channel::ReleaseChannel, http::ZedHttpClient};
46use util::{ResultExt, TryFutureExt};
47
48pub use rpc::*;
49pub use telemetry::Event;
50pub use user::*;
51
52lazy_static! {
53 static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
54 static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
55 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
56 .ok()
57 .and_then(|s| if s.is_empty() { None } else { Some(s) });
58 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
59 .ok()
60 .and_then(|s| if s.is_empty() { None } else { Some(s) });
61 pub static ref ZED_APP_VERSION: Option<SemanticVersion> = std::env::var("ZED_APP_VERSION")
62 .ok()
63 .and_then(|v| v.parse().ok());
64 pub static ref ZED_APP_PATH: Option<PathBuf> =
65 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
66 pub static ref ZED_ALWAYS_ACTIVE: bool =
67 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
68}
69
/// First back-off delay used by the reconnection loop after a lost connection.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// How long a connection attempt (including the server hello) may take.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(5_000);
72
73actions!(client, [SignIn, SignOut, Reconnect]);
74
// Raw, optional form of `ClientSettings` as read from a settings file.
// Plain `//` comments are used so the derived JSON schema is left unchanged.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    // Zed server URL, if this settings file specifies one.
    server_url: Option<String>,
}
79
/// Resolved client settings.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Base URL of the Zed server; replaced by `ZED_SERVER_URL` when that
    /// environment variable is set.
    pub server_url: String,
}
84
85impl Settings for ClientSettings {
86 const KEY: Option<&'static str> = None;
87
88 type FileContent = ClientSettingsContent;
89
90 fn load(
91 default_value: &Self::FileContent,
92 user_values: &[&Self::FileContent],
93 _: &mut AppContext,
94 ) -> Result<Self>
95 where
96 Self: Sized,
97 {
98 let mut result = Self::load_via_json_merge(default_value, user_values)?;
99 if let Some(server_url) = &*ZED_SERVER_URL {
100 result.server_url = server_url.clone()
101 }
102 Ok(result)
103 }
104}
105
106pub fn init_settings(cx: &mut AppContext) {
107 TelemetrySettings::register(cx);
108 cx.update_global(|store: &mut SettingsStore, cx| {
109 store.register_setting::<ClientSettings>(cx);
110 });
111}
112
113pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
114 let client = Arc::downgrade(client);
115 cx.on_action({
116 let client = client.clone();
117 move |_: &SignIn, cx| {
118 if let Some(client) = client.upgrade() {
119 cx.spawn(
120 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
121 )
122 .detach();
123 }
124 }
125 });
126
127 cx.on_action({
128 let client = client.clone();
129 move |_: &SignOut, cx| {
130 if let Some(client) = client.upgrade() {
131 cx.spawn(|cx| async move {
132 client.disconnect(&cx);
133 })
134 .detach();
135 }
136 }
137 });
138
139 cx.on_action({
140 let client = client.clone();
141 move |_: &Reconnect, cx| {
142 if let Some(client) = client.upgrade() {
143 cx.spawn(|cx| async move {
144 client.reconnect(&cx);
145 })
146 .detach();
147 }
148 }
149 });
150}
151
/// Client for Zed's RPC server: owns the RPC `Peer`, the HTTP client,
/// telemetry, and the lock-protected connection/subscription state.
pub struct Client {
    // Client identifier; `authenticate_and_connect` sets it to the user id of the
    // credentials in use, and it appears in `set_status` log lines.
    id: AtomicU64,
    peer: Arc<Peer>,
    http: Arc<ZedHttpClient>,
    telemetry: Arc<Telemetry>,
    state: RwLock<ClientState>,

    // Test-only override consulted by `authenticate` before the browser flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    // Test-only override consulted by `establish_connection` before opening a websocket.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
181
/// Reasons an attempt to open a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// Mapped from an HTTP 426 (`UPGRADE_REQUIRED`) websocket handshake response.
    #[error("upgrade required")]
    UpgradeRequired,
    /// Mapped from an HTTP 401 (`UNAUTHORIZED`) websocket handshake response.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other error, including websocket errors without a dedicated variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the HTTP client (`util::http::Error`).
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// I/O error, e.g. from opening the TCP connection.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// `http::Error` raised while building the websocket handshake request
    /// (the variant is named `Websocket`, but it wraps an `http::Error`).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
197
198impl From<WebsocketError> for EstablishConnectionError {
199 fn from(error: WebsocketError) -> Self {
200 if let WebsocketError::Http(response) = &error {
201 match response.status() {
202 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
203 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
204 _ => {}
205 }
206 }
207 EstablishConnectionError::Other(error.into())
208 }
209}
210
211impl EstablishConnectionError {
212 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
213 Self::Other(error.into())
214 }
215}
216
/// Connection status of a [`Client`], published on a watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in; the initial status of a new client.
    SignedOut,
    /// The server rejected this client's protocol version.
    UpgradeRequired,
    /// Obtaining credentials after having been signed out.
    Authenticating,
    /// Opening a connection after having been signed out.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected and greeted by the server.
    Connected {
        /// Peer id announced in the server's hello message.
        peer_id: PeerId,
        /// Id of the connection registered with the local `Peer`.
        connection_id: ConnectionId,
    },
    /// The connection's I/O ended with an error; a reconnection task is started.
    ConnectionLost,
    /// Obtaining credentials while not in the signed-out state.
    Reauthenticating,
    /// Opening a connection while not in the signed-out state.
    Reconnecting,
    /// A reconnection attempt failed; another is scheduled.
    ReconnectionError {
        /// When the next attempt is expected to start.
        next_reconnection: Instant,
    },
}
235
236impl Status {
237 pub fn is_connected(&self) -> bool {
238 matches!(self, Self::Connected { .. })
239 }
240
241 pub fn is_signed_out(&self) -> bool {
242 matches!(self, Self::SignedOut | Self::UpgradeRequired)
243 }
244}
245
/// Mutable state of a [`Client`], kept behind `Client::state`.
struct ClientState {
    // Credentials stored once a connection has been established with them;
    // cleared when the server answers `Unauthorized`.
    credentials: Option<Credentials>,
    // Sender and receiver of the status watch channel; `Client::status` hands
    // out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    // Per entity-message type: extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    // Reconnection loop spawned on `ConnectionLost`; reset on `Connected`,
    // `SignedOut`, and `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    // Upper bound on the back-off delay between reconnection attempts.
    reconnect_interval: Duration,
    // Entity subscriptions keyed by (entity type, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    // Model registered for each message type by `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModel>,
    // Entity type registered for each message type by `add_entity_message_handler`.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    // Type-erased handler for each message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyModel,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
270
/// Recipient registered for one (entity type, remote id) pair.
enum WeakSubscriber {
    /// A model has been attached; it is held weakly.
    Entity { handle: AnyWeakModel },
    /// No model yet. `set_model` replays these buffered messages, and dropping
    /// the pending subscription logs them as unhandled. NOTE(review): messages are
    /// presumably pushed here by `handle_message` while pending — confirm.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
275
/// User id and access token used to authenticate with the Zed server.
///
/// `Debug` is implemented by hand so the access token is never exposed through
/// `{:?}` formatting (for example in log lines or panic messages).
#[derive(Clone)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
281
282impl Default for ClientState {
283 fn default() -> Self {
284 Self {
285 credentials: None,
286 status: watch::channel_with(Status::SignedOut),
287 entity_id_extractors: Default::default(),
288 _reconnect_task: None,
289 reconnect_interval: Duration::from_secs(5),
290 models_by_message_type: Default::default(),
291 entities_by_type_and_remote_id: Default::default(),
292 entity_types_by_message_type: Default::default(),
293 message_handlers: Default::default(),
294 }
295 }
296}
297
/// Registration handle returned by a [`Client`]; dropping it removes the
/// registration if the client is still alive.
pub enum Subscription {
    /// Routing entry for one remote entity, keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message handler, keyed by the message's type id.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
308
309impl Drop for Subscription {
310 fn drop(&mut self) {
311 match self {
312 Subscription::Entity { client, id } => {
313 if let Some(client) = client.upgrade() {
314 let mut state = client.state.write();
315 let _ = state.entities_by_type_and_remote_id.remove(id);
316 }
317 }
318 Subscription::Message { client, id } => {
319 if let Some(client) = client.upgrade() {
320 let mut state = client.state.write();
321 let _ = state.entity_types_by_message_type.remove(id);
322 let _ = state.message_handlers.remove(id);
323 }
324 }
325 }
326 }
327}
328
/// Subscription to a remote entity whose local model is not available yet.
///
/// Created by `Client::subscribe_to_entity` and completed with `set_model`.
/// If dropped without `set_model`, the pending entry is removed and its
/// buffered messages are logged as unhandled.
pub struct PendingEntitySubscription<T: 'static> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    // Set by `set_model` so that `Drop` leaves the now-attached entry alone.
    consumed: bool,
}
335
336impl<T: 'static> PendingEntitySubscription<T> {
337 pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
338 self.consumed = true;
339 let mut state = self.client.state.write();
340 let id = (TypeId::of::<T>(), self.remote_id);
341 let Some(WeakSubscriber::Pending(messages)) =
342 state.entities_by_type_and_remote_id.remove(&id)
343 else {
344 unreachable!()
345 };
346
347 state.entities_by_type_and_remote_id.insert(
348 id,
349 WeakSubscriber::Entity {
350 handle: model.downgrade().into(),
351 },
352 );
353 drop(state);
354 for message in messages {
355 self.client.handle_message(message, cx);
356 }
357 Subscription::Entity {
358 client: Arc::downgrade(&self.client),
359 id,
360 }
361 }
362}
363
364impl<T: 'static> Drop for PendingEntitySubscription<T> {
365 fn drop(&mut self) {
366 if !self.consumed {
367 let mut state = self.client.state.write();
368 if let Some(WeakSubscriber::Pending(messages)) = state
369 .entities_by_type_and_remote_id
370 .remove(&(TypeId::of::<T>(), self.remote_id))
371 {
372 for message in messages {
373 log::info!("unhandled message {}", message.payload_type_name());
374 }
375 }
376 }
377 }
378}
379
/// Resolved telemetry preferences, loaded from the `"telemetry"` settings key.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
385
// Raw form of the `"telemetry"` settings key. When loading, a field is taken from
// the first user settings value if set there, otherwise from the defaults.
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option<bool>,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option<bool>,
}
398
399impl settings::Settings for TelemetrySettings {
400 const KEY: Option<&'static str> = Some("telemetry");
401
402 type FileContent = TelemetrySettingsContent;
403
404 fn load(
405 default_value: &Self::FileContent,
406 user_values: &[&Self::FileContent],
407 _: &mut AppContext,
408 ) -> Result<Self> {
409 Ok(Self {
410 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
411 default_value
412 .diagnostics
413 .ok_or_else(Self::missing_default)?,
414 ),
415 metrics: user_values
416 .first()
417 .and_then(|v| v.metrics)
418 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
419 })
420 }
421}
422
423impl Client {
424 pub fn new(http: Arc<ZedHttpClient>, cx: &mut AppContext) -> Arc<Self> {
425 let client = Arc::new(Self {
426 id: AtomicU64::new(0),
427 peer: Peer::new(0),
428 telemetry: Telemetry::new(http.clone(), cx),
429 http,
430 state: Default::default(),
431
432 #[cfg(any(test, feature = "test-support"))]
433 authenticate: Default::default(),
434 #[cfg(any(test, feature = "test-support"))]
435 establish_connection: Default::default(),
436 });
437
438 client
439 }
440
441 pub fn id(&self) -> u64 {
442 self.id.load(std::sync::atomic::Ordering::SeqCst)
443 }
444
445 pub fn http_client(&self) -> Arc<ZedHttpClient> {
446 self.http.clone()
447 }
448
449 pub fn set_id(&self, id: u64) -> &Self {
450 self.id.store(id, std::sync::atomic::Ordering::SeqCst);
451 self
452 }
453
454 #[cfg(any(test, feature = "test-support"))]
455 pub fn teardown(&self) {
456 let mut state = self.state.write();
457 state._reconnect_task.take();
458 state.message_handlers.clear();
459 state.models_by_message_type.clear();
460 state.entities_by_type_and_remote_id.clear();
461 state.entity_id_extractors.clear();
462 self.peer.teardown();
463 }
464
465 #[cfg(any(test, feature = "test-support"))]
466 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
467 where
468 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
469 {
470 *self.authenticate.write() = Some(Box::new(authenticate));
471 self
472 }
473
474 #[cfg(any(test, feature = "test-support"))]
475 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
476 where
477 F: 'static
478 + Send
479 + Sync
480 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
481 {
482 *self.establish_connection.write() = Some(Box::new(connect));
483 self
484 }
485
486 pub fn user_id(&self) -> Option<u64> {
487 self.state
488 .read()
489 .credentials
490 .as_ref()
491 .map(|credentials| credentials.user_id)
492 }
493
494 pub fn peer_id(&self) -> Option<PeerId> {
495 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
496 Some(*peer_id)
497 } else {
498 None
499 }
500 }
501
502 pub fn status(&self) -> watch::Receiver<Status> {
503 self.state.read().status.1.clone()
504 }
505
    /// Publishes `status` on the status channel and adjusts the reconnection task.
    ///
    /// - `Connected`: drops any reconnection task.
    /// - `ConnectionLost`: spawns a task that retries `authenticate_and_connect`
    ///   with a randomized back-off capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and
    ///   drops any reconnection task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(move |cx| async move {
                    // Fixed seed under test for reproducible jitter; real entropy otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only while the failure left the client in `ConnectionError`;
                        // any other status (e.g. `UpgradeRequired`) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background_executor().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                self.telemetry.set_authenticated_user_info(None, false);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
551
552 pub fn subscribe_to_entity<T>(
553 self: &Arc<Self>,
554 remote_id: u64,
555 ) -> Result<PendingEntitySubscription<T>>
556 where
557 T: 'static,
558 {
559 let id = (TypeId::of::<T>(), remote_id);
560
561 let mut state = self.state.write();
562 if state.entities_by_type_and_remote_id.contains_key(&id) {
563 return Err(anyhow!("already subscribed to entity"));
564 } else {
565 state
566 .entities_by_type_and_remote_id
567 .insert(id, WeakSubscriber::Pending(Default::default()));
568 Ok(PendingEntitySubscription {
569 client: self.clone(),
570 remote_id,
571 consumed: false,
572 _entity_type: PhantomData,
573 })
574 }
575 }
576
577 #[track_caller]
578 pub fn add_message_handler<M, E, H, F>(
579 self: &Arc<Self>,
580 entity: WeakModel<E>,
581 handler: H,
582 ) -> Subscription
583 where
584 M: EnvelopedMessage,
585 E: 'static,
586 H: 'static
587 + Sync
588 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
589 + Send
590 + Sync,
591 F: 'static + Future<Output = Result<()>>,
592 {
593 let message_type_id = TypeId::of::<M>();
594 let mut state = self.state.write();
595 state
596 .models_by_message_type
597 .insert(message_type_id, entity.into());
598
599 let prev_handler = state.message_handlers.insert(
600 message_type_id,
601 Arc::new(move |subscriber, envelope, client, cx| {
602 let subscriber = subscriber.downcast::<E>().unwrap();
603 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
604 handler(subscriber, *envelope, client.clone(), cx).boxed_local()
605 }),
606 );
607 if prev_handler.is_some() {
608 let location = std::panic::Location::caller();
609 panic!(
610 "{}:{} registered handler for the same message {} twice",
611 location.file(),
612 location.line(),
613 std::any::type_name::<M>()
614 );
615 }
616
617 Subscription::Message {
618 client: Arc::downgrade(self),
619 id: message_type_id,
620 }
621 }
622
623 pub fn add_request_handler<M, E, H, F>(
624 self: &Arc<Self>,
625 model: WeakModel<E>,
626 handler: H,
627 ) -> Subscription
628 where
629 M: RequestMessage,
630 E: 'static,
631 H: 'static
632 + Sync
633 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
634 + Send
635 + Sync,
636 F: 'static + Future<Output = Result<M::Response>>,
637 {
638 self.add_message_handler(model, move |handle, envelope, this, cx| {
639 Self::respond_to_request(
640 envelope.receipt(),
641 handler(handle, envelope, this.clone(), cx),
642 this,
643 )
644 })
645 }
646
647 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
648 where
649 M: EntityMessage,
650 E: 'static,
651 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
652 F: 'static + Future<Output = Result<()>>,
653 {
654 self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
655 handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
656 })
657 }
658
659 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
660 where
661 M: EntityMessage,
662 E: 'static,
663 H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
664 F: 'static + Future<Output = Result<()>>,
665 {
666 let model_type_id = TypeId::of::<E>();
667 let message_type_id = TypeId::of::<M>();
668
669 let mut state = self.state.write();
670 state
671 .entity_types_by_message_type
672 .insert(message_type_id, model_type_id);
673 state
674 .entity_id_extractors
675 .entry(message_type_id)
676 .or_insert_with(|| {
677 |envelope| {
678 envelope
679 .as_any()
680 .downcast_ref::<TypedEnvelope<M>>()
681 .unwrap()
682 .payload
683 .remote_entity_id()
684 }
685 });
686 let prev_handler = state.message_handlers.insert(
687 message_type_id,
688 Arc::new(move |handle, envelope, client, cx| {
689 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
690 handler(handle, *envelope, client.clone(), cx).boxed_local()
691 }),
692 );
693 if prev_handler.is_some() {
694 panic!("registered handler for the same message twice");
695 }
696 }
697
698 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
699 where
700 M: EntityMessage + RequestMessage,
701 E: 'static,
702 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
703 F: 'static + Future<Output = Result<M::Response>>,
704 {
705 self.add_model_message_handler(move |entity, envelope, client, cx| {
706 Self::respond_to_request::<M, _>(
707 envelope.receipt(),
708 handler(entity, envelope, client.clone(), cx),
709 client,
710 )
711 })
712 }
713
714 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
715 receipt: Receipt<T>,
716 response: F,
717 client: Arc<Self>,
718 ) -> Result<()> {
719 match response.await {
720 Ok(response) => {
721 client.respond(receipt, response)?;
722 Ok(())
723 }
724 Err(error) => {
725 client.respond_with_error(receipt, error.to_proto())?;
726 Err(error)
727 }
728 }
729 }
730
731 pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
732 read_credentials_from_keychain(cx).await.is_some()
733 }
734
    /// Signs in if needed and connects to the server.
    ///
    /// Returns `Ok(())` right away if already connected or if a connection attempt
    /// is in progress. Fails immediately if an upgrade is required.
    ///
    /// Credentials are taken, in order, from:
    /// 1. memory;
    /// 2. the keychain, when `try_keychain` is set;
    /// 3. `authenticate`.
    ///
    /// On a successful connection the credentials are kept in memory. They are
    /// also written to the keychain, unless they came from it or `ZED_IMPERSONATE`
    /// is set. A single `CONNECTION_TIMEOUT` timer, created before connecting,
    /// bounds both establishing the connection and waiting for the server's hello.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted and
    /// the flow is retried once without consulting the keychain.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            // Already connected or mid-connect: nothing to do.
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx).await;
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Skip the status that is already current so the select below reacts
            // only to a later status change, which is treated as cancellation.
            // NOTE(review): relies on the watch receiver yielding its current value
            // first — confirm against postage's `watch` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // This one timer is shared by the connection attempt and the wait for the
        // server hello below.
        let mut timeout =
            futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(credentials, cx).await.log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry with
                            // `try_keychain = false`, which goes to `authenticate`.
                            delete_credentials_from_keychain(cx).await.log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
839
    /// Adopts a freshly opened connection.
    ///
    /// The connection is registered with the peer, and the first incoming message
    /// must be the server's `Hello` carrying our peer id. The client then switches
    /// to `Connected` and spawns two detached tasks: one dispatches the remaining
    /// incoming messages, the other reacts to the connection's I/O finishing.
    ///
    /// # Errors
    ///
    /// If no valid hello is received, the connection is disconnected from the peer
    /// and the error is returned.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background_executor();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        // Drive the connection's I/O on the background executor; its result is
        // observed by the last task spawned below.
        let handle_io = executor.spawn(handle_io);

        // The first message must be `proto::Hello`, which carries our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        // Dispatch every remaining incoming message through `handle_message`.
        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        // When the I/O finishes cleanly, sign out only if the status still refers
        // to this exact connection; on an I/O error, report `ConnectionLost`.
        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
937
938 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
939 #[cfg(any(test, feature = "test-support"))]
940 if let Some(callback) = self.authenticate.read().as_ref() {
941 return callback(cx);
942 }
943
944 self.authenticate_with_browser(cx)
945 }
946
947 fn establish_connection(
948 self: &Arc<Self>,
949 credentials: &Credentials,
950 cx: &AsyncAppContext,
951 ) -> Task<Result<Connection, EstablishConnectionError>> {
952 #[cfg(any(test, feature = "test-support"))]
953 if let Some(callback) = self.establish_connection.read().as_ref() {
954 return callback(credentials, cx);
955 }
956
957 self.establish_websocket_connection(credentials, cx)
958 }
959
960 async fn get_rpc_url(
961 http: Arc<ZedHttpClient>,
962 release_channel: Option<ReleaseChannel>,
963 ) -> Result<Url> {
964 if let Some(url) = &*ZED_RPC_URL {
965 return Url::parse(url).context("invalid rpc url");
966 }
967
968 let mut url = http.zed_url("/rpc");
969 if let Some(preview_param) =
970 release_channel.and_then(|channel| channel.release_query_param())
971 {
972 url += "?";
973 url += preview_param;
974 }
975 let response = http.get(&url, Default::default(), false).await?;
976 let collab_url = if response.status().is_redirection() {
977 response
978 .headers()
979 .get("Location")
980 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
981 .to_str()
982 .map_err(EstablishConnectionError::other)?
983 .to_string()
984 } else {
985 Err(anyhow!(
986 "unexpected /rpc response status {}",
987 response.status()
988 ))?
989 };
990
991 Url::parse(&collab_url).context("invalid rpc url")
992 }
993
994 fn establish_websocket_connection(
995 self: &Arc<Self>,
996 credentials: &Credentials,
997 cx: &AsyncAppContext,
998 ) -> Task<Result<Connection, EstablishConnectionError>> {
999 let release_channel = cx.try_read_global(|channel: &ReleaseChannel, _| *channel);
1000
1001 let request = Request::builder()
1002 .header(
1003 "Authorization",
1004 format!("{} {}", credentials.user_id, credentials.access_token),
1005 )
1006 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1007
1008 let http = self.http.clone();
1009 cx.background_executor().spawn(async move {
1010 let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
1011 let rpc_host = rpc_url
1012 .host_str()
1013 .zip(rpc_url.port_or_known_default())
1014 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1015 let stream = smol::net::TcpStream::connect(rpc_host).await?;
1016
1017 log::info!("connected to rpc endpoint {}", rpc_url);
1018
1019 match rpc_url.scheme() {
1020 "https" => {
1021 rpc_url.set_scheme("wss").unwrap();
1022 let request = request.uri(rpc_url.as_str()).body(())?;
1023 let (stream, _) =
1024 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1025 Ok(Connection::new(
1026 stream
1027 .map_err(|error| anyhow!(error))
1028 .sink_map_err(|error| anyhow!(error)),
1029 ))
1030 }
1031 "http" => {
1032 rpc_url.set_scheme("ws").unwrap();
1033 let request = request.uri(rpc_url.as_str()).body(())?;
1034 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1035 Ok(Connection::new(
1036 stream
1037 .map_err(|error| anyhow!(error))
1038 .sink_map_err(|error| anyhow!(error)),
1039 ))
1040 }
1041 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1042 }
1043 })
1044 }
1045
1046 pub fn authenticate_with_browser(
1047 self: &Arc<Self>,
1048 cx: &AsyncAppContext,
1049 ) -> Task<Result<Credentials>> {
1050 let http = self.http.clone();
1051 cx.spawn(|cx| async move {
1052 let background = cx.background_executor().clone();
1053
1054 let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
1055 cx.update(|cx| {
1056 cx.spawn(move |cx| async move {
1057 let url = open_url_rx.await?;
1058 cx.update(|cx| cx.open_url(&url))
1059 })
1060 .detach_and_log_err(cx);
1061 })
1062 .log_err();
1063
1064 let credentials = background
1065 .clone()
1066 .spawn(async move {
1067 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1068 // zed server to encrypt the user's access token, so that it can'be intercepted by
1069 // any other app running on the user's device.
1070 let (public_key, private_key) =
1071 rpc::auth::keypair().expect("failed to generate keypair for auth");
1072 let public_key_string = String::try_from(public_key)
1073 .expect("failed to serialize public key for auth");
1074
1075 if let Some((login, token)) =
1076 IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
1077 {
1078 return Self::authenticate_as_admin(http, login.clone(), token.clone())
1079 .await;
1080 }
1081
1082 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1083 let server =
1084 tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1085 let port = server.server_addr().port();
1086
1087 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1088 // that the user is signing in from a Zed app running on the same device.
1089 let mut url = http.zed_url(&format!(
1090 "/native_app_signin?native_app_port={}&native_app_public_key={}",
1091 port, public_key_string
1092 ));
1093
1094 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1095 log::info!("impersonating user @{}", impersonate_login);
1096 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1097 }
1098
1099 open_url_tx.send(url).log_err();
1100
1101 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1102 // access token from the query params.
1103 //
1104 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1105 // custom URL scheme instead of this local HTTP server.
1106 let (user_id, access_token) = background
1107 .spawn(async move {
1108 for _ in 0..100 {
1109 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1110 let path = req.url();
1111 let mut user_id = None;
1112 let mut access_token = None;
1113 let url = Url::parse(&format!("http://example.com{}", path))
1114 .context("failed to parse login notification url")?;
1115 for (key, value) in url.query_pairs() {
1116 if key == "access_token" {
1117 access_token = Some(value.to_string());
1118 } else if key == "user_id" {
1119 user_id = Some(value.to_string());
1120 }
1121 }
1122
1123 let post_auth_url =
1124 http.zed_url("/native_app_signin_succeeded");
1125 req.respond(
1126 tiny_http::Response::empty(302).with_header(
1127 tiny_http::Header::from_bytes(
1128 &b"Location"[..],
1129 post_auth_url.as_bytes(),
1130 )
1131 .unwrap(),
1132 ),
1133 )
1134 .context("failed to respond to login http request")?;
1135 return Ok((
1136 user_id
1137 .ok_or_else(|| anyhow!("missing user_id parameter"))?,
1138 access_token.ok_or_else(|| {
1139 anyhow!("missing access_token parameter")
1140 })?,
1141 ));
1142 }
1143 }
1144
1145 Err(anyhow!("didn't receive login redirect"))
1146 })
1147 .await?;
1148
1149 let access_token = private_key
1150 .decrypt_string(&access_token)
1151 .context("failed to decrypt access token")?;
1152
1153 Ok(Credentials {
1154 user_id: user_id.parse()?,
1155 access_token,
1156 })
1157 })
1158 .await?;
1159
1160 cx.update(|cx| cx.activate(true))?;
1161 Ok(credentials)
1162 })
1163 }
1164
1165 async fn authenticate_as_admin(
1166 http: Arc<ZedHttpClient>,
1167 login: String,
1168 mut api_token: String,
1169 ) -> Result<Credentials> {
1170 #[derive(Deserialize)]
1171 struct AuthenticatedUserResponse {
1172 user: User,
1173 }
1174
1175 #[derive(Deserialize)]
1176 struct User {
1177 id: u64,
1178 }
1179
1180 // Use the collab server's admin API to retrieve the id
1181 // of the impersonated user.
1182 let mut url = Self::get_rpc_url(http.clone(), None).await?;
1183 url.set_path("/user");
1184 url.set_query(Some(&format!("github_login={login}")));
1185 let request = Request::get(url.as_str())
1186 .header("Authorization", format!("token {api_token}"))
1187 .body("".into())?;
1188
1189 let mut response = http.send(request).await?;
1190 let mut body = String::new();
1191 response.body_mut().read_to_string(&mut body).await?;
1192 if !response.status().is_success() {
1193 Err(anyhow!(
1194 "admin user request failed {} - {}",
1195 response.status().as_u16(),
1196 body,
1197 ))?;
1198 }
1199 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1200
1201 // Use the admin API token to authenticate as the impersonated user.
1202 api_token.insert_str(0, "ADMIN_TOKEN:");
1203 Ok(Credentials {
1204 user_id: response.user.id,
1205 access_token: api_token,
1206 })
1207 }
1208
1209 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1210 self.peer.teardown();
1211 self.set_status(Status::SignedOut, cx);
1212 }
1213
1214 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1215 self.peer.teardown();
1216 self.set_status(Status::ConnectionLost, cx);
1217 }
1218
1219 fn connection_id(&self) -> Result<ConnectionId> {
1220 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1221 Ok(connection_id)
1222 } else {
1223 Err(anyhow!("not connected"))
1224 }
1225 }
1226
1227 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1228 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1229 self.peer.send(self.connection_id()?, message)
1230 }
1231
1232 pub fn request<T: RequestMessage>(
1233 &self,
1234 request: T,
1235 ) -> impl Future<Output = Result<T::Response>> {
1236 self.request_envelope(request)
1237 .map_ok(|envelope| envelope.payload)
1238 }
1239
1240 pub fn request_envelope<T: RequestMessage>(
1241 &self,
1242 request: T,
1243 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1244 let client_id = self.id();
1245 log::debug!(
1246 "rpc request start. client_id:{}. name:{}",
1247 client_id,
1248 T::NAME
1249 );
1250 let response = self
1251 .connection_id()
1252 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1253 async move {
1254 let response = response?.await;
1255 log::debug!(
1256 "rpc request finish. client_id:{}. name:{}",
1257 client_id,
1258 T::NAME
1259 );
1260 response
1261 }
1262 }
1263
1264 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1265 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1266 self.peer.respond(receipt, response)
1267 }
1268
1269 fn respond_with_error<T: RequestMessage>(
1270 &self,
1271 receipt: Receipt<T>,
1272 error: proto::Error,
1273 ) -> Result<()> {
1274 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1275 self.peer.respond_with_error(receipt, error)
1276 }
1277
1278 fn handle_message(
1279 self: &Arc<Client>,
1280 message: Box<dyn AnyTypedEnvelope>,
1281 cx: &AsyncAppContext,
1282 ) {
1283 let mut state = self.state.write();
1284 let type_name = message.payload_type_name();
1285 let payload_type_id = message.payload_type_id();
1286 let sender_id = message.original_sender_id();
1287
1288 let mut subscriber = None;
1289
1290 if let Some(handle) = state
1291 .models_by_message_type
1292 .get(&payload_type_id)
1293 .and_then(|handle| handle.upgrade())
1294 {
1295 subscriber = Some(handle);
1296 } else if let Some((extract_entity_id, entity_type_id)) =
1297 state.entity_id_extractors.get(&payload_type_id).zip(
1298 state
1299 .entity_types_by_message_type
1300 .get(&payload_type_id)
1301 .copied(),
1302 )
1303 {
1304 let entity_id = (extract_entity_id)(message.as_ref());
1305
1306 match state
1307 .entities_by_type_and_remote_id
1308 .get_mut(&(entity_type_id, entity_id))
1309 {
1310 Some(WeakSubscriber::Pending(pending)) => {
1311 pending.push(message);
1312 return;
1313 }
1314 Some(weak_subscriber @ _) => match weak_subscriber {
1315 WeakSubscriber::Entity { handle } => {
1316 subscriber = handle.upgrade();
1317 }
1318
1319 WeakSubscriber::Pending(_) => {}
1320 },
1321 _ => {}
1322 }
1323 }
1324
1325 let subscriber = if let Some(subscriber) = subscriber {
1326 subscriber
1327 } else {
1328 log::info!("unhandled message {}", type_name);
1329 self.peer.respond_with_unhandled_message(message).log_err();
1330 return;
1331 };
1332
1333 let handler = state.message_handlers.get(&payload_type_id).cloned();
1334 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1335 // It also ensures we don't hold the lock while yielding back to the executor, as
1336 // that might cause the executor thread driving this future to block indefinitely.
1337 drop(state);
1338
1339 if let Some(handler) = handler {
1340 let future = handler(subscriber, message, self, cx.clone());
1341 let client_id = self.id();
1342 log::debug!(
1343 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1344 client_id,
1345 sender_id,
1346 type_name
1347 );
1348 cx.spawn(move |_| async move {
1349 match future.await {
1350 Ok(()) => {
1351 log::debug!(
1352 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1353 client_id,
1354 sender_id,
1355 type_name
1356 );
1357 }
1358 Err(error) => {
1359 log::error!(
1360 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1361 client_id,
1362 sender_id,
1363 type_name,
1364 error
1365 );
1366 }
1367 }
1368 })
1369 .detach();
1370 } else {
1371 log::info!("unhandled message {}", type_name);
1372 self.peer.respond_with_unhandled_message(message).log_err();
1373 }
1374 }
1375
1376 pub fn telemetry(&self) -> &Arc<Telemetry> {
1377 &self.telemetry
1378 }
1379}
1380
1381async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1382 if IMPERSONATE_LOGIN.is_some() {
1383 return None;
1384 }
1385
1386 let (user_id, access_token) = cx
1387 .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
1388 .log_err()?
1389 .await
1390 .log_err()??;
1391
1392 Some(Credentials {
1393 user_id: user_id.parse().ok()?,
1394 access_token: String::from_utf8(access_token).ok()?,
1395 })
1396}
1397
1398async fn write_credentials_to_keychain(
1399 credentials: Credentials,
1400 cx: &AsyncAppContext,
1401) -> Result<()> {
1402 cx.update(move |cx| {
1403 cx.write_credentials(
1404 &ClientSettings::get_global(cx).server_url,
1405 &credentials.user_id.to_string(),
1406 credentials.access_token.as_bytes(),
1407 )
1408 })?
1409 .await
1410}
1411
1412async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
1413 cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
1414 .await
1415}
1416
/// Prefix shared by URLs built by [`encode_worktree_url`] and parsed by [`decode_worktree_url`].
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Leading and trailing whitespace is ignored. Everything after the first `/` that
/// follows the id is taken as the access token, so tokens that contain `/` round-trip
/// intact. Returns `None` in any of these cases:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // `split_once` keeps the whole remainder as the token. A plain `split('/')` would
    // silently truncate tokens containing `/`, breaking the encode/decode round-trip.
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1433
// Unit tests for `Client`. They cover reconnection and connection timeouts,
// repeated authentication, message/entity subscriptions, and worktree URL
// encoding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;

    use gpui::{BackgroundExecutor, Context, TestAppContext};
    use parking_lot::Mutex;
    use settings::SettingsStore;
    use std::future;
    use util::http::FakeHttpClient;

    // Checks two things after the connection drops: the client reconnects using its
    // cached credentials, and it re-authenticates once the server has rolled the
    // access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones; wait for a failed reconnect.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Allow connections again and advance the fake clock so a pending
        // reconnection attempt can run.
        server.allow_connections();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.executor().run_until_parked();
        cx.executor().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Checks that both the initial connection and a later reconnection attempt
    // report an error status once `CONNECTION_TIMEOUT` elapses without the
    // connection being established.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.background_executor().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // The connection future never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.background_executor().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        executor.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Checks that starting a second `authenticate_and_connect` while the first one
    // is still authenticating drops the first authentication future. The counters
    // show two authentication attempts, with the first one dropped.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        executor: BackgroundExecutor,
    ) {
        init_test(cx);
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        // Each authentication attempt bumps `auth_count`, then hangs forever.
        // `dropped_auth_count` is bumped by a deferred guard when the attempt's
        // future is dropped.
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.background_executor().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn({
            let client = client.clone();
            move |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing `_authenticate` does not drop the first task. The first
        // authentication is dropped because of the second call.
        let _authenticate = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        executor.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Round-trips a worktree URL, including one wrapped in surrounding whitespace,
    // and rejects a URL with the wrong prefix.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Checks that entity messages reach the model subscribed to the matching remote
    // id. The handler maps `project_id` 1 and 2 to models 1 and 2. It also checks
    // that dropping one entity's subscription leaves the others working.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
                match model.update(&mut cx, |model, _| model.id).unwrap() {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.new_model(|_| TestModel {
            id: 1,
            subscription: None,
        });
        let model2 = cx.new_model(|_| TestModel {
            id: 2,
            subscription: None,
        });
        let model3 = cx.new_model(|_| TestModel {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // Checks that after one handler's subscription is dropped, a new handler for
    // the same message type can be registered and receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.downgrade(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // Checks that a handler can drop its own subscription, which is stored on the
    // model, while the handler is running, and still completes.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        init_test(cx);
        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.new_model(|_| TestModel::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone().downgrade(),
            move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model
                    .update(&mut cx, |model, _| model.subscription.take())
                    .unwrap();
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model used as a message subscriber. `subscription` lets a test store
    // a handler's subscription on the model itself.
    #[derive(Default)]
    struct TestModel {
        id: usize,
        subscription: Option<Subscription>,
    }

    // Installs a test `SettingsStore` as a global, then calls `init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            init_settings(cx);
        });
    }
}