1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
15 TryStreamExt,
16};
17use gpui::{
18 actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
19 AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
20 WeakViewHandle,
21};
22use lazy_static::lazy_static;
23use parking_lot::RwLock;
24use postage::watch;
25use rand::prelude::*;
26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use std::{
30 any::TypeId,
31 collections::HashMap,
32 convert::TryFrom,
33 fmt::Write as _,
34 future::Future,
35 marker::PhantomData,
36 path::PathBuf,
37 sync::{Arc, Weak},
38 time::{Duration, Instant},
39};
40use telemetry::Telemetry;
41use thiserror::Error;
42use url::Url;
43use util::channel::ReleaseChannel;
44use util::http::HttpClient;
45use util::{ResultExt, TryFutureExt};
46
47pub use rpc::*;
48pub use telemetry::ClickhouseEvent;
49pub use user::*;
50
51lazy_static! {
52 pub static ref ZED_SERVER_URL: String =
53 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
54 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
55 .ok()
56 .and_then(|s| if s.is_empty() { None } else { Some(s) });
57 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
58 .ok()
59 .and_then(|s| if s.is_empty() { None } else { Some(s) });
60 pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
61 .ok()
62 .and_then(|v| v.parse().ok());
63 pub static ref ZED_APP_PATH: Option<PathBuf> =
64 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
65}
66
/// Public constant token. NOTE(review): it is not referenced in the visible part of this file;
/// confirm its purpose at the use site.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// Delay before the first reconnection attempt after a lost connection; later attempts grow
/// from here (see `Client::set_status`).
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Upper bound on establishing a connection and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

// Global `client::SignIn` / `client::SignOut` actions, handled by `init`.
actions!(client, [SignIn, SignOut]);
72
/// Registers the `telemetry` settings ([`TelemetrySettings`]) with the settings system.
pub fn init_settings(cx: &mut AppContext) {
    settings::register::<TelemetrySettings>(cx);
}
76
77pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
78 init_settings(cx);
79
80 let client = Arc::downgrade(client);
81 cx.add_global_action({
82 let client = client.clone();
83 move |_: &SignIn, cx| {
84 if let Some(client) = client.upgrade() {
85 cx.spawn(
86 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
87 )
88 .detach();
89 }
90 }
91 });
92 cx.add_global_action({
93 let client = client.clone();
94 move |_: &SignOut, cx| {
95 if let Some(client) = client.upgrade() {
96 cx.spawn(|cx| async move {
97 client.disconnect(&cx);
98 })
99 .detach();
100 }
101 }
102 });
103}
104
/// Client for the Zed collaboration server.
///
/// It bundles the RPC [`Peer`], the HTTP client, telemetry, and the lock-protected
/// [`ClientState`]: credentials, connection status, and message-routing tables.
pub struct Client {
    /// Identifier used in log output. `new` sets it to `0`; tests may change it with `set_id`.
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    state: RwLock<ClientState>,

    /// Test-support override. When set, `authenticate` calls it instead of the browser flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-support override. When set, `establish_connection` calls it instead of opening
    /// a websocket.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
134
/// Reasons a connection to the collaboration server could not be established.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server rejected this client version (HTTP 426 during the websocket handshake).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server rejected the credentials (HTTP 401 during the websocket handshake).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including websocket errors that are not 401/426.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure while building the websocket HTTP request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
150
151impl From<WebsocketError> for EstablishConnectionError {
152 fn from(error: WebsocketError) -> Self {
153 if let WebsocketError::Http(response) = &error {
154 match response.status() {
155 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
156 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
157 _ => {}
158 }
159 }
160 EstablishConnectionError::Other(error.into())
161 }
162}
163
164impl EstablishConnectionError {
165 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
166 Self::Other(error.into())
167 }
168}
169
/// Connection lifecycle of a [`Client`], published through the watch channel returned by
/// `Client::status`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    SignedOut,
    /// The server refused this client version. `authenticate_and_connect` will not retry.
    UpgradeRequired,
    /// Obtaining credentials while starting from `SignedOut`.
    Authenticating,
    /// Opening the connection while starting from `SignedOut`.
    Connecting,
    ConnectionError,
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O task failed. `set_status` starts the reconnect loop.
    ConnectionLost,
    /// Obtaining credentials after a previous connection or error.
    Reauthenticating,
    /// Opening the connection after a previous connection or error.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
