1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
15 TryStreamExt,
16};
17use gpui::{
18 actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
19 AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
20 WeakViewHandle,
21};
22use lazy_static::lazy_static;
23use parking_lot::RwLock;
24use postage::watch;
25use rand::prelude::*;
26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use std::{
30 any::TypeId,
31 collections::HashMap,
32 convert::TryFrom,
33 fmt::Write as _,
34 future::Future,
35 marker::PhantomData,
36 path::PathBuf,
37 sync::{atomic::AtomicU64, Arc, Weak},
38 time::{Duration, Instant},
39};
40use telemetry::Telemetry;
41use thiserror::Error;
42use url::Url;
43use util::channel::ReleaseChannel;
44use util::http::HttpClient;
45use util::{ResultExt, TryFutureExt};
46
47pub use rpc::*;
48pub use telemetry::ClickhouseEvent;
49pub use user::*;
50
51lazy_static! {
52 pub static ref ZED_SERVER_URL: String =
53 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
54 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
55 .ok()
56 .and_then(|s| if s.is_empty() { None } else { Some(s) });
57 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
58 .ok()
59 .and_then(|s| if s.is_empty() { None } else { Some(s) });
60 pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
61 .ok()
62 .and_then(|v| v.parse().ok());
63 pub static ref ZED_APP_PATH: Option<PathBuf> =
64 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
65 pub static ref ZED_ALWAYS_ACTIVE: bool =
66 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
67}
68
/// Client token constant exported by this module (its consumers are outside this chunk).
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";

/// Delay before the first reconnection attempt after a lost connection. Each later attempt
/// multiplies it by a random factor in `[1, 2]`, capped at the client's reconnect interval.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);

/// Upper bound on establishing a connection, including waiting for the server's hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(5_000);
72
// Global actions declared under the `client` namespace; `init` registers their handlers.
actions!(client, [SignIn, SignOut, Reconnect]);
74
/// Registers [`TelemetrySettings`] with the settings system.
pub fn init_settings(cx: &mut AppContext) {
    settings::register::<TelemetrySettings>(cx);
}
78
79pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
80 init_settings(cx);
81
82 let client = Arc::downgrade(client);
83 cx.add_global_action({
84 let client = client.clone();
85 move |_: &SignIn, cx| {
86 if let Some(client) = client.upgrade() {
87 cx.spawn(
88 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
89 )
90 .detach();
91 }
92 }
93 });
94 cx.add_global_action({
95 let client = client.clone();
96 move |_: &SignOut, cx| {
97 if let Some(client) = client.upgrade() {
98 cx.spawn(|cx| async move {
99 client.disconnect(&cx);
100 })
101 .detach();
102 }
103 }
104 });
105 cx.add_global_action({
106 let client = client.clone();
107 move |_: &Reconnect, cx| {
108 if let Some(client) = client.upgrade() {
109 cx.spawn(|cx| async move {
110 client.reconnect(&cx);
111 })
112 .detach();
113 }
114 }
115 });
116}
117
/// Connection to the Zed server. It authenticates, owns the RPC peer connection, and keeps the
/// tables that route incoming messages to registered handlers.
pub struct Client {
    /// Identifier written by `set_id`; set to the user id once credentials are known.
    id: AtomicU64,
    /// RPC peer that owns the active connection.
    peer: Arc<Peer>,
    /// HTTP client used for the `/rpc` lookup and shared with telemetry.
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Mutable state: credentials, status channel and message-routing tables.
    state: RwLock<ClientState>,

    /// Test hook that, when set, replaces browser authentication (see `authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test hook that, when set, replaces the websocket connection (see `establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
147
/// Errors that can occur while establishing a connection to the server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP 426 Upgrade Required.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP 401 Unauthorized.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including other websocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure building the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
163
164impl From<WebsocketError> for EstablishConnectionError {
165 fn from(error: WebsocketError) -> Self {
166 if let WebsocketError::Http(response) = &error {
167 match response.status() {
168 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
169 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
170 _ => {}
171 }
172 }
173 EstablishConnectionError::Other(error.into())
174 }
175}
176
177impl EstablishConnectionError {
178 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
179 Self::Other(error.into())
180 }
181}
182
/// Connection state of a [`Client`], published on its status watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No session. This is the initial status.
    SignedOut,
    /// The connection attempt was answered with `426 Upgrade Required`.
    /// Counts as signed out for [`Status::is_signed_out`].
    UpgradeRequired,
    /// Obtaining credentials for an attempt that started from `SignedOut`.
    Authenticating,
    /// Opening the connection for an attempt that started from `SignedOut`.
    Connecting,
    /// The most recent connection attempt failed.
    ConnectionError,
    /// Connected. `peer_id` comes from the server's hello; `connection_id` comes from the peer.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// An established connection ended with an I/O error. Setting this starts a reconnect loop.
    ConnectionLost,
    /// Like `Authenticating`, for an attempt that did not start from `SignedOut`.
    Reauthenticating,
    /// Like `Connecting`, for an attempt that did not start from `SignedOut`.
    Reconnecting,
    /// A reconnection attempt failed; the next one is due at `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
