1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
15};
16use gpui2::{
17 serde_json, AnyHandle, AnyWeakHandle, AnyWindowHandle, AppContext, AsyncAppContext, Handle,
18 SemanticVersion, Task, ViewContext,
19};
20use lazy_static::lazy_static;
21use parking_lot::RwLock;
22use postage::watch;
23use rand::prelude::*;
24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{
28 any::TypeId,
29 collections::HashMap,
30 convert::TryFrom,
31 fmt::Write as _,
32 future::Future,
33 marker::PhantomData,
34 path::PathBuf,
35 sync::{atomic::AtomicU64, Arc, Weak},
36 time::{Duration, Instant},
37};
38use telemetry::Telemetry;
39use thiserror::Error;
40use url::Url;
41use util::channel::ReleaseChannel;
42use util::http::HttpClient;
43use util::{ResultExt, TryFutureExt};
44
45pub use rpc::*;
46pub use telemetry::ClickhouseEvent;
47pub use user::*;
48
49lazy_static! {
50 pub static ref ZED_SERVER_URL: String =
51 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
52 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
53 .ok()
54 .and_then(|s| if s.is_empty() { None } else { Some(s) });
55 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
56 .ok()
57 .and_then(|s| if s.is_empty() { None } else { Some(s) });
58 pub static ref ZED_APP_VERSION: Option<SemanticVersion> = std::env::var("ZED_APP_VERSION")
59 .ok()
60 .and_then(|v| v.parse().ok());
61 pub static ref ZED_APP_PATH: Option<PathBuf> =
62 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
63 pub static ref ZED_ALWAYS_ACTIVE: bool =
64 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
65}
66
/// Client token constant. NOTE(review): its consumers are not visible in this part of the
/// file — confirm where it is sent before changing it.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// First delay used by the reconnection loop spawned in `Client::set_status`.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time budget `Client::authenticate_and_connect` allows for establishing the connection
/// and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
70
/// Action type registered by `init`; its handler calls `Client::authenticate_and_connect`.
#[derive(Clone, Default, PartialEq, Deserialize)]
pub struct SignIn;
73
/// Action type registered by `init`; its handler calls `Client::disconnect`.
#[derive(Clone, Default, PartialEq, Deserialize)]
pub struct SignOut;
76
/// Action type registered by `init`; its handler calls `Client::reconnect`.
#[derive(Clone, Default, PartialEq, Deserialize)]
pub struct Reconnect;
79
/// Registers the settings type owned by this crate (`TelemetrySettings`).
pub fn init_settings(cx: &mut AppContext) {
    settings2::register::<TelemetrySettings>(cx);
}
83
84pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
85 init_settings(cx);
86
87 let client = Arc::downgrade(client);
88 cx.register_action_type::<SignIn>();
89 cx.on_action({
90 let client = client.clone();
91 move |_: &SignIn, cx| {
92 if let Some(client) = client.upgrade() {
93 cx.spawn(
94 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
95 )
96 .detach();
97 }
98 }
99 });
100
101 cx.register_action_type::<SignOut>();
102 cx.on_action({
103 let client = client.clone();
104 move |_: &SignOut, cx| {
105 if let Some(client) = client.upgrade() {
106 cx.spawn(|cx| async move {
107 client.disconnect(&cx);
108 })
109 .detach();
110 }
111 }
112 });
113
114 cx.register_action_type::<Reconnect>();
115 cx.on_action({
116 let client = client.clone();
117 move |_: &Reconnect, cx| {
118 if let Some(client) = client.upgrade() {
119 cx.spawn(|cx| async move {
120 client.reconnect(&cx);
121 })
122 .detach();
123 }
124 }
125 });
126}
127
/// Client for the Zed collaboration server: owns the RPC peer, the HTTP client, telemetry
/// and the mutable connection state.
pub struct Client {
    /// Identifier reported by `id()`; `authenticate_and_connect` stores the user id here.
    id: AtomicU64,
    /// RPC peer that connections are added to.
    peer: Arc<Peer>,
    /// HTTP client used for `/rpc` discovery and admin requests.
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, subscriptions and message handlers.
    state: RwLock<ClientState>,

    /// Test-only replacement for the authentication step (see `override_authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for connection establishment
    /// (see `override_establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
157
/// Reasons why establishing a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other error, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error from the `http` types re-exported by tungstenite (e.g. while building a request).
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
173
174impl From<WebsocketError> for EstablishConnectionError {
175 fn from(error: WebsocketError) -> Self {
176 if let WebsocketError::Http(response) = &error {
177 match response.status() {
178 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
179 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
180 _ => {}
181 }
182 }
183 EstablishConnectionError::Other(error.into())
184 }
185}
186
impl EstablishConnectionError {
    /// Wraps any error convertible into `anyhow::Error` in the `Other` variant.
    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
        Self::Other(error.into())
    }
}
192
/// Connection status published through the client's status watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial status; also set on disconnect and when the connection's I/O ends cleanly.
    SignedOut,
    /// The server rejected the connection with "upgrade required".
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Establishing the connection, starting from `SignedOut`.
    Connecting,
    /// Authentication or connection establishment failed or timed out.
    ConnectionError,
    /// The server's hello message was received on the given connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O task ended with an error; setting this spawns the reconnect loop.
    ConnectionLost,
    /// Like `Authenticating`, but the previous status was not `SignedOut`.
    Reauthenticating,
    /// Like `Connecting`, but the previous status was not `SignedOut`.
    Reconnecting,
    /// Set by the reconnect loop between attempts; `next_reconnection` is when its current
    /// wait is due to end.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
