1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
15 TryStreamExt,
16};
17use gpui::{
18 actions, platform::AppVersion, serde_json, AnyModelHandle, AnyWeakModelHandle,
19 AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View, ViewContext,
20 WeakViewHandle,
21};
22use lazy_static::lazy_static;
23use parking_lot::RwLock;
24use postage::watch;
25use rand::prelude::*;
26use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use std::{
30 any::TypeId,
31 collections::HashMap,
32 convert::TryFrom,
33 fmt::Write as _,
34 future::Future,
35 marker::PhantomData,
36 path::PathBuf,
37 sync::{Arc, Weak},
38 time::{Duration, Instant},
39};
40use telemetry::Telemetry;
41use thiserror::Error;
42use url::Url;
43use util::channel::ReleaseChannel;
44use util::http::HttpClient;
45use util::{ResultExt, TryFutureExt};
46
47pub use rpc::*;
48pub use telemetry::ClickhouseEvent;
49pub use user::*;
50
51lazy_static! {
52 pub static ref ZED_SERVER_URL: String =
53 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
54 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
55 .ok()
56 .and_then(|s| if s.is_empty() { None } else { Some(s) });
57 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
58 .ok()
59 .and_then(|s| if s.is_empty() { None } else { Some(s) });
60 pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
61 .ok()
62 .and_then(|v| v.parse().ok());
63 pub static ref ZED_APP_PATH: Option<PathBuf> =
64 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
65}
66
/// Delay before the first reconnection attempt after the connection is lost; subsequent
/// attempts grow from this value.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);

/// Time budget for establishing a connection and receiving the server's hello message.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);

/// NOTE(review): not referenced in this file's visible code; presumably consumed elsewhere —
/// confirm its purpose before changing it.
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
70
// Global actions handled in `init`: `SignIn` connects the client, `SignOut` disconnects it.
actions!(client, [SignIn, SignOut]);
72
73pub fn init(client: Arc<Client>, cx: &mut AppContext) {
74 settings::register_setting::<TelemetrySettings>(cx);
75
76 cx.add_global_action({
77 let client = client.clone();
78 move |_: &SignIn, cx| {
79 let client = client.clone();
80 cx.spawn(
81 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
82 )
83 .detach();
84 }
85 });
86 cx.add_global_action({
87 let client = client.clone();
88 move |_: &SignOut, cx| {
89 let client = client.clone();
90 cx.spawn(|cx| async move {
91 client.disconnect(&cx);
92 })
93 .detach();
94 }
95 });
96}
97
/// Connection to the Zed collab server: owns the RPC peer, the HTTP client, telemetry, and the
/// mutable connection/subscription state.
pub struct Client {
    /// Identifier included in status log lines; 0 unless changed via `set_id` (tests only).
    id: usize,
    /// RPC peer that owns the active connection and its message streams.
    peer: Arc<Peer>,
    /// HTTP client used to resolve the `/rpc` endpoint and for admin authentication.
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, reconnect task, and handler/subscription registries.
    state: RwLock<ClientState>,

