1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
15 TryStreamExt,
16};
17use gpui::{
18 actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
19 AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
20 WeakViewHandle,
21};
22use lazy_static::lazy_static;
23use parking_lot::RwLock;
24use postage::watch;
25use rand::prelude::*;
26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use std::{
30 any::TypeId,
31 collections::HashMap,
32 convert::TryFrom,
33 fmt::Write as _,
34 future::Future,
35 marker::PhantomData,
36 path::PathBuf,
37 sync::{atomic::AtomicU64, Arc, Weak},
38 time::{Duration, Instant},
39};
40use telemetry::Telemetry;
41use thiserror::Error;
42use url::Url;
43use util::channel::ReleaseChannel;
44use util::http::HttpClient;
45use util::{ResultExt, TryFutureExt};
46
47pub use rpc::*;
48pub use telemetry::ClickhouseEvent;
49pub use user::*;
50
51lazy_static! {
52 pub static ref ZED_SERVER_URL: String =
53 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
54 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
55 .ok()
56 .and_then(|s| if s.is_empty() { None } else { Some(s) });
57 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
58 .ok()
59 .and_then(|s| if s.is_empty() { None } else { Some(s) });
60 pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
61 .ok()
62 .and_then(|v| v.parse().ok());
63 pub static ref ZED_APP_PATH: Option<PathBuf> =
64 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
65 pub static ref ZED_ALWAYS_ACTIVE: bool =
66 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
67}
68
// NOTE(review): not referenced in the visible part of this file, so its role can't be
// confirmed from here — presumably a token the server recognizes for this client. TODO confirm.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// Initial wait between failed reconnection attempts after a lost connection. Each later wait
/// is the previous one multiplied by a random factor in `[1, 2]`, capped at the client's
/// reconnect interval.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time budget shared by establishing a connection and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

// Global actions whose handlers are registered in `init`.
actions!(client, [SignIn, SignOut, Reconnect]);
74
75pub fn init_settings(cx: &mut AppContext) {
76 settings::register::<TelemetrySettings>(cx);
77}
78
79pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
80 init_settings(cx);
81
82 let client = Arc::downgrade(client);
83 cx.add_global_action({
84 let client = client.clone();
85 move |_: &SignIn, cx| {
86 if let Some(client) = client.upgrade() {
87 cx.spawn(
88 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
89 )
90 .detach();
91 }
92 }
93 });
94 cx.add_global_action({
95 let client = client.clone();
96 move |_: &SignOut, cx| {
97 if let Some(client) = client.upgrade() {
98 cx.spawn(|cx| async move {
99 client.disconnect(&cx);
100 })
101 .detach();
102 }
103 }
104 });
105 cx.add_global_action({
106 let client = client.clone();
107 move |_: &Reconnect, cx| {
108 if let Some(client) = client.upgrade() {
109 cx.spawn(|cx| async move {
110 client.reconnect(&cx);
111 })
112 .detach();
113 }
114 }
115 });
116}
117
/// Connection to the Zed collaboration server.
///
/// Bundles the RPC `Peer`, the HTTP client, telemetry, and the lock-protected
/// `ClientState`, which holds the credentials, the connection status, and the
/// message-routing tables.
pub struct Client {
    // Numeric id of this client (see `id` / `set_id`); `authenticate_and_connect` stores the
    // signed-in user's id here.
    id: AtomicU64,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    state: RwLock<ClientState>,

    // Test hook: when set, `authenticate` calls this instead of the browser sign-in flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    // Test hook: when set, `establish_connection` calls this instead of opening a websocket.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
147
/// Reasons that opening a connection to the collab server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP `426 Upgrade Required`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP `401 Unauthorized`.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors without a dedicated variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// I/O error, e.g. while opening the TCP connection.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error while building the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
163
164impl From<WebsocketError> for EstablishConnectionError {
165 fn from(error: WebsocketError) -> Self {
166 if let WebsocketError::Http(response) = &error {
167 match response.status() {
168 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
169 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
170 _ => {}
171 }
172 }
173 EstablishConnectionError::Other(error.into())
174 }
175}
176
177impl EstablishConnectionError {
178 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
179 Self::Other(error.into())
180 }
181}
182
/// Connection lifecycle state of a `Client`, published through `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Not signed in. This is the initial state.
    SignedOut,
    /// The server answered the handshake with `426 Upgrade Required`.
    /// `authenticate_and_connect` refuses to connect while in this state.
    UpgradeRequired,
    /// Obtaining credentials for a first connection.
    Authenticating,
    /// Opening a first connection once credentials are available.
    Connecting,
    /// The most recent connection attempt failed.
    ConnectionError,
    /// Connected. `peer_id` comes from the server's hello message; `connection_id` is the
    /// local peer's id for this connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// An established connection ended with an I/O error. Entering this state starts the
    /// reconnect loop.
    ConnectionLost,
    /// Obtaining credentials while recovering from an earlier state.
    Reauthenticating,
    /// Reconnecting while recovering from an earlier state.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is due at `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