188
189impl Status {
190 pub fn is_connected(&self) -> bool {
191 matches!(self, Self::Connected { .. })
192 }
193
194 pub fn is_signed_out(&self) -> bool {
195 matches!(self, Self::SignedOut | Self::UpgradeRequired)
196 }
197}
198
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials from the last successful connection, reused by `authenticate_and_connect`.
    credentials: Option<Credentials>,
    /// Watch channel carrying the current [`Status`]: (sender, receiver template).
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: function that pulls the remote entity id out of an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`. It is kept here so it stays alive and can
    /// be replaced or cleared.
    _reconnect_task: Option<Task<()>>,
    /// Cap on the randomized delay between reconnection attempts.
    reconnect_interval: Duration,
    /// (entity type, remote id) -> subscriber that should receive entity messages.
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Message type -> model registered through `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Message type -> entity type, registered through entity message handlers.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Message type -> type-erased handler invoked with the resolved subscriber.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
223
/// Registration stored in `entities_by_type_and_remote_id`.
enum WeakSubscriber {
    /// Model installed by `PendingEntitySubscription::set_model`.
    Model(AnyWeakModelHandle),
    /// View installed by `Client::add_view_for_remote_entity`.
    View(AnyWeakViewHandle),
    /// Subscription created by `subscribe_to_entity` whose model is not set yet. Buffered
    /// messages are replayed by `set_model`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
229
/// Resolved recipient passed to a message handler: a live model handle, or a (still weak)
/// view handle.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyWeakViewHandle),
}
234
/// User id and access token used to authenticate the websocket connection.
///
/// NOTE(review): the derived `Debug` prints `access_token` verbatim. Avoid logging this type
/// with `{:?}`, or consider a redacting `Debug` impl.
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
240
241impl Default for ClientState {
242 fn default() -> Self {
243 Self {
244 credentials: None,
245 status: watch::channel_with(Status::SignedOut),
246 entity_id_extractors: Default::default(),
247 _reconnect_task: None,
248 reconnect_interval: Duration::from_secs(5),
249 models_by_message_type: Default::default(),
250 entities_by_type_and_remote_id: Default::default(),
251 entity_types_by_message_type: Default::default(),
252 message_handlers: Default::default(),
253 }
254 }
255}
256
/// Guard for a registration on a [`Client`]. Dropping it removes the registration (if the
/// client is still alive).
pub enum Subscription {
    /// Entity subscription keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message-handler registration keyed by message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
267
268impl Drop for Subscription {
269 fn drop(&mut self) {
270 match self {
271 Subscription::Entity { client, id } => {
272 if let Some(client) = client.upgrade() {
273 let mut state = client.state.write();
274 let _ = state.entities_by_type_and_remote_id.remove(id);
275 }
276 }
277 Subscription::Message { client, id } => {
278 if let Some(client) = client.upgrade() {
279 let mut state = client.state.write();
280 let _ = state.entity_types_by_message_type.remove(id);
281 let _ = state.message_handlers.remove(id);
282 }
283 }
284 }
285 }
286}
287
/// Entity subscription created by `Client::subscribe_to_entity` before its model exists.
///
/// Messages arriving in the meantime are buffered in a `WeakSubscriber::Pending` entry.
/// `set_model` replays them. If this value is dropped without calling `set_model`, the
/// buffered messages are logged and discarded.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so `Drop` does not remove the now-installed registration.
    consumed: bool,
}
294
295impl<T: Entity> PendingEntitySubscription<T> {
296 pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
297 self.consumed = true;
298 let mut state = self.client.state.write();
299 let id = (TypeId::of::<T>(), self.remote_id);
300 let Some(WeakSubscriber::Pending(messages)) =
301 state.entities_by_type_and_remote_id.remove(&id)
302 else {
303 unreachable!()
304 };
305
306 state
307 .entities_by_type_and_remote_id
308 .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
309 drop(state);
310 for message in messages {
311 self.client.handle_message(message, cx);
312 }
313 Subscription::Entity {
314 client: Arc::downgrade(&self.client),
315 id,
316 }
317 }
318}
319
320impl<T: Entity> Drop for PendingEntitySubscription<T> {
321 fn drop(&mut self) {
322 if !self.consumed {
323 let mut state = self.client.state.write();
324 if let Some(WeakSubscriber::Pending(messages)) = state
325 .entities_by_type_and_remote_id
326 .remove(&(TypeId::of::<T>(), self.remote_id))
327 {
328 for message in messages {
329 log::info!("unhandled message {}", message.payload_type_name());
330 }
331 }
332 }
333 }
334}
335
/// Resolved values of the `telemetry` settings section (see `TelemetrySettings::load`).
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Resolved `telemetry.diagnostics` value.
    pub diagnostics: bool,
    /// Resolved `telemetry.metrics` value.
    pub metrics: bool,
}
341
/// Raw `telemetry` section as it appears in a settings file; absent keys are `None`.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
347
348impl settings::Setting for TelemetrySettings {
349 const KEY: Option<&'static str> = Some("telemetry");
350
351 type FileContent = TelemetrySettingsContent;
352
353 fn load(
354 default_value: &Self::FileContent,
355 user_values: &[&Self::FileContent],
356 _: &AppContext,
357 ) -> Result<Self> {
358 Ok(Self {
359 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
360 default_value
361 .diagnostics
362 .ok_or_else(Self::missing_default)?,
363 ),
364 metrics: user_values
365 .first()
366 .and_then(|v| v.metrics)
367 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
368 })
369 }
370}
371
372impl Client {
373 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
374 Arc::new(Self {
375 id: 0,
376 peer: Peer::new(0),
377 telemetry: Telemetry::new(http.clone(), cx),
378 http,
379 state: Default::default(),
380
381 #[cfg(any(test, feature = "test-support"))]
382 authenticate: Default::default(),
383 #[cfg(any(test, feature = "test-support"))]
384 establish_connection: Default::default(),
385 })
386 }
387
    /// Returns the identifier used in this client's log output.
    pub fn id(&self) -> usize {
        self.id
    }