    /// Test-only override consulted by `authenticate` before the browser flow.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only override consulted by `establish_connection` before opening a websocket.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
127
/// Errors produced while establishing a connection to the collab server.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake was answered with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake was answered with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including non-401/426 websocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    #[error("{0}")]
    Http(#[from] util::http::Error),
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Wraps `tungstenite::http::Error` (e.g. failure to build the handshake request), not a
    /// websocket protocol error — those go through the `From<WebsocketError>` impl.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
143
144impl From<WebsocketError> for EstablishConnectionError {
145 fn from(error: WebsocketError) -> Self {
146 if let WebsocketError::Http(response) = &error {
147 match response.status() {
148 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
149 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
150 _ => {}
151 }
152 }
153 EstablishConnectionError::Other(error.into())
154 }
155}
156
157impl EstablishConnectionError {
158 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
159 Self::Other(error.into())
160 }
161}
162
/// Connection lifecycle of a `Client`, published through a watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// Initial state; also set when the connection closes cleanly or rejected keychain
    /// credentials are discarded.
    SignedOut,
    /// The server demanded an upgrade; `authenticate_and_connect` refuses to proceed from here.
    UpgradeRequired,
    /// Obtaining credentials for a fresh sign-in.
    Authenticating,
    /// Opening a connection for a fresh sign-in.
    Connecting,
    /// The most recent authentication or connection attempt failed.
    ConnectionError,
    /// Connected and the server hello has been received.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The established connection ended with an I/O error; a reconnect loop is started.
    ConnectionLost,
    /// Obtaining credentials after a previous connection/attempt.
    Reauthenticating,
    /// Opening a connection after a previous connection/attempt.
    Reconnecting,
    /// Waiting before the next reconnection attempt.
    ReconnectionError {
        /// When the next attempt is scheduled.
        next_reconnection: Instant,
    },
}
181
182impl Status {
183 pub fn is_connected(&self) -> bool {
184 matches!(self, Self::Connected { .. })
185 }
186
187 pub fn is_signed_out(&self) -> bool {
188 matches!(self, Self::SignedOut | Self::UpgradeRequired)
189 }
190}
191
/// Mutable state of a `Client`, guarded by `Client::state`.
struct ClientState {
    /// Credentials that last produced a successful connection; cleared when rejected as
    /// unauthorized.
    credentials: Option<Credentials>,
    /// Watch channel for `Status`: `set_status` writes through the sender, `status()` clones
    /// the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function extracting the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnect loop spawned on `ConnectionLost`; replaced with `None` on `Connected` and taken
    /// on `SignedOut` / `UpgradeRequired`.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Subscribers keyed by (entity `TypeId`, remote entity id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model registered by `add_message_handler` for each message type.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type registered by `add_entity_message_handler` for each message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
216
/// Entry stored in `entities_by_type_and_remote_id`.
enum WeakSubscriber {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
    /// Messages held while a `PendingEntitySubscription` awaits `set_model`; they are replayed
    /// through `handle_message` once the model is set. NOTE(review): the code that appends to
    /// this buffer is not visible here — presumably `handle_message`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}

/// Target passed to entity message handlers: a model handle or a weak view handle.
enum Subscriber {
    Model(AnyModelHandle),
    View(AnyWeakViewHandle),
}
227
/// User id and access token sent in the `Authorization` header when connecting to the server.
///
/// `Debug` is implemented by hand so the access token is never printed in logs or panic
/// messages.
#[derive(Clone)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
233
234impl Default for ClientState {
235 fn default() -> Self {
236 Self {
237 credentials: None,
238 status: watch::channel_with(Status::SignedOut),
239 entity_id_extractors: Default::default(),
240 _reconnect_task: None,
241 reconnect_interval: Duration::from_secs(5),
242 models_by_message_type: Default::default(),
243 entities_by_type_and_remote_id: Default::default(),
244 entity_types_by_message_type: Default::default(),
245 message_handlers: Default::default(),
246 }
247 }
248}
249
/// Handle for a registration on a `Client`; dropping it removes the registration (if the
/// client is still alive).
pub enum Subscription {
    /// Entry in `entities_by_type_and_remote_id` keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Message handler registered for the message type `id`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
260
261impl Drop for Subscription {
262 fn drop(&mut self) {
263 match self {
264 Subscription::Entity { client, id } => {
265 if let Some(client) = client.upgrade() {
266 let mut state = client.state.write();
267 let _ = state.entities_by_type_and_remote_id.remove(id);
268 }
269 }
270 Subscription::Message { client, id } => {
271 if let Some(client) = client.upgrade() {
272 let mut state = client.state.write();
273 let _ = state.entity_types_by_message_type.remove(id);
274 let _ = state.message_handlers.remove(id);
275 }
276 }
277 }
278 }
279}
280
/// Subscription to a remote entity whose local model is not yet known.
///
/// Created by `Client::subscribe_to_entity`, which inserts a `WeakSubscriber::Pending` entry.
/// `set_model` consumes it; dropping it unconsumed removes the pending entry and logs any
/// buffered messages as unhandled.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the registration in place.
    consumed: bool,
}
287
288impl<T: Entity> PendingEntitySubscription<T> {
289 pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
290 self.consumed = true;
291 let mut state = self.client.state.write();
292 let id = (TypeId::of::<T>(), self.remote_id);
293 let Some(WeakSubscriber::Pending(messages)) =
294 state.entities_by_type_and_remote_id.remove(&id)
295 else {
296 unreachable!()
297 };
298
299 state
300 .entities_by_type_and_remote_id
301 .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
302 drop(state);
303 for message in messages {
304 self.client.handle_message(message, cx);
305 }
306 Subscription::Entity {
307 client: Arc::downgrade(&self.client),
308 id,
309 }
310 }
311}
312
313impl<T: Entity> Drop for PendingEntitySubscription<T> {
314 fn drop(&mut self) {
315 if !self.consumed {
316 let mut state = self.client.state.write();
317 if let Some(WeakSubscriber::Pending(messages)) = state
318 .entities_by_type_and_remote_id
319 .remove(&(TypeId::of::<T>(), self.remote_id))
320 {
321 for message in messages {
322 log::info!("unhandled message {}", message.payload_type_name());
323 }
324 }
325 }
326 }
327}
328
/// Resolved `telemetry` settings.
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Resolved `telemetry.diagnostics` flag (presumably gates diagnostic reporting — confirm).
    pub diagnostics: bool,
    /// Resolved `telemetry.metrics` flag (presumably gates usage metrics — confirm).
    pub metrics: bool,
}
334
/// `telemetry` section as written in a settings file; absent keys are `None`.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
pub struct TelemetrySettingsContent {
    pub diagnostics: Option<bool>,
    pub metrics: Option<bool>,
}
340
341impl settings::Setting for TelemetrySettings {
342 const KEY: Option<&'static str> = Some("telemetry");
343
344 type FileContent = TelemetrySettingsContent;
345
346 fn load(
347 default_value: &Self::FileContent,
348 user_values: &[&Self::FileContent],
349 _: &AppContext,
350 ) -> Self {
351 Self {
352 diagnostics: user_values
353 .first()
354 .and_then(|v| v.diagnostics)
355 .unwrap_or(default_value.diagnostics.unwrap()),
356 metrics: user_values
357 .first()
358 .and_then(|v| v.metrics)
359 .unwrap_or(default_value.metrics.unwrap()),
360 }
361 }
362}
363
364impl Client {
365 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
366 Arc::new(Self {
367 id: 0,
368 peer: Peer::new(0),
369 telemetry: Telemetry::new(http.clone(), cx),
370 http,
371 state: Default::default(),
372
373 #[cfg(any(test, feature = "test-support"))]
374 authenticate: Default::default(),
375 #[cfg(any(test, feature = "test-support"))]
376 establish_connection: Default::default(),
377 })
378 }
379
380 pub fn id(&self) -> usize {
381 self.id
382 }
383
384 pub fn http_client(&self) -> Arc<dyn HttpClient> {
385 self.http.clone()
386 }
387
388 #[cfg(any(test, feature = "test-support"))]
389 pub fn set_id(&mut self, id: usize) -> &Self {
390 self.id = id;
391 self
392 }
393
394 #[cfg(any(test, feature = "test-support"))]
395 pub fn teardown(&self) {
396 let mut state = self.state.write();
397 state._reconnect_task.take();
398 state.message_handlers.clear();
399 state.models_by_message_type.clear();
400 state.entities_by_type_and_remote_id.clear();
401 state.entity_id_extractors.clear();
402 self.peer.teardown();
403 }
404
405 #[cfg(any(test, feature = "test-support"))]
406 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
407 where
408 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
409 {
410 *self.authenticate.write() = Some(Box::new(authenticate));
411 self
412 }
413
414 #[cfg(any(test, feature = "test-support"))]
415 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
416 where
417 F: 'static
418 + Send
419 + Sync
420 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
421 {
422 *self.establish_connection.write() = Some(Box::new(connect));
423 self
424 }
425
426 pub fn user_id(&self) -> Option<u64> {
427 self.state
428 .read()
429 .credentials
430 .as_ref()
431 .map(|credentials| credentials.user_id)
432 }
433
434 pub fn peer_id(&self) -> Option<PeerId> {
435 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
436 Some(*peer_id)
437 } else {
438 None
439 }
440 }
441
442 pub fn status(&self) -> watch::Receiver<Status> {
443 self.state.read().status.1.clone()
444 }
445
    /// Publishes `status` to all status receivers and performs status-specific bookkeeping.
    ///
    /// - `Connected`: clears the reconnect task.
    /// - `ConnectionLost`: spawns a loop that retries `authenticate_and_connect` with a
    ///   randomized, growing delay capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the telemetry user info and takes the
    ///   reconnect task.
    ///
    /// NOTE: the state write lock is held for the whole match, including the telemetry call.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        // Publish the new status through the watch sender.
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                // Assigning here replaces (and drops) any earlier reconnect task.
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Fixed seed in tests for deterministic delays; entropy otherwise.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Only keep retrying while the failed attempt left the client in
                        // `ConnectionError`; any other status ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            // Grow the delay by a random factor in [1, 2], capped at the
                            // configured interval.
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                cx.read(|cx| self.telemetry.set_authenticated_user_info(None, false, cx));
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
491
492 pub fn add_view_for_remote_entity<T: View>(
493 self: &Arc<Self>,
494 remote_id: u64,
495 cx: &mut ViewContext<T>,
496 ) -> Subscription {
497 let id = (TypeId::of::<T>(), remote_id);
498 self.state
499 .write()
500 .entities_by_type_and_remote_id
501 .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
502 Subscription::Entity {
503 client: Arc::downgrade(self),
504 id,
505 }
506 }
507
508 pub fn subscribe_to_entity<T: Entity>(
509 self: &Arc<Self>,
510 remote_id: u64,
511 ) -> Result<PendingEntitySubscription<T>> {
512 let id = (TypeId::of::<T>(), remote_id);
513
514 let mut state = self.state.write();
515 if state.entities_by_type_and_remote_id.contains_key(&id) {
516 return Err(anyhow!("already subscribed to entity"));
517 } else {
518 state
519 .entities_by_type_and_remote_id
520 .insert(id, WeakSubscriber::Pending(Default::default()));
521 Ok(PendingEntitySubscription {
522 client: self.clone(),
523 remote_id,
524 consumed: false,
525 _entity_type: PhantomData,
526 })
527 }
528 }
529
530 pub fn add_message_handler<M, E, H, F>(
531 self: &Arc<Self>,
532 model: ModelHandle<E>,
533 handler: H,
534 ) -> Subscription
535 where
536 M: EnvelopedMessage,
537 E: Entity,
538 H: 'static
539 + Send
540 + Sync
541 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
542 F: 'static + Future<Output = Result<()>>,
543 {
544 let message_type_id = TypeId::of::<M>();
545
546 let mut state = self.state.write();
547 state
548 .models_by_message_type
549 .insert(message_type_id, model.downgrade().into_any());
550
551 let prev_handler = state.message_handlers.insert(
552 message_type_id,
553 Arc::new(move |handle, envelope, client, cx| {
554 let handle = if let Subscriber::Model(handle) = handle {
555 handle
556 } else {
557 unreachable!();
558 };
559 let model = handle.downcast::<E>().unwrap();
560 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
561 handler(model, *envelope, client.clone(), cx).boxed_local()
562 }),
563 );
564 if prev_handler.is_some() {
565 panic!("registered handler for the same message twice");
566 }
567
568 Subscription::Message {
569 client: Arc::downgrade(self),
570 id: message_type_id,
571 }
572 }
573
574 pub fn add_request_handler<M, E, H, F>(
575 self: &Arc<Self>,
576 model: ModelHandle<E>,
577 handler: H,
578 ) -> Subscription
579 where
580 M: RequestMessage,
581 E: Entity,
582 H: 'static
583 + Send
584 + Sync
585 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
586 F: 'static + Future<Output = Result<M::Response>>,
587 {
588 self.add_message_handler(model, move |handle, envelope, this, cx| {
589 Self::respond_to_request(
590 envelope.receipt(),
591 handler(handle, envelope, this.clone(), cx),
592 this,
593 )
594 })
595 }
596
597 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
598 where
599 M: EntityMessage,
600 E: View,
601 H: 'static
602 + Send
603 + Sync
604 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
605 F: 'static + Future<Output = Result<()>>,
606 {
607 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
608 if let Subscriber::View(handle) = handle {
609 handler(handle.downcast::<E>().unwrap(), message, client, cx)
610 } else {
611 unreachable!();
612 }
613 })
614 }
615
616 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
617 where
618 M: EntityMessage,
619 E: Entity,
620 H: 'static
621 + Send
622 + Sync
623 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
624 F: 'static + Future<Output = Result<()>>,
625 {
626 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
627 if let Subscriber::Model(handle) = handle {
628 handler(handle.downcast::<E>().unwrap(), message, client, cx)
629 } else {
630 unreachable!();
631 }
632 })
633 }
634
635 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
636 where
637 M: EntityMessage,
638 E: Entity,
639 H: 'static
640 + Send
641 + Sync
642 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
643 F: 'static + Future<Output = Result<()>>,
644 {
645 let model_type_id = TypeId::of::<E>();
646 let message_type_id = TypeId::of::<M>();
647
648 let mut state = self.state.write();
649 state
650 .entity_types_by_message_type
651 .insert(message_type_id, model_type_id);
652 state
653 .entity_id_extractors
654 .entry(message_type_id)
655 .or_insert_with(|| {
656 |envelope| {
657 envelope
658 .as_any()
659 .downcast_ref::<TypedEnvelope<M>>()
660 .unwrap()
661 .payload
662 .remote_entity_id()
663 }
664 });
665 let prev_handler = state.message_handlers.insert(
666 message_type_id,
667 Arc::new(move |handle, envelope, client, cx| {
668 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
669 handler(handle, *envelope, client.clone(), cx).boxed_local()
670 }),
671 );
672 if prev_handler.is_some() {
673 panic!("registered handler for the same message twice");
674 }
675 }
676
677 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
678 where
679 M: EntityMessage + RequestMessage,
680 E: Entity,
681 H: 'static
682 + Send
683 + Sync
684 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
685 F: 'static + Future<Output = Result<M::Response>>,
686 {
687 self.add_model_message_handler(move |entity, envelope, client, cx| {
688 Self::respond_to_request::<M, _>(
689 envelope.receipt(),
690 handler(entity, envelope, client.clone(), cx),
691 client,
692 )
693 })
694 }
695
696 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
697 where
698 M: EntityMessage + RequestMessage,
699 E: View,
700 H: 'static
701 + Send
702 + Sync
703 + Fn(WeakViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
704 F: 'static + Future<Output = Result<M::Response>>,
705 {
706 self.add_view_message_handler(move |entity, envelope, client, cx| {
707 Self::respond_to_request::<M, _>(
708 envelope.receipt(),
709 handler(entity, envelope, client.clone(), cx),
710 client,
711 )
712 })
713 }
714
715 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
716 receipt: Receipt<T>,
717 response: F,
718 client: Arc<Self>,
719 ) -> Result<()> {
720 match response.await {
721 Ok(response) => {
722 client.respond(receipt, response)?;
723 Ok(())
724 }
725 Err(error) => {
726 client.respond_with_error(
727 receipt,
728 proto::Error {
729 message: format!("{:?}", error),
730 },
731 )?;
732 Err(error)
733 }
734 }
735 }
736
737 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
738 read_credentials_from_keychain(cx).is_some()
739 }
740
    /// Obtains credentials if needed and connects to the collab server.
    ///
    /// Returns `Ok(())` immediately when already `Connected`, `Connecting`, or `Reconnecting`.
    /// Fails with `EstablishConnectionError::UpgradeRequired` when the status is
    /// `UpgradeRequired`. Credentials are looked up in this order: the in-memory state, then
    /// the keychain (only when `try_keychain` is set), then the interactive `authenticate`
    /// flow. Establishing the connection and receiving the server hello share one
    /// `CONNECTION_TIMEOUT` timer.
    ///
    /// If keychain credentials are rejected as unauthorized, they are deleted through the
    /// platform and the method retries once with `try_keychain = false`.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Starting from `SignedOut` is a fresh sign-in; error/lost states are re-attempts.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                cx.read(|cx| {
                    self.telemetry().report_mixpanel_event(
                        "read credentials from keychain",
                        Default::default(),
                        *settings::get_setting::<TelemetrySettings>(None, cx),
                    );
                });
            }
        }
        if credentials.is_none() {
            let mut status_rx = self.status();
            // NOTE(review): presumably drains the receiver's current value so the select below
            // only fires on a *subsequent* status change — confirm against postage semantics.
            let _ = status_rx.next().await;
            // Interactive authentication is cancelled if the status changes meanwhile.
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist freshly obtained credentials, except when impersonating.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        // The same timer keeps running: the hello handshake must finish within
                        // whatever remains of CONNECTION_TIMEOUT.
                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry without the
                            // keychain (the retry cannot recurse again, since it never reads
                            // from the keychain).
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
852
    /// Adopts an established `conn`: registers it with the peer, waits for the server's
    /// `Hello`, and marks the client `Connected`.
    ///
    /// Spawns the connection's I/O on the background executor, a foreground loop that
    /// dispatches incoming messages, and a foreground watcher that sets `SignedOut` when I/O
    /// ends cleanly (if this connection is still the current one) or `ConnectionLost` on an
    /// I/O error.
    ///
    /// # Errors
    /// Fails (after disconnecting the connection) if no message arrives, the first message is
    /// not a `Hello`, or the hello carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first incoming message must be the server's Hello carrying our peer id.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Only sign out if a newer connection hasn't replaced this one.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
948
949 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
950 #[cfg(any(test, feature = "test-support"))]
951 if let Some(callback) = self.authenticate.read().as_ref() {
952 return callback(cx);
953 }
954
955 self.authenticate_with_browser(cx)
956 }
957
958 fn establish_connection(
959 self: &Arc<Self>,
960 credentials: &Credentials,
961 cx: &AsyncAppContext,
962 ) -> Task<Result<Connection, EstablishConnectionError>> {
963 #[cfg(any(test, feature = "test-support"))]
964 if let Some(callback) = self.establish_connection.read().as_ref() {
965 return callback(credentials, cx);
966 }
967
968 self.establish_websocket_connection(credentials, cx)
969 }
970
971 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
972 let preview_param = if is_preview { "?preview=1" } else { "" };
973 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
974 let response = http.get(&url, Default::default(), false).await?;
975
976 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
977 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
978 // which requires authorization via an HTTP header.
979 //
980 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
981 // of a collab server. In that case, a request to the /rpc endpoint will
982 // return an 'unauthorized' response.
983 let collab_url = if response.status().is_redirection() {
984 response
985 .headers()
986 .get("Location")
987 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
988 .to_str()
989 .map_err(EstablishConnectionError::other)?
990 .to_string()
991 } else if response.status() == StatusCode::UNAUTHORIZED {
992 url
993 } else {
994 Err(anyhow!(
995 "unexpected /rpc response status {}",
996 response.status()
997 ))?
998 };
999
1000 Url::parse(&collab_url).context("invalid rpc url")
1001 }
1002
    /// Opens a websocket `Connection` to the collab server on the background executor.
    ///
    /// The handshake request carries `Authorization: "{user_id} {access_token}"` and the
    /// `x-zed-protocol-version` header. The RPC URL comes from `get_rpc_url` (with the preview
    /// flag set when the `ReleaseChannel` global is `Preview`). An `https` URL is upgraded to
    /// `wss` over TLS, `http` to plain `ws`; any other scheme is an error.
    fn establish_websocket_connection(
        self: &Arc<Self>,
        credentials: &Credentials,
        cx: &AsyncAppContext,
    ) -> Task<Result<Connection, EstablishConnectionError>> {
        // Absence of the ReleaseChannel global is treated as "not preview".
        let is_preview = cx.read(|cx| {
            if cx.has_global::<ReleaseChannel>() {
                *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
            } else {
                false
            }
        });

        let request = Request::builder()
            .header(
                "Authorization",
                format!("{} {}", credentials.user_id, credentials.access_token),
            )
            .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);

        let http = self.http.clone();
        cx.background().spawn(async move {
            let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
            let rpc_host = rpc_url
                .host_str()
                .zip(rpc_url.port_or_known_default())
                .ok_or_else(|| anyhow!("missing host in rpc url"))?;
            let stream = smol::net::TcpStream::connect(rpc_host).await?;

            log::info!("connected to rpc endpoint {}", rpc_url);

            // Both arms build the same `Connection`, but the TLS and plain stream types differ.
            match rpc_url.scheme() {
                "https" => {
                    rpc_url.set_scheme("wss").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) =
                        async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                "http" => {
                    rpc_url.set_scheme("ws").unwrap();
                    let request = request.uri(rpc_url.as_str()).body(())?;
                    let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                    Ok(Connection::new(
                        stream
                            .map_err(|error| anyhow!(error))
                            .sink_map_err(|error| anyhow!(error)),
                    ))
                }
                _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
            }
        })
    }