201
202impl Status {
203 pub fn is_connected(&self) -> bool {
204 matches!(self, Self::Connected { .. })
205 }
206
207 pub fn is_signed_out(&self) -> bool {
208 matches!(self, Self::SignedOut | Self::UpgradeRequired)
209 }
210}
211
/// Mutable state of a `Client`, kept behind the client's `RwLock`.
struct ClientState {
    // Credentials of the signed-in user, if any.
    credentials: Option<Credentials>,
    // Sender/receiver pair for the current `Status`; `Client::status` hands out receiver clones.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    // Per message type: reads the remote entity id out of a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    // Reconnect loop spawned on `ConnectionLost`; set back to `None` on connect or sign-out.
    _reconnect_task: Option<Task<()>>,
    // Upper bound for the delay between reconnection attempts.
    reconnect_interval: Duration,
    // Subscribers keyed by (entity TypeId, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    // Model that owns the handler of each message type registered via `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    // Entity TypeId that handles each entity-scoped message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    // Type-erased handler per message TypeId.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}

/// Weakly held recipient of messages for a remote entity.
enum WeakSubscriber {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
    // Placeholder inserted by `subscribe_to_entity`. Holds envelopes that
    // `PendingEntitySubscription::set_model` later replays.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}

/// Recipient handed to a type-erased message handler.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyWeakViewHandle),
}

/// User id and access token used to authenticate with the server.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
253
254impl Default for ClientState {
255 fn default() -> Self {
256 Self {
257 credentials: None,
258 status: watch::channel_with(Status::SignedOut),
259 entity_id_extractors: Default::default(),
260 _reconnect_task: None,
261 reconnect_interval: Duration::from_secs(5),
262 models_by_message_type: Default::default(),
263 entities_by_type_and_remote_id: Default::default(),
264 entity_types_by_message_type: Default::default(),
265 message_handlers: Default::default(),
266 }
267 }
268}
269
/// Handle to a registration made on a `Client`. Dropping it removes the corresponding
/// entries from the client state, if the client is still alive.
pub enum Subscription {
    /// Registration of a remote entity, keyed by (entity TypeId, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Registration of a message handler, keyed by the message TypeId.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
280
281impl Drop for Subscription {
282 fn drop(&mut self) {
283 match self {
284 Subscription::Entity { client, id } => {
285 if let Some(client) = client.upgrade() {
286 let mut state = client.state.write();
287 let _ = state.entities_by_type_and_remote_id.remove(id);
288 }
289 }
290 Subscription::Message { client, id } => {
291 if let Some(client) = client.upgrade() {
292 let mut state = client.state.write();
293 let _ = state.entity_types_by_message_type.remove(id);
294 let _ = state.message_handlers.remove(id);
295 }
296 }
297 }
298 }
299}
300
/// Reservation returned by `Client::subscribe_to_entity` for a remote entity whose model does
/// not exist yet.
///
/// Call `set_model` to attach the model. Dropping the reservation unconsumed removes it and
/// logs any messages that were buffered for it.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    // Ties the reservation to entity type `T` without storing a `T`.
    _entity_type: PhantomData<T>,
    // Set by `set_model` so that `Drop` leaves the model-backed entry in place.
    consumed: bool,
}
307
308impl<T: Entity> PendingEntitySubscription<T> {
309 pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
310 self.consumed = true;
311 let mut state = self.client.state.write();
312 let id = (TypeId::of::<T>(), self.remote_id);
313 let Some(WeakSubscriber::Pending(messages)) =
314 state.entities_by_type_and_remote_id.remove(&id)
315 else {
316 unreachable!()
317 };
318
319 state
320 .entities_by_type_and_remote_id
321 .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
322 drop(state);
323 for message in messages {
324 self.client.handle_message(message, cx);
325 }
326 Subscription::Entity {
327 client: Arc::downgrade(&self.client),
328 id,
329 }
330 }
331}
332
333impl<T: Entity> Drop for PendingEntitySubscription<T> {
334 fn drop(&mut self) {
335 if !self.consumed {
336 let mut state = self.client.state.write();
337 if let Some(WeakSubscriber::Pending(messages)) = state
338 .entities_by_type_and_remote_id
339 .remove(&(TypeId::of::<T>(), self.remote_id))
340 {
341 for message in messages {
342 log::info!("unhandled message {}", message.payload_type_name());
343 }
344 }
345 }
346 }
347}
348
/// Resolved telemetry settings, loaded from the `telemetry` settings key.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Resolved value of the `diagnostics` setting.
    pub diagnostics: bool,
    /// Resolved value of the `metrics` setting.
    pub metrics: bool,
}