201
202impl Status {
203 pub fn is_connected(&self) -> bool {
204 matches!(self, Self::Connected { .. })
205 }
206
207 pub fn is_signed_out(&self) -> bool {
208 matches!(self, Self::SignedOut | Self::UpgradeRequired)
209 }
210}
211
/// Mutable state of a [`Client`], guarded by its `RwLock`.
struct ClientState {
    /// Credentials cached after a successful connection; cleared when the server answers
    /// `Unauthorized`.
    credentials: Option<Credentials>,
    /// Status watch channel. `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: extracts the remote entity id from a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`. Dropping it cancels the loop.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the randomized reconnect delay.
    reconnect_interval: Duration,
    /// Local subscriber for each (entity `TypeId`, remote id) pair.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered by `add_message_handler`, per message type.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type addressed by each entity message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler per message type. Each handler receives the resolved subscriber, the
    /// envelope, the client and an async context, and returns the handling future.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
236
/// Registry entry for a remote entity, holding only weak references to the local subscriber.
enum WeakSubscriber {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
    /// Placeholder inserted by `subscribe_to_entity`. Its messages are replayed by `set_model`,
    /// or logged as unhandled if the pending subscription is dropped.
    /// NOTE(review): presumably `handle_message` queues messages here — not visible in this chunk.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
242
/// Resolved subscriber handed to a message handler: a strong model handle or a weak view handle.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyWeakViewHandle),
}
247
/// Credentials used to authenticate the RPC connection.
///
/// `Debug` is implemented by hand so that the access token is never written to logs or panic
/// messages.
#[derive(Clone)]
pub struct Credentials {
    /// Id of the authenticated user.
    pub user_id: u64,
    /// Secret token sent with `user_id` in the `Authorization` header of the websocket handshake.
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
253
254impl Default for ClientState {
255 fn default() -> Self {
256 Self {
257 credentials: None,
258 status: watch::channel_with(Status::SignedOut),
259 entity_id_extractors: Default::default(),
260 _reconnect_task: None,
261 reconnect_interval: Duration::from_secs(5),
262 models_by_message_type: Default::default(),
263 entities_by_type_and_remote_id: Default::default(),
264 entity_types_by_message_type: Default::default(),
265 message_handlers: Default::default(),
266 }
267 }
268}
269
/// Handle to a registration in a [`Client`]. Dropping it removes the registration, provided the
/// client is still alive.
pub enum Subscription {
    /// Entry in the entity registry, keyed by (entity `TypeId`, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Handler registration for one message type, keyed by the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
280
281impl Drop for Subscription {
282 fn drop(&mut self) {
283 match self {
284 Subscription::Entity { client, id } => {
285 if let Some(client) = client.upgrade() {
286 let mut state = client.state.write();
287 let _ = state.entities_by_type_and_remote_id.remove(id);
288 }
289 }
290 Subscription::Message { client, id } => {
291 if let Some(client) = client.upgrade() {
292 let mut state = client.state.write();
293 let _ = state.entity_types_by_message_type.remove(id);
294 let _ = state.message_handlers.remove(id);
295 }
296 }
297 }
298 }
299}
300
/// A reserved, not-yet-bound subscription to a remote entity, returned by
/// `Client::subscribe_to_entity`.
///
/// Finish it with `set_model`. Dropping it without doing so removes the reservation and logs any
/// messages that were buffered for the entity.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the now-bound registration alone.
    consumed: bool,
}
307
308impl<T: Entity> PendingEntitySubscription<T> {
309 pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
310 self.consumed = true;
311 let mut state = self.client.state.write();
312 let id = (TypeId::of::<T>(), self.remote_id);
313 let Some(WeakSubscriber::Pending(messages)) =
314 state.entities_by_type_and_remote_id.remove(&id)
315 else {
316 unreachable!()
317 };
318
319 state
320 .entities_by_type_and_remote_id
321 .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
322 drop(state);
323 for message in messages {
324 self.client.handle_message(message, cx);
325 }
326 Subscription::Entity {
327 client: Arc::downgrade(&self.client),
328 id,
329 }
330 }
331}
332
333impl<T: Entity> Drop for PendingEntitySubscription<T> {
334 fn drop(&mut self) {
335 if !self.consumed {
336 let mut state = self.client.state.write();
337 if let Some(WeakSubscriber::Pending(messages)) = state
338 .entities_by_type_and_remote_id
339 .remove(&(TypeId::of::<T>(), self.remote_id))
340 {
341 for message in messages {
342 log::info!("unhandled message {}", message.payload_type_name());
343 }
344 }
345 }
346 }
347}
348
/// Resolved `telemetry` settings: each flag is the user value, or the default if the user left
/// it unset.
#[derive(Debug, Copy, Clone)]
pub struct TelemetrySettings {
    pub diagnostics: bool,
    pub metrics: bool,
}
354
/// The `telemetry` section as written in a settings file. Unset fields fall back to the
/// defaults when the settings are loaded.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
360
361impl settings::Setting for TelemetrySettings {
362 const KEY: Option<&'static str> = Some("telemetry");
363
364 type FileContent = TelemetrySettingsContent;
365
366 fn load(
367 default_value: &Self::FileContent,
368 user_values: &[&Self::FileContent],
369 _: &AppContext,
370 ) -> Result<Self> {
371 Ok(Self {
372 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
373 default_value
374 .diagnostics
375 .ok_or_else(Self::missing_default)?,
376 ),
377 metrics: user_values
378 .first()
379 .and_then(|v| v.metrics)
380 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
381 })
382 }
383}
384
385impl Client {
386 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
387 Arc::new(Self {
388 id: AtomicU64::new(0),
389 peer: Peer::new(0),
390 telemetry: Telemetry::new(http.clone(), cx),
391 http,
392 state: Default::default(),
393
394 #[cfg(any(test, feature = "test-support"))]
395 authenticate: Default::default(),
396 #[cfg(any(test, feature = "test-support"))]
397 establish_connection: Default::default(),
398 })
399 }
400
401 pub fn id(&self) -> u64 {
402 self.id.load(std::sync::atomic::Ordering::SeqCst)
403 }
404
405 pub fn http_client(&self) -> Arc<dyn HttpClient> {
406 self.http.clone()
407 }
408
409 pub fn set_id(&self, id: u64) -> &Self {
410 self.id.store(id, std::sync::atomic::Ordering::SeqCst);
411 self
412 }
413
414 #[cfg(any(test, feature = "test-support"))]
415 pub fn teardown(&self) {
416 let mut state = self.state.write();
417 state._reconnect_task.take();
418 state.message_handlers.clear();
419 state.models_by_message_type.clear();
420 state.entities_by_type_and_remote_id.clear();
421 state.entity_id_extractors.clear();
422 self.peer.teardown();
423 }
424
425 #[cfg(any(test, feature = "test-support"))]
426 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
427 where
428 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
429 {
430 *self.authenticate.write() = Some(Box::new(authenticate));
431 self
432 }
433
434 #[cfg(any(test, feature = "test-support"))]
435 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
436 where
437 F: 'static
438 + Send
439 + Sync
440 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
441 {
442 *self.establish_connection.write() = Some(Box::new(connect));
443 self
444 }
445
446 pub fn user_id(&self) -> Option<u64> {
447 self.state
448 .read()
449 .credentials
450 .as_ref()
451 .map(|credentials| credentials.user_id)
452 }
453
454 pub fn peer_id(&self) -> Option<PeerId> {
455 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
456 Some(*peer_id)
457 } else {
458 None
459 }
460 }
461
462 pub fn status(&self) -> watch::Receiver<Status> {
463 self.state.read().status.1.clone()
464 }
465
    /// Publishes `status` on the status channel and reacts to the transition:
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task. It retries `authenticate_and_connect` while
    ///   each failed attempt leaves the status at `ConnectionError`. The delay between attempts
    ///   starts at `INITIAL_RECONNECTION_DELAY`, grows by a random factor in `[1, 2]`, and is
    ///   capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the authenticated user in telemetry and drops
    ///   any pending reconnect task.
    ///
    /// The client state's write lock is held from the status update until the call returns.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed in test builds so reconnect jitter is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only while the failed attempt left us in
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Randomized growth in [1x, 2x], never beyond the cap.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
511
512 pub fn add_view_for_remote_entity<T: View>(
513 self: &Arc<Self>,
514 remote_id: u64,
515 cx: &mut ViewContext<T>,
516 ) -> Subscription {
517 let id = (TypeId::of::<T>(), remote_id);
518 self.state
519 .write()
520 .entities_by_type_and_remote_id
521 .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
522 Subscription::Entity {
523 client: Arc::downgrade(self),
524 id,
525 }
526 }
527
528 pub fn subscribe_to_entity<T: Entity>(
529 self: &Arc<Self>,
530 remote_id: u64,
531 ) -> Result<PendingEntitySubscription<T>> {
532 let id = (TypeId::of::<T>(), remote_id);
533
534 let mut state = self.state.write();
535 if state.entities_by_type_and_remote_id.contains_key(&id) {
536 return Err(anyhow!("already subscribed to entity"));
537 } else {
538 state
539 .entities_by_type_and_remote_id
540 .insert(id, WeakSubscriber::Pending(Default::default()));
541 Ok(PendingEntitySubscription {
542 client: self.clone(),
543 remote_id,
544 consumed: false,
545 _entity_type: PhantomData,
546 })
547 }
548 }
549
    /// Registers `handler` for every incoming message of type `M`, routed to `model`.
    ///
    /// Only a weak handle to `model` is stored. The returned [`Subscription`] removes the handler
    /// when dropped.
    ///
    /// # Panics
    /// Panics if a handler for `M` is already registered. Because of `#[track_caller]`, the panic
    /// message reports the caller's file and line.
    #[track_caller]
    pub fn add_message_handler<M, E, H, F>(
        self: &Arc<Self>,
        model: ModelHandle<E>,
        handler: H,
    ) -> Subscription
    where
        M: EnvelopedMessage,
        E: Entity,
        H: 'static
            + Send
            + Sync
            + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
        F: 'static + Future<Output = Result<()>>,
    {
        let message_type_id = TypeId::of::<M>();

        let mut state = self.state.write();
        state
            .models_by_message_type
            .insert(message_type_id, model.downgrade().into_any());

        // The closure is passed inline to `Arc::new` so that its parameter types are inferred
        // from the `message_handlers` value type.
        let prev_handler = state.message_handlers.insert(
            message_type_id,
            Arc::new(move |handle, envelope, client, cx| {
                // The wrapper expects a model subscriber and an `M` envelope; anything else is a
                // routing bug.
                let handle = if let Subscriber::Model(handle) = handle {
                    handle
                } else {
                    unreachable!();
                };
                let model = handle.downcast::<E>().unwrap();
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                handler(model, *envelope, client.clone(), cx).boxed_local()
            }),
        );
        if prev_handler.is_some() {
            let location = std::panic::Location::caller();
            panic!(
                "{}:{} registered handler for the same message {} twice",
                location.file(),
                location.line(),
                std::any::type_name::<M>()
            );
        }

        Subscription::Message {
            client: Arc::downgrade(self),
            id: message_type_id,
        }
    }