391
392 pub fn http_client(&self) -> Arc<dyn HttpClient> {
393 self.http.clone()
394 }
395
    /// Test-support: overrides the identifier shown in log output.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_id(&mut self, id: usize) -> &Self {
        self.id = id;
        self
    }
401
402 #[cfg(any(test, feature = "test-support"))]
403 pub fn teardown(&self) {
404 let mut state = self.state.write();
405 state._reconnect_task.take();
406 state.message_handlers.clear();
407 state.models_by_message_type.clear();
408 state.entities_by_type_and_remote_id.clear();
409 state.entity_id_extractors.clear();
410 self.peer.teardown();
411 }
412
    /// Test-support: replaces the browser sign-in flow used by `authenticate` with `authenticate`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
    where
        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
    {
        *self.authenticate.write() = Some(Box::new(authenticate));
        self
    }
421
    /// Test-support: replaces the websocket connection used by `establish_connection` with `connect`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn override_establish_connection<F>(&self, connect: F) -> &Self
    where
        F: 'static
            + Send
            + Sync
            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
    {
        *self.establish_connection.write() = Some(Box::new(connect));
        self
    }
433
434 pub fn user_id(&self) -> Option<u64> {
435 self.state
436 .read()
437 .credentials
438 .as_ref()
439 .map(|credentials| credentials.user_id)
440 }
441
442 pub fn peer_id(&self) -> Option<PeerId> {
443 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
444 Some(*peer_id)
445 } else {
446 None
447 }
448 }
449
    /// Returns a new receiver on the status watch channel, so the caller can read the current
    /// status and await changes.
    pub fn status(&self) -> watch::Receiver<Status> {
        self.state.read().status.1.clone()
    }
453
    /// Publishes `status` on the watch channel and performs the side effects tied to it.
    ///
    /// - `Connected`: clears any reconnect task.
    /// - `ConnectionLost`: stores a newly spawned reconnect loop. The loop retries
    ///   `authenticate_and_connect` while attempts end in `ConnectionError`, with a randomized,
    ///   growing delay that starts at `INITIAL_RECONNECTION_DELAY` and is capped at
    ///   `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and the reconnect task.
    ///
    /// The state write lock is held for the whole function, including while spawning.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                // Replaces any earlier reconnect task. NOTE(review): assumes dropping a gpui
                // `Task` cancels it; confirm.
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed under test support, so the jittered delays are reproducible.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Keep retrying only after plain connection errors. Any other
                        // resulting status (e.g. UpgradeRequired) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow by a random factor in [1, 2], capped at reconnect_interval.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
