1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod telemetry;
5pub mod user;
6
7use anyhow::{anyhow, Context, Result};
8use async_recursion::async_recursion;
9use async_tungstenite::tungstenite::{
10 error::Error as WebsocketError,
11 http::{Request, StatusCode},
12};
13use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
14use gpui::{
15 actions,
16 serde_json::{self, Value},
17 AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext, AppVersion,
18 AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
19};
20use lazy_static::lazy_static;
21use parking_lot::RwLock;
22use postage::watch;
23use rand::prelude::*;
24use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
25use serde::Deserialize;
26use settings::{Settings, TelemetrySettings};
27use std::{
28 any::TypeId,
29 collections::HashMap,
30 convert::TryFrom,
31 fmt::Write as _,
32 future::Future,
33 marker::PhantomData,
34 path::PathBuf,
35 sync::{Arc, Weak},
36 time::{Duration, Instant},
37};
38use telemetry::Telemetry;
39use thiserror::Error;
40use url::Url;
41use util::channel::ReleaseChannel;
42use util::http::HttpClient;
43use util::{ResultExt, TryFutureExt};
44
45pub use rpc::*;
46pub use user::*;
47
48lazy_static! {
49 pub static ref ZED_SERVER_URL: String =
50 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
51 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
52 .ok()
53 .and_then(|s| if s.is_empty() { None } else { Some(s) });
54 pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
55 .ok()
56 .and_then(|s| if s.is_empty() { None } else { Some(s) });
57 pub static ref ZED_APP_VERSION: Option<AppVersion> = std::env::var("ZED_APP_VERSION")
58 .ok()
59 .and_then(|v| v.parse().ok());
60 pub static ref ZED_APP_PATH: Option<PathBuf> =
61 std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
62}
63
/// Public client token constant (not referenced within this part of the file).
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
/// Delay before the first reconnection attempt after `Status::ConnectionLost`;
/// subsequent delays grow from this value.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
/// Time allowed to establish a connection and receive the server hello.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(5_000);
67
// Declares the `SignIn` and `SignOut` actions in the `client` namespace; `init`
// registers their global handlers.
actions!(client, [SignIn, SignOut]);
69
70pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
71 cx.add_global_action({
72 let client = client.clone();
73 move |_: &SignIn, cx| {
74 let client = client.clone();
75 cx.spawn(
76 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
77 )
78 .detach();
79 }
80 });
81 cx.add_global_action({
82 let client = client.clone();
83 move |_: &SignOut, cx| {
84 let client = client.clone();
85 cx.spawn(|cx| async move {
86 client.disconnect(&cx);
87 })
88 .detach();
89 }
90 });
91}
92
/// Client for the Zed server's RPC protocol.
///
/// Holds the RPC `Peer`, an HTTP client, telemetry, and the lock-protected
/// connection and message-routing state. Test builds can also override how
/// credentials are obtained and how the connection is established.
pub struct Client {
    /// Identifier included in log messages; `new` sets it to 0 and `set_id`
    /// overrides it in test builds.
    id: usize,
    /// RPC peer to which the server connection is added in `set_connection`.
    peer: Arc<Peer>,
    /// HTTP client used for the `/rpc` lookup and the admin user lookup.
    http: Arc<dyn HttpClient>,
    /// Telemetry reporter (events and authenticated-user info).
    telemetry: Arc<Telemetry>,
    /// Credentials, status channel, reconnect task and message-routing tables.
    state: RwLock<ClientState>,