// File-level shape of the `telemetry` settings; `None` means "not set in this file".
// Plain comments on purpose: `JsonSchema` would copy doc comments into the settings schema.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
360
361impl settings::Setting for TelemetrySettings {
362 const KEY: Option<&'static str> = Some("telemetry");
363
364 type FileContent = TelemetrySettingsContent;
365
366 fn load(
367 default_value: &Self::FileContent,
368 user_values: &[&Self::FileContent],
369 _: &AppContext,
370 ) -> Result<Self> {
371 Ok(Self {
372 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
373 default_value
374 .diagnostics
375 .ok_or_else(Self::missing_default)?,
376 ),
377 metrics: user_values
378 .first()
379 .and_then(|v| v.metrics)
380 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
381 })
382 }
383}
384
385impl Client {
386 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
387 Arc::new(Self {
388 id: AtomicU64::new(0),
389 peer: Peer::new(0),
390 telemetry: Telemetry::new(http.clone(), cx),
391 http,
392 state: Default::default(),
393
394 #[cfg(any(test, feature = "test-support"))]
395 authenticate: Default::default(),
396 #[cfg(any(test, feature = "test-support"))]
397 establish_connection: Default::default(),
398 })
399 }
400
401 pub fn id(&self) -> u64 {
402 self.id.load(std::sync::atomic::Ordering::SeqCst)
403 }
404
405 pub fn http_client(&self) -> Arc<dyn HttpClient> {
406 self.http.clone()
407 }
408
409 pub fn set_id(&self, id: u64) -> &Self {
410 self.id.store(id, std::sync::atomic::Ordering::SeqCst);
411 self
412 }
413
414 #[cfg(any(test, feature = "test-support"))]
415 pub fn teardown(&self) {
416 let mut state = self.state.write();
417 state._reconnect_task.take();
418 state.message_handlers.clear();
419 state.models_by_message_type.clear();
420 state.entities_by_type_and_remote_id.clear();
421 state.entity_id_extractors.clear();
422 self.peer.teardown();
423 }
424
    /// Test support: makes `authenticate` call `authenticate` instead of the browser sign-in
    /// flow. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
    where
        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
    {
        *self.authenticate.write() = Some(Box::new(authenticate));
        self
    }

    /// Test support: makes `establish_connection` call `connect` instead of opening a
    /// websocket connection. Returns `self` for chaining.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        *self.establish_connection.write() = Some(Box::new(connect));
        self
    }