499
500 pub fn add_view_for_remote_entity<T: View>(
501 self: &Arc<Self>,
502 remote_id: u64,
503 cx: &mut ViewContext<T>,
504 ) -> Subscription {
505 let id = (TypeId::of::<T>(), remote_id);
506 self.state
507 .write()
508 .entities_by_type_and_remote_id
509 .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
510 Subscription::Entity {
511 client: Arc::downgrade(self),
512 id,
513 }
514 }
515
516 pub fn subscribe_to_entity<T: Entity>(
517 self: &Arc<Self>,
518 remote_id: u64,
519 ) -> Result<PendingEntitySubscription<T>> {
520 let id = (TypeId::of::<T>(), remote_id);
521
522 let mut state = self.state.write();
523 if state.entities_by_type_and_remote_id.contains_key(&id) {
524 return Err(anyhow!("already subscribed to entity"));
525 } else {
526 state
527 .entities_by_type_and_remote_id
528 .insert(id, WeakSubscriber::Pending(Default::default()));
529 Ok(PendingEntitySubscription {
530 client: self.clone(),
531 remote_id,
532 consumed: false,
533 _entity_type: PhantomData,
534 })
535 }
536 }
537
538 pub fn add_message_handler<M, E, H, F>(
539 self: &Arc<Self>,
540 model: ModelHandle<E>,
541 handler: H,
542 ) -> Subscription
543 where
544 M: EnvelopedMessage,
545 E: Entity,
546 H: 'static
547 + Send
548 + Sync
549 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
550 F: 'static + Future<Output = Result<()>>,
551 {
552 let message_type_id = TypeId::of::<M>();
553
554 let mut state = self.state.write();
555 state
556 .models_by_message_type
557 .insert(message_type_id, model.downgrade().into_any());
558
559 let prev_handler = state.message_handlers.insert(
560 message_type_id,
561 Arc::new(move |handle, envelope, client, cx| {
562 let handle = if let Subscriber::Model(handle) = handle {
563 handle
564 } else {
565 unreachable!();
566 };
567 let model = handle.downcast::<E>().unwrap();
568 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
569 handler(model, *envelope, client.clone(), cx).boxed_local()
570 }),
571 );
572 if prev_handler.is_some() {
573 panic!("registered handler for the same message twice");
574 }
575
576 Subscription::Message {
577 client: Arc::downgrade(self),
578 id: message_type_id,
579 }
580 }
581
582 pub fn add_request_handler<M, E, H, F>(
583 self: &Arc<Self>,
584 model: ModelHandle<E>,
585 handler: H,
586 ) -> Subscription
587 where
588 M: RequestMessage,
589 E: Entity,
590 H: 'static
591 + Send
592 + Sync
593 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
594 F: 'static + Future<Output = Result<M::Response>>,
595 {
596 self.add_message_handler(model, move |handle, envelope, this, cx| {
597 Self::respond_to_request(
598 envelope.receipt(),
599 handler(handle, envelope, this.clone(), cx),
600 this,
601 )
602 })
603 }
604
605 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
606 where
607 M: EntityMessage,
608 E: View,
609 H: 'static
610 + Send
611 + Sync
612 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
613 F: 'static + Future<Output = Result<()>>,
614 {
615 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
616 if let Subscriber::View(handle) = handle {
617 handler(handle.downcast::<E>().unwrap(), message, client, cx)
618 } else {
619 unreachable!();
620 }
621 })
622 }
623
624 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
625 where
626 M: EntityMessage,
627 E: Entity,
628 H: 'static
629 + Send
630 + Sync
631 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
632 F: 'static + Future<Output = Result<()>>,
633 {
634 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
635 if let Subscriber::Model(handle) = handle {
636 handler(handle.downcast::<E>().unwrap(), message, client, cx)
637 } else {
638 unreachable!();
639 }
640 })
641 }
642
643 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
644 where
645 M: EntityMessage,
646 E: Entity,
647 H: 'static
648 + Send
649 + Sync
650 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
651 F: 'static + Future<Output = Result<()>>,
652 {
653 let model_type_id = TypeId::of::<E>();
654 let message_type_id = TypeId::of::<M>();
655
656 let mut state = self.state.write();
657 state
658 .entity_types_by_message_type
659 .insert(message_type_id, model_type_id);
660 state
661 .entity_id_extractors
662 .entry(message_type_id)
663 .or_insert_with(|| {
664 |envelope| {
665 envelope
666 .as_any()
667 .downcast_ref::<TypedEnvelope<M>>()
668 .unwrap()
669 .payload
670 .remote_entity_id()
671 }
672 });
673 let prev_handler = state.message_handlers.insert(
674 message_type_id,
675 Arc::new(move |handle, envelope, client, cx| {
676 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
677 handler(handle, *envelope, client.clone(), cx).boxed_local()
678 }),
679 );
680 if prev_handler.is_some() {
681 panic!("registered handler for the same message twice");
682 }
683 }
684
685 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
686 where
687 M: EntityMessage + RequestMessage,
688 E: Entity,
689 H: 'static
690 + Send
691 + Sync
692 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
693 F: 'static + Future<Output = Result<M::Response>>,
694 {
695 self.add_model_message_handler(move |entity, envelope, client, cx| {
696 Self::respond_to_request::<M, _>(
697 envelope.receipt(),
698 handler(entity, envelope, client.clone(), cx),
699 client,
700 )
701 })
702 }
703
704 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
705 where
706 M: EntityMessage + RequestMessage,
707 E: View,
708 H: 'static
709 + Send
710 + Sync
711 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
712 F: 'static + Future<Output = Result<M::Response>>,
713 {
714 self.add_view_message_handler(move |entity, envelope, client, cx| {
715 Self::respond_to_request::<M, _>(
716 envelope.receipt(),
717 handler(entity, envelope, client.clone(), cx),
718 client,
719 )
720 })
721 }
722
723 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
724 receipt: Receipt<T>,
725 response: F,
726 client: Arc<Self>,
727 ) -> Result<()> {
728 match response.await {
729 Ok(response) => {
730 client.respond(receipt, response)?;
731 Ok(())
732 }
733 Err(error) => {
734 client.respond_with_error(
735 receipt,
736 proto::Error {
737 message: format!("{:?}", error),
738 },
739 )?;
740 Err(error)
741 }
742 }
743 }
744
    /// Returns whether `read_credentials_from_keychain` currently yields credentials.
    pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
        read_credentials_from_keychain(cx).is_some()
    }
