1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
15 TryStreamExt,
16};
17use gpui2::{
18 serde_json, AnyHandle, AnyWeakHandle, AnyWindowHandle, AppContext, AsyncAppContext,
19 AsyncWindowContext, Handle, SemanticVersion, Task, ViewContext, WeakHandle, WindowId,
20};
21use lazy_static::lazy_static;
22use parking_lot::RwLock;
23use postage::watch;
24use rand::prelude::*;
25use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use std::{
29 any::TypeId,
30 collections::HashMap,
31 convert::TryFrom,
32 fmt::Write as _,
33 future::Future,
34 marker::PhantomData,
35 path::PathBuf,
36 sync::{atomic::AtomicU64, Arc, Weak},
37 time::{Duration, Instant},
38};
39use telemetry::Telemetry;
40use thiserror::Error;
41use url::Url;
42use util::channel::ReleaseChannel;
43use util::http::HttpClient;
44use util::{ResultExt, TryFutureExt};
45
46pub use rpc::*;
47pub use telemetry::ClickhouseEvent;
48pub use user::*;
49
50lazy_static! {
51 pub static ref ZED_SERVER_URL: String =
52 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
53 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
54 .ok()
55 .and_then(|s| if s.is_empty() { None } else { Some(s) });
56 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
57 .ok()
58 .and_then(|s| if s.is_empty() { None } else { Some(s) });
59 pub static ref ZED_APP_VERSION: Option<SemanticVersion> = std::env::var("ZED_APP_VERSION")
60 .ok()
61 .and_then(|v| v.parse().ok());
62 pub static ref ZED_APP_PATH: Option<PathBuf> =
63 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
64 pub static ref ZED_ALWAYS_ACTIVE: bool =
65 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
66}
67
/// Client token constant.
/// NOTE(review): its consumers are not visible in this part of the file — confirm usage.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// Initial delay used by the reconnect loop in `Client::set_status` before it backs off.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time budget used by `Client::authenticate_and_connect` for establishing the
/// connection and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
71
/// Action that starts sign-in; `init` handles it by calling
/// `Client::authenticate_and_connect`.
#[derive(Clone, Default, PartialEq, Deserialize)]
pub struct SignIn;

/// Action that signs out; `init` handles it by calling `Client::disconnect`.
#[derive(Clone, Default, PartialEq, Deserialize)]
pub struct SignOut;

/// Action that forces a reconnect; `init` handles it by calling `Client::reconnect`.
#[derive(Clone, Default, PartialEq, Deserialize)]
pub struct Reconnect;
80
/// Registers the settings types owned by this crate (currently `TelemetrySettings`).
pub fn init_settings(cx: &mut AppContext) {
    settings2::register::<TelemetrySettings>(cx);
}
84
85pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
86 init_settings(cx);
87
88 let client = Arc::downgrade(client);
89 cx.register_action_type::<SignIn>();
90 cx.on_action({
91 let client = client.clone();
92 move |_: &SignIn, cx| {
93 if let Some(client) = client.upgrade() {
94 cx.spawn(
95 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
96 )
97 .detach();
98 }
99 }
100 });
101
102 cx.register_action_type::<SignOut>();
103 cx.on_action({
104 let client = client.clone();
105 move |_: &SignOut, cx| {
106 if let Some(client) = client.upgrade() {
107 cx.spawn(|cx| async move {
108 client.disconnect(&cx);
109 })
110 .detach();
111 }
112 }
113 });
114
115 cx.register_action_type::<Reconnect>();
116 cx.on_action({
117 let client = client.clone();
118 move |_: &Reconnect, cx| {
119 if let Some(client) = client.upgrade() {
120 cx.spawn(|cx| async move {
121 client.reconnect(&cx);
122 })
123 .detach();
124 }
125 }
126 });
127}
128
/// Client-side handle for the RPC connection to the server.
///
/// Bundles the RPC `Peer`, the HTTP client, telemetry and the lock-protected
/// `ClientState` (credentials, status channel and message-handler registries).
pub struct Client {
    /// Value reported by `id()`; `authenticate_and_connect` stores the user id here.
    id: AtomicU64,
    /// RPC peer that connections are added to in `set_connection`.
    peer: Arc<Peer>,
    /// HTTP client used to resolve the RPC endpoint and handed out by `http_client()`.
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Mutable state shared by all methods; always accessed through the lock.
    state: RwLock<ClientState>,

    /// Test-only override consulted by `authenticate` before the browser flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only override consulted by `establish_connection` before the websocket path.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
158
/// Errors that can occur while establishing the connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the websocket handshake with HTTP "upgrade required".
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the websocket handshake with HTTP "unauthorized".
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Failure reported by the HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// I/O failure (for example while opening the TCP stream).
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure building the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
174
175impl From<WebsocketError> for EstablishConnectionError {
176 fn from(error: WebsocketError) -> Self {
177 if let WebsocketError::Http(response) = &error {
178 match response.status() {
179 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
180 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
181 _ => {}
182 }
183 }
184 EstablishConnectionError::Other(error.into())
185 }
186}
187
impl EstablishConnectionError {
    /// Wraps any error convertible into `anyhow::Error` in the `Other` variant.
    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
        Self::Other(error.into())
    }
}
193
/// Connection status published through the client's watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial state (see `ClientState::default`); also set when the connection's I/O
    /// finishes without error.
    SignedOut,
    /// The server rejected the connection with "upgrade required".
    UpgradeRequired,
    /// Credentials are being obtained, starting from `SignedOut`.
    Authenticating,
    /// A connection is being established, starting from `SignedOut`.
    Connecting,
    /// Authentication or connection failed.
    ConnectionError,
    /// Connected; carries the peer id from the server's hello and the connection id.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O ended with an error; `set_status` starts a reconnect loop.
    ConnectionLost,
    /// Credentials are being obtained again after a previous attempt or connection.
    Reauthenticating,
    /// A connection is being re-established after a previous attempt or connection.
    Reconnecting,
    /// A reconnect attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
212
impl Status {
    /// Returns `true` only for the `Connected` variant.
    pub fn is_connected(&self) -> bool {
        matches!(self, Self::Connected { .. })
    }

