1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context as _, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{
14 future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _,
15 TryStreamExt,
16};
17use gpui::{
18 actions, serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model,
19 SemanticVersion, Task, WeakModel,
20};
21use lazy_static::lazy_static;
22use parking_lot::RwLock;
23use postage::watch;
24use rand::prelude::*;
25use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use settings::Settings;
29use std::{
30 any::TypeId,
31 collections::HashMap,
32 convert::TryFrom,
33 fmt::Write as _,
34 future::Future,
35 marker::PhantomData,
36 path::PathBuf,
37 sync::{atomic::AtomicU64, Arc, Weak},
38 time::{Duration, Instant},
39};
40use telemetry::Telemetry;
41use thiserror::Error;
42use url::Url;
43use util::channel::ReleaseChannel;
44use util::http::HttpClient;
45use util::{ResultExt, TryFutureExt};
46
47pub use rpc::*;
48pub use telemetry::ClickhouseEvent;
49pub use user::*;
50
51lazy_static! {
52 pub static ref ZED_SERVER_URL: String =
53 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
54 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
55 .ok()
56 .and_then(|s| if s.is_empty() { None } else { Some(s) });
57 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
58 .ok()
59 .and_then(|s| if s.is_empty() { None } else { Some(s) });
60 pub static ref ZED_APP_VERSION: Option<SemanticVersion> = std::env::var("ZED_APP_VERSION")
61 .ok()
62 .and_then(|v| v.parse().ok());
63 pub static ref ZED_APP_PATH: Option<PathBuf> =
64 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
65 pub static ref ZED_ALWAYS_ACTIVE: bool =
66 std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
67}
68
69pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
70pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
71pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
72
73actions!(SignIn, SignOut, Reconnect);
74
75pub fn init_settings(cx: &mut AppContext) {
76 TelemetrySettings::register(cx);
77}
78
79pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
80 init_settings(cx);
81
82 let client = Arc::downgrade(client);
83 cx.register_action_type::<SignIn>();
84 cx.on_action({
85 let client = client.clone();
86 move |_: &SignIn, cx| {
87 if let Some(client) = client.upgrade() {
88 cx.spawn(
89 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
90 )
91 .detach();
92 }
93 }
94 });
95
96 cx.register_action_type::<SignOut>();
97 cx.on_action({
98 let client = client.clone();
99 move |_: &SignOut, cx| {
100 if let Some(client) = client.upgrade() {
101 cx.spawn(|cx| async move {
102 client.disconnect(&cx);
103 })
104 .detach();
105 }
106 }
107 });
108
109 cx.register_action_type::<Reconnect>();
110 cx.on_action({
111 let client = client.clone();
112 move |_: &Reconnect, cx| {
113 if let Some(client) = client.upgrade() {
114 cx.spawn(|cx| async move {
115 client.reconnect(&cx);
116 })
117 .detach();
118 }
119 }
120 });
121}
122
123pub struct Client {
124 id: AtomicU64,
125 peer: Arc<Peer>,
126 http: Arc<dyn HttpClient>,
127 telemetry: Arc<Telemetry>,
128 state: RwLock<ClientState>,
129
130 #[allow(clippy::type_complexity)]
131 #[cfg(any(test, feature = "test-support"))]
132 authenticate: RwLock<
133 Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
134 >,
135
136 #[allow(clippy::type_complexity)]
137 #[cfg(any(test, feature = "test-support"))]
138 establish_connection: RwLock<
139 Option<
140 Box<
141 dyn 'static
142 + Send
143 + Sync
144 + Fn(
145 &Credentials,
146 &AsyncAppContext,
147 ) -> Task<Result<Connection, EstablishConnectionError>>,
148 >,
149 >,
150 >,
151}
152
153#[derive(Error, Debug)]
154pub enum EstablishConnectionError {
155 #[error("upgrade required")]
156 UpgradeRequired,
157 #[error("unauthorized")]
158 Unauthorized,
159 #[error("{0}")]
160 Other(#[from] anyhow::Error),
161 #[error("{0}")]
162 Http(#[from] util::http::Error),
163 #[error("{0}")]
164 Io(#[from] std::io::Error),
165 #[error("{0}")]
166 Websocket(#[from] async_tungstenite::tungstenite::http::Error),
167}
168
169impl From<WebsocketError> for EstablishConnectionError {
170 fn from(error: WebsocketError) -> Self {
171 if let WebsocketError::Http(response) = &error {
172 match response.status() {
173 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
174 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
175 _ => {}
176 }
177 }
178 EstablishConnectionError::Other(error.into())
179 }
180}
181
182impl EstablishConnectionError {
183 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
184 Self::Other(error.into())
185 }
186}
187
188#[derive(Copy, Clone, Debug, PartialEq)]
189pub enum Status {
190 SignedOut,
191 UpgradeRequired,
192 Authenticating,
193 Connecting,
194 ConnectionError,
195 Connected {
196 peer_id: PeerId,
197 connection_id: ConnectionId,
198 },
199 ConnectionLost,
200 Reauthenticating,
201 Reconnecting,
202 ReconnectionError {
203 next_reconnection: Instant,
204 },
205}
206
207impl Status {
208 pub fn is_connected(&self) -> bool {
209 matches!(self, Self::Connected { .. })
210 }
211
212 pub fn is_signed_out(&self) -> bool {
213 matches!(self, Self::SignedOut | Self::UpgradeRequired)
214 }
215}
216
217struct ClientState {
218 credentials: Option<Credentials>,
219 status: (watch::Sender<Status>, watch::Receiver<Status>),
220 entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
221 _reconnect_task: Option<Task<()>>,
222 reconnect_interval: Duration,
223 entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
224 models_by_message_type: HashMap<TypeId, AnyWeakModel>,
225 entity_types_by_message_type: HashMap<TypeId, TypeId>,
226 #[allow(clippy::type_complexity)]
227 message_handlers: HashMap<
228 TypeId,
229 Arc<
230 dyn Send
231 + Sync
232 + Fn(
233 AnyModel,
234 Box<dyn AnyTypedEnvelope>,
235 &Arc<Client>,
236 AsyncAppContext,
237 ) -> LocalBoxFuture<'static, Result<()>>,
238 >,
239 >,
240}
241
242enum WeakSubscriber {
243 Entity { handle: AnyWeakModel },
244 Pending(Vec<Box<dyn AnyTypedEnvelope>>),
245}
246
/// User id and access token used to authorize the connection to the server.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Id of the user; also used as the client id once connected.
    pub user_id: u64,
    /// Token sent in the `Authorization` header together with `user_id`.
    pub access_token: String,
}
252
253impl Default for ClientState {
254 fn default() -> Self {
255 Self {
256 credentials: None,
257 status: watch::channel_with(Status::SignedOut),
258 entity_id_extractors: Default::default(),
259 _reconnect_task: None,
260 reconnect_interval: Duration::from_secs(5),
261 models_by_message_type: Default::default(),
262 entities_by_type_and_remote_id: Default::default(),
263 entity_types_by_message_type: Default::default(),
264 message_handlers: Default::default(),
265 }
266 }
267}
268
269pub enum Subscription {
270 Entity {
271 client: Weak<Client>,
272 id: (TypeId, u64),
273 },
274 Message {
275 client: Weak<Client>,
276 id: TypeId,
277 },
278}
279
280impl Drop for Subscription {
281 fn drop(&mut self) {
282 match self {
283 Subscription::Entity { client, id } => {
284 if let Some(client) = client.upgrade() {
285 let mut state = client.state.write();
286 let _ = state.entities_by_type_and_remote_id.remove(id);
287 }
288 }
289 Subscription::Message { client, id } => {
290 if let Some(client) = client.upgrade() {
291 let mut state = client.state.write();
292 let _ = state.entity_types_by_message_type.remove(id);
293 let _ = state.message_handlers.remove(id);
294 }
295 }
296 }
297 }
298}
299
300pub struct PendingEntitySubscription<T: 'static> {
301 client: Arc<Client>,
302 remote_id: u64,
303 _entity_type: PhantomData<T>,
304 consumed: bool,
305}
306
307impl<T: 'static> PendingEntitySubscription<T> {
308 pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
309 self.consumed = true;
310 let mut state = self.client.state.write();
311 let id = (TypeId::of::<T>(), self.remote_id);
312 let Some(WeakSubscriber::Pending(messages)) =
313 state.entities_by_type_and_remote_id.remove(&id)
314 else {
315 unreachable!()
316 };
317
318 state.entities_by_type_and_remote_id.insert(
319 id,
320 WeakSubscriber::Entity {
321 handle: model.downgrade().into(),
322 },
323 );
324 drop(state);
325 for message in messages {
326 self.client.handle_message(message, cx);
327 }
328 Subscription::Entity {
329 client: Arc::downgrade(&self.client),
330 id,
331 }
332 }
333}
334
335impl<T: 'static> Drop for PendingEntitySubscription<T> {
336 fn drop(&mut self) {
337 if !self.consumed {
338 let mut state = self.client.state.write();
339 if let Some(WeakSubscriber::Pending(messages)) = state
340 .entities_by_type_and_remote_id
341 .remove(&(TypeId::of::<T>(), self.remote_id))
342 {
343 for message in messages {
344 log::info!("unhandled message {}", message.payload_type_name());
345 }
346 }
347 }
348 }
349}
350
/// Resolved telemetry settings (stored under the `telemetry` settings key).
#[derive(Copy, Clone)]
pub struct TelemetrySettings {
    /// Value of the `diagnostics` setting.
    pub diagnostics: bool,
    /// Value of the `metrics` setting.
    pub metrics: bool,
}
356
357#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
358pub struct TelemetrySettingsContent {
359 pub diagnostics: Option<bool>,
360 pub metrics: Option<bool>,
361}
362
363impl settings::Settings for TelemetrySettings {
364 const KEY: Option<&'static str> = Some("telemetry");
365
366 type FileContent = TelemetrySettingsContent;
367
368 fn load(
369 default_value: &Self::FileContent,
370 user_values: &[&Self::FileContent],
371 _: &mut AppContext,
372 ) -> Result<Self> {
373 Ok(Self {
374 diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or(
375 default_value
376 .diagnostics
377 .ok_or_else(Self::missing_default)?,
378 ),
379 metrics: user_values
380 .first()
381 .and_then(|v| v.metrics)
382 .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?),
383 })
384 }
385}
386
387impl Client {
388 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
389 Arc::new(Self {
390 id: AtomicU64::new(0),
391 peer: Peer::new(0),
392 telemetry: Telemetry::new(http.clone(), cx),
393 http,
394 state: Default::default(),
395
396 #[cfg(any(test, feature = "test-support"))]
397 authenticate: Default::default(),
398 #[cfg(any(test, feature = "test-support"))]
399 establish_connection: Default::default(),
400 })
401 }
402
403 pub fn id(&self) -> u64 {
404 self.id.load(std::sync::atomic::Ordering::SeqCst)
405 }
406
407 pub fn http_client(&self) -> Arc<dyn HttpClient> {
408 self.http.clone()
409 }
410
411 pub fn set_id(&self, id: u64) -> &Self {
412 self.id.store(id, std::sync::atomic::Ordering::SeqCst);
413 self
414 }
415
/// Test-only: stops the reconnect task, clears the registered handlers, models, entity
/// subscriptions and id extractors, and tears down the peer.
#[cfg(any(test, feature = "test-support"))]
pub fn teardown(&self) {
    let mut state = self.state.write();
    state._reconnect_task = None;
    state.message_handlers.clear();
    state.models_by_message_type.clear();
    state.entities_by_type_and_remote_id.clear();
    state.entity_id_extractors.clear();
    self.peer.teardown();
}
426
/// Test-only: replaces the authentication step with `authenticate`.
#[cfg(any(test, feature = "test-support"))]
pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
where
    F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
{
    let mut slot = self.authenticate.write();
    *slot = Some(Box::new(authenticate));
    drop(slot);
    self
}
435
/// Test-only: replaces the connection-establishment step with `connect`.
#[cfg(any(test, feature = "test-support"))]
pub fn override_establish_connection<F>(&self, connect: F) -> &Self
where
    F: 'static
        + Send
        + Sync
        + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
{
    let mut slot = self.establish_connection.write();
    *slot = Some(Box::new(connect));
    drop(slot);
    self
}
447
448 pub fn user_id(&self) -> Option<u64> {
449 self.state
450 .read()
451 .credentials
452 .as_ref()
453 .map(|credentials| credentials.user_id)
454 }
455
456 pub fn peer_id(&self) -> Option<PeerId> {
457 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
458 Some(*peer_id)
459 } else {
460 None
461 }
462 }
463
464 pub fn status(&self) -> watch::Receiver<Status> {
465 self.state.read().status.1.clone()
466 }
467
468 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
469 log::info!("set status on client {}: {:?}", self.id(), status);
470 let mut state = self.state.write();
471 *state.status.0.borrow_mut() = status;
472
473 match status {
474 Status::Connected { .. } => {
475 state._reconnect_task = None;
476 }
477 Status::ConnectionLost => {
478 let this = self.clone();
479 let reconnect_interval = state.reconnect_interval;
480 state._reconnect_task = Some(cx.spawn(move |cx| async move {
481 #[cfg(any(test, feature = "test-support"))]
482 let mut rng = StdRng::seed_from_u64(0);
483 #[cfg(not(any(test, feature = "test-support")))]
484 let mut rng = StdRng::from_entropy();
485
486 let mut delay = INITIAL_RECONNECTION_DELAY;
487 while let Err(error) = this.authenticate_and_connect(true, &cx).await {
488 log::error!("failed to connect {}", error);
489 if matches!(*this.status().borrow(), Status::ConnectionError) {
490 this.set_status(
491 Status::ReconnectionError {
492 next_reconnection: Instant::now() + delay,
493 },
494 &cx,
495 );
496 cx.background_executor().timer(delay).await;
497 delay = delay
498 .mul_f32(rng.gen_range(1.0..=2.0))
499 .min(reconnect_interval);
500 } else {
501 break;
502 }
503 }
504 }));
505 }
506 Status::SignedOut | Status::UpgradeRequired => {
507 cx.update(|cx| self.telemetry.set_authenticated_user_info(None, false, cx))
508 .log_err();
509 state._reconnect_task.take();
510 }
511 _ => {}
512 }
513 }
514
515 pub fn subscribe_to_entity<T>(
516 self: &Arc<Self>,
517 remote_id: u64,
518 ) -> Result<PendingEntitySubscription<T>>
519 where
520 T: 'static,
521 {
522 let id = (TypeId::of::<T>(), remote_id);
523
524 let mut state = self.state.write();
525 if state.entities_by_type_and_remote_id.contains_key(&id) {
526 return Err(anyhow!("already subscribed to entity"));
527 } else {
528 state
529 .entities_by_type_and_remote_id
530 .insert(id, WeakSubscriber::Pending(Default::default()));
531 Ok(PendingEntitySubscription {
532 client: self.clone(),
533 remote_id,
534 consumed: false,
535 _entity_type: PhantomData,
536 })
537 }
538 }
539
540 #[track_caller]
541 pub fn add_message_handler<M, E, H, F>(
542 self: &Arc<Self>,
543 entity: WeakModel<E>,
544 handler: H,
545 ) -> Subscription
546 where
547 M: EnvelopedMessage,
548 E: 'static,
549 H: 'static
550 + Sync
551 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
552 + Send
553 + Sync,
554 F: 'static + Future<Output = Result<()>>,
555 {
556 let message_type_id = TypeId::of::<M>();
557
558 let mut state = self.state.write();
559 state
560 .models_by_message_type
561 .insert(message_type_id, entity.into());
562
563 let prev_handler = state.message_handlers.insert(
564 message_type_id,
565 Arc::new(move |subscriber, envelope, client, cx| {
566 let subscriber = subscriber.downcast::<E>().unwrap();
567 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
568 handler(subscriber, *envelope, client.clone(), cx).boxed_local()
569 }),
570 );
571 if prev_handler.is_some() {
572 let location = std::panic::Location::caller();
573 panic!(
574 "{}:{} registered handler for the same message {} twice",
575 location.file(),
576 location.line(),
577 std::any::type_name::<M>()
578 );
579 }
580
581 Subscription::Message {
582 client: Arc::downgrade(self),
583 id: message_type_id,
584 }
585 }
586
587 pub fn add_request_handler<M, E, H, F>(
588 self: &Arc<Self>,
589 model: WeakModel<E>,
590 handler: H,
591 ) -> Subscription
592 where
593 M: RequestMessage,
594 E: 'static,
595 H: 'static
596 + Sync
597 + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
598 + Send
599 + Sync,
600 F: 'static + Future<Output = Result<M::Response>>,
601 {
602 self.add_message_handler(model, move |handle, envelope, this, cx| {
603 Self::respond_to_request(
604 envelope.receipt(),
605 handler(handle, envelope, this.clone(), cx),
606 this,
607 )
608 })
609 }
610
611 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
612 where
613 M: EntityMessage,
614 E: 'static,
615 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
616 F: 'static + Future<Output = Result<()>>,
617 {
618 self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
619 handler(subscriber.downcast::<E>().unwrap(), message, client, cx)
620 })
621 }
622
623 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
624 where
625 M: EntityMessage,
626 E: 'static,
627 H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
628 F: 'static + Future<Output = Result<()>>,
629 {
630 let model_type_id = TypeId::of::<E>();
631 let message_type_id = TypeId::of::<M>();
632
633 let mut state = self.state.write();
634 state
635 .entity_types_by_message_type
636 .insert(message_type_id, model_type_id);
637 state
638 .entity_id_extractors
639 .entry(message_type_id)
640 .or_insert_with(|| {
641 |envelope| {
642 envelope
643 .as_any()
644 .downcast_ref::<TypedEnvelope<M>>()
645 .unwrap()
646 .payload
647 .remote_entity_id()
648 }
649 });
650 let prev_handler = state.message_handlers.insert(
651 message_type_id,
652 Arc::new(move |handle, envelope, client, cx| {
653 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
654 handler(handle, *envelope, client.clone(), cx).boxed_local()
655 }),
656 );
657 if prev_handler.is_some() {
658 panic!("registered handler for the same message twice");
659 }
660 }
661
662 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
663 where
664 M: EntityMessage + RequestMessage,
665 E: 'static,
666 H: 'static + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
667 F: 'static + Future<Output = Result<M::Response>>,
668 {
669 self.add_model_message_handler(move |entity, envelope, client, cx| {
670 Self::respond_to_request::<M, _>(
671 envelope.receipt(),
672 handler(entity, envelope, client.clone(), cx),
673 client,
674 )
675 })
676 }
677
678 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
679 receipt: Receipt<T>,
680 response: F,
681 client: Arc<Self>,
682 ) -> Result<()> {
683 match response.await {
684 Ok(response) => {
685 client.respond(receipt, response)?;
686 Ok(())
687 }
688 Err(error) => {
689 client.respond_with_error(
690 receipt,
691 proto::Error {
692 message: format!("{:?}", error),
693 },
694 )?;
695 Err(error)
696 }
697 }
698 }
699
700 pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
701 read_credentials_from_keychain(cx).await.is_some()
702 }
703
704 #[async_recursion(?Send)]
705 pub async fn authenticate_and_connect(
706 self: &Arc<Self>,
707 try_keychain: bool,
708 cx: &AsyncAppContext,
709 ) -> anyhow::Result<()> {
710 let was_disconnected = match *self.status().borrow() {
711 Status::SignedOut => true,
712 Status::ConnectionError
713 | Status::ConnectionLost
714 | Status::Authenticating { .. }
715 | Status::Reauthenticating { .. }
716 | Status::ReconnectionError { .. } => false,
717 Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
718 return Ok(())
719 }
720 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
721 };
722
723 if was_disconnected {
724 self.set_status(Status::Authenticating, cx);
725 } else {
726 self.set_status(Status::Reauthenticating, cx)
727 }
728
729 let mut read_from_keychain = false;
730 let mut credentials = self.state.read().credentials.clone();
731 if credentials.is_none() && try_keychain {
732 credentials = read_credentials_from_keychain(cx).await;
733 read_from_keychain = credentials.is_some();
734 }
735 if credentials.is_none() {
736 let mut status_rx = self.status();
737 let _ = status_rx.next().await;
738 futures::select_biased! {
739 authenticate = self.authenticate(cx).fuse() => {
740 match authenticate {
741 Ok(creds) => credentials = Some(creds),
742 Err(err) => {
743 self.set_status(Status::ConnectionError, cx);
744 return Err(err);
745 }
746 }
747 }
748 _ = status_rx.next().fuse() => {
749 return Err(anyhow!("authentication canceled"));
750 }
751 }
752 }
753 let credentials = credentials.unwrap();
754 self.set_id(credentials.user_id);
755
756 if was_disconnected {
757 self.set_status(Status::Connecting, cx);
758 } else {
759 self.set_status(Status::Reconnecting, cx);
760 }
761
762 let mut timeout =
763 futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
764 futures::select_biased! {
765 connection = self.establish_connection(&credentials, cx).fuse() => {
766 match connection {
767 Ok(conn) => {
768 self.state.write().credentials = Some(credentials.clone());
769 if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
770 write_credentials_to_keychain(credentials, cx).log_err();
771 }
772
773 futures::select_biased! {
774 result = self.set_connection(conn, cx).fuse() => result,
775 _ = timeout => {
776 self.set_status(Status::ConnectionError, cx);
777 Err(anyhow!("timed out waiting on hello message from server"))
778 }
779 }
780 }
781 Err(EstablishConnectionError::Unauthorized) => {
782 self.state.write().credentials.take();
783 if read_from_keychain {
784 delete_credentials_from_keychain(cx).log_err();
785 self.set_status(Status::SignedOut, cx);
786 self.authenticate_and_connect(false, cx).await
787 } else {
788 self.set_status(Status::ConnectionError, cx);
789 Err(EstablishConnectionError::Unauthorized)?
790 }
791 }
792 Err(EstablishConnectionError::UpgradeRequired) => {
793 self.set_status(Status::UpgradeRequired, cx);
794 Err(EstablishConnectionError::UpgradeRequired)?
795 }
796 Err(error) => {
797 self.set_status(Status::ConnectionError, cx);
798 Err(error)?
799 }
800 }
801 }
802 _ = &mut timeout => {
803 self.set_status(Status::ConnectionError, cx);
804 Err(anyhow!("timed out trying to establish connection"))
805 }
806 }
807 }
808
809 async fn set_connection(
810 self: &Arc<Self>,
811 conn: Connection,
812 cx: &AsyncAppContext,
813 ) -> Result<()> {
814 let executor = cx.background_executor();
815 log::info!("add connection to peer");
816 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
817 let executor = executor.clone();
818 move |duration| executor.timer(duration)
819 });
820 let handle_io = executor.spawn(handle_io);
821
822 let peer_id = async {
823 log::info!("waiting for server hello");
824 let message = incoming
825 .next()
826 .await
827 .ok_or_else(|| anyhow!("no hello message received"))?;
828 log::info!("got server hello");
829 let hello_message_type_name = message.payload_type_name().to_string();
830 let hello = message
831 .into_any()
832 .downcast::<TypedEnvelope<proto::Hello>>()
833 .map_err(|_| {
834 anyhow!(
835 "invalid hello message received: {:?}",
836 hello_message_type_name
837 )
838 })?;
839 let peer_id = hello
840 .payload
841 .peer_id
842 .ok_or_else(|| anyhow!("invalid peer id"))?;
843 Ok(peer_id)
844 };
845
846 let peer_id = match peer_id.await {
847 Ok(peer_id) => peer_id,
848 Err(error) => {
849 self.peer.disconnect(connection_id);
850 return Err(error);
851 }
852 };
853
854 log::info!(
855 "set status to connected (connection id: {:?}, peer id: {:?})",
856 connection_id,
857 peer_id
858 );
859 self.set_status(
860 Status::Connected {
861 peer_id,
862 connection_id,
863 },
864 cx,
865 );
866
867 cx.spawn({
868 let this = self.clone();
869 |cx| {
870 async move {
871 while let Some(message) = incoming.next().await {
872 this.handle_message(message, &cx);
873 // Don't starve the main thread when receiving lots of messages at once.
874 smol::future::yield_now().await;
875 }
876 }
877 }
878 })
879 .detach();
880
881 cx.spawn({
882 let this = self.clone();
883 move |cx| async move {
884 match handle_io.await {
885 Ok(()) => {
886 if this.status().borrow().clone()
887 == (Status::Connected {
888 connection_id,
889 peer_id,
890 })
891 {
892 this.set_status(Status::SignedOut, &cx);
893 }
894 }
895 Err(err) => {
896 log::error!("connection error: {:?}", err);
897 this.set_status(Status::ConnectionLost, &cx);
898 }
899 }
900 }
901 })
902 .detach();
903
904 Ok(())
905 }
906
907 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
908 #[cfg(any(test, feature = "test-support"))]
909 if let Some(callback) = self.authenticate.read().as_ref() {
910 return callback(cx);
911 }
912
913 self.authenticate_with_browser(cx)
914 }
915
916 fn establish_connection(
917 self: &Arc<Self>,
918 credentials: &Credentials,
919 cx: &AsyncAppContext,
920 ) -> Task<Result<Connection, EstablishConnectionError>> {
921 #[cfg(any(test, feature = "test-support"))]
922 if let Some(callback) = self.establish_connection.read().as_ref() {
923 return callback(credentials, cx);
924 }
925
926 self.establish_websocket_connection(credentials, cx)
927 }
928
929 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
930 let preview_param = if is_preview { "?preview=1" } else { "" };
931 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
932 let response = http.get(&url, Default::default(), false).await?;
933
934 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
935 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
936 // which requires authorization via an HTTP header.
937 //
938 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
939 // of a collab server. In that case, a request to the /rpc endpoint will
940 // return an 'unauthorized' response.
941 let collab_url = if response.status().is_redirection() {
942 response
943 .headers()
944 .get("Location")
945 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
946 .to_str()
947 .map_err(EstablishConnectionError::other)?
948 .to_string()
949 } else if response.status() == StatusCode::UNAUTHORIZED {
950 url
951 } else {
952 Err(anyhow!(
953 "unexpected /rpc response status {}",
954 response.status()
955 ))?
956 };
957
958 Url::parse(&collab_url).context("invalid rpc url")
959 }
960
961 fn establish_websocket_connection(
962 self: &Arc<Self>,
963 credentials: &Credentials,
964 cx: &AsyncAppContext,
965 ) -> Task<Result<Connection, EstablishConnectionError>> {
966 let use_preview_server = cx
967 .try_read_global(|channel: &ReleaseChannel, _| *channel != ReleaseChannel::Stable)
968 .unwrap_or(false);
969
970 let request = Request::builder()
971 .header(
972 "Authorization",
973 format!("{} {}", credentials.user_id, credentials.access_token),
974 )
975 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
976
977 let http = self.http.clone();
978 cx.background_executor().spawn(async move {
979 let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
980 let rpc_host = rpc_url
981 .host_str()
982 .zip(rpc_url.port_or_known_default())
983 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
984 let stream = smol::net::TcpStream::connect(rpc_host).await?;
985
986 log::info!("connected to rpc endpoint {}", rpc_url);
987
988 match rpc_url.scheme() {
989 "https" => {
990 rpc_url.set_scheme("wss").unwrap();
991 let request = request.uri(rpc_url.as_str()).body(())?;
992 let (stream, _) =
993 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
994 Ok(Connection::new(
995 stream
996 .map_err(|error| anyhow!(error))
997 .sink_map_err(|error| anyhow!(error)),
998 ))
999 }
1000 "http" => {
1001 rpc_url.set_scheme("ws").unwrap();
1002 let request = request.uri(rpc_url.as_str()).body(())?;
1003 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1004 Ok(Connection::new(
1005 stream
1006 .map_err(|error| anyhow!(error))
1007 .sink_map_err(|error| anyhow!(error)),
1008 ))
1009 }
1010 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1011 }
1012 })
1013 }
1014
1015 pub fn authenticate_with_browser(
1016 self: &Arc<Self>,
1017 cx: &AsyncAppContext,
1018 ) -> Task<Result<Credentials>> {
1019 let http = self.http.clone();
1020 cx.spawn(|cx| async move {
1021 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1022 // zed server to encrypt the user's access token, so that it can'be intercepted by
1023 // any other app running on the user's device.
1024 let (public_key, private_key) =
1025 rpc::auth::keypair().expect("failed to generate keypair for auth");
1026 let public_key_string =
1027 String::try_from(public_key).expect("failed to serialize public key for auth");
1028
1029 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1030 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1031 }
1032
1033 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1034 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1035 let port = server.server_addr().port();
1036
1037 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1038 // that the user is signing in from a Zed app running on the same device.
1039 let mut url = format!(
1040 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1041 *ZED_SERVER_URL, port, public_key_string
1042 );
1043
1044 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1045 log::info!("impersonating user @{}", impersonate_login);
1046 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1047 }
1048
1049 cx.update(|cx| cx.open_url(&url))?;
1050
1051 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1052 // access token from the query params.
1053 //
1054 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1055 // custom URL scheme instead of this local HTTP server.
1056 let (user_id, access_token) = cx
1057 .spawn(|_| async move {
1058 for _ in 0..100 {
1059 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1060 let path = req.url();
1061 let mut user_id = None;
1062 let mut access_token = None;
1063 let url = Url::parse(&format!("http://example.com{}", path))
1064 .context("failed to parse login notification url")?;
1065 for (key, value) in url.query_pairs() {
1066 if key == "access_token" {
1067 access_token = Some(value.to_string());
1068 } else if key == "user_id" {
1069 user_id = Some(value.to_string());
1070 }
1071 }
1072
1073 let post_auth_url =
1074 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1075 req.respond(
1076 tiny_http::Response::empty(302).with_header(
1077 tiny_http::Header::from_bytes(
1078 &b"Location"[..],
1079 post_auth_url.as_bytes(),
1080 )
1081 .unwrap(),
1082 ),
1083 )
1084 .context("failed to respond to login http request")?;
1085 return Ok((
1086 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1087 access_token
1088 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1089 ));
1090 }
1091 }
1092
1093 Err(anyhow!("didn't receive login redirect"))
1094 })
1095 .await?;
1096
1097 let access_token = private_key
1098 .decrypt_string(&access_token)
1099 .context("failed to decrypt access token")?;
1100 cx.update(|cx| cx.activate(true))?;
1101
1102 Ok(Credentials {
1103 user_id: user_id.parse()?,
1104 access_token,
1105 })
1106 })
1107 }
1108
1109 async fn authenticate_as_admin(
1110 http: Arc<dyn HttpClient>,
1111 login: String,
1112 mut api_token: String,
1113 ) -> Result<Credentials> {
1114 #[derive(Deserialize)]
1115 struct AuthenticatedUserResponse {
1116 user: User,
1117 }
1118
1119 #[derive(Deserialize)]
1120 struct User {
1121 id: u64,
1122 }
1123
1124 // Use the collab server's admin API to retrieve the id
1125 // of the impersonated user.
1126 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1127 url.set_path("/user");
1128 url.set_query(Some(&format!("github_login={login}")));
1129 let request = Request::get(url.as_str())
1130 .header("Authorization", format!("token {api_token}"))
1131 .body("".into())?;
1132
1133 let mut response = http.send(request).await?;
1134 let mut body = String::new();
1135 response.body_mut().read_to_string(&mut body).await?;
1136 if !response.status().is_success() {
1137 Err(anyhow!(
1138 "admin user request failed {} - {}",
1139 response.status().as_u16(),
1140 body,
1141 ))?;
1142 }
1143 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1144
1145 // Use the admin API token to authenticate as the impersonated user.
1146 api_token.insert_str(0, "ADMIN_TOKEN:");
1147 Ok(Credentials {
1148 user_id: response.user.id,
1149 access_token: api_token,
1150 })
1151 }
1152
1153 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1154 self.peer.teardown();
1155 self.set_status(Status::SignedOut, cx);
1156 }
1157
1158 pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1159 self.peer.teardown();
1160 self.set_status(Status::ConnectionLost, cx);
1161 }
1162
1163 fn connection_id(&self) -> Result<ConnectionId> {
1164 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1165 Ok(connection_id)
1166 } else {
1167 Err(anyhow!("not connected"))
1168 }
1169 }
1170
1171 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1172 log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
1173 self.peer.send(self.connection_id()?, message)
1174 }
1175
1176 pub fn request<T: RequestMessage>(
1177 &self,
1178 request: T,
1179 ) -> impl Future<Output = Result<T::Response>> {
1180 self.request_envelope(request)
1181 .map_ok(|envelope| envelope.payload)
1182 }
1183
1184 pub fn request_envelope<T: RequestMessage>(
1185 &self,
1186 request: T,
1187 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
1188 let client_id = self.id();
1189 log::debug!(
1190 "rpc request start. client_id:{}. name:{}",
1191 client_id,
1192 T::NAME
1193 );
1194 let response = self
1195 .connection_id()
1196 .map(|conn_id| self.peer.request_envelope(conn_id, request));
1197 async move {
1198 let response = response?.await;
1199 log::debug!(
1200 "rpc request finish. client_id:{}. name:{}",
1201 client_id,
1202 T::NAME
1203 );
1204 response
1205 }
1206 }
1207
1208 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1209 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1210 self.peer.respond(receipt, response)
1211 }
1212
1213 fn respond_with_error<T: RequestMessage>(
1214 &self,
1215 receipt: Receipt<T>,
1216 error: proto::Error,
1217 ) -> Result<()> {
1218 log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
1219 self.peer.respond_with_error(receipt, error)
1220 }
1221
1222 fn handle_message(
1223 self: &Arc<Client>,
1224 message: Box<dyn AnyTypedEnvelope>,
1225 cx: &AsyncAppContext,
1226 ) {
1227 let mut state = self.state.write();
1228 let type_name = message.payload_type_name();
1229 let payload_type_id = message.payload_type_id();
1230 let sender_id = message.original_sender_id();
1231
1232 let mut subscriber = None;
1233
1234 if let Some(handle) = state
1235 .models_by_message_type
1236 .get(&payload_type_id)
1237 .and_then(|handle| handle.upgrade())
1238 {
1239 subscriber = Some(handle);
1240 } else if let Some((extract_entity_id, entity_type_id)) =
1241 state.entity_id_extractors.get(&payload_type_id).zip(
1242 state
1243 .entity_types_by_message_type
1244 .get(&payload_type_id)
1245 .copied(),
1246 )
1247 {
1248 let entity_id = (extract_entity_id)(message.as_ref());
1249
1250 match state
1251 .entities_by_type_and_remote_id
1252 .get_mut(&(entity_type_id, entity_id))
1253 {
1254 Some(WeakSubscriber::Pending(pending)) => {
1255 pending.push(message);
1256 return;
1257 }
1258 Some(weak_subscriber @ _) => match weak_subscriber {
1259 WeakSubscriber::Entity { handle } => {
1260 subscriber = handle.upgrade();
1261 }
1262
1263 WeakSubscriber::Pending(_) => {}
1264 },
1265 _ => {}
1266 }
1267 }
1268
1269 let subscriber = if let Some(subscriber) = subscriber {
1270 subscriber
1271 } else {
1272 log::info!("unhandled message {}", type_name);
1273 self.peer.respond_with_unhandled_message(message).log_err();
1274 return;
1275 };
1276
1277 let handler = state.message_handlers.get(&payload_type_id).cloned();
1278 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1279 // It also ensures we don't hold the lock while yielding back to the executor, as
1280 // that might cause the executor thread driving this future to block indefinitely.
1281 drop(state);
1282
1283 if let Some(handler) = handler {
1284 let future = handler(subscriber, message, &self, cx.clone());
1285 let client_id = self.id();
1286 log::debug!(
1287 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1288 client_id,
1289 sender_id,
1290 type_name
1291 );
1292 cx.spawn(move |_| async move {
1293 match future.await {
1294 Ok(()) => {
1295 log::debug!(
1296 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1297 client_id,
1298 sender_id,
1299 type_name
1300 );
1301 }
1302 Err(error) => {
1303 log::error!(
1304 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1305 client_id,
1306 sender_id,
1307 type_name,
1308 error
1309 );
1310 }
1311 }
1312 })
1313 .detach();
1314 } else {
1315 log::info!("unhandled message {}", type_name);
1316 self.peer.respond_with_unhandled_message(message).log_err();
1317 }
1318 }
1319
1320 pub fn telemetry(&self) -> &Arc<Telemetry> {
1321 &self.telemetry
1322 }
1323}
1324
1325async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1326 if IMPERSONATE_LOGIN.is_some() {
1327 return None;
1328 }
1329
1330 let (user_id, access_token) = cx
1331 .update(|cx| cx.read_credentials(&ZED_SERVER_URL).log_err().flatten())
1332 .ok()??;
1333
1334 Some(Credentials {
1335 user_id: user_id.parse().ok()?,
1336 access_token: String::from_utf8(access_token).ok()?,
1337 })
1338}
1339
1340async fn write_credentials_to_keychain(
1341 credentials: Credentials,
1342 cx: &AsyncAppContext,
1343) -> Result<()> {
1344 cx.update(move |cx| {
1345 cx.write_credentials(
1346 &ZED_SERVER_URL,
1347 &credentials.user_id.to_string(),
1348 credentials.access_token.as_bytes(),
1349 )
1350 })?
1351}
1352
1353async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
1354 cx.update(move |cx| cx.delete_credentials(&ZED_SERVER_URL))?
1355}
1356
/// Scheme and path prefix shared by every worktree url.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Formats a worktree url as `zed://worktrees/<id>/<access_token>`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{WORKTREE_URL_PREFIX}{id}/{access_token}")
}