211
impl Status {
    /// Returns `true` only for `Status::Connected`.
    pub fn is_connected(&self) -> bool {
        matches!(self, Self::Connected { .. })
    }

    /// Returns `true` for `SignedOut` and for `UpgradeRequired`.
    pub fn is_signed_out(&self) -> bool {
        matches!(self, Self::SignedOut | Self::UpgradeRequired)
    }
}
221
/// Mutable state of a `Client`, guarded by `Client::state`.
struct ClientState {
    /// Stored after a connection is established; cleared when the server reports
    /// `Unauthorized`.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; replaced or cleared by `set_status`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the reconnect loop's backoff delay.
    reconnect_interval: Duration,
    /// Subscribers keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Weak handle of the entity registered per message type by `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakHandle>,
    /// Entity type registered per message type by `add_entity_message_handler`.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> BoxFuture<'static, Result<()>>,
        >,
    >,
}
246
/// Entry stored per (entity type, remote id) in `ClientState`.
enum WeakSubscriber {
    /// A registered entity, held weakly, plus the window it belongs to (views only).
    Entity {
        handle: AnyWeakHandle,
        window_handle: Option<AnyWindowHandle>,
    },
    /// Placeholder created by `subscribe_to_entity`; holds the envelopes that
    /// `PendingEntitySubscription::set_model` replays.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
254
/// Strong counterpart of `WeakSubscriber::Entity`, passed to message handlers.
struct Subscriber {
    handle: AnyHandle,
    window_handle: Option<AnyWindowHandle>,
}
259
/// User id and access token sent in the `Authorization` header when connecting.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
265
266impl Default for ClientState {
267 fn default() -> Self {
268 Self {
269 credentials: None,
270 status: watch::channel_with(Status::SignedOut),
271 entity_id_extractors: Default::default(),
272 _reconnect_task: None,
273 reconnect_interval: Duration::from_secs(5),
274 models_by_message_type: Default::default(),
275 entities_by_type_and_remote_id: Default::default(),
276 entity_types_by_message_type: Default::default(),
277 message_handlers: Default::default(),
278 }
279 }
280}
281
/// Registration guard: dropping it removes the corresponding entries from the client's
/// state (see the `Drop` impl).
pub enum Subscription {
    /// Registration of an entity under (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for a message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
292
293impl Drop for Subscription {
294 fn drop(&mut self) {
295 match self {
296 Subscription::Entity { client, id } => {
297 if let Some(client) = client.upgrade() {
298 let mut state = client.state.write();
299 let _ = state.entities_by_type_and_remote_id.remove(id);
300 }
301 }
302 Subscription::Message { client, id } => {
303 if let Some(client) = client.upgrade() {
304 let mut state = client.state.write();
305 let _ = state.entity_types_by_message_type.remove(id);
306 let _ = state.message_handlers.remove(id);
307 }
308 }
309 }
310 }
311}
312
313pub struct PendingEntitySubscription<T> {
314 client: Arc<Client>,
315 remote_id: u64,
316 _entity_type: PhantomData<T>,
317 consumed: bool,
318}
319
320impl<T> PendingEntitySubscription<T>
321where
322 T: 'static + Send + Sync,
323{
324 pub fn set_model(mut self, model: &Handle<T>, cx: &mut AsyncAppContext) -> Subscription {
325 self.consumed = true;
326 let mut state = self.client.state.write();
327 let id = (TypeId::of::<T>(), self.remote_id);
328 let Some(WeakSubscriber::Pending(messages)) =
329 state.entities_by_type_and_remote_id.remove(&id)
330 else {
331 unreachable!()
332 };
333
334 state.entities_by_type_and_remote_id.insert(
335 id,
336 WeakSubscriber::Entity {
337 handle: model.downgrade().into(),
338 window_handle: None,
339 },
340 );
341 drop(state);
342 for message in messages {
343 self.client.handle_message(message, cx);
344 }
345 Subscription::Entity {
346 client: Arc::downgrade(&self.client),
347 id,
348 }
349 }
350}
351
352impl<T> Drop for PendingEntitySubscription<T> {
353 fn drop(&mut self) {
354 if !self.consumed {
355 let mut state = self.client.state.write();
356 if let Some(WeakSubscriber::Pending(messages)) = state
357 .entities_by_type_and_remote_id
358 .remove(&(TypeId::of::<T>(), self.remote_id))
359 {
360 for message in messages {
361 log::info!("unhandled message {}", message.payload_type_name());
362 }
363 }
364 }
365 }
366}
367
/// Resolved telemetry settings, loaded from the `telemetry` settings key.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    pub diagnostics: bool,
    pub metrics: bool,
}
373
/// On-disk form of `TelemetrySettings`; every field is optional so a settings file may
/// specify only some of them.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
379
380impl settings2::Setting for TelemetrySettings {
381 const KEY: Option<&'static str> = Some("telemetry");
382
383 type FileContent = TelemetrySettingsContent;
384
385 fn load(
386 default_value: &Self::FileContent,
387 user_values: &[&Self::FileContent],
388 _: &AppContext,
389 ) -> Result<Self> {
390 Ok(Self {
391 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
392 default_value
393 .diagnostics
394 .ok_or_else(Self::missing_default)?,
395 ),
396 metrics: user_values
397 .first()
398 .and_then(|v| v.metrics)
399 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
400 })
401 }
402}
403
404impl Client {
    /// Creates a client with id 0, a fresh peer, telemetry built on `http`, and default
    /// (signed-out) state.
    pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
        Arc::new(Self {
            id: AtomicU64::new(0),
            peer: Peer::new(0),
            telemetry: Telemetry::new(http.clone(), cx),
            http,
            state: Default::default(),

            #[cfg(any(test, feature = "test-support"))]
            authenticate: Default::default(),
            #[cfg(any(test, feature = "test-support"))]
            establish_connection: Default::default(),
        })
    }