    /// Returns `true` for `SignedOut` and for `UpgradeRequired`.
    pub fn is_signed_out(&self) -> bool {
        matches!(self, Self::SignedOut | Self::UpgradeRequired)
    }
}
222
/// Mutable client state, kept behind `Client::state`.
struct ClientState {
    /// Credentials of the current user; stored once a connection has been established
    /// and cleared when the server reports them as unauthorized.
    credentials: Option<Credentials>,
    /// Sender/receiver pair of the status watch channel.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: function extracting the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; replaced or dropped by `set_status`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the delay between reconnect attempts.
    reconnect_interval: Duration,
    /// Subscribers keyed by (entity type, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Per message type: weak handle of the model registered via `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakHandle>,
    /// Per message type: entity type registered via `add_entity_message_handler`.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Per message type: type-erased handler invoked with the subscriber and envelope.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
247
/// Entry stored in `ClientState::entities_by_type_and_remote_id`.
enum WeakSubscriber {
    /// A registered entity, held weakly; `window_id` is set for views and `None` for
    /// models registered through `PendingEntitySubscription::set_model`.
    Entity {
        handle: AnyWeakHandle,
        window_id: Option<WindowId>,
    },
    /// Envelopes held for a pending subscription; `set_model` replays them and dropping
    /// the pending subscription logs them as unhandled.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
255
/// Strong counterpart of a subscriber, passed to message handlers.
struct Subscriber {
    /// Type-erased handle that handlers downcast to their entity type.
    handle: AnyHandle,
    /// Window of the subscriber, when it has one; view handlers require it.
    window_handle: Option<AnyWindowHandle>,
}
260
/// User id and access token sent in the `Authorization` header when connecting.
// NOTE(review): the derived `Debug` prints `access_token` verbatim — confirm that
// credentials are never logged with `{:?}`.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
266
267impl Default for ClientState {
268 fn default() -> Self {
269 Self {
270 credentials: None,
271 status: watch::channel_with(Status::SignedOut),
272 entity_id_extractors: Default::default(),
273 _reconnect_task: None,
274 reconnect_interval: Duration::from_secs(5),
275 models_by_message_type: Default::default(),
276 entities_by_type_and_remote_id: Default::default(),
277 entity_types_by_message_type: Default::default(),
278 message_handlers: Default::default(),
279 }
280 }
281}
282
/// Guard returned when registering a handler or entity; dropping it unregisters the
/// corresponding entries from the client's state (see its `Drop` impl).
pub enum Subscription {
    /// Registration of an entity under (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a handler for a message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
293
294impl Drop for Subscription {
295 fn drop(&mut self) {
296 match self {
297 Subscription::Entity { client, id } => {
298 if let Some(client) = client.upgrade() {
299 let mut state = client.state.write();
300 let _ = state.entities_by_type_and_remote_id.remove(id);
301 }
302 }
303 Subscription::Message { client, id } => {
304 if let Some(client) = client.upgrade() {
305 let mut state = client.state.write();
306 let _ = state.entity_types_by_message_type.remove(id);
307 let _ = state.message_handlers.remove(id);
308 }
309 }
310 }
311 }
312}
313
/// Placeholder subscription created by `Client::subscribe_to_entity` before the model
/// for the remote entity exists; finished with `set_model`.
pub struct PendingEntitySubscription<T> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model`; tells `Drop` not to remove the registered entry.
    consumed: bool,
}
320
321impl<T> PendingEntitySubscription<T>
322where
323 T: 'static + Send + Sync,
324{
325 pub fn set_model(mut self, model: &Handle<T>, cx: &mut AsyncAppContext) -> Subscription {
326 self.consumed = true;
327 let mut state = self.client.state.write();
328 let id = (TypeId::of::<T>(), self.remote_id);
329 let Some(WeakSubscriber::Pending(messages)) =
330 state.entities_by_type_and_remote_id.remove(&id)
331 else {
332 unreachable!()
333 };
334
335 state.entities_by_type_and_remote_id.insert(
336 id,
337 WeakSubscriber::Entity {
338 handle: model.downgrade().into(),
339 window_id: None,
340 },
341 );
342 drop(state);
343 for message in messages {
344 self.client.handle_message(message, cx);
345 }
346 Subscription::Entity {
347 client: Arc::downgrade(&self.client),
348 id,
349 }
350 }
351}
352
353impl<T> Drop for PendingEntitySubscription<T> {
354 fn drop(&mut self) {
355 if !self.consumed {
356 let mut state = self.client.state.write();
357 if let Some(WeakSubscriber::Pending(messages)) = state
358 .entities_by_type_and_remote_id
359 .remove(&(TypeId::of::<T>(), self.remote_id))
360 {
361 for message in messages {
362 log::info!("unhandled message {}", message.payload_type_name());
363 }
364 }
365 }
366 }
367}
368
/// Resolved telemetry settings (registered under the `"telemetry"` key).
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    pub diagnostics: bool,
    pub metrics: bool,
}
374
/// File representation of the telemetry settings; each field may be omitted.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
380
381impl settings2::Setting for TelemetrySettings {
382 const KEY: Option<&'static str> = Some("telemetry");
383
384 type FileContent = TelemetrySettingsContent;
385
386 fn load(
387 default_value: &Self::FileContent,
388 user_values: &[&Self::FileContent],
389 _: &AppContext,
390 ) -> Result<Self> {
391 Ok(Self {
392 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
393 default_value
394 .diagnostics
395 .ok_or_else(Self::missing_default)?,
396 ),
397 metrics: user_values
398 .first()
399 .and_then(|v| v.metrics)
400 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
401 })
402 }
403}
404
405impl Client {
406 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
407 Arc::new(Self {
408 id: AtomicU64::new(0),
409 peer: Peer::new(0),
410 telemetry: Telemetry::new(http.clone(), cx),
411 http,
412 state: Default::default(),
413
414 #[cfg(any(test, feature = "test-support"))]
415 authenticate: Default::default(),
416 #[cfg(any(test, feature = "test-support"))]
417 establish_connection: Default::default(),
418 })
419 }
420
    /// Returns the id most recently stored with `set_id` (initially `0`).
    pub fn id(&self) -> u64 {
        self.id.load(std::sync::atomic::Ordering::SeqCst)
    }
424
    /// Returns a shared handle to the HTTP client this client was created with.
    pub fn http_client(&self) -> Arc<dyn HttpClient> {
        self.http.clone()
    }
428
    /// Stores `id` as the client's id and returns `self` so calls can be chained.
    pub fn set_id(&self, id: u64) -> &Self {
        self.id.store(id, std::sync::atomic::Ordering::SeqCst);
        self
    }
433
    /// Test helper: drops the reconnect task, clears the handler and entity registries,
    /// and tears down the peer.
    // NOTE(review): `entity_types_by_message_type` is the one registry not cleared
    // here — confirm whether that is intentional.
    #[cfg(any(test, feature = "test-support"))]
    pub fn teardown(&self) {
        let mut state = self.state.write();
        state._reconnect_task.take();
        state.message_handlers.clear();
        state.models_by_message_type.clear();
        state.entities_by_type_and_remote_id.clear();
        state.entity_id_extractors.clear();
        self.peer.teardown();
    }
444
445 #[cfg(any(test, feature = "test-support"))]
446 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
447 where
448 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
449 {
450 *self.authenticate.write() = Some(Box::new(authenticate));
451 self
452 }
453
454 #[cfg(any(test, feature = "test-support"))]
455 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
456 where
457 F: 'static
458 + Send
459 + Sync
460 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
461 {
462 *self.establish_connection.write() = Some(Box::new(connect));
463 self
464 }
465
466 pub fn user_id(&self) -> Option<u64> {
467 self.state
468 .read()
469 .credentials
470 .as_ref()
471 .map(|credentials| credentials.user_id)
472 }
473
474 pub fn peer_id(&self) -> Option<PeerId> {
475 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
476 Some(*peer_id)
477 } else {
478 None
479 }
480 }
481
482 pub fn status(&self) -> watch::Receiver<Status> {
483 self.state.read().status.1.clone()
484 }
485
    /// Publishes `status` on the watch channel and performs the side effect attached
    /// to the new status:
    ///
    /// * `Connected` drops any reconnect task.
    /// * `ConnectionLost` spawns a reconnect loop and stores it in the state.
    /// * `SignedOut` / `UpgradeRequired` clear the authenticated telemetry user and
    ///   drop any reconnect task.
    ///
    /// The state write lock is held for the whole function body.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Tests use a fixed seed so the backoff sequence is reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failure left the client in
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at
                            // the configured reconnect interval.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
531
532 pub fn add_view_for_remote_entity<T>(
533 self: &Arc<Self>,
534 remote_id: u64,
535 cx: &mut ViewContext<T>,
536 ) -> Subscription {
537 let id = (TypeId::of::<T>(), remote_id);
538 self.state.write().entities_by_type_and_remote_id.insert(
539 id,
540 WeakSubscriber::Entity {
541 handle: cx.handle().into_any(),
542 window_id: Some(cx.window_id()),
543 },
544 );
545 Subscription::Entity {
546 client: Arc::downgrade(self),
547 id,
548 }
549 }
550
551 pub fn subscribe_to_entity<T>(
552 self: &Arc<Self>,
553 remote_id: u64,
554 ) -> Result<PendingEntitySubscription<T>> {
555 let id = (TypeId::of::<T>(), remote_id);
556
557 let mut state = self.state.write();
558 if state.entities_by_type_and_remote_id.contains_key(&id) {
559 return Err(anyhow!("already subscribed to entity"));
560 } else {
561 state
562 .entities_by_type_and_remote_id
563 .insert(id, WeakSubscriber::Pending(Default::default()));
564 Ok(PendingEntitySubscription {
565 client: self.clone(),
566 remote_id,
567 consumed: false,
568 _entity_type: PhantomData,
569 })
570 }
571 }
572
573 #[track_caller]
574 pub fn add_message_handler<M, E, H, F>(
575 self: &Arc<Self>,
576 model: Handle<E>,
577 handler: H,
578 ) -> Subscription
579 where
580 M: EnvelopedMessage,
581 E: 'static + Send + Sync,
582 H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
583 F: 'static + Future<Output = Result<()>>,
584 {
585 let message_type_id = TypeId::of::<M>();
586
587 let mut state = self.state.write();
588 state
589 .models_by_message_type
590 .insert(message_type_id, model.downgrade().into_any());
591
592 let prev_handler = state.message_handlers.insert(
593 message_type_id,
594 Arc::new(move |handle, envelope, client, cx| {
595 let handle = if let Subscriber::Model(handle) = handle {
596 handle
597 } else {
598 unreachable!();
599 };
600 let model = handle.downcast::<E>().unwrap();
601 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
602 handler(model, *envelope, client.clone(), cx).boxed_local()
603 }),
604 );
605 if prev_handler.is_some() {
606 let location = std::panic::Location::caller();
607 panic!(
608 "{}:{} registered handler for the same message {} twice",
609 location.file(),
610 location.line(),
611 std::any::type_name::<M>()
612 );
613 }
614
615 Subscription::Message {
616 client: Arc::downgrade(self),
617 id: message_type_id,
618 }
619 }
620
621 pub fn add_request_handler<M, E, H, F>(
622 self: &Arc<Self>,
623 model: Handle<E>,
624 handler: H,
625 ) -> Subscription
626 where
627 M: RequestMessage,
628 E: 'static + Send + Sync,
629 H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
630 F: 'static + Future<Output = Result<M::Response>>,
631 {
632 self.add_message_handler(model, move |handle, envelope, this, cx| {
633 Self::respond_to_request(
634 envelope.receipt(),
635 handler(handle, envelope, this.clone(), cx),
636 this,
637 )
638 })
639 }
640
    /// Registers `handler` for entity messages of type `M` addressed to views of type
    /// `E`; the handler is run through `update_window`.
    ///
    /// # Panics
    ///
    /// The registered closure panics when the subscriber has no window handle, and
    /// when the subscriber's handle cannot be downcast to `E`.
    // NOTE(review): `window_handle` is bound but unused, and `subscriber` is passed as
    // the first argument of `update_window` while also being used inside the closure —
    // presumably `window_handle` was intended; confirm against the gpui2 API.
    // NOTE(review): `handler` is declared to take `WeakHandle<E>`, but it is given the
    // result of `subscriber.handle.downcast::<E>()` — confirm the handle type.
    // NOTE(review): `E` has no bounds here although it is forwarded to
    // `add_entity_message_handler`, which calls `TypeId::of::<E>()`.
    pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage,
        H: 'static
            + Send
            + Sync
            + Fn(WeakHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncWindowContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
            if let Some(window_handle) = subscriber.window_handle {
                cx.update_window(subscriber, |cx| {
                    handler(
                        subscriber.handle.downcast::<E>().unwrap(),
                        message,
                        client,
                        cx,
                    )
                })
            } else {
                panic!()
            }
        })
    }
665
666 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
667 where
668 M: EntityMessage,
669 E: 'static + Send + Sync,
670 H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
671 F: 'static + Future<Output = Result<()>>,
672 {
673 self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
674 handler(
675 subscriber.handle.downcast::<E>().unwrap(),
676 message,
677 client,
678 cx,
679 )
680 })
681 }
682
683 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
684 where
685 M: EntityMessage,
686 H: 'static
687 + Send
688 + Sync
689 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
690 F: 'static + Future<Output = Result<()>>,
691 {
692 let model_type_id = TypeId::of::<E>();
693 let message_type_id = TypeId::of::<M>();
694
695 let mut state = self.state.write();
696 state
697 .entity_types_by_message_type
698 .insert(message_type_id, model_type_id);
699 state
700 .entity_id_extractors
701 .entry(message_type_id)
702 .or_insert_with(|| {
703 |envelope| {
704 envelope
705 .as_any()
706 .downcast_ref::<TypedEnvelope<M>>()
707 .unwrap()
708 .payload
709 .remote_entity_id()
710 }
711 });
712 let prev_handler = state.message_handlers.insert(
713 message_type_id,
714 Arc::new(move |handle, envelope, client, cx| {
715 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
716 handler(handle, *envelope, client.clone(), cx).boxed_local()
717 }),
718 );
719 if prev_handler.is_some() {
720 panic!("registered handler for the same message twice");
721 }
722 }
723
724 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
725 where
726 M: EntityMessage + RequestMessage,
727 E: 'static + Send + Sync,
728 H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
729 F: 'static + Future<Output = Result<M::Response>>,
730 {
731 self.add_model_message_handler(move |entity, envelope, client, cx| {
732 Self::respond_to_request::<M, _>(
733 envelope.receipt(),
734 handler(entity, envelope, client.clone(), cx),
735 client,
736 )
737 })
738 }
739
    /// Registers `handler` for entity requests of type `M` addressed to views of type
    /// `E`; the handler's result is sent back through `respond_to_request`.
    // NOTE(review): `handler` takes `Handle<E>` here, while `add_view_message_handler`
    // declares its callback with `WeakHandle<E>` — confirm which handle type is meant.
    pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
    where
        M: EntityMessage + RequestMessage,
        E: 'static + Send + Sync,
        H: 'static
            + Send
            + Sync
            + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncWindowContext) -> F,
        F: 'static + Future<Output = Result<M::Response>>,
    {
        self.add_view_message_handler(move |entity, envelope, client, cx| {
            Self::respond_to_request::<M, _>(
                envelope.receipt(),
                handler(entity, envelope, client.clone(), cx),
                client,
            )
        })
    }
758
759 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
760 receipt: Receipt<T>,
761 response: F,
762 client: Arc<Self>,
763 ) -> Result<()> {
764 match response.await {
765 Ok(response) => {
766 client.respond(receipt, response)?;
767 Ok(())
768 }
769 Err(error) => {
770 client.respond_with_error(
771 receipt,
772 proto::Error {
773 message: format!("{:?}", error),
774 },
775 )?;
776 Err(error)
777 }
778 }
779 }
780
    /// Reports whether `read_credentials_from_keychain` yields credentials.
    // NOTE(review): `authenticate_and_connect` awaits `read_credentials_from_keychain`
    // before inspecting its result, whereas this call does not — confirm the helper's
    // signature.
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
784
    /// Obtains credentials (stored, keychain, or interactive authentication) and
    /// establishes the connection, updating the client status along the way.
    ///
    /// Returns `Ok(())` immediately when the client is already connected or in the
    /// middle of connecting. When `try_keychain` is `true` and no credentials are
    /// stored, the keychain is consulted before falling back to `authenticate`.
    ///
    /// # Errors
    ///
    /// Fails when the status is `UpgradeRequired`, when authentication fails or is
    /// canceled by a status change, when the connection cannot be established, or when
    /// either the connection or the server hello does not arrive within
    /// `CONNECTION_TIMEOUT`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Decide from the current status whether this is a first connection
        // (`Authenticating`/`Connecting`) or a retry (`Reauthenticating`/`Reconnecting`).
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx).await;
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Presumably consumes the receiver's current value so that the `select`
            // below only reacts to a later status change — confirm against
            // `postage::watch` semantics.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer covers both establishing the connection and waiting for the hello.
        let mut timeout = futures::FutureExt::fuse(cx.executor()?.timer(CONNECTION_TIMEOUT));
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Credentials that came from the keychain, or that belong to an
                        // impersonated login, are not written back.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // The keychain credentials were rejected: delete them and
                            // retry once without consulting the keychain.
                            delete_credentials_from_keychain(cx).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
888
889 async fn set_connection(
890 self: &Arc<Self>,
891 conn: Connection,
892 cx: &AsyncAppContext,
893 ) -> Result<()> {
894 let executor = cx.executor()?;
895 log::info!("add connection to peer");
896 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
897 let executor = executor.clone();
898 move |duration| executor.timer(duration)
899 });
900 let handle_io = executor.spawn(handle_io);
901
902 let peer_id = async {
903 log::info!("waiting for server hello");
904 let message = incoming
905 .next()
906 .await
907 .ok_or_else(|| anyhow!("no hello message received"))?;
908 log::info!("got server hello");
909 let hello_message_type_name = message.payload_type_name().to_string();
910 let hello = message
911 .into_any()
912 .downcast::<TypedEnvelope<proto::Hello>>()
913 .map_err(|_| {
914 anyhow!(
915 "invalid hello message received: {:?}",
916 hello_message_type_name
917 )
918 })?;
919 let peer_id = hello
920 .payload
921 .peer_id
922 .ok_or_else(|| anyhow!("invalid peer id"))?;
923 Ok(peer_id)
924 };
925
926 let peer_id = match peer_id.await {
927 Ok(peer_id) => peer_id,
928 Err(error) => {
929 self.peer.disconnect(connection_id);
930 return Err(error);
931 }
932 };
933
934 log::info!(
935 "set status to connected (connection id: {:?}, peer id: {:?})",
936 connection_id,
937 peer_id
938 );
939 self.set_status(
940 Status::Connected {
941 peer_id,
942 connection_id,
943 },
944 cx,
945 );
946
947 cx.spawn({
948 let this = self.clone();
949 |cx| {
950 async move {
951 while let Some(message) = incoming.next().await {
952 this.handle_message(message, &cx);
953 // Don't starve the main thread when receiving lots of messages at once.
954 smol::future::yield_now().await;
955 }
956 }
957 }
958 })?
959 .detach();
960
961 cx.spawn({
962 let this = self.clone();
963 move |cx| async move {
964 match handle_io.await {
965 Ok(()) => {
966 if this.status().borrow().clone()
967 == (Status::Connected {
968 connection_id,
969 peer_id,
970 })
971 {
972 this.set_status(Status::SignedOut, &cx);
973 }
974 }
975 Err(err) => {
976 log::error!("connection error: {:?}", err);
977 this.set_status(Status::ConnectionLost, &cx);
978 }
979 }
980 }
981 })?;
982
983 Ok(())
984 }
985
    /// Starts authentication: uses the test override when one is installed (test
    /// builds only), otherwise the browser-based flow.
    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.authenticate.read().as_ref() {
            return callback(cx);
        }

        self.authenticate_with_browser(cx)
    }
994
    /// Starts connecting with `credentials`: uses the test override when one is
    /// installed (test builds only), otherwise the websocket connection.
    fn establish_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.establish_connection.read().as_ref() {
            return callback(credentials, cx);
        }

        self.establish_websocket_connection(credentials, cx)
    }
1007
1008 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
1009 let preview_param = if is_preview { "?preview=1" } else { "" };
1010 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
1011 let response = http.get(&url, Default::default(), false).await?;
1012
1013 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
1014 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
1015 // which requires authorization via an HTTP header.
1016 //
1017 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
1018 // of a collab server. In that case, a request to the /rpc endpoint will
1019 // return an 'unauthorized' response.
1020 let collab_url = if response.status().is_redirection() {
1021 response
1022 .headers()
1023 .get("Location")
1024 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1025 .to_str()
1026 .map_err(EstablishConnectionError::other)?
1027 .to_string()
1028 } else if response.status() == StatusCode::UNAUTHORIZED {
1029 url
1030 } else {
1031 Err(anyhow!(
1032 "unexpected /rpc response status {}",
1033 response.status()
1034 ))?
1035 };
1036
1037 Url::parse(&collab_url).context("invalid rpc url")
1038 }
1039
1040 fn establish_websocket_connection(
1041 self: &Arc<Self>,
1042 credentials: &Credentials,
1043 cx: &AsyncAppContext,
1044 ) -> Task<Result<Connection, EstablishConnectionError>> {
1045 let executor = match cx.executor() {
1046 Ok(executor) => executor,
1047 Err(error) => return Task::ready(Err(error)),
1048 };
1049
1050 let use_preview_server = cx
1051 .try_read_global(|channel: &ReleaseChannel, _| channel != ReleaseChannel::Stable)
1052 .unwrap_or(false);
1053
1054 let request = Request::builder()
1055 .header(
1056 "Authorization",
1057 format!("{} {}", credentials.user_id, credentials.access_token),
1058 )
1059 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1060
1061 let http = self.http.clone();
1062 executor.spawn(async move {
1063 let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
1064 let rpc_host = rpc_url
1065 .host_str()
1066 .zip(rpc_url.port_or_known_default())
1067 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1068 let stream = smol::net::TcpStream::connect(rpc_host).await?;
1069
1070 log::info!("connected to rpc endpoint {}", rpc_url);
1071
1072 match rpc_url.scheme() {
1073 "https" => {
1074 rpc_url.set_scheme("wss").unwrap();
1075 let request = request.uri(rpc_url.as_str()).body(())?;
1076 let (stream, _) =
1077 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1078 Ok(Connection::new(
1079 stream
1080 .map_err(|error| anyhow!(error))
1081 .sink_map_err(|error| anyhow!(error)),
1082 ))
1083 }
1084 "http" => {
1085 rpc_url.set_scheme("ws").unwrap();
1086 let request = request.uri(rpc_url.as_str()).body(())?;
1087 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1088 Ok(Connection::new(
1089 stream
1090 .map_err(|error| anyhow!(error))
1091 .sink_map_err(|error| anyhow!(error)),
1092 ))
1093 }
1094 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1095 }
1096 })
1097 }
1098
    /// Authenticates through the user's web browser.
    ///
    /// Generates a key pair, starts a local HTTP server on an OS-assigned port, opens
    /// the Zed sign-in page with that port and the public key, and waits for the
    /// browser to be redirected back with `user_id` and an encrypted `access_token`,
    /// which is then decrypted with the private key.
    ///
    /// When both `IMPERSONATE_LOGIN` and `ADMIN_API_TOKEN` are set, the browser flow is
    /// skipped in favor of `authenticate_as_admin`.
    ///
    /// # Errors
    ///
    /// The returned task fails when no redirect arrives within the 100 one-second
    /// receive attempts, when the redirect lacks `user_id` or `access_token`, when the
    /// token cannot be decrypted, or when `user_id` does not parse as a number.
    ///
    /// # Panics
    ///
    /// Panics when the key pair cannot be generated or serialized, or when the local
    /// HTTP server cannot be started.
    // NOTE(review): `platform` and `executor` are used below but are not defined in
    // this function — confirm where they are meant to come from.
    pub fn authenticate_with_browser(
        self: &Arc<Self>,
        cx: &AsyncAppContext,
    ) -> Task<Result<Credentials>> {
        let http = self.http.clone();
        cx.spawn(|cx| async move {
            // Generate a pair of asymmetric encryption keys. The public key will be used by the
            // zed server to encrypt the user's access token, so that it can't be intercepted by
            // any other app running on the user's device.
            let (public_key, private_key) =
                rpc::auth::keypair().expect("failed to generate keypair for auth");
            let public_key_string =
                String::try_from(public_key).expect("failed to serialize public key for auth");

            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
            }

            // Start an HTTP server to receive the redirect from Zed's sign-in page.
            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
            let port = server.server_addr().port();

            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
            // that the user is signing in from a Zed app running on the same device.
            let mut url = format!(
                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
                *ZED_SERVER_URL, port, public_key_string
            );

            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
                log::info!("impersonating user @{}", impersonate_login);
                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
            }

            platform.open_url(&url);

            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
            // access token from the query params.
            //
            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
            // custom URL scheme instead of this local HTTP server.
            let (user_id, access_token) = executor
                .spawn(async move {
                    // Poll for the redirect in one-second slices, giving up after 100 tries.
                    for _ in 0..100 {
                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
                            let path = req.url();
                            let mut user_id = None;
                            let mut access_token = None;
                            // The request only carries a path; prefix a placeholder
                            // origin so it can be parsed as a full URL.
                            let url = Url::parse(&format!("http://example.com{}", path))
                                .context("failed to parse login notification url")?;
                            for (key, value) in url.query_pairs() {
                                if key == "access_token" {
                                    access_token = Some(value.to_string());
                                } else if key == "user_id" {
                                    user_id = Some(value.to_string());
                                }
                            }

                            // Send the browser on to the "sign-in succeeded" page.
                            let post_auth_url =
                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
                            req.respond(
                                tiny_http::Response::empty(302).with_header(
                                    tiny_http::Header::from_bytes(
                                        &b"Location"[..],
                                        post_auth_url.as_bytes(),
                                    )
                                    .unwrap(),
                                ),
                            )
                            .context("failed to respond to login http request")?;
                            return Ok((
                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
                                access_token
                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
                            ));
                        }
                    }

                    Err(anyhow!("didn't receive login redirect"))
                })
                .await?;

            let access_token = private_key
                .decrypt_string(&access_token)
                .context("failed to decrypt access token")?;
            platform.activate(true);

            Ok(Credentials {
                user_id: user_id.parse()?,
                access_token,
            })
        })
        .unwrap_or_else(|error| Task::ready(Err(error)))
    }
1193
1194 async fn authenticate_as_admin(
1195 http: Arc<dyn HttpClient>,
1196 login: String,
1197 mut api_token: String,
1198 ) -> Result<Credentials> {
1199 #[derive(Deserialize)]
1200 struct AuthenticatedUserResponse {
1201 user: User,
1202 }
1203
1204 #[derive(Deserialize)]
1205 struct User {
1206 id: u64,
1207 }
1208
1209 // Use the collab server's admin API to retrieve the id
1210 // of the impersonated user.
1211 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1212 url.set_path("/user");
1213 url.set_query(Some(&format!("github_login={login}")));
1214 let request = Request::get(url.as_str())
1215 .header("Authorization", format!("token {api_token}"))
1216 .body("".into())?;
1217
1218 let mut response = http.send(request).await?;
1219 let mut body = String::new();
1220 response.body_mut().read_to_string(&mut body).await?;
1221 if !response.status().is_success() {
1222 Err(anyhow!(
1223 "admin user request failed {} - {}",
1224 response.status().as_u16(),
1225 body,
1226 ))?;
1227 }
1228 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1229
1230 // Use the admin API token to authenticate as the impersonated user.
1231 api_token.insert_str(0, "ADMIN_TOKEN:");
1232 Ok(Credentials {
1233 user_id: response.user.id,
1234 access_token: api_token,
1235 })
1236 }
1237
1238 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1239 self.peer.teardown();
1240 self.set_status(Status::SignedOut, cx);
1241 }
1242
1243 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1244 self.peer.teardown();
1245 self.set_status(Status::ConnectionLost, cx);
1246 }
1247
1248 fn connection_id(&self) -> Result<ConnectionId> {
1249 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1250 Ok(connection_id)
1251 } else {
1252 Err(anyhow!("not connected"))
1253 }
1254 }
1255
1256 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1257 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1258 self.peer.send(self.connection_id()?, message)
1259 }
1260
1261 pub fn request<T: RequestMessage>(
1262 &self,
1263 request: T,
1264 ) -> impl Future<Output = Result<T::Response>> {
1265 self.request_envelope(request)
1266 .map_ok(|envelope| envelope.payload)
1267 }
1268
1269 pub fn request_envelope<T: RequestMessage>(
1270 &self,
1271 request: T,
1272 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1273 let client_id = self.id();
1274 log::debug!(
1275 "rpc request start. client_id:{}. name:{}",
1276 client_id,
1277 T::NAME
1278 );
1279 let response = self
1280 .connection_id()
1281 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1282 async move {
1283 let response = response?.await;
1284 log::debug!(
1285 "rpc request finish. client_id:{}. name:{}",
1286 client_id,
1287 T::NAME
1288 );
1289 response
1290 }
1291 }
1292
1293 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1294 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1295 self.peer.respond(receipt, response)
1296 }
1297
1298 fn respond_with_error<T: RequestMessage>(
1299 &self,
1300 receipt: Receipt<T>,
1301 error: proto::Error,
1302 ) -> Result<()> {
1303 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1304 self.peer.respond_with_error(receipt, error)
1305 }
1306
1307 fn handle_message(
1308 self: &Arc<Client>,
1309 message: Box<dyn AnyTypedEnvelope>,
1310 cx: &AsyncAppContext,
1311 ) {
1312 let mut state = self.state.write();
1313 let type_name = message.payload_type_name();
1314 let payload_type_id = message.payload_type_id();
1315 let sender_id = message.original_sender_id();
1316
1317 let mut subscriber = None;
1318
1319 if let Some(handle) = state
1320 .models_by_message_type
1321 .get(&payload_type_id)
1322 .and_then(|handle| handle.upgrade())
1323 {
1324 subscriber = Some(Subscriber {
1325 handle,
1326 window_handle: None,
1327 });
1328 } else if let Some((extract_entity_id, entity_type_id)) =
1329 state.entity_id_extractors.get(&payload_type_id).zip(
1330 state
1331 .entity_types_by_message_type
1332 .get(&payload_type_id)
1333 .copied(),
1334 )
1335 {
1336 let entity_id = (extract_entity_id)(message.as_ref());
1337
1338 match state
1339 .entities_by_type_and_remote_id
1340 .get_mut(&(entity_type_id, entity_id))
1341 {
1342 Some(WeakSubscriber::Pending(pending)) => {
1343 pending.push(message);
1344 return;
1345 }
1346 Some(weak_subscriber @ _) => match weak_subscriber {
1347 WeakSubscriber::Model(handle) => {
1348 subscriber = handle.upgrade(cx).map(Subscriber::Model);
1349 }
1350 WeakSubscriber::View(handle) => {
1351 subscriber = Some(Subscriber::View(handle.clone()));
1352 }
1353 WeakSubscriber::Pending(_) => {}
1354 },
1355 _ => {}
1356 }
1357 }
1358
1359 let subscriber = if let Some(subscriber) = subscriber {
1360 subscriber
1361 } else {
1362 log::info!("unhandled message {}", type_name);
1363 self.peer.respond_with_unhandled_message(message).log_err();
1364 return;
1365 };
1366
1367 let handler = state.message_handlers.get(&payload_type_id).cloned();
1368 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1369 // It also ensures we don't hold the lock while yielding back to the executor, as
1370 // that might cause the executor thread driving this future to block indefinitely.
1371 drop(state);
1372
1373 if let Some(handler) = handler {
1374 let future = handler(subscriber, message, &self, cx.clone());
1375 let client_id = self.id();
1376 log::debug!(
1377 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1378 client_id,
1379 sender_id,
1380 type_name
1381 );
1382 cx.foreground()
1383 .spawn(async move {
1384 match future.await {
1385 Ok(()) => {
1386 log::debug!(
1387 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1388 client_id,
1389 sender_id,
1390 type_name
1391 );
1392 }
1393 Err(error) => {
1394 log::error!(
1395 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1396 client_id,
1397 sender_id,
1398 type_name,
1399 error
1400 );
1401 }
1402 }
1403 })
1404 .detach();
1405 } else {
1406 log::info!("unhandled message {}", type_name);
1407 self.peer.respond_with_unhandled_message(message).log_err();
1408 }
1409 }
1410
1411 pub fn telemetry(&self) -> &Arc<Telemetry> {
1412 &self.telemetry
1413 }
1414}
1415
1416async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1417 if IMPERSONATE_LOGIN.is_some() {
1418 return None;
1419 }
1420
1421 let (user_id, access_token) = cx
1422 .run_on_main(|cx| cx.read_credentials(&ZED_SERVER_URL).log_err().flatten())
1423 .ok()?
1424 .await?;
1425
1426 Some(Credentials {
1427 user_id: user_id.parse().ok()?,
1428 access_token: String::from_utf8(access_token).ok()?,
1429 })
1430}
1431
1432async fn write_credentials_to_keychain(
1433 credentials: Credentials,
1434 cx: &AsyncAppContext,
1435) -> Result<()> {
1436 cx.run_on_main(move |cx| {
1437 cx.write_credentials(
1438 &ZED_SERVER_URL,
1439 &credentials.user_id.to_string(),
1440 credentials.access_token.as_bytes(),
1441 )
1442 })?
1443 .await
1444}
1445
1446async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
1447 cx.run_on_main(move |cx| cx.delete_credentials(&ZED_SERVER_URL))?
1448 .await
1449}
1450
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Formats a worktree URL as `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a worktree URL back into its id and access token.
///
/// Surrounding whitespace is ignored. Returns `None` when the prefix is missing, the
/// id is not a valid `u64`, or the access token segment is absent or empty. Any
/// segments after the access token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let remainder = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = remainder.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next() {
        Some(token) if !token.is_empty() => Some((id, token.to_owned())),
        _ => None,
    }
}
1467
1468// #[cfg(test)]
1469// mod tests {
1470// use super::*;
1471// use crate::test::FakeServer;
1472// use gpui::{executor::Deterministic, TestAppContext};
1473// use parking_lot::Mutex;
1474// use std::future;
1475// use util::http::FakeHttpClient;
1476
1477// #[gpui::test(iterations = 10)]
1478// async fn test_reconnection(cx: &mut TestAppContext) {
1479// cx.foreground().forbid_parking();
1480
1481// let user_id = 5;
1482// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1483// let server = FakeServer::for_client(user_id, &client, cx).await;
1484// let mut status = client.status();
1485// assert!(matches!(
1486// status.next().await,
1487// Some(Status::Connected { .. })
1488// ));
1489// assert_eq!(server.auth_count(), 1);
1490
1491// server.forbid_connections();
1492// server.disconnect();
1493// while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1494
1495// server.allow_connections();
1496// cx.foreground().advance_clock(Duration::from_secs(10));
1497// while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1498// assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1499
1500// server.forbid_connections();
1501// server.disconnect();
1502// while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1503
1504// // Clear cached credentials after authentication fails
1505// server.roll_access_token();
1506// server.allow_connections();
1507// cx.foreground().advance_clock(Duration::from_secs(10));
1508// while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1509// assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1510// }
1511
1512// #[gpui::test(iterations = 10)]
1513// async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
1514// deterministic.forbid_parking();
1515
1516// let user_id = 5;
1517// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1518// let mut status = client.status();
1519
1520// // Time out when client tries to connect.
1521// client.override_authenticate(move |cx| {
1522// cx.foreground().spawn(async move {
1523// Ok(Credentials {
1524// user_id,
1525// access_token: "token".into(),
1526// })
1527// })
1528// });
1529// client.override_establish_connection(|_, cx| {
1530// cx.foreground().spawn(async move {
1531// future::pending::<()>().await;
1532// unreachable!()
1533// })
1534// });
1535// let auth_and_connect = cx.spawn({
1536// let client = client.clone();
1537// |cx| async move { client.authenticate_and_connect(false, &cx).await }
1538// });
1539// deterministic.run_until_parked();
1540// assert!(matches!(status.next().await, Some(Status::Connecting)));
1541
1542// deterministic.advance_clock(CONNECTION_TIMEOUT);
1543// assert!(matches!(
1544// status.next().await,
1545// Some(Status::ConnectionError { .. })
1546// ));
1547// auth_and_connect.await.unwrap_err();
1548
1549// // Allow the connection to be established.
1550// let server = FakeServer::for_client(user_id, &client, cx).await;
1551// assert!(matches!(
1552// status.next().await,
1553// Some(Status::Connected { .. })
1554// ));
1555
1556// // Disconnect client.
1557// server.forbid_connections();
1558// server.disconnect();
1559// while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1560
1561// // Time out when re-establishing the connection.
1562// server.allow_connections();
1563// client.override_establish_connection(|_, cx| {
1564// cx.foreground().spawn(async move {
1565// future::pending::<()>().await;
1566// unreachable!()
1567// })
1568// });
1569// deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1570// assert!(matches!(
1571// status.next().await,
1572// Some(Status::Reconnecting { .. })
1573// ));
1574
1575// deterministic.advance_clock(CONNECTION_TIMEOUT);
1576// assert!(matches!(
1577// status.next().await,
1578// Some(Status::ReconnectionError { .. })
1579// ));
1580// }
1581
1582// #[gpui::test(iterations = 10)]
1583// async fn test_authenticating_more_than_once(
1584// cx: &mut TestAppContext,
1585// deterministic: Arc<Deterministic>,
1586// ) {
1587// cx.foreground().forbid_parking();
1588
1589// let auth_count = Arc::new(Mutex::new(0));
1590// let dropped_auth_count = Arc::new(Mutex::new(0));
1591// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1592// client.override_authenticate({
1593// let auth_count = auth_count.clone();
1594// let dropped_auth_count = dropped_auth_count.clone();
1595// move |cx| {
1596// let auth_count = auth_count.clone();
1597// let dropped_auth_count = dropped_auth_count.clone();
1598// cx.foreground().spawn(async move {
1599// *auth_count.lock() += 1;
1600// let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1601// future::pending::<()>().await;
1602// unreachable!()
1603// })
1604// }
1605// });
1606
1607// let _authenticate = cx.spawn(|cx| {
1608// let client = client.clone();
1609// async move { client.authenticate_and_connect(false, &cx).await }
1610// });
1611// deterministic.run_until_parked();
1612// assert_eq!(*auth_count.lock(), 1);
1613// assert_eq!(*dropped_auth_count.lock(), 0);
1614
1615// let _authenticate = cx.spawn(|cx| {
1616// let client = client.clone();
1617// async move { client.authenticate_and_connect(false, &cx).await }
1618// });
1619// deterministic.run_until_parked();
1620// assert_eq!(*auth_count.lock(), 2);
1621// assert_eq!(*dropped_auth_count.lock(), 1);
1622// }
1623
1624// #[test]
1625// fn test_encode_and_decode_worktree_url() {
1626// let url = encode_worktree_url(5, "deadbeef");
1627// assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
1628// assert_eq!(
1629// decode_worktree_url(&format!("\n {}\t", url)),
1630// Some((5, "deadbeef".to_string()))
1631// );
1632// assert_eq!(decode_worktree_url("not://the-right-format"), None);
1633// }
1634
1635// #[gpui::test]
1636// async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1637// cx.foreground().forbid_parking();
1638
1639// let user_id = 5;
1640// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1641// let server = FakeServer::for_client(user_id, &client, cx).await;
1642
1643// let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1644// let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1645// client.add_model_message_handler(
1646// move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
1647// match model.read_with(&cx, |model, _| model.id) {
1648// 1 => done_tx1.try_send(()).unwrap(),
1649// 2 => done_tx2.try_send(()).unwrap(),
1650// _ => unreachable!(),
1651// }
1652// async { Ok(()) }
1653// },
1654// );
1655// let model1 = cx.add_model(|_| Model {
1656// id: 1,
1657// subscription: None,
1658// });
1659// let model2 = cx.add_model(|_| Model {
1660// id: 2,
1661// subscription: None,
1662// });
1663// let model3 = cx.add_model(|_| Model {
1664// id: 3,
1665// subscription: None,
1666// });
1667
1668// let _subscription1 = client
1669// .subscribe_to_entity(1)
1670// .unwrap()
1671// .set_model(&model1, &mut cx.to_async());
1672// let _subscription2 = client
1673// .subscribe_to_entity(2)
1674// .unwrap()
1675// .set_model(&model2, &mut cx.to_async());
1676// // Ensure dropping a subscription for the same entity type still allows receiving of
1677// // messages for other entity IDs of the same type.
1678// let subscription3 = client
1679// .subscribe_to_entity(3)
1680// .unwrap()
1681// .set_model(&model3, &mut cx.to_async());
1682// drop(subscription3);
1683
1684// server.send(proto::JoinProject { project_id: 1 });
1685// server.send(proto::JoinProject { project_id: 2 });
1686// done_rx1.next().await.unwrap();
1687// done_rx2.next().await.unwrap();
1688// }
1689
1690// #[gpui::test]
1691// async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1692// cx.foreground().forbid_parking();
1693
1694// let user_id = 5;
1695// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1696// let server = FakeServer::for_client(user_id, &client, cx).await;
1697
1698// let model = cx.add_model(|_| Model::default());
1699// let (done_tx1, _done_rx1) = smol::channel::unbounded();
1700// let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1701// let subscription1 = client.add_message_handler(
1702// model.clone(),
1703// move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1704// done_tx1.try_send(()).unwrap();
1705// async { Ok(()) }
1706// },
1707// );
1708// drop(subscription1);
1709// let _subscription2 = client.add_message_handler(
1710// model.clone(),
1711// move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1712// done_tx2.try_send(()).unwrap();
1713// async { Ok(()) }
1714// },
1715// );
1716// server.send(proto::Ping {});
1717// done_rx2.next().await.unwrap();
1718// }
1719
1720// #[gpui::test]
1721// async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1722// cx.foreground().forbid_parking();
1723
1724// let user_id = 5;
1725// let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1726// let server = FakeServer::for_client(user_id, &client, cx).await;
1727
1728// let model = cx.add_model(|_| Model::default());
1729// let (done_tx, mut done_rx) = smol::channel::unbounded();
1730// let subscription = client.add_message_handler(
1731// model.clone(),
1732// move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1733// model.update(&mut cx, |model, _| model.subscription.take());
1734// done_tx.try_send(()).unwrap();
1735// async { Ok(()) }
1736// },
1737// );
1738// model.update(cx, |model, _| {
1739// model.subscription = Some(subscription);
1740// });
1741// server.send(proto::Ping {});
1742// done_rx.next().await.unwrap();
1743// }
1744
1745// #[derive(Default)]
1746// struct Model {
1747// id: usize,
1748// subscription: Option<Subscription>,
1749// }
1750
1751// impl Entity for Model {
1752// type Event = ();
1753// }
1754// }