600
601 pub fn add_request_handler<M, E, H, F>(
602 self: &Arc<Self>,
603 model: ModelHandle<E>,
604 handler: H,
605 ) -> Subscription
606 where
607 M: RequestMessage,
608 E: Entity,
609 H: 'static
610 + Send
611 + Sync
612 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
613 F: 'static + Future<Output = Result<M::Response>>,
614 {
615 self.add_message_handler(model, move |handle, envelope, this, cx| {
616 Self::respond_to_request(
617 envelope.receipt(),
618 handler(handle, envelope, this.clone(), cx),
619 this,
620 )
621 })
622 }
623
624 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
625 where
626 M: EntityMessage,
627 E: View,
628 H: 'static
629 + Send
630 + Sync
631 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
632 F: 'static + Future<Output = Result<()>>,
633 {
634 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
635 if let Subscriber::View(handle) = handle {
636 handler(handle.downcast::<E>().unwrap(), message, client, cx)
637 } else {
638 unreachable!();
639 }
640 })
641 }
642
643 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
644 where
645 M: EntityMessage,
646 E: Entity,
647 H: 'static
648 + Send
649 + Sync
650 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
651 F: 'static + Future<Output = Result<()>>,
652 {
653 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
654 if let Subscriber::Model(handle) = handle {
655 handler(handle.downcast::<E>().unwrap(), message, client, cx)
656 } else {
657 unreachable!();
658 }
659 })
660 }
661
662 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
663 where
664 M: EntityMessage,
665 E: Entity,
666 H: 'static
667 + Send
668 + Sync
669 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
670 F: 'static + Future<Output = Result<()>>,
671 {
672 let model_type_id = TypeId::of::<E>();
673 let message_type_id = TypeId::of::<M>();
674
675 let mut state = self.state.write();
676 state
677 .entity_types_by_message_type
678 .insert(message_type_id, model_type_id);
679 state
680 .entity_id_extractors
681 .entry(message_type_id)
682 .or_insert_with(|| {
683 |envelope| {
684 envelope
685 .as_any()
686 .downcast_ref::<TypedEnvelope<M>>()
687 .unwrap()
688 .payload
689 .remote_entity_id()
690 }
691 });
692 let prev_handler = state.message_handlers.insert(
693 message_type_id,
694 Arc::new(move |handle, envelope, client, cx| {
695 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
696 handler(handle, *envelope, client.clone(), cx).boxed_local()
697 }),
698 );
699 if prev_handler.is_some() {
700 panic!("registered handler for the same message twice");
701 }
702 }
703
704 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
705 where
706 M: EntityMessage + RequestMessage,
707 E: Entity,
708 H: 'static
709 + Send
710 + Sync
711 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
712 F: 'static + Future<Output = Result<M::Response>>,
713 {
714 self.add_model_message_handler(move |entity, envelope, client, cx| {
715 Self::respond_to_request::<M, _>(
716 envelope.receipt(),
717 handler(entity, envelope, client.clone(), cx),
718 client,
719 )
720 })
721 }
722
723 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
724 where
725 M: EntityMessage + RequestMessage,
726 E: View,
727 H: 'static
728 + Send
729 + Sync
730 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
731 F: 'static + Future<Output = Result<M::Response>>,
732 {
733 self.add_view_message_handler(move |entity, envelope, client, cx| {
734 Self::respond_to_request::<M, _>(
735 envelope.receipt(),
736 handler(entity, envelope, client.clone(), cx),
737 client,
738 )
739 })
740 }
741
742 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
743 receipt: Receipt<T>,
744 response: F,
745 client: Arc<Self>,
746 ) -> Result<()> {
747 match response.await {
748 Ok(response) => {
749 client.respond(receipt, response)?;
750 Ok(())
751 }
752 Err(error) => {
753 client.respond_with_error(
754 receipt,
755 proto::Error {
756 message: format!("{:?}", error),
757 },
758 )?;
759 Err(error)
760 }
761 }
762 }
763
764 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
765 read_credentials_from_keychain(cx).is_some()
766 }
767
    /// Authenticates if needed, then connects to the server.
    ///
    /// Returns `Ok(())` immediately if already connected or if a connect attempt is in progress.
    /// Fails immediately if the status is `UpgradeRequired`.
    ///
    /// Credentials are taken from the first of these that yields some:
    /// 1. the credentials cached on the client;
    /// 2. the keychain, when `try_keychain` is set;
    /// 3. `authenticate` (the browser flow, or the test hook). This step is abandoned with
    ///    "authentication canceled" if the status changes while it is waiting.
    ///
    /// Opening the connection and receiving the server's hello share one `CONNECTION_TIMEOUT`.
    ///
    /// When the server answers `Unauthorized` to credentials read from the keychain, the
    /// keychain entry is deleted and the call recurses with `try_keychain = false`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Starting from `SignedOut` uses the Authenticating/Connecting statuses; any other
        // starting point uses Reauthenticating/Reconnecting.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Consume the current status so the select below wakes only on a *new* status.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single deadline shared by the connection attempt and the hello wait below.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        // Cache the credentials. Persist them to the keychain unless they came
                        // from there or this is an impersonated login.
                        self.state.write().credentials = Some(credentials.clone());
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: forget them and retry without the
                            // keychain, which falls through to interactive authentication.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