419
    /// Returns the client's current id.
    pub fn id(&self) -> u64 {
        self.id.load(std::sync::atomic::Ordering::SeqCst)
    }
423
    /// Returns a clone of the shared HTTP client.
    pub fn http_client(&self) -> Arc<dyn HttpClient> {
        self.http.clone()
    }
427
    /// Stores `id` as the client's id and returns `self` for chaining.
    pub fn set_id(&self, id: u64) -> &Self {
        self.id.store(id, std::sync::atomic::Ordering::SeqCst);
        self
    }
432
    /// Test-only: drops the reconnect task, clears the handler and subscription maps
    /// listed below, and tears down the peer.
    #[cfg(any(test, feature = "test-support"))]
    pub fn teardown(&self) {
        let mut state = self.state.write();
        state._reconnect_task.take();
        state.message_handlers.clear();
        state.models_by_message_type.clear();
        state.entities_by_type_and_remote_id.clear();
        state.entity_id_extractors.clear();
        self.peer.teardown();
    }
443
    /// Test-only: installs a callback that `authenticate` uses instead of the browser flow.
    /// Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
    where
        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
    {
        *self.authenticate.write() = Some(Box::new(authenticate));
        self
    }
452
    /// Test-only: installs a callback that `establish_connection` uses instead of opening
    /// a websocket. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        *self.establish_connection.write() = Some(Box::new(connect));
        self
    }
464
465 pub fn user_id(&self) -> Option<u64> {
466 self.state
467 .read()
468 .credentials
469 .as_ref()
470 .map(|credentials| credentials.user_id)
471 }
472
473 pub fn peer_id(&self) -> Option<PeerId> {
474 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
475 Some(*peer_id)
476 } else {
477 None
478 }
479 }
480
    /// Returns a new receiver for the status watch channel.
    pub fn status(&self) -> watch::Receiver<Status> {
        self.state.read().status.1.clone()
    }
484
    /// Publishes `status` on the status channel and manages the reconnect task.
    ///
    /// - `Connected` drops any stored reconnect task.
    /// - `ConnectionLost` spawns a task that retries `authenticate_and_connect` with a
    ///   randomized, growing delay.
    /// - `SignedOut` / `UpgradeRequired` clear telemetry's authenticated user info and drop
    ///   the reconnect task.
    ///
    /// The state write lock is held for the whole body.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed in tests, entropy-seeded otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only after a plain connection error; any other
                        // resulting status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.executor().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // `reconnect_interval`.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.update(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
530
531 pub fn add_view_for_remote_entity<T>(
532 self: &Arc<Self>,
533 remote_id: u64,
534 cx: &mut ViewContext<T>,
535 ) -> Subscription
536 where
537 T: 'static + Send + Sync,
538 {
539 let id = (TypeId::of::<T>(), remote_id);
540 self.state.write().entities_by_type_and_remote_id.insert(
541 id,
542 WeakSubscriber::Entity {
543 handle: cx.handle().into(),
544 window_handle: Some(cx.window_handle()),
545 },
546 );
547 Subscription::Entity {
548 client: Arc::downgrade(self),
549 id,
550 }
551 }
552
553 pub fn subscribe_to_entity<T>(
554 self: &Arc<Self>,
555 remote_id: u64,
556 ) -> Result<PendingEntitySubscription<T>> {
557 let id = (TypeId::of::<T>(), remote_id);
558
559 let mut state = self.state.write();
560 if state.entities_by_type_and_remote_id.contains_key(&id) {
561 return Err(anyhow!("already subscribed to entity"));
562 } else {
563 state
564 .entities_by_type_and_remote_id
565 .insert(id, WeakSubscriber::Pending(Default::default()));
566 Ok(PendingEntitySubscription {
567 client: self.clone(),
568 remote_id,
569 consumed: false,
570 _entity_type: PhantomData,
571 })
572 }
573 }
574
    /// Registers `handler` for messages of type `M`, to be invoked with `entity`.
    ///
    /// A weak handle to `entity` is stored per message type. Dropping the returned
    /// `Subscription` removes the handler.
    ///
    /// # Panics
    ///
    /// Panics, naming the caller's location, if a handler for `M` was already registered.
    /// The new handler has replaced the old one by the time the panic is raised.
    #[track_caller]
    pub fn add_message_handler<M, E, H, F>(
        self: &Arc<Self>,
        entity: Handle<E>,
        handler: H,
    ) -> Subscription
    where
        M: EnvelopedMessage,
        E: 'static + Send + Sync,
        H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>> + Send,
    {
        let message_type_id = TypeId::of::<M>();

        let mut state = self.state.write();
        state
            .models_by_message_type
            .insert(message_type_id, entity.downgrade().into());

        let prev_handler = state.message_handlers.insert(
            message_type_id,
            // Type-erased wrapper: downcasts the subscriber and the envelope back to the
            // concrete types before calling the typed handler.
            Arc::new(move |subscriber, envelope, client, cx| {
                let subscriber = subscriber.handle.downcast::<E>().unwrap();
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                handler(subscriber, *envelope, client.clone(), cx).boxed()
            }),
        );
        if prev_handler.is_some() {
            let location = std::panic::Location::caller();
            panic!(
                "{}:{} registered handler for the same message {} twice",
                location.file(),
                location.line(),
                std::any::type_name::<M>()
            );
        }

        Subscription::Message {
            client: Arc::downgrade(self),
            id: message_type_id,
        }
    }