445
446 pub fn user_id(&self) -> Option<u64> {
447 self.state
448 .read()
449 .credentials
450 .as_ref()
451 .map(|credentials| credentials.user_id)
452 }
453
454 pub fn peer_id(&self) -> Option<PeerId> {
455 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
456 Some(*peer_id)
457 } else {
458 None
459 }
460 }
461
462 pub fn status(&self) -> watch::Receiver<Status> {
463 self.state.read().status.1.clone()
464 }
465
    /// Publishes `status` to every receiver obtained from `status()` and runs the side
    /// effects tied to the new state:
    ///
    /// - `Connected`: drops any pending reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task that calls `authenticate_and_connect`
    ///   until it succeeds. After each failure that leaves the status at `ConnectionError`,
    ///   the task publishes `ReconnectionError` with the next attempt time and waits. It then
    ///   multiplies the delay by a random factor in `[1, 2]`, capped at `reconnect_interval`.
    ///   A failure that leaves any other status ends the loop.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and drops any
    ///   pending reconnect task.
    ///
    /// The state write lock is held for the whole call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id(), status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed in test builds for deterministic jitter; entropy otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay with random jitter, but never past the interval.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            // The failed attempt left some other status; stop retrying.
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
511
512 pub fn add_view_for_remote_entity<T: View>(
513 self: &Arc<Self>,
514 remote_id: u64,
515 cx: &mut ViewContext<T>,
516 ) -> Subscription {
517 let id = (TypeId::of::<T>(), remote_id);
518 self.state
519 .write()
520 .entities_by_type_and_remote_id
521 .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
522 Subscription::Entity {
523 client: Arc::downgrade(self),
524 id,
525 }
526 }
527
528 pub fn subscribe_to_entity<T: Entity>(
529 self: &Arc<Self>,
530 remote_id: u64,
531 ) -> Result<PendingEntitySubscription<T>> {
532 let id = (TypeId::of::<T>(), remote_id);
533
534 let mut state = self.state.write();
535 if state.entities_by_type_and_remote_id.contains_key(&id) {
536 return Err(anyhow!("already subscribed to entity"));
537 } else {
538 state
539 .entities_by_type_and_remote_id
540 .insert(id, WeakSubscriber::Pending(Default::default()));
541 Ok(PendingEntitySubscription {
542 client: self.clone(),
543 remote_id,
544 consumed: false,
545 _entity_type: PhantomData,
546 })
547 }
548 }
549
550 #[track_caller]
551 pub fn add_message_handler<M, E, H, F>(
552 self: &Arc<Self>,
553 model: ModelHandle<E>,
554 handler: H,
555 ) -> Subscription
556 where
557 M: EnvelopedMessage,
558 E: Entity,
559 H: 'static
560 + Send
561 + Sync
562 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
563 F: 'static + Future<Output = Result<()>>,
564 {
565 let message_type_id = TypeId::of::<M>();
566
567 let mut state = self.state.write();
568 state
569 .models_by_message_type
570 .insert(message_type_id, model.downgrade().into_any());
571
572 let prev_handler = state.message_handlers.insert(
573 message_type_id,
574 Arc::new(move |handle, envelope, client, cx| {
575 let handle = if let Subscriber::Model(handle) = handle {
576 handle
577 } else {
578 unreachable!();
579 };
580 let model = handle.downcast::<E>().unwrap();
581 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
582 handler(model, *envelope, client.clone(), cx).boxed_local()
583 }),
584 );
585 if prev_handler.is_some() {
586 let location = std::panic::Location::caller();
587 panic!(
588 "{}:{} registered handler for the same message {} twice",
589 location.file(),
590 location.line(),
591 std::any::type_name::<M>()
592 );
593 }
594
595 Subscription::Message {
596 client: Arc::downgrade(self),
597 id: message_type_id,
598 }
599 }
600
601 pub fn add_request_handler<M, E, H, F>(
602 self: &Arc<Self>,
603 model: ModelHandle<E>,
604 handler: H,
605 ) -> Subscription
606 where
607 M: RequestMessage,
608 E: Entity,
609 H: 'static
610 + Send
611 + Sync
612 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
613 F: 'static + Future<Output = Result<M::Response>>,
614 {
615 self.add_message_handler(model, move |handle, envelope, this, cx| {
616 Self::respond_to_request(
617 envelope.receipt(),
618 handler(handle, envelope, this.clone(), cx),
619 this,
620 )
621 })
622 }
623
624 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
625 where
626 M: EntityMessage,
627 E: View,
628 H: 'static
629 + Send
630 + Sync
631 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
632 F: 'static + Future<Output = Result<()>>,
633 {
634 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
635 if let Subscriber::View(handle) = handle {
636 handler(handle.downcast::<E>().unwrap(), message, client, cx)
637 } else {
638 unreachable!();
639 }
640 })
641 }
642
643 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
644 where
645 M: EntityMessage,
646 E: Entity,
647 H: 'static
648 + Send
649 + Sync
650 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
651 F: 'static + Future<Output = Result<()>>,
652 {
653 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
654 if let Subscriber::Model(handle) = handle {
655 handler(handle.downcast::<E>().unwrap(), message, client, cx)
656 } else {
657 unreachable!();
658 }
659 })
660 }
661
662 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
663 where
664 M: EntityMessage,
665 E: Entity,
666 H: 'static
667 + Send
668 + Sync
669 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
670 F: 'static + Future<Output = Result<()>>,
671 {
672 let model_type_id = TypeId::of::<E>();
673 let message_type_id = TypeId::of::<M>();
674
675 let mut state = self.state.write();
676 state
677 .entity_types_by_message_type
678 .insert(message_type_id, model_type_id);
679 state
680 .entity_id_extractors
681 .entry(message_type_id)
682 .or_insert_with(|| {
683 |envelope| {
684 envelope
685 .as_any()
686 .downcast_ref::<TypedEnvelope<M>>()
687 .unwrap()
688 .payload
689 .remote_entity_id()
690 }
691 });
692 let prev_handler = state.message_handlers.insert(
693 message_type_id,
694 Arc::new(move |handle, envelope, client, cx| {
695 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
696 handler(handle, *envelope, client.clone(), cx).boxed_local()
697 }),
698 );
699 if prev_handler.is_some() {
700 panic!("registered handler for the same message twice");
701 }
702 }
703
704 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
705 where
706 M: EntityMessage + RequestMessage,
707 E: Entity,
708 H: 'static
709 + Send
710 + Sync
711 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
712 F: 'static + Future<Output = Result<M::Response>>,
713 {
714 self.add_model_message_handler(move |entity, envelope, client, cx| {
715 Self::respond_to_request::<M, _>(
716 envelope.receipt(),
717 handler(entity, envelope, client.clone(), cx),
718 client,
719 )
720 })
721 }
722
723 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
724 where
725 M: EntityMessage + RequestMessage,
726 E: View,
727 H: 'static
728 + Send
729 + Sync
730 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
731 F: 'static + Future<Output = Result<M::Response>>,
732 {
733 self.add_view_message_handler(move |entity, envelope, client, cx| {
734 Self::respond_to_request::<M, _>(
735 envelope.receipt(),
736 handler(entity, envelope, client.clone(), cx),
737 client,
738 )
739 })
740 }
741
742 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
743 receipt: Receipt<T>,
744 response: F,
745 client: Arc<Self>,
746 ) -> Result<()> {
747 match response.await {
748 Ok(response) => {
749 client.respond(receipt, response)?;
750 Ok(())
751 }
752 Err(error) => {
753 client.respond_with_error(
754 receipt,
755 proto::Error {
756 message: format!("{:?}", error),
757 },
758 )?;
759 Err(error)
760 }
761 }
762 }
763
764 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
765 read_credentials_from_keychain(cx).is_some()
766 }
767
    /// Signs in if needed and opens a connection to the collab server.
    ///
    /// Returns `Ok(())` at once when already connected or when a connection attempt is in
    /// progress. Fails at once with `EstablishConnectionError::UpgradeRequired` when the
    /// status is `UpgradeRequired`.
    ///
    /// Credentials are taken from, in order:
    /// 1. the in-memory state;
    /// 2. the keychain, only when `try_keychain` is set;
    /// 3. an interactive `authenticate` call.
    ///
    /// After a successful connection the credentials are stored in the state. They are also
    /// written to the keychain unless they came from it or `IMPERSONATE_LOGIN` is set.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted and the method
    /// recurses with `try_keychain = false` (hence `#[async_recursion]`).
    ///
    /// A single `CONNECTION_TIMEOUT` timer bounds both establishing the connection and
    /// receiving the server's hello.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Decide between a first connection (Authenticating/Connecting), a recovery
        // (Reauthenticating/Reconnecting), or nothing to do.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        // Prefer in-memory credentials, then (optionally) the keychain.
        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            // Interactive authentication, abandoned if the status changes meanwhile.
            // NOTE(review): the first `next()` presumably yields the current status, so the
            // select below only fires on a later change — confirm against postage `watch`.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();
        self.set_id(credentials.user_id);

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer bounds both the connection attempt and the wait for the hello below.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist credentials that did not come from the keychain, unless
                        // impersonating another user.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Rejected keychain credentials: delete them and retry without
                            // consulting the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