748
    /// Obtains credentials (stored, keychain, or interactive) and connects to the server.
    ///
    /// Returns `Ok(())` immediately if already connected or a connection attempt is underway.
    /// Fails straight away with `UpgradeRequired` in that state.
    ///
    /// If the status changes while interactive authentication is pending, the attempt is
    /// abandoned with "authentication canceled". If keychain credentials are rejected as
    /// unauthorized, they are deleted and the flow restarts once without the keychain. The
    /// whole connect-plus-hello exchange is bounded by `CONNECTION_TIMEOUT`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Decide between the first-time (`Authenticating`/`Connecting`) and the retry
        // (`Reauthenticating`/`Reconnecting`) status sequence.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // Consume the current value so the `select` below only fires on a later change.
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // One timer covers both establishing the connection and waiting for the hello.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist freshly obtained credentials, except when impersonating.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Keychain credentials were rejected: delete them and retry with
                            // interactive authentication.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
851
    /// Adds `conn` to the peer, waits for the server's `Hello`, and marks the client `Connected`.
    ///
    /// Also spawns two detached foreground tasks: one dispatches incoming messages, the other
    /// watches the connection's I/O. When I/O ends cleanly while this connection is still
    /// current, the status becomes `SignedOut`; when it ends with an error, `ConnectionLost`.
    ///
    /// # Errors
    /// Fails, and disconnects, if no valid `Hello` with a peer id is received.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first message on a new connection must be `Hello`, carrying our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if a newer connection has not replaced this one.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
947
    /// Obtains credentials through the browser sign-in flow.
    ///
    /// Under test support, a callback installed with `override_authenticate` takes precedence.
    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.authenticate.read().as_ref() {
            return callback(cx);
        }

        self.authenticate_with_browser(cx)
    }
956
    /// Opens a server connection authenticated with `credentials`, normally over a websocket.
    ///
    /// Under test support, a callback installed with `override_establish_connection` takes
    /// precedence.
    fn establish_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(callback) = self.establish_connection.read().as_ref() {
            return callback(credentials, cx);
        }

        self.establish_websocket_connection(credentials, cx)
    }