617
618 pub fn add_request_handler<M, E, H, F>(
619 self: &Arc<Self>,
620 model: Handle<E>,
621 handler: H,
622 ) -> Subscription
623 where
624 M: RequestMessage,
625 E: 'static + Send + Sync,
626 H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
627 F: 'static + Future<Output = Result<M::Response>> + Send,
628 {
629 self.add_message_handler(model, move |handle, envelope, this, cx| {
630 Self::respond_to_request(
631 envelope.receipt(),
632 handler(handle, envelope, this.clone(), cx),
633 this,
634 )
635 })
636 }
637
    /// Registers `handler` for entity messages of type `M` addressed to entities of type
    /// `E`.
    ///
    /// Wraps `add_entity_message_handler`, downcasting the subscriber's handle to
    /// `Handle<E>` before calling `handler`.
    pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage,
        E: 'static + Send + Sync,
        H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>> + Send,
    {
        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
            handler(
                subscriber.handle.downcast::<E>().unwrap(),
                message,
                client,
                cx,
            )
        })
    }
654
655 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
656 where
657 M: EntityMessage,
658 H: 'static
659 + Send
660 + Sync
661 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
662 F: 'static + Future<Output = Result<()>> + Send,
663 {
664 let model_type_id = TypeId::of::<E>();
665 let message_type_id = TypeId::of::<M>();
666
667 let mut state = self.state.write();
668 state
669 .entity_types_by_message_type
670 .insert(message_type_id, model_type_id);
671 state
672 .entity_id_extractors
673 .entry(message_type_id)
674 .or_insert_with(|| {
675 |envelope| {
676 envelope
677 .as_any()
678 .downcast_ref::<TypedEnvelope<M>>()
679 .unwrap()
680 .payload
681 .remote_entity_id()
682 }
683 });
684 let prev_handler = state.message_handlers.insert(
685 message_type_id,
686 Arc::new(move |handle, envelope, client, cx| {
687 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
688 handler(handle, *envelope, client.clone(), cx).boxed()
689 }),
690 );
691 if prev_handler.is_some() {
692 panic!("registered handler for the same message twice");
693 }
694 }
695
    /// Registers `handler` for entity requests of type `M` addressed to entities of type
    /// `E`.
    ///
    /// Like `add_model_message_handler`, but the handler's result is sent back to the
    /// requester through `respond_to_request`.
    pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage + RequestMessage,
        E: 'static + Send + Sync,
        H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<M::Response>> + Send,
    {
        self.add_model_message_handler(move |entity, envelope, client, cx| {
            Self::respond_to_request::<M, _>(
                envelope.receipt(),
                handler(entity, envelope, client.clone(), cx),
                client,
            )
        })
    }
711
712 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
713 receipt: Receipt<T>,
714 response: F,
715 client: Arc<Self>,
716 ) -> Result<()> {
717 match response.await {
718 Ok(response) => {
719 client.respond(receipt, response)?;
720 Ok(())
721 }
722 Err(error) => {
723 client.respond_with_error(
724 receipt,
725 proto::Error {
726 message: format!("{:?}", error),
727 },
728 )?;
729 Err(error)
730 }
731 }
732 }
733
    /// Returns `true` if `read_credentials_from_keychain` yields credentials.
    pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).await.is_some()
    }
737
    /// Obtains credentials (if needed) and connects to the server, updating the status
    /// along the way.
    ///
    /// Returns `Ok(())` immediately when already connected or while a connection attempt
    /// is in progress, and an `UpgradeRequired` error when the status is
    /// `UpgradeRequired`.
    ///
    /// Credentials come from, in order: the client's state, the keychain (only when
    /// `try_keychain` is set), and finally `authenticate`. If keychain credentials are
    /// rejected as unauthorized they are deleted and the function calls itself once more
    /// with `try_keychain = false`.
    ///
    /// A single `CONNECTION_TIMEOUT` timer covers both establishing the connection and
    /// waiting for the server's hello message.
    #[async_recursion]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Decides between the "Authenticating/Connecting" and the
        // "Reauthenticating/Reconnecting" status pairs below.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx).await;
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably consumes the already-current status so that the
            // `status_rx.next()` in the select below only fires on a later change — confirm
            // against the watch channel's semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                // A status change while authenticating aborts the attempt.
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        let mut timeout = futures::FutureExt::fuse(cx.executor().timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only credentials that did not come from the keychain and
                        // that are not the result of impersonation.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(credentials, cx).log_err();
                        }

                        // The same timer keeps running while waiting for the hello message.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: forget them and retry without
                            // consulting the keychain.
                            delete_credentials_from_keychain(cx).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
841
    /// Adds `conn` to the peer, waits for the server's hello message, and on success marks
    /// the client as connected.
    ///
    /// Two detached tasks are spawned after the hello: one forwards every incoming message
    /// to `handle_message`, the other waits for the connection's I/O future and updates
    /// the status when it ends.
    ///
    /// # Errors
    ///
    /// Fails, after disconnecting the connection from the peer, when no message arrives,
    /// the first message is not a `proto::Hello`, or the hello carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.executor();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
            let executor = executor.clone();
            move |duration| executor.timer(duration)
        });
        let handle_io = executor.spawn(handle_io);

        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            // Captured up front because `into_any` consumes the message.
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );

        cx.spawn({
            let this = self.clone();
            |cx| {
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            }
        })
        .detach();

        cx.spawn({
            let this = self.clone();
            move |cx| async move {
                match handle_io.await {
                    Ok(()) => {
                        // Sign out only if the status still refers to this exact connection.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            }
        })
        .detach();

        Ok(())
    }
939
    /// Obtains credentials via the browser flow, or via the test override when one is
    /// installed (test builds only).
    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.authenticate.read().as_ref() {
            return callback(cx);
        }

        self.authenticate_with_browser(cx)
    }
948
    /// Opens a websocket connection with `credentials`, or delegates to the test override
    /// when one is installed (test builds only).
    fn establish_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.establish_connection.read().as_ref() {
            return callback(credentials, cx);
        }

        self.establish_websocket_connection(credentials, cx)
    }