871
    /// Registers `conn` with the peer, waits for the server's `Hello`, and marks the client
    /// `Connected`.
    ///
    /// On success it also spawns two detached tasks:
    /// - a foreground loop that passes every later incoming message to `handle_message`;
    /// - a task that awaits the connection's `handle_io` future. When that future ends
    ///   cleanly and this connection is still the current status, the client is marked
    ///   `SignedOut`. When it ends with an error, the client is marked `ConnectionLost`,
    ///   which starts reconnection in `set_status`.
    ///
    /// # Errors
    /// Disconnects the peer connection and fails if:
    /// - no message arrives;
    /// - the first message is not `proto::Hello`;
    /// - the hello carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        // The peer gets a timer factory backed by the background executor. It returns the
        // connection id, the `handle_io` future (spawned just below), and the incoming stream.
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first message must be the server's `Hello`, which carries our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        // Dispatch every remaining incoming message on the foreground executor.
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        // Track the end of the I/O future. A clean end signs out, but only if this connection
        // is still the current one; an error marks the connection lost.
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
967
968 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
969 #[cfg(any(test, feature = "test-support"))]
970 if let Some(callback) = self.authenticate.read().as_ref() {
971 return callback(cx);
972 }
973
974 self.authenticate_with_browser(cx)
975 }
976
977 fn establish_connection(
978 self: &Arc<Self>,
979 credentials: &Credentials,
980 cx: &AsyncAppContext,
981 ) -> Task<Result<Connection, EstablishConnectionError>> {
982 #[cfg(any(test, feature = "test-support"))]
983 if let Some(callback) = self.establish_connection.read().as_ref() {
984 return callback(credentials, cx);
985 }
986
987 self.establish_websocket_connection(credentials, cx)
988 }
989
990 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
991 let preview_param = if is_preview { "?preview=1" } else { "" };
992 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
993 let response = http.get(&url, Default::default(), false).await?;
994
995 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
996 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
997 // which requires authorization via an HTTP header.
998 //
999 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
1000 // of a collab server. In that case, a request to the /rpc endpoint will
1001 // return an 'unauthorized' response.
1002 let collab_url = if response.status().is_redirection() {
1003 response
1004 .headers()
1005 .get("Location")
1006 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
1007 .to_str()
1008 .map_err(EstablishConnectionError::other)?
1009 .to_string()
1010 } else if response.status() == StatusCode::UNAUTHORIZED {
1011 url
1012 } else {
1013 Err(anyhow!(
1014 "unexpected /rpc response status {}",
1015 response.status()
1016 ))?
1017 };
1018
1019 Url::parse(&collab_url).context("invalid rpc url")
1020 }
1021
1022 fn establish_websocket_connection(
1023 self: &Arc<Self>,
1024 credentials: &Credentials,
1025 cx: &AsyncAppContext,
1026 ) -> Task<Result<Connection, EstablishConnectionError>> {
1027 let use_preview_server = cx.read(|cx| {
1028 if cx.has_global::<ReleaseChannel>() {
1029 *cx.global::<ReleaseChannel>() != ReleaseChannel::Stable
1030 } else {
1031 false
1032 }
1033 });
1034
1035 let request = Request::builder()
1036 .header(
1037 "Authorization",
1038 format!("{} {}", credentials.user_id, credentials.access_token),
1039 )
1040 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
1041
1042 let http = self.http.clone();
1043 cx.background().spawn(async move {
1044 let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
1045 let rpc_host = rpc_url
1046 .host_str()
1047 .zip(rpc_url.port_or_known_default())
1048 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
1049 let stream = smol::net::TcpStream::connect(rpc_host).await?;
1050
1051 log::info!("connected to rpc endpoint {}", rpc_url);
1052
1053 match rpc_url.scheme() {
1054 "https" => {
1055 rpc_url.set_scheme("wss").unwrap();
1056 let request = request.uri(rpc_url.as_str()).body(())?;
1057 let (stream, _) =
1058 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
1059 Ok(Connection::new(
1060 stream
1061 .map_err(|error| anyhow!(error))
1062 .sink_map_err(|error| anyhow!(error)),
1063 ))
1064 }
1065 "http" => {
1066 rpc_url.set_scheme("ws").unwrap();
1067 let request = request.uri(rpc_url.as_str()).body(())?;
1068 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1069 Ok(Connection::new(
1070 stream
1071 .map_err(|error| anyhow!(error))
1072 .sink_map_err(|error| anyhow!(error)),
1073 ))
1074 }
1075 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1076 }
1077 })
1078 }
1079
1080 pub fn authenticate_with_browser(
1081 self: &Arc<Self>,
1082 cx: &AsyncAppContext,
1083 ) -> Task<Result<Credentials>> {
1084 let platform = cx.platform();
1085 let executor = cx.background();
1086 let http = self.http.clone();
1087
1088 executor.clone().spawn(async move {
1089 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1090 // zed server to encrypt the user's access token, so that it can'be intercepted by
1091 // any other app running on the user's device.
1092 let (public_key, private_key) =
1093 rpc::auth::keypair().expect("failed to generate keypair for auth");
1094 let public_key_string =
1095 String::try_from(public_key).expect("failed to serialize public key for auth");
1096
1097 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1098 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1099 }
1100
1101 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1102 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1103 let port = server.server_addr().port();
1104
1105 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1106 // that the user is signing in from a Zed app running on the same device.
1107 let mut url = format!(
1108 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1109 *ZED_SERVER_URL, port, public_key_string
1110 );
1111
1112 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1113 log::info!("impersonating user @{}", impersonate_login);
1114 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1115 }
1116
1117 platform.open_url(&url);
1118
1119 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1120 // access token from the query params.
1121 //
1122 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1123 // custom URL scheme instead of this local HTTP server.
1124 let (user_id, access_token) = executor
1125 .spawn(async move {
1126 for _ in 0..100 {
1127 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1128 let path = req.url();
1129 let mut user_id = None;
1130 let mut access_token = None;
1131 let url = Url::parse(&format!("http://example.com{}", path))
1132 .context("failed to parse login notification url")?;
1133 for (key, value) in url.query_pairs() {
1134 if key == "access_token" {
1135 access_token = Some(value.to_string());
1136 } else if key == "user_id" {
1137 user_id = Some(value.to_string());
1138 }
1139 }
1140
1141 let post_auth_url =
1142 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1143 req.respond(
1144 tiny_http::Response::empty(302).with_header(
1145 tiny_http::Header::from_bytes(
1146 &b"Location"[..],
1147 post_auth_url.as_bytes(),
1148 )
1149 .unwrap(),
1150 ),
1151 )
1152 .context("failed to respond to login http request")?;
1153 return Ok((
1154 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1155 access_token
1156 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1157 ));
1158 }
1159 }
1160
1161 Err(anyhow!("didn't receive login redirect"))
1162 })
1163 .await?;
1164
1165 let access_token = private_key
1166 .decrypt_string(&access_token)
1167 .context("failed to decrypt access token")?;
1168 platform.activate(true);
1169
1170 Ok(Credentials {
1171 user_id: user_id.parse()?,
1172 access_token,
1173 })
1174 })
1175 }
1176
1177 async fn authenticate_as_admin(
1178 http: Arc<dyn HttpClient>,
1179 login: String,
1180 mut api_token: String,
1181 ) -> Result<Credentials> {
1182 #[derive(Deserialize)]
1183 struct AuthenticatedUserResponse {
1184 user: User,
1185 }
1186
1187 #[derive(Deserialize)]
1188 struct User {
1189 id: u64,
1190 }
1191
1192 // Use the collab server's admin API to retrieve the id
1193 // of the impersonated user.
1194 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1195 url.set_path("/user");
1196 url.set_query(Some(&format!("github_login={login}")));
1197 let request = Request::get(url.as_str())
1198 .header("Authorization", format!("token {api_token}"))
1199 .body("".into())?;
1200
1201 let mut response = http.send(request).await?;
1202 let mut body = String::new();
1203 response.body_mut().read_to_string(&mut body).await?;
1204 if !response.status().is_success() {
1205 Err(anyhow!(
1206 "admin user request failed {} - {}",
1207 response.status().as_u16(),
1208 body,
1209 ))?;
1210 }
1211 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1212
1213 // Use the admin API token to authenticate as the impersonated user.
1214 api_token.insert_str(0, "ADMIN_TOKEN:");
1215 Ok(Credentials {
1216 user_id: response.user.id,
1217 access_token: api_token,
1218 })
1219 }
1220
1221 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1222 self.peer.teardown();
1223 self.set_status(Status::SignedOut, cx);
1224 }
1225
1226 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1227 self.peer.teardown();
1228 self.set_status(Status::ConnectionLost, cx);
1229 }
1230
1231 fn connection_id(&self) -> Result<ConnectionId> {
1232 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1233 Ok(connection_id)
1234 } else {
1235 Err(anyhow!("not connected"))
1236 }
1237 }
1238
1239 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1240 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1241 self.peer.send(self.connection_id()?, message)
1242 }
1243
1244 pub fn request<T: RequestMessage>(
1245 &self,
1246 request: T,
1247 ) -> impl Future<Output = Result<T::Response>> {
1248 self.request_envelope(request)
1249 .map_ok(|envelope| envelope.payload)
1250 }
1251
1252 pub fn request_envelope<T: RequestMessage>(
1253 &self,
1254 request: T,
1255 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1256 let client_id = self.id();
1257 log::debug!(
1258 "rpc request start. client_id:{}. name:{}",
1259 client_id,
1260 T::NAME
1261 );
1262 let response = self
1263 .connection_id()
1264 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1265 async move {
1266 let response = response?.await;
1267 log::debug!(
1268 "rpc request finish. client_id:{}. name:{}",
1269 client_id,
1270 T::NAME
1271 );
1272 response
1273 }
1274 }
1275
1276 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1277 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1278 self.peer.respond(receipt, response)
1279 }
1280
1281 fn respond_with_error<T: RequestMessage>(
1282 &self,
1283 receipt: Receipt<T>,
1284 error: proto::Error,
1285 ) -> Result<()> {
1286 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1287 self.peer.respond_with_error(receipt, error)
1288 }
1289
1290 fn handle_message(
1291 self: &Arc<Client>,
1292 message: Box<dyn AnyTypedEnvelope>,
1293 cx: &AsyncAppContext,
1294 ) {
1295 let mut state = self.state.write();
1296 let type_name = message.payload_type_name();
1297 let payload_type_id = message.payload_type_id();
1298 let sender_id = message.original_sender_id();
1299
1300 let mut subscriber = None;
1301
1302 if let Some(message_model) = state
1303 .models_by_message_type
1304 .get(&payload_type_id)
1305 .and_then(|model| model.upgrade(cx))
1306 {
1307 subscriber = Some(Subscriber::Model(message_model));
1308 } else if let Some((extract_entity_id, entity_type_id)) =
1309 state.entity_id_extractors.get(&payload_type_id).zip(
1310 state
1311 .entity_types_by_message_type
1312 .get(&payload_type_id)
1313 .copied(),
1314 )
1315 {
1316 let entity_id = (extract_entity_id)(message.as_ref());
1317
1318 match state
1319 .entities_by_type_and_remote_id
1320 .get_mut(&(entity_type_id, entity_id))
1321 {
1322 Some(WeakSubscriber::Pending(pending)) => {
1323 pending.push(message);
1324 return;
1325 }
1326 Some(weak_subscriber @ _) => match weak_subscriber {
1327 WeakSubscriber::Model(handle) => {
1328 subscriber = handle.upgrade(cx).map(Subscriber::Model);
1329 }
1330 WeakSubscriber::View(handle) => {
1331 subscriber = Some(Subscriber::View(handle.clone()));
1332 }
1333 WeakSubscriber::Pending(_) => {}
1334 },
1335 _ => {}
1336 }
1337 }
1338
1339 let subscriber = if let Some(subscriber) = subscriber {
1340 subscriber
1341 } else {
1342 log::info!("unhandled message {}", type_name);
1343 self.peer.respond_with_unhandled_message(message).log_err();
1344 return;
1345 };
1346
1347 let handler = state.message_handlers.get(&payload_type_id).cloned();
1348 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1349 // It also ensures we don't hold the lock while yielding back to the executor, as
1350 // that might cause the executor thread driving this future to block indefinitely.
1351 drop(state);
1352
1353 if let Some(handler) = handler {
1354 let future = handler(subscriber, message, &self, cx.clone());
1355 let client_id = self.id();
1356 log::debug!(
1357 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1358 client_id,
1359 sender_id,
1360 type_name
1361 );
1362 cx.foreground()
1363 .spawn(async move {
1364 match future.await {
1365 Ok(()) => {
1366 log::debug!(
1367 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1368 client_id,
1369 sender_id,
1370 type_name
1371 );
1372 }
1373 Err(error) => {
1374 log::error!(
1375 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1376 client_id,
1377 sender_id,
1378 type_name,
1379 error
1380 );
1381 }
1382 }
1383 })
1384 .detach();
1385 } else {
1386 log::info!("unhandled message {}", type_name);
1387 self.peer.respond_with_unhandled_message(message).log_err();
1388 }
1389 }
1390
1391 pub fn telemetry(&self) -> &Arc<Telemetry> {
1392 &self.telemetry
1393 }
1394}
1395
1396fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1397 if IMPERSONATE_LOGIN.is_some() {
1398 return None;
1399 }
1400
1401 let (user_id, access_token) = cx
1402 .platform()
1403 .read_credentials(&ZED_SERVER_URL)
1404 .log_err()
1405 .flatten()?;
1406 Some(Credentials {
1407 user_id: user_id.parse().ok()?,
1408 access_token: String::from_utf8(access_token).ok()?,
1409 })
1410}
1411
1412fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1413 cx.platform().write_credentials(
1414 &ZED_SERVER_URL,
1415 &credentials.user_id.to_string(),
1416 credentials.access_token.as_bytes(),
1417 )
1418}
1419
/// Scheme + path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the first `/` following the id is
/// treated as the access token, so tokens that themselves contain `/` round-trip correctly.
/// Returns `None` if the prefix is wrong, the id is not a `u64`, or the token is missing or
/// empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1436
// Unit tests for `Client`: connection/reconnection status transitions, authentication
// cancellation, worktree URL encoding, and message-handler / entity subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;
    use util::http::FakeHttpClient;

    // After the server disconnects, the client reconnects reusing its cached credentials
    // (auth_count stays 1); once the server rolls the access token, the client
    // re-authenticates (auth_count becomes 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes ends in `ConnectionError` once
    // `CONNECTION_TIMEOUT` elapses; later, a stalled reconnection attempt likewise ends in
    // `ReconnectionError` after the same timeout.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // A connection future that never resolves, so only the timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still authenticating
    // drops the first authentication future: after the second call, two authentications have
    // started and exactly one has been dropped.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        // Counts authentications started, and authentication futures dropped (via `defer`).
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Round-trips a worktree URL (also with surrounding whitespace) and rejects a URL with
    // the wrong prefix.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity-scoped handlers receive messages for their own entity id, and dropping the
    // subscription of one entity doesn't affect delivery to other entities of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signals the channel matching the id of the model the message was routed to.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping a handler subscription for `Ping`, a newly registered `Ping` handler
    // on the same model still receives the message.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the model) while it runs, and the
    // message is still handled.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: `id` tells handlers which model a message was routed to, and
    // `subscription` lets a test store a subscription on the model.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}