/// Parses a url produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` when the prefix is missing, the id
/// is not a valid `u64`, or the access token segment is absent or empty. Anything after
/// a further `/` is ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    let access_token = segments.next().filter(|token| !token.is_empty())?;
    Some((id, access_token.to_string()))
}
1373
1374#[cfg(test)]
1375mod tests {
1376 use super::*;
1377 use crate::test::FakeServer;
1378
1379 use gpui::{BackgroundExecutor, Context, TestAppContext};
1380 use parking_lot::Mutex;
1381 use std::future;
1382 use util::http::FakeHttpClient;
1383
1384 #[gpui::test(iterations = 10)]
1385 async fn test_reconnection(cx: &mut TestAppContext) {
1386 let user_id = 5;
1387 let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1388 let server = FakeServer::for_client(user_id, &client, cx).await;
1389 let mut status = client.status();
1390 assert!(matches!(
1391 status.next().await,
1392 Some(Status::Connected { .. })
1393 ));
1394 assert_eq!(server.auth_count(), 1);
1395
1396 server.forbid_connections();
1397 server.disconnect();
1398 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1399
1400 server.allow_connections();
1401 cx.executor().advance_clock(Duration::from_secs(10));
1402 while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1403 assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
1404
1405 server.forbid_connections();
1406 server.disconnect();
1407 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1408
1409 // Clear cached credentials after authentication fails
1410 server.roll_access_token();
1411 server.allow_connections();
1412 cx.executor().run_until_parked();
1413 cx.executor().advance_clock(Duration::from_secs(10));
1414 while !matches!(status.next().await, Some(Status::Connected { .. })) {}
1415 assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
1416 }
1417
1418 #[gpui::test(iterations = 10)]
1419 async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
1420 let user_id = 5;
1421 let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1422 let mut status = client.status();
1423
1424 // Time out when client tries to connect.
1425 client.override_authenticate(move |cx| {
1426 cx.background_executor().spawn(async move {
1427 Ok(Credentials {
1428 user_id,
1429 access_token: "token".into(),
1430 })
1431 })
1432 });
1433 client.override_establish_connection(|_, cx| {
1434 cx.background_executor().spawn(async move {
1435 future::pending::<()>().await;
1436 unreachable!()
1437 })
1438 });
1439 let auth_and_connect = cx.spawn({
1440 let client = client.clone();
1441 |cx| async move { client.authenticate_and_connect(false, &cx).await }
1442 });
1443 executor.run_until_parked();
1444 assert!(matches!(status.next().await, Some(Status::Connecting)));
1445
1446 executor.advance_clock(CONNECTION_TIMEOUT);
1447 assert!(matches!(
1448 status.next().await,
1449 Some(Status::ConnectionError { .. })
1450 ));
1451 auth_and_connect.await.unwrap_err();
1452
1453 // Allow the connection to be established.
1454 let server = FakeServer::for_client(user_id, &client, cx).await;
1455 assert!(matches!(
1456 status.next().await,
1457 Some(Status::Connected { .. })
1458 ));
1459
1460 // Disconnect client.
1461 server.forbid_connections();
1462 server.disconnect();
1463 while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
1464
1465 // Time out when re-establishing the connection.
1466 server.allow_connections();
1467 client.override_establish_connection(|_, cx| {
1468 cx.background_executor().spawn(async move {
1469 future::pending::<()>().await;
1470 unreachable!()
1471 })
1472 });
1473 executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
1474 assert!(matches!(
1475 status.next().await,
1476 Some(Status::Reconnecting { .. })
1477 ));
1478
1479 executor.advance_clock(CONNECTION_TIMEOUT);
1480 assert!(matches!(
1481 status.next().await,
1482 Some(Status::ReconnectionError { .. })
1483 ));
1484 }
1485
1486 #[gpui::test(iterations = 10)]
1487 async fn test_authenticating_more_than_once(
1488 cx: &mut TestAppContext,
1489 executor: BackgroundExecutor,
1490 ) {
1491 let auth_count = Arc::new(Mutex::new(0));
1492 let dropped_auth_count = Arc::new(Mutex::new(0));
1493 let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1494 client.override_authenticate({
1495 let auth_count = auth_count.clone();
1496 let dropped_auth_count = dropped_auth_count.clone();
1497 move |cx| {
1498 let auth_count = auth_count.clone();
1499 let dropped_auth_count = dropped_auth_count.clone();
1500 cx.background_executor().spawn(async move {
1501 *auth_count.lock() += 1;
1502 let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
1503 future::pending::<()>().await;
1504 unreachable!()
1505 })
1506 }
1507 });
1508
1509 let _authenticate = cx.spawn({
1510 let client = client.clone();
1511 move |cx| async move { client.authenticate_and_connect(false, &cx).await }
1512 });
1513 executor.run_until_parked();
1514 assert_eq!(*auth_count.lock(), 1);
1515 assert_eq!(*dropped_auth_count.lock(), 0);
1516
1517 let _authenticate = cx.spawn({
1518 let client = client.clone();
1519 |cx| async move { client.authenticate_and_connect(false, &cx).await }
1520 });
1521 executor.run_until_parked();
1522 assert_eq!(*auth_count.lock(), 2);
1523 assert_eq!(*dropped_auth_count.lock(), 1);
1524 }
1525
/// Round-trips a worktree url, with and without surrounding whitespace, and rejects a
/// url that has the wrong prefix.
#[test]
fn test_encode_and_decode_worktree_url() {
    let expected = Some((5, "deadbeef".to_string()));
    let url = encode_worktree_url(5, "deadbeef");
    assert_eq!(decode_worktree_url(&url), expected);

    let padded = format!("\n {}\t", url);
    assert_eq!(decode_worktree_url(&padded), expected);

    assert_eq!(decode_worktree_url("not://the-right-format"), None);
}
1536
1537 #[gpui::test]
1538 async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
1539 let user_id = 5;
1540 let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1541 let server = FakeServer::for_client(user_id, &client, cx).await;
1542
1543 let (done_tx1, mut done_rx1) = smol::channel::unbounded();
1544 let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1545 client.add_model_message_handler(
1546 move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
1547 match model.update(&mut cx, |model, _| model.id).unwrap() {
1548 1 => done_tx1.try_send(()).unwrap(),
1549 2 => done_tx2.try_send(()).unwrap(),
1550 _ => unreachable!(),
1551 }
1552 async { Ok(()) }
1553 },
1554 );
1555 let model1 = cx.build_model(|_| TestModel {
1556 id: 1,
1557 subscription: None,
1558 });
1559 let model2 = cx.build_model(|_| TestModel {
1560 id: 2,
1561 subscription: None,
1562 });
1563 let model3 = cx.build_model(|_| TestModel {
1564 id: 3,
1565 subscription: None,
1566 });
1567
1568 let _subscription1 = client
1569 .subscribe_to_entity(1)
1570 .unwrap()
1571 .set_model(&model1, &mut cx.to_async());
1572 let _subscription2 = client
1573 .subscribe_to_entity(2)
1574 .unwrap()
1575 .set_model(&model2, &mut cx.to_async());
1576 // Ensure dropping a subscription for the same entity type still allows receiving of
1577 // messages for other entity IDs of the same type.
1578 let subscription3 = client
1579 .subscribe_to_entity(3)
1580 .unwrap()
1581 .set_model(&model3, &mut cx.to_async());
1582 drop(subscription3);
1583
1584 server.send(proto::JoinProject { project_id: 1 });
1585 server.send(proto::JoinProject { project_id: 2 });
1586 done_rx1.next().await.unwrap();
1587 done_rx2.next().await.unwrap();
1588 }
1589
1590 #[gpui::test]
1591 async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
1592 let user_id = 5;
1593 let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1594 let server = FakeServer::for_client(user_id, &client, cx).await;
1595
1596 let model = cx.build_model(|_| TestModel::default());
1597 let (done_tx1, _done_rx1) = smol::channel::unbounded();
1598 let (done_tx2, mut done_rx2) = smol::channel::unbounded();
1599 let subscription1 = client.add_message_handler(
1600 model.downgrade(),
1601 move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1602 done_tx1.try_send(()).unwrap();
1603 async { Ok(()) }
1604 },
1605 );
1606 drop(subscription1);
1607 let _subscription2 = client.add_message_handler(
1608 model.downgrade(),
1609 move |_, _: TypedEnvelope<proto::Ping>, _, _| {
1610 done_tx2.try_send(()).unwrap();
1611 async { Ok(()) }
1612 },
1613 );
1614 server.send(proto::Ping {});
1615 done_rx2.next().await.unwrap();
1616 }
1617
1618 #[gpui::test]
1619 async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
1620 let user_id = 5;
1621 let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
1622 let server = FakeServer::for_client(user_id, &client, cx).await;
1623
1624 let model = cx.build_model(|_| TestModel::default());
1625 let (done_tx, mut done_rx) = smol::channel::unbounded();
1626 let subscription = client.add_message_handler(
1627 model.clone().downgrade(),
1628 move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
1629 model
1630 .update(&mut cx, |model, _| model.subscription.take())
1631 .unwrap();
1632 done_tx.try_send(()).unwrap();
1633 async { Ok(()) }
1634 },
1635 );
1636 model.update(cx, |model, _| {
1637 model.subscription = Some(subscription);
1638 });
1639 server.send(proto::Ping {});
1640 done_rx.next().await.unwrap();
1641 }
1642
1643 #[derive(Default)]
1644 struct TestModel {
1645 id: usize,
1646 subscription: Option<Subscription>,
1647 }
1648}