961
962 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
963 let preview_param = if is_preview { "?preview=1" } else { "" };
964 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
965 let response = http.get(&url, Default::default(), false).await?;
966
967 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
968 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
969 // which requires authorization via an HTTP header.
970 //
971 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
972 // of a collab server. In that case, a request to the /rpc endpoint will
973 // return an 'unauthorized' response.
974 let collab_url = if response.status().is_redirection() {
975 response
976 .headers()
977 .get("Location")
978 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
979 .to_str()
980 .map_err(EstablishConnectionError::other)?
981 .to_string()
982 } else if response.status() == StatusCode::UNAUTHORIZED {
983 url
984 } else {
985 Err(anyhow!(
986 "unexpected /rpc response status {}",
987 response.status()
988 ))?
989 };
990
991 Url::parse(&collab_url).context("invalid rpc url")
992 }
993
    /// Spawns a task that resolves the RPC URL, opens a TCP stream to its host, and
    /// performs the websocket handshake (TLS for `https`, plain for `http`).
    ///
    /// The handshake request carries an `Authorization` header of the form
    /// `"{user_id} {access_token}"` and the protocol version header. The preview server is
    /// requested whenever a `ReleaseChannel` global is present and is not `Stable`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let use_preview_server = cx
            .try_read_global(|channel: &ReleaseChannel, _| *channel != ReleaseChannel::Stable)
            .unwrap_or(false);

        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.executor().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // Rewrite the HTTP scheme to the matching websocket scheme before the handshake.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1047
    /// Spawns a task that signs the user in through the browser and resolves to the
    /// resulting credentials.
    ///
    /// When both `IMPERSONATE_LOGIN` and `ADMIN_API_TOKEN` are set, the browser is skipped
    /// and `authenticate_as_admin` is used instead. Otherwise a local HTTP server is
    /// started on an OS-assigned port, the sign-in page is opened, and the task waits for
    /// the redirect carrying `user_id` and the encrypted `access_token`.
    ///
    /// # Errors
    ///
    /// Fails if no redirect arrives within 100 one-second polls, if either query parameter
    /// is missing, if the token cannot be decrypted, or if `user_id` does not parse.
    ///
    /// # Panics
    ///
    /// Panics if the keypair cannot be generated or serialized, or if the local HTTP
    /// server cannot be started.
    pub fn authenticate_with_browser(
        self: &Arc<Self>,
        cx: &AsyncAppContext,
    ) -> Task<Result<Credentials>> {
        let http = self.http.clone();
        cx.spawn(|cx| async move {
            // Generate a pair of asymmetric encryption keys. The public key will be used by the
            // zed server to encrypt the user's access token, so that it can't be intercepted by
            // any other app running on the user's device.
            let (public_key, private_key) =
                rpc::auth::keypair().expect("failed to generate keypair for auth");
            let public_key_string =
                String::try_from(public_key).expect("failed to serialize public key for auth");

            // NOTE(review): the keypair above is unused on this admin path.
            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
            }

            // Start an HTTP server to receive the redirect from Zed's sign-in page.
            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
            let port = server.server_addr().port();

            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
            // that the user is signing in from a Zed app running on the same device.
            let mut url = format!(
                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
                *ZED_SERVER_URL, port, public_key_string
            );

            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
                log::info!("impersonating user @{}", impersonate_login);
                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
            }

            cx.run_on_main(|cx| cx.open_url(&url))?.await;

            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
            // access token from the query params.
            //
            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
            // custom URL scheme instead of this local HTTP server.
            let (user_id, access_token) = cx
                .spawn(|_| async move {
                    // Poll for up to 100 one-second intervals.
                    for _ in 0..100 {
                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
                            let path = req.url();
                            let mut user_id = None;
                            let mut access_token = None;
                            // The request URL is only a path; the placeholder host makes it
                            // parseable so the query pairs can be read.
                            let url = Url::parse(&format!("http://example.com{}", path))
                                .context("failed to parse login notification url")?;
                            for (key, value) in url.query_pairs() {
                                if key == "access_token" {
                                    access_token = Some(value.to_string());
                                } else if key == "user_id" {
                                    user_id = Some(value.to_string());
                                }
                            }

                            // Send the browser on to the "sign-in succeeded" page.
                            let post_auth_url =
                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
                            req.respond(
                                tiny_http::Response::empty(302).with_header(
                                    tiny_http::Header::from_bytes(
                                        &b"Location"[..],
                                        post_auth_url.as_bytes(),
                                    )
                                    .unwrap(),
                                ),
                            )
                            .context("failed to respond to login http request")?;
                            return Ok((
                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
                                access_token
                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
                            ));
                        }
                    }

                    Err(anyhow!("didn't receive login redirect"))
                })
                .await?;

            let access_token = private_key
                .decrypt_string(&access_token)
                .context("failed to decrypt access token")?;
            cx.run_on_main(|cx| cx.activate(true))?.await;

            Ok(Credentials {
                user_id: user_id.parse()?,
                access_token,
            })
        })
    }