    /// Test-only replacement for `authenticate` (installed by `override_authenticate`).
    // The boxed closure type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for `establish_connection` (installed by
    /// `override_establish_connection`).
    // The boxed closure type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
122
/// Reasons why establishing a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The websocket handshake returned HTTP `UPGRADE_REQUIRED`;
    /// `authenticate_and_connect` then moves the client to `Status::UpgradeRequired`.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The websocket handshake returned HTTP `UNAUTHORIZED` for the supplied credentials.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including other websocket handshake errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the HTTP client.
    #[error("{0}")]
    Http(#[from] util::http::Error),
    /// I/O error, e.g. from opening the TCP stream.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Error while building the websocket HTTP request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
138
139impl From<WebsocketError> for EstablishConnectionError {
140 fn from(error: WebsocketError) -> Self {
141 if let WebsocketError::Http(response) = &error {
142 match response.status() {
143 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
144 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
145 _ => {}
146 }
147 }
148 EstablishConnectionError::Other(error.into())
149 }
150}
151
152impl EstablishConnectionError {
153 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
154 Self::Other(error.into())
155 }
156}
157
/// Connection lifecycle state, published through the client's status watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No session. This is the initial state and the state after `disconnect`.
    SignedOut,
    /// The server requires a newer client; `authenticate_and_connect` refuses to run.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening a connection, starting from `SignedOut`.
    Connecting,
    /// The last connection attempt failed.
    ConnectionError,
    /// Connected; identifies this peer and its connection.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; `set_status` spawns a reconnect task.
    ConnectionLost,
    /// Obtaining credentials after an error or a lost connection.
    Reauthenticating,
    /// Opening a connection after an error or a lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next one is due at `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
176
177impl Status {
178 pub fn is_connected(&self) -> bool {
179 matches!(self, Self::Connected { .. })
180 }
181
182 pub fn is_signed_out(&self) -> bool {
183 matches!(self, Self::SignedOut | Self::UpgradeRequired)
184 }
185}
186
/// Mutable client state, guarded by the `RwLock` in `Client::state`.
struct ClientState {
    /// Credentials of the current session, if any.
    credentials: Option<Credentials>,
    /// Sender and receiver of the status watch channel; `Client::status` hands
    /// out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, a function that extracts the remote entity id from an envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection task spawned on `Status::ConnectionLost`; cleared once
    /// connected or signed out.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound on the delay between reconnection attempts.
    reconnect_interval: Duration,
    /// Subscribers for specific remote entities, keyed by (entity type, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
    /// Model receiving each message type registered via `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type whose instances handle each entity message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handler for each message type.
    // The boxed handler signature trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    Subscriber,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
211
/// Weakly held recipient registered for a remote entity.
enum WeakSubscriber {
    /// A model.
    Model(AnyWeakModelHandle),
    /// A view.
    View(AnyWeakViewHandle),
    /// No model bound yet (see `Client::subscribe_to_entity`). Messages collected
    /// here are replayed by `PendingEntitySubscription::set_model`.
    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
217
/// Strong handle to the recipient of a message, passed to message handlers.
enum Subscriber {
    /// A model.
    Model(AnyModelHandle),
    /// A view.
    View(AnyViewHandle),
}
222
/// Credentials used to authenticate with the server.
///
/// `Debug` is implemented by hand so the access token never ends up in logs or
/// panic messages that format the credentials with `{:?}`.
#[derive(Clone)]
pub struct Credentials {
    /// Numeric id of the user.
    pub user_id: u64,
    /// Access token; sent together with `user_id` in the `Authorization` header.
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
228
229impl Default for ClientState {
230 fn default() -> Self {
231 Self {
232 credentials: None,
233 status: watch::channel_with(Status::SignedOut),
234 entity_id_extractors: Default::default(),
235 _reconnect_task: None,
236 reconnect_interval: Duration::from_secs(5),
237 models_by_message_type: Default::default(),
238 entities_by_type_and_remote_id: Default::default(),
239 entity_types_by_message_type: Default::default(),
240 message_handlers: Default::default(),
241 }
242 }
243}
244
/// Registration handle returned by the `Client`; dropping it removes the
/// registration. It holds the client weakly, so it does not keep the client alive.
pub enum Subscription {
    /// A remote-entity registration, keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// A message-handler registration, keyed by the message's `TypeId`.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
255
256impl Drop for Subscription {
257 fn drop(&mut self) {
258 match self {
259 Subscription::Entity { client, id } => {
260 if let Some(client) = client.upgrade() {
261 let mut state = client.state.write();
262 let _ = state.entities_by_type_and_remote_id.remove(id);
263 }
264 }
265 Subscription::Message { client, id } => {
266 if let Some(client) = client.upgrade() {
267 let mut state = client.state.write();
268 let _ = state.entity_types_by_message_type.remove(id);
269 let _ = state.message_handlers.remove(id);
270 }
271 }
272 }
273 }
274}
275
/// A remote-entity subscription reserved by `Client::subscribe_to_entity` that
/// has not yet been bound to a model.
///
/// Call `set_model` to bind it. Dropping it unbound removes the reservation and
/// logs any messages that were waiting for it.
pub struct PendingEntitySubscription<T: Entity> {
    client: Arc<Client>,
    remote_id: u64,
    /// Ties the subscription to entity type `T` without storing a value of it.
    _entity_type: PhantomData<T>,
    /// Set by `set_model` so that `Drop` leaves the bound entry in place.
    consumed: bool,
}
282
283impl<T: Entity> PendingEntitySubscription<T> {
284 pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
285 self.consumed = true;
286 let mut state = self.client.state.write();
287 let id = (TypeId::of::<T>(), self.remote_id);
288 let Some(WeakSubscriber::Pending(messages)) =
289 state.entities_by_type_and_remote_id.remove(&id)
290 else {
291 unreachable!()
292 };
293
294 state
295 .entities_by_type_and_remote_id
296 .insert(id, WeakSubscriber::Model(model.downgrade().into_any()));
297 drop(state);
298 for message in messages {
299 self.client.handle_message(message, cx);
300 }
301 Subscription::Entity {
302 client: Arc::downgrade(&self.client),
303 id,
304 }
305 }
306}
307
308impl<T: Entity> Drop for PendingEntitySubscription<T> {
309 fn drop(&mut self) {
310 if !self.consumed {
311 let mut state = self.client.state.write();
312 if let Some(WeakSubscriber::Pending(messages)) = state
313 .entities_by_type_and_remote_id
314 .remove(&(TypeId::of::<T>(), self.remote_id))
315 {
316 for message in messages {
317 log::info!("unhandled message {}", message.payload_type_name());
318 }
319 }
320 }
321 }
322}
323
324impl Client {
325 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
326 Arc::new(Self {
327 id: 0,
328 peer: Peer::new(0),
329 telemetry: Telemetry::new(http.clone(), cx),
330 http,
331 state: Default::default(),
332
333 #[cfg(any(test, feature = "test-support"))]
334 authenticate: Default::default(),
335 #[cfg(any(test, feature = "test-support"))]
336 establish_connection: Default::default(),
337 })
338 }
339
340 pub fn id(&self) -> usize {
341 self.id
342 }
343
344 pub fn http_client(&self) -> Arc<dyn HttpClient> {
345 self.http.clone()
346 }
347
348 #[cfg(any(test, feature = "test-support"))]
349 pub fn set_id(&mut self, id: usize) -> &Self {
350 self.id = id;
351 self
352 }
353
354 #[cfg(any(test, feature = "test-support"))]
355 pub fn teardown(&self) {
356 let mut state = self.state.write();
357 state._reconnect_task.take();
358 state.message_handlers.clear();
359 state.models_by_message_type.clear();
360 state.entities_by_type_and_remote_id.clear();
361 state.entity_id_extractors.clear();
362 self.peer.teardown();
363 }
364
365 #[cfg(any(test, feature = "test-support"))]
366 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
367 where
368 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
369 {
370 *self.authenticate.write() = Some(Box::new(authenticate));
371 self
372 }
373
374 #[cfg(any(test, feature = "test-support"))]
375 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
376 where
377 F: 'static
378 + Send
379 + Sync
380 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
381 {
382 *self.establish_connection.write() = Some(Box::new(connect));
383 self
384 }
385
386 pub fn user_id(&self) -> Option<u64> {
387 self.state
388 .read()
389 .credentials
390 .as_ref()
391 .map(|credentials| credentials.user_id)
392 }
393
394 pub fn peer_id(&self) -> Option<PeerId> {
395 if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
396 Some(*peer_id)
397 } else {
398 None
399 }
400 }
401
402 pub fn status(&self) -> watch::Receiver<Status> {
403 self.state.read().status.1.clone()
404 }
405
    /// Publishes `status` on the status watch channel and reacts to it:
    ///
    /// - `Connected`: drops any reconnect task.
    /// - `ConnectionLost`: spawns a reconnect task. It retries
    ///   `authenticate_and_connect` for as long as each failed attempt leaves the
    ///   status at `ConnectionError`. The delay starts at
    ///   `INITIAL_RECONNECTION_DELAY`, is multiplied by a random factor in
    ///   `[1.0, 2.0]` after each retry, and is capped at `reconnect_interval`.
    /// - `SignedOut` / `UpgradeRequired`: clears the authenticated user in
    ///   telemetry and drops the reconnect task.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        log::info!("set status on client {}: {:?}", self.id, status);
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                state._reconnect_task = None;
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let reconnect_interval = state.reconnect_interval;
                state._reconnect_task = Some(cx.spawn(|cx| async move {
                    // Test builds use a fixed seed so the backoff jitter is deterministic.
                    #[cfg(any(test, feature = "test-support"))]
                    let mut rng = StdRng::seed_from_u64(0);
                    #[cfg(not(any(test, feature = "test-support")))]
                    let mut rng = StdRng::from_entropy();

                    let mut delay = INITIAL_RECONNECTION_DELAY;
                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                        log::error!("failed to connect {}", error);
                        // Retry only after a plain connection error; any other resulting
                        // status (e.g. `UpgradeRequired`) ends the loop.
                        if matches!(*this.status().borrow(), Status::ConnectionError) {
                            this.set_status(
                                Status::ReconnectionError {
                                    next_reconnection: Instant::now() + delay,
                                },
                                &cx,
                            );
                            cx.background().timer(delay).await;
                            delay = delay
                                .mul_f32(rng.gen_range(1.0..=2.0))
                                .min(reconnect_interval);
                        } else {
                            break;
                        }
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                let telemetry_settings = cx.read(|cx| cx.global::<Settings>().telemetry());
                self.telemetry
                    .set_authenticated_user_info(None, false, telemetry_settings);
                state._reconnect_task.take();
            }
            _ => {}
        }
    }
453
454 pub fn add_view_for_remote_entity<T: View>(
455 self: &Arc<Self>,
456 remote_id: u64,
457 cx: &mut ViewContext<T>,
458 ) -> Subscription {
459 let id = (TypeId::of::<T>(), remote_id);
460 self.state
461 .write()
462 .entities_by_type_and_remote_id
463 .insert(id, WeakSubscriber::View(cx.weak_handle().into_any()));
464 Subscription::Entity {
465 client: Arc::downgrade(self),
466 id,
467 }
468 }
469
470 pub fn subscribe_to_entity<T: Entity>(
471 self: &Arc<Self>,
472 remote_id: u64,
473 ) -> PendingEntitySubscription<T> {
474 let id = (TypeId::of::<T>(), remote_id);
475 self.state
476 .write()
477 .entities_by_type_and_remote_id
478 .insert(id, WeakSubscriber::Pending(Default::default()));
479
480 PendingEntitySubscription {
481 client: self.clone(),
482 remote_id,
483 consumed: false,
484 _entity_type: PhantomData,
485 }
486 }
487
488 pub fn add_message_handler<M, E, H, F>(
489 self: &Arc<Self>,
490 model: ModelHandle<E>,
491 handler: H,
492 ) -> Subscription
493 where
494 M: EnvelopedMessage,
495 E: Entity,
496 H: 'static
497 + Send
498 + Sync
499 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
500 F: 'static + Future<Output = Result<()>>,
501 {
502 let message_type_id = TypeId::of::<M>();
503
504 let mut state = self.state.write();
505 state
506 .models_by_message_type
507 .insert(message_type_id, model.downgrade().into_any());
508
509 let prev_handler = state.message_handlers.insert(
510 message_type_id,
511 Arc::new(move |handle, envelope, client, cx| {
512 let handle = if let Subscriber::Model(handle) = handle {
513 handle
514 } else {
515 unreachable!();
516 };
517 let model = handle.downcast::<E>().unwrap();
518 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
519 handler(model, *envelope, client.clone(), cx).boxed_local()
520 }),
521 );
522 if prev_handler.is_some() {
523 panic!("registered handler for the same message twice");
524 }
525
526 Subscription::Message {
527 client: Arc::downgrade(self),
528 id: message_type_id,
529 }
530 }
531
532 pub fn add_request_handler<M, E, H, F>(
533 self: &Arc<Self>,
534 model: ModelHandle<E>,
535 handler: H,
536 ) -> Subscription
537 where
538 M: RequestMessage,
539 E: Entity,
540 H: 'static
541 + Send
542 + Sync
543 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
544 F: 'static + Future<Output = Result<M::Response>>,
545 {
546 self.add_message_handler(model, move |handle, envelope, this, cx| {
547 Self::respond_to_request(
548 envelope.receipt(),
549 handler(handle, envelope, this.clone(), cx),
550 this,
551 )
552 })
553 }
554
555 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
556 where
557 M: EntityMessage,
558 E: View,
559 H: 'static
560 + Send
561 + Sync
562 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
563 F: 'static + Future<Output = Result<()>>,
564 {
565 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
566 if let Subscriber::View(handle) = handle {
567 handler(handle.downcast::<E>().unwrap(), message, client, cx)
568 } else {
569 unreachable!();
570 }
571 })
572 }
573
574 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
575 where
576 M: EntityMessage,
577 E: Entity,
578 H: 'static
579 + Send
580 + Sync
581 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
582 F: 'static + Future<Output = Result<()>>,
583 {
584 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
585 if let Subscriber::Model(handle) = handle {
586 handler(handle.downcast::<E>().unwrap(), message, client, cx)
587 } else {
588 unreachable!();
589 }
590 })
591 }
592
593 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
594 where
595 M: EntityMessage,
596 E: Entity,
597 H: 'static
598 + Send
599 + Sync
600 + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
601 F: 'static + Future<Output = Result<()>>,
602 {
603 let model_type_id = TypeId::of::<E>();
604 let message_type_id = TypeId::of::<M>();
605
606 let mut state = self.state.write();
607 state
608 .entity_types_by_message_type
609 .insert(message_type_id, model_type_id);
610 state
611 .entity_id_extractors
612 .entry(message_type_id)
613 .or_insert_with(|| {
614 |envelope| {
615 envelope
616 .as_any()
617 .downcast_ref::<TypedEnvelope<M>>()
618 .unwrap()
619 .payload
620 .remote_entity_id()
621 }
622 });
623 let prev_handler = state.message_handlers.insert(
624 message_type_id,
625 Arc::new(move |handle, envelope, client, cx| {
626 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
627 handler(handle, *envelope, client.clone(), cx).boxed_local()
628 }),
629 );
630 if prev_handler.is_some() {
631 panic!("registered handler for the same message twice");
632 }
633 }
634
635 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
636 where
637 M: EntityMessage + RequestMessage,
638 E: Entity,
639 H: 'static
640 + Send
641 + Sync
642 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
643 F: 'static + Future<Output = Result<M::Response>>,
644 {
645 self.add_model_message_handler(move |entity, envelope, client, cx| {
646 Self::respond_to_request::<M, _>(
647 envelope.receipt(),
648 handler(entity, envelope, client.clone(), cx),
649 client,
650 )
651 })
652 }
653
654 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
655 where
656 M: EntityMessage + RequestMessage,
657 E: View,
658 H: 'static
659 + Send
660 + Sync
661 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
662 F: 'static + Future<Output = Result<M::Response>>,
663 {
664 self.add_view_message_handler(move |entity, envelope, client, cx| {
665 Self::respond_to_request::<M, _>(
666 envelope.receipt(),
667 handler(entity, envelope, client.clone(), cx),
668 client,
669 )
670 })
671 }
672
673 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
674 receipt: Receipt<T>,
675 response: F,
676 client: Arc<Self>,
677 ) -> Result<()> {
678 match response.await {
679 Ok(response) => {
680 client.respond(receipt, response)?;
681 Ok(())
682 }
683 Err(error) => {
684 client.respond_with_error(
685 receipt,
686 proto::Error {
687 message: format!("{:?}", error),
688 },
689 )?;
690 Err(error)
691 }
692 }
693 }
694
695 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
696 read_credentials_from_keychain(cx).is_some()
697 }
698
    /// Authenticates if needed, then connects to the server, updating the status
    /// along the way.
    ///
    /// - Returns `Ok(())` right away if already connected or if a connection is in
    ///   progress.
    /// - Fails if the status is `UpgradeRequired`.
    /// - Takes credentials from the client state first. If `try_keychain` is set, it
    ///   then tries the keychain, and finally falls back to `authenticate`.
    /// - If keychain credentials are rejected as unauthorized, they are deleted and
    ///   the flow runs again without the keychain. That retry is the reason for
    ///   `async_recursion`.
    ///
    /// # Errors
    /// Fails on authentication or connection errors, when `CONNECTION_TIMEOUT`
    /// elapses, or when a status change cancels an in-progress authentication.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                cx.read(|cx| {
                    self.report_event(
                        "read credentials from keychain",
                        Default::default(),
                        cx.global::<Settings>().telemetry(),
                    );
                });
            }
        }
        if credentials.is_none() {
            // NOTE(review): the first `next()` presumably yields the current status;
            // skipping it makes the select below fire only on a later status change,
            // which is treated as cancelling authentication.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single timeout covers both establishing the connection and waiting for
        // the server hello in `set_connection`.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist only freshly obtained, non-impersonated credentials.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and retry once
                            // without the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
810
    /// Adds `conn` to the peer, waits for the server's `Hello` to learn this
    /// client's peer id, and sets the status to `Connected`.
    ///
    /// It then detaches two foreground tasks:
    /// - one dispatches incoming messages through `handle_message`, yielding after
    ///   each message;
    /// - one awaits the connection's I/O future. On a clean close it sets
    ///   `SignedOut`, but only if this connection is still current. On an error it
    ///   sets `ConnectionLost`.
    ///
    /// # Errors
    /// Fails, and disconnects the connection, if the first message is missing, is
    /// not a `Hello`, or carries no peer id.
    async fn set_connection(
        self: &Arc<Self>,
        conn: Connection,
        cx: &AsyncAppContext,
    ) -> Result<()> {
        let executor = cx.background();
        log::info!("add connection to peer");
        let (connection_id, handle_io, mut incoming) = self
            .peer
            .add_connection(conn, move |duration| executor.timer(duration));
        let handle_io = cx.background().spawn(handle_io);

        // The first message on a new connection must be the server's `Hello`.
        let peer_id = async {
            log::info!("waiting for server hello");
            let message = incoming
                .next()
                .await
                .ok_or_else(|| anyhow!("no hello message received"))?;
            log::info!("got server hello");
            let hello_message_type_name = message.payload_type_name().to_string();
            let hello = message
                .into_any()
                .downcast::<TypedEnvelope<proto::Hello>>()
                .map_err(|_| {
                    anyhow!(
                        "invalid hello message received: {:?}",
                        hello_message_type_name
                    )
                })?;
            let peer_id = hello
                .payload
                .peer_id
                .ok_or_else(|| anyhow!("invalid peer id"))?;
            Ok(peer_id)
        };

        let peer_id = match peer_id.await {
            Ok(peer_id) => peer_id,
            Err(error) => {
                self.peer.disconnect(connection_id);
                return Err(error);
            }
        };

        log::info!(
            "set status to connected (connection id: {:?}, peer id: {:?})",
            connection_id,
            peer_id
        );
        self.set_status(
            Status::Connected {
                peer_id,
                connection_id,
            },
            cx,
        );
        cx.foreground()
            .spawn({
                let cx = cx.clone();
                let this = self.clone();
                async move {
                    while let Some(message) = incoming.next().await {
                        this.handle_message(message, &cx);
                        // Don't starve the main thread when receiving lots of messages at once.
                        smol::future::yield_now().await;
                    }
                }
            })
            .detach();

        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                match handle_io.await {
                    Ok(()) => {
                        // Sign out only if no newer connection has replaced this one.
                        if this.status().borrow().clone()
                            == (Status::Connected {
                                connection_id,
                                peer_id,
                            })
                        {
                            this.set_status(Status::SignedOut, &cx);
                        }
                    }
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();

        Ok(())
    }
906
907 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
908 #[cfg(any(test, feature = "test-support"))]
909 if let Some(callback) = self.authenticate.read().as_ref() {
910 return callback(cx);
911 }
912
913 self.authenticate_with_browser(cx)
914 }
915
916 fn establish_connection(
917 self: &Arc<Self>,
918 credentials: &Credentials,
919 cx: &AsyncAppContext,
920 ) -> Task<Result<Connection, EstablishConnectionError>> {
921 #[cfg(any(test, feature = "test-support"))]
922 if let Some(callback) = self.establish_connection.read().as_ref() {
923 return callback(credentials, cx);
924 }
925
926 self.establish_websocket_connection(credentials, cx)
927 }
928
929 async fn get_rpc_url(http: Arc<dyn HttpClient>, is_preview: bool) -> Result<Url> {
930 let preview_param = if is_preview { "?preview=1" } else { "" };
931 let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL);
932 let response = http.get(&url, Default::default(), false).await?;
933
934 // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
935 // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
936 // which requires authorization via an HTTP header.
937 //
938 // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
939 // of a collab server. In that case, a request to the /rpc endpoint will
940 // return an 'unauthorized' response.
941 let collab_url = if response.status().is_redirection() {
942 response
943 .headers()
944 .get("Location")
945 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
946 .to_str()
947 .map_err(EstablishConnectionError::other)?
948 .to_string()
949 } else if response.status() == StatusCode::UNAUTHORIZED {
950 url
951 } else {
952 Err(anyhow!(
953 "unexpected /rpc response status {}",
954 response.status()
955 ))?
956 };
957
958 Url::parse(&collab_url).context("invalid rpc url")
959 }
960
961 fn establish_websocket_connection(
962 self: &Arc<Self>,
963 credentials: &Credentials,
964 cx: &AsyncAppContext,
965 ) -> Task<Result<Connection, EstablishConnectionError>> {
966 let is_preview = cx.read(|cx| {
967 if cx.has_global::<ReleaseChannel>() {
968 *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
969 } else {
970 false
971 }
972 });
973
974 let request = Request::builder()
975 .header(
976 "Authorization",
977 format!("{} {}", credentials.user_id, credentials.access_token),
978 )
979 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
980
981 let http = self.http.clone();
982 cx.background().spawn(async move {
983 let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
984 let rpc_host = rpc_url
985 .host_str()
986 .zip(rpc_url.port_or_known_default())
987 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
988 let stream = smol::net::TcpStream::connect(rpc_host).await?;
989
990 log::info!("connected to rpc endpoint {}", rpc_url);
991
992 match rpc_url.scheme() {
993 "https" => {
994 rpc_url.set_scheme("wss").unwrap();
995 let request = request.uri(rpc_url.as_str()).body(())?;
996 let (stream, _) =
997 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
998 Ok(Connection::new(
999 stream
1000 .map_err(|error| anyhow!(error))
1001 .sink_map_err(|error| anyhow!(error)),
1002 ))
1003 }
1004 "http" => {
1005 rpc_url.set_scheme("ws").unwrap();
1006 let request = request.uri(rpc_url.as_str()).body(())?;
1007 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
1008 Ok(Connection::new(
1009 stream
1010 .map_err(|error| anyhow!(error))
1011 .sink_map_err(|error| anyhow!(error)),
1012 ))
1013 }
1014 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
1015 }
1016 })
1017 }
1018
1019 pub fn authenticate_with_browser(
1020 self: &Arc<Self>,
1021 cx: &AsyncAppContext,
1022 ) -> Task<Result<Credentials>> {
1023 let platform = cx.platform();
1024 let executor = cx.background();
1025 let telemetry = self.telemetry.clone();
1026 let http = self.http.clone();
1027 let metrics_enabled = cx.read(|cx| cx.global::<Settings>().telemetry());
1028
1029 executor.clone().spawn(async move {
1030 // Generate a pair of asymmetric encryption keys. The public key will be used by the
1031 // zed server to encrypt the user's access token, so that it can'be intercepted by
1032 // any other app running on the user's device.
1033 let (public_key, private_key) =
1034 rpc::auth::keypair().expect("failed to generate keypair for auth");
1035 let public_key_string =
1036 String::try_from(public_key).expect("failed to serialize public key for auth");
1037
1038 if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
1039 return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
1040 }
1041
1042 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1043 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1044 let port = server.server_addr().port();
1045
1046 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1047 // that the user is signing in from a Zed app running on the same device.
1048 let mut url = format!(
1049 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1050 *ZED_SERVER_URL, port, public_key_string
1051 );
1052
1053 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1054 log::info!("impersonating user @{}", impersonate_login);
1055 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1056 }
1057
1058 platform.open_url(&url);
1059
1060 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1061 // access token from the query params.
1062 //
1063 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1064 // custom URL scheme instead of this local HTTP server.
1065 let (user_id, access_token) = executor
1066 .spawn(async move {
1067 for _ in 0..100 {
1068 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1069 let path = req.url();
1070 let mut user_id = None;
1071 let mut access_token = None;
1072 let url = Url::parse(&format!("http://example.com{}", path))
1073 .context("failed to parse login notification url")?;
1074 for (key, value) in url.query_pairs() {
1075 if key == "access_token" {
1076 access_token = Some(value.to_string());
1077 } else if key == "user_id" {
1078 user_id = Some(value.to_string());
1079 }
1080 }
1081
1082 let post_auth_url =
1083 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1084 req.respond(
1085 tiny_http::Response::empty(302).with_header(
1086 tiny_http::Header::from_bytes(
1087 &b"Location"[..],
1088 post_auth_url.as_bytes(),
1089 )
1090 .unwrap(),
1091 ),
1092 )
1093 .context("failed to respond to login http request")?;
1094 return Ok((
1095 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1096 access_token
1097 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1098 ));
1099 }
1100 }
1101
1102 Err(anyhow!("didn't receive login redirect"))
1103 })
1104 .await?;
1105
1106 let access_token = private_key
1107 .decrypt_string(&access_token)
1108 .context("failed to decrypt access token")?;
1109 platform.activate(true);
1110
1111 telemetry.report_event(
1112 "authenticate with browser",
1113 Default::default(),
1114 metrics_enabled,
1115 );
1116
1117 Ok(Credentials {
1118 user_id: user_id.parse()?,
1119 access_token,
1120 })
1121 })
1122 }
1123
1124 async fn authenticate_as_admin(
1125 http: Arc<dyn HttpClient>,
1126 login: String,
1127 mut api_token: String,
1128 ) -> Result<Credentials> {
1129 #[derive(Deserialize)]
1130 struct AuthenticatedUserResponse {
1131 user: User,
1132 }
1133
1134 #[derive(Deserialize)]
1135 struct User {
1136 id: u64,
1137 }
1138
1139 // Use the collab server's admin API to retrieve the id
1140 // of the impersonated user.
1141 let mut url = Self::get_rpc_url(http.clone(), false).await?;
1142 url.set_path("/user");
1143 url.set_query(Some(&format!("github_login={login}")));
1144 let request = Request::get(url.as_str())
1145 .header("Authorization", format!("token {api_token}"))
1146 .body("".into())?;
1147
1148 let mut response = http.send(request).await?;
1149 let mut body = String::new();
1150 response.body_mut().read_to_string(&mut body).await?;
1151 if !response.status().is_success() {
1152 Err(anyhow!(
1153 "admin user request failed {} - {}",
1154 response.status().as_u16(),
1155 body,
1156 ))?;
1157 }
1158 let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
1159
1160 // Use the admin API token to authenticate as the impersonated user.
1161 api_token.insert_str(0, "ADMIN_TOKEN:");
1162 Ok(Credentials {
1163 user_id: response.user.id,
1164 access_token: api_token,
1165 })
1166 }
1167
1168 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
1169 self.peer.teardown();
1170 self.set_status(Status::SignedOut, cx);
1171 }
1172
1173 fn connection_id(&self) -> Result<ConnectionId> {
1174 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1175 Ok(connection_id)
1176 } else {
1177 Err(anyhow!("not connected"))
1178 }
1179 }
1180
1181 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1182 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1183 self.peer.send(self.connection_id()?, message)
1184 }
1185
1186 pub fn request<T: RequestMessage>(
1187 &self,
1188 request: T,
1189 ) -> impl Future<Output = Result<T::Response>> {
1190 let client_id = self.id;
1191 log::debug!(
1192 "rpc request start. client_id:{}. name:{}",
1193 client_id,
1194 T::NAME
1195 );
1196 let response = self
1197 .connection_id()
1198 .map(|conn_id| self.peer.request(conn_id, request));
1199 async move {
1200 let response = response?.await;
1201 log::debug!(
1202 "rpc request finish. client_id:{}. name:{}",
1203 client_id,
1204 T::NAME
1205 );
1206 response
1207 }
1208 }
1209
1210 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1211 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1212 self.peer.respond(receipt, response)
1213 }
1214
1215 fn respond_with_error<T: RequestMessage>(
1216 &self,
1217 receipt: Receipt<T>,
1218 error: proto::Error,
1219 ) -> Result<()> {
1220 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1221 self.peer.respond_with_error(receipt, error)
1222 }
1223
1224 fn handle_message(
1225 self: &Arc<Client>,
1226 message: Box<dyn AnyTypedEnvelope>,
1227 cx: &AsyncAppContext,
1228 ) {
1229 let mut state = self.state.write();
1230 let type_name = message.payload_type_name();
1231 let payload_type_id = message.payload_type_id();
1232 let sender_id = message.original_sender_id();
1233
1234 let mut subscriber = None;
1235
1236 if let Some(message_model) = state
1237 .models_by_message_type
1238 .get(&payload_type_id)
1239 .and_then(|model| model.upgrade(cx))
1240 {
1241 subscriber = Some(Subscriber::Model(message_model));
1242 } else if let Some((extract_entity_id, entity_type_id)) =
1243 state.entity_id_extractors.get(&payload_type_id).zip(
1244 state
1245 .entity_types_by_message_type
1246 .get(&payload_type_id)
1247 .copied(),
1248 )
1249 {
1250 let entity_id = (extract_entity_id)(message.as_ref());
1251
1252 match state
1253 .entities_by_type_and_remote_id
1254 .get_mut(&(entity_type_id, entity_id))
1255 {
1256 Some(WeakSubscriber::Pending(pending)) => {
1257 pending.push(message);
1258 return;
1259 }
1260 Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx),
1261 _ => {}
1262 }
1263 }
1264
1265 let subscriber = if let Some(subscriber) = subscriber {
1266 subscriber
1267 } else {
1268 log::info!("unhandled message {}", type_name);
1269 self.peer.respond_with_unhandled_message(message).log_err();
1270 return;
1271 };
1272
1273 let handler = state.message_handlers.get(&payload_type_id).cloned();
1274 // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
1275 // It also ensures we don't hold the lock while yielding back to the executor, as
1276 // that might cause the executor thread driving this future to block indefinitely.
1277 drop(state);
1278
1279 if let Some(handler) = handler {
1280 let future = handler(subscriber, message, &self, cx.clone());
1281 let client_id = self.id;
1282 log::debug!(
1283 "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
1284 client_id,
1285 sender_id,
1286 type_name
1287 );
1288 cx.foreground()
1289 .spawn(async move {
1290 match future.await {
1291 Ok(()) => {
1292 log::debug!(
1293 "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
1294 client_id,
1295 sender_id,
1296 type_name
1297 );
1298 }
1299 Err(error) => {
1300 log::error!(
1301 "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
1302 client_id,
1303 sender_id,
1304 type_name,
1305 error
1306 );
1307 }
1308 }
1309 })
1310 .detach();
1311 } else {
1312 log::info!("unhandled message {}", type_name);
1313 self.peer.respond_with_unhandled_message(message).log_err();
1314 }
1315 }
1316
1317 pub fn start_telemetry(&self) {
1318 self.telemetry.start();
1319 }
1320
1321 pub fn report_event(
1322 &self,
1323 kind: &str,
1324 properties: Value,
1325 telemetry_settings: TelemetrySettings,
1326 ) {
1327 self.telemetry
1328 .report_event(kind, properties.clone(), telemetry_settings);
1329 }
1330
1331 pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1332 self.telemetry.log_file_path()
1333 }
1334
1335 pub fn metrics_id(&self) -> Option<Arc<str>> {
1336 self.telemetry.metrics_id()
1337 }
1338
1339 pub fn is_staff(&self) -> Option<bool> {
1340 self.telemetry.is_staff()
1341 }
1342}
1343
1344impl WeakSubscriber {
1345 fn upgrade(&self, cx: &AsyncAppContext) -> Option<Subscriber> {
1346 match self {
1347 WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model),
1348 WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View),
1349 WeakSubscriber::Pending(_) => None,
1350 }
1351 }
1352}
1353
1354fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1355 if IMPERSONATE_LOGIN.is_some() {
1356 return None;
1357 }
1358
1359 let (user_id, access_token) = cx
1360 .platform()
1361 .read_credentials(&ZED_SERVER_URL)
1362 .log_err()
1363 .flatten()?;
1364 Some(Credentials {
1365 user_id: user_id.parse().ok()?,
1366 access_token: String::from_utf8(access_token).ok()?,
1367 })
1368}
1369
1370fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1371 cx.platform().write_credentials(
1372 &ZED_SERVER_URL,
1373 &credentials.user_id.to_string(),
1374 credentials.access_token.as_bytes(),
1375 )
1376}
1377
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/<id>/<access_token>` URL.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Everything after the first `/` that follows the id is
/// the access token, so tokens that themselves contain `/` survive a round trip; splitting
/// on every `/` would silently truncate them.
///
/// Returns `None` in any of these cases:
/// - the prefix is wrong;
/// - the id is not a valid `u64`;
/// - the token is missing or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1394
#[cfg(test)]
mod tests {
    // Unit tests for `Client`. They cover:
    // - reconnection and connection timeouts;
    // - overlapping authentication attempts;
    // - worktree URL encoding;
    // - routing of incoming messages to subscribed models.
    use super::*;
    use crate::test::FakeServer;
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;
    use util::http::FakeHttpClient;

    // After a server-side disconnect the client keeps retrying. On reconnect it reuses its
    // cached credentials, and it re-authenticates only after the server rolls the token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Advance the fake clock (presumably past the reconnection delay) so the client retries.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Both the initial connection and a later reconnection attempt must give up with an
    // error status once `CONNECTION_TIMEOUT` elapses while connection establishment hangs.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second `authenticate_and_connect` while the first is still authenticating
    // must drop the first authentication future. The `defer` guard counts those drops.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // The second attempt starts a new authentication and drops the pending first one.
        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs round-trip, including with surrounding whitespace, and URLs with the
    // wrong prefix are rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity-scoped messages reach the model subscribed for the matching entity id.
    // Here `JoinProject.project_id` is the routing key.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = client
            .subscribe_to_entity(1)
            .set_model(&model1, &mut cx.to_async());
        let _subscription2 = client
            .subscribe_to_entity(2)
            .set_model(&model2, &mut cx.to_async());
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = client
            .subscribe_to_entity(3)
            .set_model(&model3, &mut cx.to_async());
        drop(subscription3);

        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a handler's subscription is dropped, a new handler can be registered for the
    // same message type, and it receives subsequent messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription (stored on the model) while it is handling
    // a message.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal entity used as a message subscriber in these tests. `id` lets handlers tell
    // instances apart, and `subscription` optionally holds a handler's own subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}