969
970 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
971 let preview_param = if is_preview { "?preview=1" } else { "" };
972 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
973 let response = http.get(&url, Default::default(), false).await?;
974
975 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
976 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
977 // which requires authorization via an HTTP header.
978 //
979 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
980 // of a collab server. In that case, a request to the /rpc endpoint will
981 // return an 'unauthorized' response.
982 let collab_url = if response.status().is_redirection() {
983 response
984 .headers()
985 .get("Location")
986 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
987 .to_str()
988 .map_err(EstablishConnectionError::other)?
989 .to_string()
990 } else if response.status() == StatusCode::UNAUTHORIZED {
991 url
992 } else {
993 Err(anyhow!(
994 "unexpected /rpc response status {}",
995 response.status()
996 ))?
997 };
998
999 Url::parse(&collab_url).context("invalid rpc url")
1000 }
1001
    /// Opens a websocket to the collab server's RPC endpoint on the background executor.
    ///
    /// Sends `Authorization: "{user_id} {access_token}"` and the protocol version header.
    /// `https` RPC URLs are upgraded to `wss` over TLS; `http` URLs become plain `ws`.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Preview builds ask the server for the preview RPC endpoint.
        let is_preview = cx.read(|cx| {
            if cx.has_global::<ReleaseChannel>() {
                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
            } else {
                false
            }
        });

        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1059