1141
1142 async fn authenticate_as_admin(
1143 http: Arc<dyn HttpClient>,
1144 login: String,
1145 mut api_token: String,
1146 ) -> Result<Credentials> {
1147 #[derive(Deserialize)]
1148 struct AuthenticatedUserResponse {
1149 user: User,
1150 }
1151
1152 #[derive(Deserialize)]
1153 struct User {
1154 id: u64,
1155 }
1156
1157 // Use the collab server's admin API to retrieve the id
1158 // of the impersonated user.
1159 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1160 url.set_path("/user");
1161 url.set_query(Some(&format!("github_login={login}")));
1162 let request = Request::get(url.as_str())
1163 .header("Authorization", format!("token {api_token}"))
1164 .body("".into())?;
1165
1166 let mut response = http.send(request).await?;
1167 let mut body = String::new();
1168 response.body_mut().read_to_string(&mut body).await?;
1169 if !response.status().is_success() {
1170 Err(anyhow!(
1171 "admin user request failed {} - {}",
1172 response.status().as_u16(),
1173 body,
1174 ))?;
1175 }
1176 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1177
1178 // Use the admin API token to authenticate as the impersonated user.
1179 api_token.insert_str(0, "ADMIN_TOKEN:");
1180 Ok(Credentials {
1181 user_id: response.user.id,
1182 access_token: api_token,
1183 })
1184 }
1185
1186 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1187 self.peer.teardown();
1188 self.set_status(Status::SignedOut, cx);
1189 }
1190
1191 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1192 self.peer.teardown();
1193 self.set_status(Status::ConnectionLost, cx);
1194 }
1195
1196 fn connection_id(&self) -> Result<ConnectionId> {
1197 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1198 Ok(connection_id)
1199 } else {
1200 Err(anyhow!("not connected"))
1201 }
1202 }
1203
1204 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1205 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1206 self.peer.send(self.connection_id()?, message)
1207 }
1208
1209 pub fn request<T: RequestMessage>(
1210 &self,
1211 request: T,
1212 ) -> impl Future<Output = Result<T::Response>> {
1213 self.request_envelope(request)
1214 .map_ok(|envelope| envelope.payload)
1215 }
1216
1217 pub fn request_envelope<T: RequestMessage>(
1218 &self,
1219 request: T,
1220 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1221 let client_id = self.id();
1222 log::debug!(
1223 "rpc request start. client_id:{}. name:{}",
1224 client_id,
1225 T::NAME
1226 );
1227 let response = self
1228 .connection_id()
1229 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1230 async move {
1231 let response = response?.await;
1232 log::debug!(
1233 "rpc request finish. client_id:{}. name:{}",
1234 client_id,
1235 T::NAME
1236 );
1237 response
1238 }
1239 }
1240
1241 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1242 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1243 self.peer.respond(receipt, response)
1244 }
1245
1246 fn respond_with_error<T: RequestMessage>(
1247 &self,
1248 receipt: Receipt<T>,
1249 error: proto::Error,
1250 ) -> Result<()> {
1251 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1252 self.peer.respond_with_error(receipt, error)
1253 }
1254
1255 fn handle_message(
1256 self: &Arc<Client>,
1257 message: Box<dyn AnyTypedEnvelope>,
1258 cx: &AsyncAppContext,
1259 ) {
1260 let mut state = self.state.write();
1261 let type_name = message.payload_type_name();
1262 let payload_type_id = message.payload_type_id();
1263 let sender_id = message.original_sender_id();
1264
1265 let mut subscriber = None;
1266
1267 if let Some(handle) = state
1268 .models_by_message_type
1269 .get(&payload_type_id)
1270 .and_then(|handle| handle.upgrade())
1271 {
1272 subscriber = Some(Subscriber {
1273 handle,
1274 window_handle: None,
1275 });
1276 } else if let Some((extract_entity_id, entity_type_id)) =
1277 state.entity_id_extractors.get(&payload_type_id).zip(
1278 state
1279 .entity_types_by_message_type
1280 .get(&payload_type_id)
1281 .copied(),
1282 )
1283 {
1284 let entity_id = (extract_entity_id)(message.as_ref());
1285
1286 match state
1287 .entities_by_type_and_remote_id
1288 .get_mut(&(entity_type_id, entity_id))
1289 {
1290 Some(WeakSubscriber::Pending(pending)) => {
1291 pending.push(message);
1292 return;
1293 }
1294 Some(weak_subscriber @ _) => match weak_subscriber {
1295 WeakSubscriber::Entity {
1296 handle,
1297 window_handle,
1298 } => {
1299 subscriber = handle.upgrade().map(|handle| Subscriber {
1300 handle,
1301 window_handle: window_handle.clone(),
1302 });
1303 }
1304
1305 WeakSubscriber::Pending(_) => {}
1306 },
1307 _ => {}
1308 }
1309 }
1310
1311 let subscriber = if let Some(subscriber) = subscriber {
1312 subscriber
1313 } else {
1314 log::info!("unhandled message {}", type_name);
1315 self.peer.respond_with_unhandled_message(message).log_err();
1316 return;
1317 };
1318
1319 let handler = state.message_handlers.get(&payload_type_id).cloned();
1320 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1321 // It also ensures we don't hold the lock while yielding back to the executor, as
1322 // that might cause the executor thread driving this future to block indefinitely.
1323 drop(state);
1324
1325 if let Some(handler) = handler {
1326 let future = handler(subscriber, message, &self, cx.clone());
1327 let client_id = self.id();
1328 log::debug!(
1329 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1330 client_id,
1331 sender_id,
1332 type_name
1333 );
1334 cx.spawn_on_main(|_| async move {
1335 match future.await {
1336 Ok(()) => {
1337 log::debug!(
1338 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1339 client_id,
1340 sender_id,
1341 type_name
1342 );
1343 }
1344 Err(error) => {
1345 log::error!(
1346 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1347 client_id,
1348 sender_id,
1349 type_name,
1350 error
1351 );
1352 }
1353 }
1354 })
1355 .detach();
1356 } else {
1357 log::info!("unhandled message {}", type_name);
1358 self.peer.respond_with_unhandled_message(message).log_err();
1359 }
1360 }
1361
1362 pub fn telemetry(&self) -> &Arc<Telemetry> {
1363 &self.telemetry
1364 }
1365}
1366
1367async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1368 if IMPERSONATE_LOGIN.is_some() {
1369 return None;
1370 }
1371
1372 let (user_id, access_token) = cx
1373 .run_on_main(|cx| cx.read_credentials(&ZED_SERVER_URL).log_err().flatten())
1374 .ok()?
1375 .await?;
1376
1377 Some(Credentials {
1378 user_id: user_id.parse().ok()?,
1379 access_token: String::from_utf8(access_token).ok()?,
1380 })
1381}
1382
1383async fn write_credentials_to_keychain(
1384 credentials: Credentials,
1385 cx: &AsyncAppContext,
1386) -> Result<()> {
1387 cx.run_on_main(move |cx| {
1388 cx.write_credentials(
1389 &ZED_SERVER_URL,
1390 &credentials.user_id.to_string(),
1391 credentials.access_token.as_bytes(),
1392 )
1393 })?
1394 .await
1395}
1396
1397async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
1398 cx.run_on_main(move |cx| cx.delete_credentials(&ZED_SERVER_URL))?
1399 .await
1400}
1401
/// Prefix shared by every worktree URL produced by `encode_worktree_url`.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
///
/// `access_token` is inserted verbatim; no escaping is applied.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let id = id.to_string();
    [WORKTREE_URL_PREFIX, id.as_str(), "/", access_token].concat()
}
1407
1408pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
1409 let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
1410 let mut parts = path.split('/');
1411 let id = parts.next()?.parse::<u64>().ok()?;
1412 let access_token = parts.next()?;
1413 if access_token.is_empty() {
1414 return None;
1415 }
1416 Some((id, access_token.to_string()))
1417}
1418
1419// #[cfg(test)]
1420// mod tests {
1421// use super::*;
1422// use crate::test::FakeServer;
1423// use gpui::{executor::Deterministic, TestAppContext};
1424// use parking_lot::Mutex;
1425// use std::future;
1426// use util::http::FakeHttpClient;
1427
1428// #[gpui::test(iterations = 10)]
1429// async fn test_reconnection(cx: &mut TestAppContext) {
1430// cx.foreground().forbid_parking();
1431
1432// let user_id = 5;
1433// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1434// let server = FakeServer::for_client(user_id, &client, cx).await;
1435// let mut status = client.status();
1436// assert!(matches!(
1437// status.next().await,
1438// Some(Status::Connected { .. })
1439// ));
1440// assert_eq!(server.auth_count(), 1);
1441
1442// server.forbid_connections();
1443// server.disconnect();
1444// while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1445
1446// server.allow_connections();
1447// cx.foreground().advance_clock(Duration::from_secs(10));
1448// while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1449// assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1450
1451// server.forbid_connections();
1452// server.disconnect();
1453// while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1454
1455// // Clear cached credentials after authentication fails
1456// server.roll_access_token();
1457// server.allow_connections();
1458// cx.foreground().advance_clock(Duration::from_secs(10));
1459// while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1460// assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1461// }
1462
1463// #[gpui::test(iterations = 10)]
1464// async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1465// deterministic.forbid_parking();
1466
1467// let user_id = 5;
1468// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1469// let mut status = client.status();
1470
1471// // Time out when client tries to connect.
1472// client.override_authenticate(move |cx| {
1473// cx.foreground().spawn(async move {
1474// Ok(Credentials {
1475// user_id,
1476// access_token: "token".into(),
1477// })
1478// })
1479// });
1480// client.override_establish_connection(|_, cx| {
1481// cx.foreground().spawn(async move {
1482// future::pending::<()>().await;
1483// unreachable!()
1484// })
1485// });
1486// let auth_and_connect = cx.spawn({
1487// let client = client.clone();
1488// |cx| async move { client.authenticate_and_connect(false, &cx).await }
1489// });
1490// deterministic.run_until_parked();
1491// assert!(matches!(status.next().await, Some(Status::Connecting)));
1492
1493// deterministic.advance_clock(CONNECTION_TIMEOUT);
1494// assert!(matches!(
1495// status.next().await,
1496// Some(Status::ConnectionError { .. })
1497// ));
1498// auth_and_connect.await.unwrap_err();
1499
1500// // Allow the connection to be established.
1501// let server = FakeServer::for_client(user_id, &client, cx).await;
1502// assert!(matches!(
1503// status.next().await,
1504// Some(Status::Connected { .. })
1505// ));
1506
1507// // Disconnect client.
1508// server.forbid_connections();
1509// server.disconnect();
1510// while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1511
1512// // Time out when re-establishing the connection.
1513// server.allow_connections();
1514// client.override_establish_connection(|_, cx| {
1515// cx.foreground().spawn(async move {
1516// future::pending::<()>().await;
1517// unreachable!()
1518// })
1519// });
1520// deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1521// assert!(matches!(
1522// status.next().await,
1523// Some(Status::Reconnecting { .. })
1524// ));
1525
1526// deterministic.advance_clock(CONNECTION_TIMEOUT);
1527// assert!(matches!(
1528// status.next().await,
1529// Some(Status::ReconnectionError { .. })
1530// ));
1531// }
1532
1533// #[gpui::test(iterations = 10)]
1534// async fn test_authenticating_more_than_once(
1535// cx: &mut TestAppContext,
1536// deterministic: Arc<Deterministic>,
1537// ) {
1538// cx.foreground().forbid_parking();
1539
1540// let auth_count = Arc::new(Mutex::new(0));
1541// let dropped_auth_count = Arc::new(Mutex::new(0));
1542// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1543// client.override_authenticate({
1544// let auth_count = auth_count.clone();
1545// let dropped_auth_count = dropped_auth_count.clone();
1546// move |cx| {
1547// let auth_count = auth_count.clone();
1548// let dropped_auth_count = dropped_auth_count.clone();
1549// cx.foreground().spawn(async move {
1550// *auth_count.lock() += 1;
1551// let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1552// future::pending::<()>().await;
1553// unreachable!()
1554// })
1555// }
1556// });
1557
1558// let _authenticate = cx.spawn(|cx| {
1559// let client = client.clone();
1560// async move { client.authenticate_and_connect(false, &cx).await }
1561// });
1562// deterministic.run_until_parked();
1563// assert_eq!(*auth_count.lock(), 1);
1564// assert_eq!(*dropped_auth_count.lock(), 0);
1565
1566// let _authenticate = cx.spawn(|cx| {
1567// let client = client.clone();
1568// async move { client.authenticate_and_connect(false, &cx).await }
1569// });
1570// deterministic.run_until_parked();
1571// assert_eq!(*auth_count.lock(), 2);
1572// assert_eq!(*dropped_auth_count.lock(), 1);
1573// }
1574
1575// #[test]
1576// fn test_encode_and_decode_worktree_url() {
1577// let url = encode_worktree_url(5, "deadbeef");
1578// assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1579// assert_eq!(
1580// decode_worktree_url(&format!("\n {}\t", url)),
1581// Some((5, "deadbeef".to_string()))
1582// );
1583// assert_eq!(decode_worktree_url("not://the-right-format"), None);
1584// }
1585
1586// #[gpui::test]
1587// async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1588// cx.foreground().forbid_parking();
1589
1590// let user_id = 5;
1591// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1592// let server = FakeServer::for_client(user_id, &client, cx).await;
1593
1594// let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1595// let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1596// client.add_model_message_handler(
1597// move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
1598// match model.read_with(&cx, |model, _| model.id) {
1599// 1 => done_tx1.try_send(()).unwrap(),
1600// 2 => done_tx2.try_send(()).unwrap(),
1601// _ => unreachable!(),
1602// }
1603// async { Ok(()) }
1604// },
1605// );
1606// let model1 = cx.add_model(|_| Model {
1607// id: 1,
1608// subscription: None,
1609// });
1610// let model2 = cx.add_model(|_| Model {
1611// id: 2,
1612// subscription: None,
1613// });
1614// let model3 = cx.add_model(|_| Model {
1615// id: 3,
1616// subscription: None,
1617// });
1618
1619// let _subscription1 = client
1620// .subscribe_to_entity(1)
1621// .unwrap()
1622// .set_model(&model1, &mut cx.to_async());
1623// let _subscription2 = client
1624// .subscribe_to_entity(2)
1625// .unwrap()
1626// .set_model(&model2, &mut cx.to_async());
1627// // Ensure dropping a subscription for the same entity type still allows receiving of
1628// // messages for other entity IDs of the same type.
1629// let subscription3 = client
1630// .subscribe_to_entity(3)
1631// .unwrap()
1632// .set_model(&model3, &mut cx.to_async());
1633// drop(subscription3);
1634
1635// server.send(proto::JoinProject { project_id: 1 });
1636// server.send(proto::JoinProject { project_id: 2 });
1637// done_rx1.next().await.unwrap();
1638// done_rx2.next().await.unwrap();
1639// }
1640
1641// #[gpui::test]
1642// async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1643// cx.foreground().forbid_parking();
1644
1645// let user_id = 5;
1646// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1647// let server = FakeServer::for_client(user_id, &client, cx).await;
1648
1649// let model = cx.add_model(|_| Model::default());
1650// let (done_tx1, _done_rx1) = smol::channel::unbounded();
1651// let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1652// let subscription1 = client.add_message_handler(
1653// model.clone(),
1654// move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1655// done_tx1.try_send(()).unwrap();
1656// async { Ok(()) }
1657// },
1658// );
1659// drop(subscription1);
1660// let _subscription2 = client.add_message_handler(
1661// model.clone(),
1662// move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1663// done_tx2.try_send(()).unwrap();
1664// async { Ok(()) }
1665// },
1666// );
1667// server.send(proto::Ping {});
1668// done_rx2.next().await.unwrap();
1669// }
1670
1671// #[gpui::test]
1672// async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1673// cx.foreground().forbid_parking();
1674
1675// let user_id = 5;
1676// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1677// let server = FakeServer::for_client(user_id, &client, cx).await;
1678
1679// let model = cx.add_model(|_| Model::default());
1680// let (done_tx, mut done_rx) = smol::channel::unbounded();
1681// let subscription = client.add_message_handler(
1682// model.clone(),
1683// move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1684// model.update(&mut cx, |model, _| model.subscription.take());
1685// done_tx.try_send(()).unwrap();
1686// async { Ok(()) }
1687// },
1688// );
1689// model.update(cx, |model, _| {
1690// model.subscription = Some(subscription);
1691// });
1692// server.send(proto::Ping {});
1693// done_rx.next().await.unwrap();
1694// }
1695
1696// #[derive(Default)]
1697// struct Model {
1698// id: usize,
1699// subscription: Option<Subscription>,
1700// }
1701
1702// impl Entity for Model {
1703// type Event = ();
1704// }
1705// }