1060
1061 pub fn authenticate_with_browser(
1062 self: &Arc<Self>,
1063 cx: &AsyncAppContext,
1064 ) -> Task<Result<Credentials>> {
1065 let platform = cx.platform();
1066 let executor = cx.background();
1067 let telemetry = self.telemetry.clone();
1068 let http = self.http.clone();
1069
1070 let telemetry_settings =
1071 cx.read(|cx| *settings::get_setting::<TelemetrySettings>(None, cx));
1072
1073 executor.clone().spawn(async move {
1074 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1075 // zed server to encrypt the user's access token, so that it can'be intercepted by
1076 // any other app running on the user's device.
1077 let (public_key, private_key) =
1078 rpc::auth::keypair().expect("failed to generate keypair for auth");
1079 let public_key_string =
1080 String::try_from(public_key).expect("failed to serialize public key for auth");
1081
1082 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1083 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1084 }
1085
1086 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1087 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1088 let port = server.server_addr().port();
1089
1090 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1091 // that the user is signing in from a Zed app running on the same device.
1092 let mut url = format!(
1093 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1094 *ZED_SERVER_URL, port, public_key_string
1095 );
1096
1097 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1098 log::info!("impersonating user @{}", impersonate_login);
1099 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1100 }
1101
1102 platform.open_url(&url);
1103
1104 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1105 // access token from the query params.
1106 //
1107 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1108 // custom URL scheme instead of this local HTTP server.
1109 let (user_id, access_token) = executor
1110 .spawn(async move {
1111 for _ in 0..100 {
1112 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1113 let path = req.url();
1114 let mut user_id = None;
1115 let mut access_token = None;
1116 let url = Url::parse(&format!("http://example.com{}", path))
1117 .context("failed to parse login notification url")?;
1118 for (key, value) in url.query_pairs() {
1119 if key == "access_token" {
1120 access_token = Some(value.to_string());
1121 } else if key == "user_id" {
1122 user_id = Some(value.to_string());
1123 }
1124 }
1125
1126 let post_auth_url =
1127 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1128 req.respond(
1129 tiny_http::Response::empty(302).with_header(
1130 tiny_http::Header::from_bytes(
1131 &b"Location"[..],
1132 post_auth_url.as_bytes(),
1133 )
1134 .unwrap(),
1135 ),
1136 )
1137 .context("failed to respond to login http request")?;
1138 return Ok((
1139 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1140 access_token
1141 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1142 ));
1143 }
1144 }
1145
1146 Err(anyhow!("didn't receive login redirect"))
1147 })
1148 .await?;
1149
1150 let access_token = private_key
1151 .decrypt_string(&access_token)
1152 .context("failed to decrypt access token")?;
1153 platform.activate(true);
1154
1155 telemetry.report_mixpanel_event(
1156 "authenticate with browser",
1157 Default::default(),
1158 telemetry_settings,
1159 );
1160
1161 Ok(Credentials {
1162 user_id: user_id.parse()?,
1163 access_token,
1164 })
1165 })
1166 }
1167
1168 async fn authenticate_as_admin(
1169 http: Arc<dyn HttpClient>,
1170 login: String,
1171 mut api_token: String,
1172 ) -> Result<Credentials> {
1173 #[derive(Deserialize)]
1174 struct AuthenticatedUserResponse {
1175 user: User,
1176 }
1177
1178 #[derive(Deserialize)]
1179 struct User {
1180 id: u64,
1181 }
1182
1183 // Use the collab server's admin API to retrieve the id
1184 // of the impersonated user.
1185 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1186 url.set_path("/user");
1187 url.set_query(Some(&format!("github_login={login}")));
1188 let request = Request::get(url.as_str())
1189 .header("Authorization", format!("token {api_token}"))
1190 .body("".into())?;
1191
1192 let mut response = http.send(request).await?;
1193 let mut body = String::new();
1194 response.body_mut().read_to_string(&mut body).await?;
1195 if !response.status().is_success() {
1196 Err(anyhow!(
1197 "admin user request failed {} - {}",
1198 response.status().as_u16(),
1199 body,
1200 ))?;
1201 }
1202 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1203
1204 // Use the admin API token to authenticate as the impersonated user.
1205 api_token.insert_str(0, "ADMIN_TOKEN:");
1206 Ok(Credentials {
1207 user_id: response.user.id,
1208 access_token: api_token,
1209 })
1210 }
1211
1212 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1213 self.peer.teardown();
1214 self.set_status(Status::SignedOut, cx);
1215 }
1216
1217 fn connection_id(&self) -> Result<ConnectionId> {
1218 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1219 Ok(connection_id)
1220 } else {
1221 Err(anyhow!("not connected"))
1222 }
1223 }
1224
1225 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1226 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1227 self.peer.send(self.connection_id()?, message)
1228 }
1229
1230 pub fn request<T: RequestMessage>(
1231 &self,
1232 request: T,
1233 ) -> impl Future<Output = Result<T::Response>> {
1234 self.request_envelope(request)
1235 .map_ok(|envelope| envelope.payload)
1236 }
1237
1238 pub fn request_envelope<T: RequestMessage>(
1239 &self,
1240 request: T,
1241 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1242 let client_id = self.id;
1243 log::debug!(
1244 "rpc request start. client_id:{}. name:{}",
1245 client_id,
1246 T::NAME
1247 );
1248 let response = self
1249 .connection_id()
1250 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1251 async move {
1252 let response = response?.await;
1253 log::debug!(
1254 "rpc request finish. client_id:{}. name:{}",
1255 client_id,
1256 T::NAME
1257 );
1258 response
1259 }
1260 }
1261
1262 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1263 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1264 self.peer.respond(receipt, response)
1265 }
1266
1267 fn respond_with_error<T: RequestMessage>(
1268 &self,
1269 receipt: Receipt<T>,
1270 error: proto::Error,
1271 ) -> Result<()> {
1272 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1273 self.peer.respond_with_error(receipt, error)
1274 }
1275
1276 fn handle_message(
1277 self: &Arc<Client>,
1278 message: Box<dyn AnyTypedEnvelope>,
1279 cx: &AsyncAppContext,
1280 ) {
1281 let mut state = self.state.write();
1282 let type_name = message.payload_type_name();
1283 let payload_type_id = message.payload_type_id();
1284 let sender_id = message.original_sender_id();
1285
1286 let mut subscriber = None;
1287
1288 if let Some(message_model) = state
1289 .models_by_message_type
1290 .get(&payload_type_id)
1291 .and_then(|model| model.upgrade(cx))
1292 {
1293 subscriber = Some(Subscriber::Model(message_model));
1294 } else if let Some((extract_entity_id, entity_type_id)) =
1295 state.entity_id_extractors.get(&payload_type_id).zip(
1296 state
1297 .entity_types_by_message_type
1298 .get(&payload_type_id)
1299 .copied(),
1300 )
1301 {
1302 let entity_id = (extract_entity_id)(message.as_ref());
1303
1304 match state
1305 .entities_by_type_and_remote_id
1306 .get_mut(&(entity_type_id, entity_id))
1307 {
1308 Some(WeakSubscriber::Pending(pending)) => {
1309 pending.push(message);
1310 return;
1311 }
1312 Some(weak_subscriber @ _) => match weak_subscriber {
1313 WeakSubscriber::Model(handle) => {
1314 subscriber = handle.upgrade(cx).map(Subscriber::Model);
1315 }
1316 WeakSubscriber::View(handle) => {
1317 subscriber = Some(Subscriber::View(handle.clone()));
1318 }
1319 WeakSubscriber::Pending(_) => {}
1320 },
1321 _ => {}
1322 }
1323 }
1324
1325 let subscriber = if let Some(subscriber) = subscriber {
1326 subscriber
1327 } else {
1328 log::info!("unhandled message {}", type_name);
1329 self.peer.respond_with_unhandled_message(message).log_err();
1330 return;
1331 };
1332
1333 let handler = state.message_handlers.get(&payload_type_id).cloned();
1334 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1335 // It also ensures we don't hold the lock while yielding back to the executor, as
1336 // that might cause the executor thread driving this future to block indefinitely.
1337 drop(state);
1338
1339 if let Some(handler) = handler {
1340 let future = handler(subscriber, message, &self, cx.clone());
1341 let client_id = self.id;
1342 log::debug!(
1343 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1344 client_id,
1345 sender_id,
1346 type_name
1347 );
1348 cx.foreground()
1349 .spawn(async move {
1350 match future.await {
1351 Ok(()) => {
1352 log::debug!(
1353 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1354 client_id,
1355 sender_id,
1356 type_name
1357 );
1358 }
1359 Err(error) => {
1360 log::error!(
1361 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1362 client_id,
1363 sender_id,
1364 type_name,
1365 error
1366 );
1367 }
1368 }
1369 })
1370 .detach();
1371 } else {
1372 log::info!("unhandled message {}", type_name);
1373 self.peer.respond_with_unhandled_message(message).log_err();
1374 }
1375 }
1376
1377 pub fn telemetry(&self) -> &Arc<Telemetry> {
1378 &self.telemetry
1379 }
1380}
1381
1382fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1383 if IMPERSONATE_LOGIN.is_some() {
1384 return None;
1385 }
1386
1387 let (user_id, access_token) = cx
1388 .platform()
1389 .read_credentials(&ZED_SERVER_URL)
1390 .log_err()
1391 .flatten()?;
1392 Some(Credentials {
1393 user_id: user_id.parse().ok()?,
1394 access_token: String::from_utf8(access_token).ok()?,
1395 })
1396}
1397
1398fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1399 cx.platform().write_credentials(
1400 &ZED_SERVER_URL,
1401 &credentials.user_id.to_string(),
1402 credentials.access_token.as_bytes(),
1403 )
1404}
1405
/// Prefix that `encode_worktree_url` writes and `decode_worktree_url` strips;
/// the worktree id and access token follow it, separated by `/`.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";
1407
1408pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
1409 format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
1410}
1411
1412pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
1413 let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
1414 let mut parts = path.split('/');
1415 let id = parts.next()?.parse::<u64>().ok()?;
1416 let access_token = parts.next()?;
1417 if access_token.is_empty() {
1418 return None;
1419 }
1420 Some((id, access_token.to_string()))
1421}
1422
// Tests for `Client`: reconnection and connection timeouts, superseding an
// in-flight authentication, worktree URL encoding, and routing of incoming
// messages to model/entity subscriptions. They run against `FakeServer` and a
// `FakeHttpClient` that answers every request with 404.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;
    use util::http::FakeHttpClient;

    // After the server disconnects, the client retries on its own. The first
    // retry reuses the cached credentials (auth count stays 1). Once the server
    // rolls the access token, the client authenticates again (auth count 2).
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // A connection attempt that never completes is abandoned after
    // `CONNECTION_TIMEOUT`. This is checked twice:
    // - on the initial connect, which yields `ConnectionError` and fails
    //   `authenticate_and_connect`;
    // - on an automatic reconnect, which yields `ReconnectionError`.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Connection establishment that never resolves, so only the timeout can end it.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first one is still
    // authenticating starts a new authentication. It also drops the first,
    // still-pending authentication future: the second call raises
    // `dropped_auth_count` from 0 to 1.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    // Increments `dropped_auth_count` when this future is dropped.
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // A worktree URL round-trips through encode/decode, surrounding whitespace
    // is tolerated, and a URL with the wrong prefix is rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages (`JoinProject`, keyed by `project_id`) reach the model
    // subscribed to that entity id. Dropping the subscription for entity 3 does
    // not stop delivery to entities 1 and 2.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .unwrap()
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .unwrap()
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .unwrap()
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping a model's `Ping` handler subscription, a new `Ping` handler
    // can be registered for the same model, and it receives the next `Ping`.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler can drop its own subscription (stored on the model) while it is
    // running, and it still runs to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal test entity: an `id` the handlers inspect, plus an optional
    // `Subscription` the model can own (and drop).
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}