1060 pub fn authenticate_with_browser(
1061 self: &Arc<Self>,
1062 cx: &AsyncAppContext,
1063 ) -> Task<Result<Credentials>> {
1064 let platform = cx.platform();
1065 let executor = cx.background();
1066 let http = self.http.clone();
1067
1068 executor.clone().spawn(async move {
1069 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1070 // zed server to encrypt the user's access token, so that it can'be intercepted by
1071 // any other app running on the user's device.
1072 let (public_key, private_key) =
1073 rpc::auth::keypair().expect("failed to generate keypair for auth");
1074 let public_key_string =
1075 String::try_from(public_key).expect("failed to serialize public key for auth");
1076
1077 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1078 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1079 }
1080
1081 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1082 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1083 let port = server.server_addr().port();
1084
1085 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1086 // that the user is signing in from a Zed app running on the same device.
1087 let mut url = format!(
1088 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1089 *ZED_SERVER_URL, port, public_key_string
1090 );
1091
1092 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1093 log::info!("impersonating user @{}", impersonate_login);
1094 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1095 }
1096
1097 platform.open_url(&url);
1098
1099 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1100 // access token from the query params.
1101 //
1102 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1103 // custom URL scheme instead of this local HTTP server.
1104 let (user_id, access_token) = executor
1105 .spawn(async move {
1106 for _ in 0..100 {
1107 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1108 let path = req.url();
1109 let mut user_id = None;
1110 let mut access_token = None;
1111 let url = Url::parse(&format!("http://example.com{}", path))
1112 .context("failed to parse login notification url")?;
1113 for (key, value) in url.query_pairs() {
1114 if key == "access_token" {
1115 access_token = Some(value.to_string());
1116 } else if key == "user_id" {
1117 user_id = Some(value.to_string());
1118 }
1119 }
1120
1121 let post_auth_url =
1122 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1123 req.respond(
1124 tiny_http::Response::empty(302).with_header(
1125 tiny_http::Header::from_bytes(
1126 &b"Location"[..],
1127 post_auth_url.as_bytes(),
1128 )
1129 .unwrap(),
1130 ),
1131 )
1132 .context("failed to respond to login http request")?;
1133 return Ok((
1134 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1135 access_token
1136 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1137 ));
1138 }
1139 }
1140
1141 Err(anyhow!("didn't receive login redirect"))
1142 })
1143 .await?;
1144
1145 let access_token = private_key
1146 .decrypt_string(&access_token)
1147 .context("failed to decrypt access token")?;
1148 platform.activate(true);
1149
1150 Ok(Credentials {
1151 user_id: user_id.parse()?,
1152 access_token,
1153 })
1154 })
1155 }
1156
1157 async fn authenticate_as_admin(
1158 http: Arc<dyn HttpClient>,
1159 login: String,
1160 mut api_token: String,
1161 ) -> Result<Credentials> {
1162 #[derive(Deserialize)]
1163 struct AuthenticatedUserResponse {
1164 user: User,
1165 }
1166
1167 #[derive(Deserialize)]
1168 struct User {
1169 id: u64,
1170 }
1171
1172 // Use the collab server's admin API to retrieve the id
1173 // of the impersonated user.
1174 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1175 url.set_path("/user");
1176 url.set_query(Some(&format!("github_login={login}")));
1177 let request = Request::get(url.as_str())
1178 .header("Authorization", format!("token {api_token}"))
1179 .body("".into())?;
1180
1181 let mut response = http.send(request).await?;
1182 let mut body = String::new();
1183 response.body_mut().read_to_string(&mut body).await?;
1184 if !response.status().is_success() {
1185 Err(anyhow!(
1186 "admin user request failed {} - {}",
1187 response.status().as_u16(),
1188 body,
1189 ))?;
1190 }
1191 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1192
1193 // Use the admin API token to authenticate as the impersonated user.
1194 api_token.insert_str(0, "ADMIN_TOKEN:");
1195 Ok(Credentials {
1196 user_id: response.user.id,
1197 access_token: api_token,
1198 })
1199 }
1200
1201 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1202 self.peer.teardown();
1203 self.set_status(Status::SignedOut, cx);
1204 }
1205
1206 fn connection_id(&self) -> Result<ConnectionId> {
1207 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1208 Ok(connection_id)
1209 } else {
1210 Err(anyhow!("not connected"))
1211 }
1212 }
1213
1214 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1215 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1216 self.peer.send(self.connection_id()?, message)
1217 }
1218
1219 pub fn request<T: RequestMessage>(
1220 &self,
1221 request: T,
1222 ) -> impl Future<Output = Result<T::Response>> {
1223 self.request_envelope(request)
1224 .map_ok(|envelope| envelope.payload)
1225 }
1226
1227 pub fn request_envelope<T: RequestMessage>(
1228 &self,
1229 request: T,
1230 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1231 let client_id = self.id;
1232 log::debug!(
1233 "rpc request start. client_id:{}. name:{}",
1234 client_id,
1235 T::NAME
1236 );
1237 let response = self
1238 .connection_id()
1239 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1240 async move {
1241 let response = response?.await;
1242 log::debug!(
1243 "rpc request finish. client_id:{}. name:{}",
1244 client_id,
1245 T::NAME
1246 );
1247 response
1248 }
1249 }
1250
1251 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1252 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1253 self.peer.respond(receipt, response)
1254 }
1255
1256 fn respond_with_error<T: RequestMessage>(
1257 &self,
1258 receipt: Receipt<T>,
1259 error: proto::Error,
1260 ) -> Result<()> {
1261 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1262 self.peer.respond_with_error(receipt, error)
1263 }
1264
1265 fn handle_message(
1266 self: &Arc<Client>,
1267 message: Box<dyn AnyTypedEnvelope>,
1268 cx: &AsyncAppContext,
1269 ) {
1270 let mut state = self.state.write();
1271 let type_name = message.payload_type_name();
1272 let payload_type_id = message.payload_type_id();
1273 let sender_id = message.original_sender_id();
1274
1275 let mut subscriber = None;
1276
1277 if let Some(message_model) = state
1278 .models_by_message_type
1279 .get(&payload_type_id)
1280 .and_then(|model| model.upgrade(cx))
1281 {
1282 subscriber = Some(Subscriber::Model(message_model));
1283 } else if let Some((extract_entity_id, entity_type_id)) =
1284 state.entity_id_extractors.get(&payload_type_id).zip(
1285 state
1286 .entity_types_by_message_type
1287 .get(&payload_type_id)
1288 .copied(),
1289 )
1290 {
1291 let entity_id = (extract_entity_id)(message.as_ref());
1292
1293 match state
1294 .entities_by_type_and_remote_id
1295 .get_mut(&(entity_type_id, entity_id))
1296 {
1297 Some(WeakSubscriber::Pending(pending)) => {
1298 pending.push(message);
1299 return;
1300 }
1301 Some(weak_subscriber @ _) => match weak_subscriber {
1302 WeakSubscriber::Model(handle) => {
1303 subscriber = handle.upgrade(cx).map(Subscriber::Model);
1304 }
1305 WeakSubscriber::View(handle) => {
1306 subscriber = Some(Subscriber::View(handle.clone()));
1307 }
1308 WeakSubscriber::Pending(_) => {}
1309 },
1310 _ => {}
1311 }
1312 }
1313
1314 let subscriber = if let Some(subscriber) = subscriber {
1315 subscriber
1316 } else {
1317 log::info!("unhandled message {}", type_name);
1318 self.peer.respond_with_unhandled_message(message).log_err();
1319 return;
1320 };
1321
1322 let handler = state.message_handlers.get(&payload_type_id).cloned();
1323 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1324 // It also ensures we don't hold the lock while yielding back to the executor, as
1325 // that might cause the executor thread driving this future to block indefinitely.
1326 drop(state);
1327
1328 if let Some(handler) = handler {
1329 let future = handler(subscriber, message, &self, cx.clone());
1330 let client_id = self.id;
1331 log::debug!(
1332 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1333 client_id,
1334 sender_id,
1335 type_name
1336 );
1337 cx.foreground()
1338 .spawn(async move {
1339 match future.await {
1340 Ok(()) => {
1341 log::debug!(
1342 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1343 client_id,
1344 sender_id,
1345 type_name
1346 );
1347 }
1348 Err(error) => {
1349 log::error!(
1350 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1351 client_id,
1352 sender_id,
1353 type_name,
1354 error
1355 );
1356 }
1357 }
1358 })
1359 .detach();
1360 } else {
1361 log::info!("unhandled message {}", type_name);
1362 self.peer.respond_with_unhandled_message(message).log_err();
1363 }
1364 }
1365
1366 pub fn telemetry(&self) -> &Arc<Telemetry> {
1367 &self.telemetry
1368 }
1369}
1370
1371fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1372 if IMPERSONATE_LOGIN.is_some() {
1373 return None;
1374 }
1375
1376 let (user_id, access_token) = cx
1377 .platform()
1378 .read_credentials(&ZED_SERVER_URL)
1379 .log_err()
1380 .flatten()?;
1381 Some(Credentials {
1382 user_id: user_id.parse().ok()?,
1383 access_token: String::from_utf8(access_token).ok()?,
1384 })
1385}
1386
1387fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1388 cx.platform().write_credentials(
1389 &ZED_SERVER_URL,
1390 &credentials.user_id.to_string(),
1391 credentials.access_token.as_bytes(),
1392 )
1393}
1394
/// Scheme-and-path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a URL in the format produced by [`encode_worktree_url`] into
/// `(id, access_token)`.
///
/// Surrounding whitespace is ignored, as are any path segments after the access
/// token.
///
/// Returns `None` when:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the access-token segment is missing or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let rest = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (raw_id, tail) = rest.split_once('/')?;
    let id: u64 = raw_id.parse().ok()?;
    let token = tail.split_once('/').map_or(tail, |(token, _)| token);
    (!token.is_empty()).then(|| (id, token.to_string()))
}
1411
#[cfg(test)]
mod tests {
    // Unit tests for the client's connection lifecycle, authentication, message
    // routing to subscribers, and the worktree URL helpers. They run against
    // `FakeServer` / `FakeHttpClient` rather than real network endpoints.
    use super::*;
    use crate::test::FakeServer;
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;
    use util::http::FakeHttpClient;

    // After the server drops the connection, the client reconnects using its
    // cached credentials. It re-authenticates only after the server rolls the
    // access token and the cached credentials stop working.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Advance the fake clock so the client retries the connection.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Both the initial connection attempt and a later reconnection attempt end in
    // an error status once `CONNECTION_TIMEOUT` elapses without a connection.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // The connection future never resolves, so only the timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still
    // authenticating drops the first authentication future. `auth_count` counts
    // started attempts; `dropped_auth_count` counts dropped ones.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Increments `dropped_auth_count` when this future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // The second attempt starts a new authentication and drops the first one.
        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs round-trip, tolerate surrounding whitespace, and reject
    // strings with the wrong prefix.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages are routed by remote entity id to the model subscribed to
    // that id. Dropping one entity's subscription does not affect the others.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // Signals the channel matching the id of the model that received the message.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After one message handler is dropped, a new handler for the same message
    // type can be registered and receives subsequent messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler can drop its own subscription (stored on the model) while it is
    // handling a message, and it still runs to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity. `id` identifies which model received a message, and
    // `subscription` lets a test keep a `Subscription` alive on the model.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}