871
    /// Adds `conn` to the peer, waits for the server's `Hello` to learn this client's peer id,
    /// and marks the client `Connected`.
    ///
    /// It then spawns two foreground tasks:
    /// - one that dispatches every further incoming message to `handle_message`;
    /// - one that waits for the connection's I/O task to finish. A clean finish sets `SignedOut`,
    ///   but only if this connection is still the current one. An I/O error sets
    ///   `ConnectionLost`.
    ///
    /// # Errors
    /// Fails if the first message is missing, is not a `Hello`, or carries no peer id. The
    /// connection is disconnected before returning.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first message on a new connection must be the server's `Hello`.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if no newer connection has replaced this one.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        // NOTE(review): unlike the clean-close branch, this does not check that
                        // this connection is still the current one — confirm that is intended.
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
967
968 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
969 #[cfg(any(test, feature = "test-support"))]
970 if let Some(callback) = self.authenticate.read().as_ref() {
971 return callback(cx);
972 }
973
974 self.authenticate_with_browser(cx)
975 }
976
977 fn establish_connection(
978 self: &Arc<Self>,
979 credentials: &Credentials,
980 cx: &AsyncAppContext,
981 ) -> Task<Result<Connection, EstablishConnectionError>> {
982 #[cfg(any(test, feature = "test-support"))]
983 if let Some(callback) = self.establish_connection.read().as_ref() {
984 return callback(credentials, cx);
985 }
986
987 self.establish_websocket_connection(credentials, cx)
988 }
989
990 async fn get_rpc_url(
991 http: Arc<dyn HttpClient>,
992 release_channel: Option<ReleaseChannel>,
993 ) -> Result<Url> {
994 let mut url = format!("{}/rpc", *ZED_SERVER_URL);
995 if let Some(preview_param) =
996 release_channel.and_then(|channel| channel.release_query_param())
997 {
998 url += "?";
999 url += preview_param;
1000 }
1001 let response = http.get(&url, Default::default(), false).await?;
1002
1003 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
1004 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
1005 // which requires authorization via an HTTP header.
1006 //
1007 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
1008 // of a collab server. In that case, a request to the /rpc endpoint will
1009 // return an 'unauthorized' response.
1010 let collab_url = if response.status().is_redirection() {
1011 response
1012 .headers()
1013 .get("Location")
1014 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1015 .to_str()
1016 .map_err(EstablishConnectionError::other)?
1017 .to_string()
1018 } else if response.status() == StatusCode::UNAUTHORIZED {
1019 url
1020 } else {
1021 Err(anyhow!(
1022 "unexpected /rpc response status {}",
1023 response.status()
1024 ))?
1025 };
1026
1027 Url::parse(&collab_url).context("invalid rpc url")
1028 }
1029
    /// Opens a websocket connection to the collab server, authenticated with `credentials`.
    ///
    /// The steps, run on the background executor:
    /// 1. Resolve the RPC URL with `get_rpc_url`, passing the `ReleaseChannel` global if set.
    /// 2. Open a TCP connection to its host and port.
    /// 3. Perform the websocket handshake: over TLS (`wss`) for an `https` URL, in plaintext
    ///    (`ws`) for an `http` URL.
    ///
    /// The handshake request carries an `Authorization: <user_id> <access_token>` header and the
    /// RPC protocol version.
    ///
    /// # Errors
    /// Handshake failures go through `From<WebsocketError>`, so 401/426 responses become
    /// `Unauthorized` / `UpgradeRequired`. Any URL scheme other than `http`/`https` is rejected.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        let release_channel = cx.read(|cx| {
            if cx.has_global::<ReleaseChannel>() {
                Some(*cx.global::<ReleaseChannel>())
            } else {
                None
            }
        });

        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // The two branches are kept separate because the TLS and plaintext streams have
            // different types.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1087
1088 pub fn authenticate_with_browser(
1089 self: &Arc<Self>,
1090 cx: &AsyncAppContext,
1091 ) -> Task<Result<Credentials>> {
1092 let platform = cx.platform();
1093 let executor = cx.background();
1094 let http = self.http.clone();
1095
1096 executor.clone().spawn(async move {
1097 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1098 // zed server to encrypt the user's access token, so that it can'be intercepted by
1099 // any other app running on the user's device.
1100 let (public_key, private_key) =
1101 rpc::auth::keypair().expect("failed to generate keypair for auth");
1102 let public_key_string =
1103 String::try_from(public_key).expect("failed to serialize public key for auth");
1104
1105 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1106 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1107 }
1108
1109 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1110 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1111 let port = server.server_addr().port();
1112
1113 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1114 // that the user is signing in from a Zed app running on the same device.
1115 let mut url = format!(
1116 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1117 *ZED_SERVER_URL, port, public_key_string
1118 );
1119
1120 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1121 log::info!("impersonating user @{}", impersonate_login);
1122 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1123 }
1124
1125 platform.open_url(&url);
1126
1127 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1128 // access token from the query params.
1129 //
1130 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1131 // custom URL scheme instead of this local HTTP server.
1132 let (user_id, access_token) = executor
1133 .spawn(async move {
1134 for _ in 0..100 {
1135 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1136 let path = req.url();
1137 let mut user_id = None;
1138 let mut access_token = None;
1139 let url = Url::parse(&format!("http://example.com{}", path))
1140 .context("failed to parse login notification url")?;
1141 for (key, value) in url.query_pairs() {
1142 if key == "access_token" {
1143 access_token = Some(value.to_string());
1144 } else if key == "user_id" {
1145 user_id = Some(value.to_string());
1146 }
1147 }
1148
1149 let post_auth_url =
1150 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1151 req.respond(
1152 tiny_http::Response::empty(302).with_header(
1153 tiny_http::Header::from_bytes(
1154 &b"Location"[..],
1155 post_auth_url.as_bytes(),
1156 )
1157 .unwrap(),
1158 ),
1159 )
1160 .context("failed to respond to login http request")?;
1161 return Ok((
1162 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1163 access_token
1164 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1165 ));
1166 }
1167 }
1168
1169 Err(anyhow!("didn't receive login redirect"))
1170 })
1171 .await?;
1172
1173 let access_token = private_key
1174 .decrypt_string(&access_token)
1175 .context("failed to decrypt access token")?;
1176 platform.activate(true);
1177
1178 Ok(Credentials {
1179 user_id: user_id.parse()?,
1180 access_token,
1181 })
1182 })
1183 }
1184
1185 async fn authenticate_as_admin(
1186 http: Arc<dyn HttpClient>,
1187 login: String,
1188 mut api_token: String,
1189 ) -> Result<Credentials> {
1190 #[derive(Deserialize)]
1191 struct AuthenticatedUserResponse {
1192 user: User,
1193 }
1194
1195 #[derive(Deserialize)]
1196 struct User {
1197 id: u64,
1198 }
1199
1200 // Use the collab server's admin API to retrieve the id
1201 // of the impersonated user.
1202 let mut url = Self::get_rpc_url(http.clone(), None).await?;
1203 url.set_path("/user");
1204 url.set_query(Some(&format!("github_login={login}")));
1205 let request = Request::get(url.as_str())
1206 .header("Authorization", format!("token {api_token}"))
1207 .body("".into())?;
1208
1209 let mut response = http.send(request).await?;
1210 let mut body = String::new();
1211 response.body_mut().read_to_string(&mut body).await?;
1212 if !response.status().is_success() {
1213 Err(anyhow!(
1214 "admin user request failed {} - {}",
1215 response.status().as_u16(),
1216 body,
1217 ))?;
1218 }
1219 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1220
1221 // Use the admin API token to authenticate as the impersonated user.
1222 api_token.insert_str(0, "ADMIN_TOKEN:");
1223 Ok(Credentials {
1224 user_id: response.user.id,
1225 access_token: api_token,
1226 })
1227 }
1228
1229 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1230 self.peer.teardown();
1231 self.set_status(Status::SignedOut, cx);
1232 }
1233
1234 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1235 self.peer.teardown();
1236 self.set_status(Status::ConnectionLost, cx);
1237 }
1238
1239 fn connection_id(&self) -> Result<ConnectionId> {
1240 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1241 Ok(connection_id)
1242 } else {
1243 Err(anyhow!("not connected"))
1244 }
1245 }
1246
1247 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1248 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1249 self.peer.send(self.connection_id()?, message)
1250 }
1251
1252 pub fn request<T: RequestMessage>(
1253 &self,
1254 request: T,
1255 ) -> impl Future<Output = Result<T::Response>> {
1256 self.request_envelope(request)
1257 .map_ok(|envelope| envelope.payload)
1258 }
1259
1260 pub fn request_envelope<T: RequestMessage>(
1261 &self,
1262 request: T,
1263 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1264 let client_id = self.id();
1265 log::debug!(
1266 "rpc request start. client_id:{}. name:{}",
1267 client_id,
1268 T::NAME
1269 );
1270 let response = self
1271 .connection_id()
1272 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1273 async move {
1274 let response = response?.await;
1275 log::debug!(
1276 "rpc request finish. client_id:{}. name:{}",
1277 client_id,
1278 T::NAME
1279 );
1280 response
1281 }
1282 }
1283
1284 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1285 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1286 self.peer.respond(receipt, response)
1287 }
1288
1289 fn respond_with_error<T: RequestMessage>(
1290 &self,
1291 receipt: Receipt<T>,
1292 error: proto::Error,
1293 ) -> Result<()> {
1294 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1295 self.peer.respond_with_error(receipt, error)
1296 }
1297
1298 fn handle_message(
1299 self: &Arc<Client>,
1300 message: Box<dyn AnyTypedEnvelope>,
1301 cx: &AsyncAppContext,
1302 ) {
1303 let mut state = self.state.write();
1304 let type_name = message.payload_type_name();
1305 let payload_type_id = message.payload_type_id();
1306 let sender_id = message.original_sender_id();
1307
1308 let mut subscriber = None;
1309
1310 if let Some(message_model) = state
1311 .models_by_message_type
1312 .get(&payload_type_id)
1313 .and_then(|model| model.upgrade(cx))
1314 {
1315 subscriber = Some(Subscriber::Model(message_model));
1316 } else if let Some((extract_entity_id, entity_type_id)) =
1317 state.entity_id_extractors.get(&payload_type_id).zip(
1318 state
1319 .entity_types_by_message_type
1320 .get(&payload_type_id)
1321 .copied(),
1322 )
1323 {
1324 let entity_id = (extract_entity_id)(message.as_ref());
1325
1326 match state
1327 .entities_by_type_and_remote_id
1328 .get_mut(&(entity_type_id, entity_id))
1329 {
1330 Some(WeakSubscriber::Pending(pending)) => {
1331 pending.push(message);
1332 return;
1333 }
1334 Some(weak_subscriber @ _) => match weak_subscriber {
1335 WeakSubscriber::Model(handle) => {
1336 subscriber = handle.upgrade(cx).map(Subscriber::Model);
1337 }
1338 WeakSubscriber::View(handle) => {
1339 subscriber = Some(Subscriber::View(handle.clone()));
1340 }
1341 WeakSubscriber::Pending(_) => {}
1342 },
1343 _ => {}
1344 }
1345 }
1346
1347 let subscriber = if let Some(subscriber) = subscriber {
1348 subscriber
1349 } else {
1350 log::info!("unhandled message {}", type_name);
1351 self.peer.respond_with_unhandled_message(message).log_err();
1352 return;
1353 };
1354
1355 let handler = state.message_handlers.get(&payload_type_id).cloned();
1356 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1357 // It also ensures we don't hold the lock while yielding back to the executor, as
1358 // that might cause the executor thread driving this future to block indefinitely.
1359 drop(state);
1360
1361 if let Some(handler) = handler {
1362 let future = handler(subscriber, message, &self, cx.clone());
1363 let client_id = self.id();
1364 log::debug!(
1365 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1366 client_id,
1367 sender_id,
1368 type_name
1369 );
1370 cx.foreground()
1371 .spawn(async move {
1372 match future.await {
1373 Ok(()) => {
1374 log::debug!(
1375 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1376 client_id,
1377 sender_id,
1378 type_name
1379 );
1380 }
1381 Err(error) => {
1382 log::error!(
1383 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1384 client_id,
1385 sender_id,
1386 type_name,
1387 error
1388 );
1389 }
1390 }
1391 })
1392 .detach();
1393 } else {
1394 log::info!("unhandled message {}", type_name);
1395 self.peer.respond_with_unhandled_message(message).log_err();
1396 }
1397 }
1398
1399 pub fn telemetry(&self) -> &Arc<Telemetry> {
1400 &self.telemetry
1401 }
1402}
1403
1404fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1405 if IMPERSONATE_LOGIN.is_some() {
1406 return None;
1407 }
1408
1409 let (user_id, access_token) = cx
1410 .platform()
1411 .read_credentials(&ZED_SERVER_URL)
1412 .log_err()
1413 .flatten()?;
1414 Some(Credentials {
1415 user_id: user_id.parse().ok()?,
1416 access_token: String::from_utf8(access_token).ok()?,
1417 })
1418}
1419
1420fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1421 cx.platform().write_credentials(
1422 &ZED_SERVER_URL,
1423 &credentials.user_id.to_string(),
1424 credentials.access_token.as_bytes(),
1425 )
1426}
1427
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
///
/// [`decode_worktree_url`] recovers `(id, access_token)` from the result.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Leading and trailing whitespace is ignored. Everything after the first `/` that follows
/// the id is the access token. As a result, tokens that themselves contain `/` round-trip
/// intact instead of being truncated at their first slash.
///
/// Returns `None` in any of these cases:
/// - the prefix is missing;
/// - there is no `/` after the id;
/// - the id is not a valid `u64`;
/// - the access token is empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1444
// Unit tests for `Client`: reconnection and re-authentication, connection timeouts,
// overlapping authentication attempts, worktree URL encoding, and routing of incoming
// messages to model/entity subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;
    use util::http::FakeHttpClient;

    // After the server drops the connection, the client keeps retrying. Once connections are
    // allowed again it reconnects with its cached credentials (auth count stays 1). It
    // re-authenticates (auth count becomes 2) only after the server rolls the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        // Advance past the reconnection delay so the next attempt fires.
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes is abandoned after `CONNECTION_TIMEOUT`:
    // the initial connect ends in `ConnectionError` (and `authenticate_and_connect` errors),
    // and a reconnect attempt ends in `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Establishing the connection never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still authenticating
    // drops the first authentication future: two authentications start, and exactly one of
    // them is dropped.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Counts the authentication future being dropped before it completes.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs round-trip, tolerate surrounding whitespace, and reject other formats.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages reach the model subscribed under the matching remote id (the
    // `JoinProject` `project_id`). Dropping one entity's subscription doesn't affect the
    // others of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a message handler's subscription is dropped, a new handler for the same message
    // type can be registered on the same model and receives subsequent messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler can drop its own subscription (stored on the model it is invoked with)
    // while handling a message and still run to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity. `id` identifies which model handled a message; `subscription`
    // lets a test keep a handler's subscription alive on the model (or drop it from a handler).
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}