1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod telemetry;
7pub mod user;
8
9use anyhow::{anyhow, Context, Result};
10use async_recursion::async_recursion;
11use async_tungstenite::tungstenite::{
12 error::Error as WebsocketError,
13 http::{Request, StatusCode},
14};
15use db::Db;
16use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
17use gpui::{
18 actions, serde_json::Value, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle,
19 AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle,
20 MutableAppContext, Task, View, ViewContext, ViewHandle,
21};
22use http::HttpClient;
23use lazy_static::lazy_static;
24use parking_lot::RwLock;
25use postage::watch;
26use rand::prelude::*;
27use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
28use std::{
29 any::TypeId,
30 collections::HashMap,
31 convert::TryFrom,
32 fmt::Write as _,
33 future::Future,
34 path::PathBuf,
35 sync::{Arc, Weak},
36 time::{Duration, Instant},
37};
38use telemetry::Telemetry;
39use thiserror::Error;
40use url::Url;
41use util::{ResultExt, TryFutureExt};
42
43pub use channel::*;
44pub use rpc::*;
45pub use user::*;
46
47lazy_static! {
48 pub static ref ZED_SERVER_URL: String =
49 std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string());
50 pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
51 .ok()
52 .and_then(|s| if s.is_empty() { None } else { Some(s) });
53}
54
55pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
56pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
57pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
58
59actions!(client, [Authenticate]);
60
61pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
62 cx.add_global_action({
63 let client = client.clone();
64 move |_: &Authenticate, cx| {
65 let client = client.clone();
66 cx.spawn(
67 |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await },
68 )
69 .detach();
70 }
71 });
72}
73
/// Client for the Zed server at `ZED_SERVER_URL`.
///
/// Owns the RPC peer, the HTTP client, telemetry, and the mutable connection and
/// handler-registration state behind a lock. In test builds, authentication and
/// connection establishment can be replaced with custom callbacks.
pub struct Client {
    /// Identifier included in log messages (0 unless changed with `set_id` in test builds).
    id: usize,
    peer: Arc<Peer>,
    http: Arc<dyn HttpClient>,
    telemetry: Arc<Telemetry>,
    state: RwLock<ClientState>,

    /// Test-only replacement for browser authentication (see `override_authenticate`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate: RwLock<
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    >,

    /// Test-only replacement for opening the websocket connection
    /// (see `override_establish_connection`).
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncAppContext,
                    ) -> Task<Result<Connection, EstablishConnectionError>>,
            >,
        >,
    >,
}
103
/// Reasons why opening a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server requires a newer client (mapped from an HTTP 426 handshake response).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The credentials were rejected (mapped from an HTTP 401 handshake response).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including non-HTTP websocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// Error from the crate's HTTP client module (`http::Error`).
    #[error("{0}")]
    Http(#[from] http::Error),
    /// I/O failure, e.g. while opening the TCP stream.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failure to build the websocket handshake request.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
119
120impl From<WebsocketError> for EstablishConnectionError {
121 fn from(error: WebsocketError) -> Self {
122 if let WebsocketError::Http(response) = &error {
123 match response.status() {
124 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
125 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
126 _ => {}
127 }
128 }
129 EstablishConnectionError::Other(error.into())
130 }
131}
132
133impl EstablishConnectionError {
134 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
135 Self::Other(error.into())
136 }
137}
138
/// Connection state of a [`Client`], observable through `Client::status`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Not signed in; the initial state.
    SignedOut,
    /// The server demands a newer client; `authenticate_and_connect` refuses to run.
    UpgradeRequired,
    /// Obtaining credentials for a fresh sign-in.
    Authenticating,
    /// Opening a connection for a fresh sign-in.
    Connecting,
    /// The most recent authentication or connection attempt failed.
    ConnectionError,
    /// Connected, with the peer id from the server's hello and the local connection id.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The connection's I/O failed; `set_status` starts a reconnection loop.
    ConnectionLost,
    /// Obtaining credentials after a lost connection or failed attempt.
    Reauthenticating,
    /// Opening a connection after a lost connection or failed attempt.
    Reconnecting,
    /// A reconnection attempt failed; the next attempt is scheduled for `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}
157
158impl Status {
159 pub fn is_connected(&self) -> bool {
160 matches!(self, Self::Connected { .. })
161 }
162}
163
/// Mutable state of a [`Client`], guarded by `Client::state`.
struct ClientState {
    /// Credentials of the current session; stored once a connection is established and
    /// cleared when the server reports them unauthorized.
    credentials: Option<Credentials>,
    /// Status channel; `Client::status` hands out clones of the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type: reads the remote entity id out of a type-erased envelope.
    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
    /// Reconnection loop spawned on `ConnectionLost`; cleared on connect or sign-out.
    _reconnect_task: Option<Task<()>>,
    /// Upper bound for the randomized reconnection backoff delay.
    reconnect_interval: Duration,
    /// Entities registered via `add_*_for_remote_entity`, keyed by (entity type, remote id).
    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
    /// Model targeted by each message type registered with `add_message_handler`.
    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
    /// Entity type that receives each entity message type.
    entity_types_by_message_type: HashMap<TypeId, TypeId>,
    /// Type-erased handlers keyed by message type.
    #[allow(clippy::type_complexity)]
    message_handlers: HashMap<
        TypeId,
        Arc<
            dyn Send
                + Sync
                + Fn(
                    AnyEntityHandle,
                    Box<dyn AnyTypedEnvelope>,
                    &Arc<Client>,
                    AsyncAppContext,
                ) -> LocalBoxFuture<'static, Result<()>>,
        >,
    >,
}
188
/// Weak, type-erased handle to either a model or a view that can receive messages.
enum AnyWeakEntityHandle {
    Model(AnyWeakModelHandle),
    View(AnyWeakViewHandle),
}

/// Strong, type-erased handle to a model or a view, passed to message handlers.
enum AnyEntityHandle {
    Model(AnyModelHandle),
    View(AnyViewHandle),
}
198
/// Credentials used to authenticate with the server.
///
/// NOTE(review): the derived `Debug` prints `access_token` verbatim; avoid logging
/// this type, or consider a redacting `Debug` impl.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Id of the signed-in user.
    pub user_id: u64,
    /// Token sent, together with `user_id`, in the `Authorization` header when connecting.
    pub access_token: String,
}
204
205impl Default for ClientState {
206 fn default() -> Self {
207 Self {
208 credentials: None,
209 status: watch::channel_with(Status::SignedOut),
210 entity_id_extractors: Default::default(),
211 _reconnect_task: None,
212 reconnect_interval: Duration::from_secs(5),
213 models_by_message_type: Default::default(),
214 entities_by_type_and_remote_id: Default::default(),
215 entity_types_by_message_type: Default::default(),
216 message_handlers: Default::default(),
217 }
218 }
219}
220
/// Registration handle returned by the client's `add_*` registration methods.
/// Dropping it unregisters what it refers to; it holds the client only weakly.
pub enum Subscription {
    /// Entity registered via `add_view_for_remote_entity` / `add_model_for_remote_entity`,
    /// keyed by (entity type, remote id).
    Entity {
        client: Weak<Client>,
        id: (TypeId, u64),
    },
    /// Handler registered via `add_message_handler` / `add_request_handler`, keyed by
    /// message type.
    Message {
        client: Weak<Client>,
        id: TypeId,
    },
}
231
232impl Drop for Subscription {
233 fn drop(&mut self) {
234 match self {
235 Subscription::Entity { client, id } => {
236 if let Some(client) = client.upgrade() {
237 let mut state = client.state.write();
238 let _ = state.entities_by_type_and_remote_id.remove(id);
239 }
240 }
241 Subscription::Message { client, id } => {
242 if let Some(client) = client.upgrade() {
243 let mut state = client.state.write();
244 let _ = state.entity_types_by_message_type.remove(id);
245 let _ = state.message_handlers.remove(id);
246 }
247 }
248 }
249 }
250}
251
252impl Client {
253 pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
254 Arc::new(Self {
255 id: 0,
256 peer: Peer::new(),
257 telemetry: Telemetry::new(http.clone(), cx),
258 http,
259 state: Default::default(),
260
261 #[cfg(any(test, feature = "test-support"))]
262 authenticate: Default::default(),
263 #[cfg(any(test, feature = "test-support"))]
264 establish_connection: Default::default(),
265 })
266 }
267
268 pub fn id(&self) -> usize {
269 self.id
270 }
271
272 pub fn http_client(&self) -> Arc<dyn HttpClient> {
273 self.http.clone()
274 }
275
276 #[cfg(any(test, feature = "test-support"))]
277 pub fn set_id(&mut self, id: usize) -> &Self {
278 self.id = id;
279 self
280 }
281
282 #[cfg(any(test, feature = "test-support"))]
283 pub fn tear_down(&self) {
284 let mut state = self.state.write();
285 state._reconnect_task.take();
286 state.message_handlers.clear();
287 state.models_by_message_type.clear();
288 state.entities_by_type_and_remote_id.clear();
289 state.entity_id_extractors.clear();
290 self.peer.reset();
291 }
292
293 #[cfg(any(test, feature = "test-support"))]
294 pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
295 where
296 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
297 {
298 *self.authenticate.write() = Some(Box::new(authenticate));
299 self
300 }
301
302 #[cfg(any(test, feature = "test-support"))]
303 pub fn override_establish_connection<F>(&self, connect: F) -> &Self
304 where
305 F: 'static
306 + Send
307 + Sync
308 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
309 {
310 *self.establish_connection.write() = Some(Box::new(connect));
311 self
312 }
313
314 pub fn user_id(&self) -> Option<u64> {
315 self.state
316 .read()
317 .credentials
318 .as_ref()
319 .map(|credentials| credentials.user_id)
320 }
321
322 pub fn status(&self) -> watch::Receiver<Status> {
323 self.state.read().status.1.clone()
324 }
325
326 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
327 log::info!("set status on client {}: {:?}", self.id, status);
328 let mut state = self.state.write();
329 *state.status.0.borrow_mut() = status;
330
331 match status {
332 Status::Connected { .. } => {
333 state._reconnect_task = None;
334 }
335 Status::ConnectionLost => {
336 let this = self.clone();
337 let reconnect_interval = state.reconnect_interval;
338 state._reconnect_task = Some(cx.spawn(|cx| async move {
339 let mut rng = StdRng::from_entropy();
340 let mut delay = INITIAL_RECONNECTION_DELAY;
341 while let Err(error) = this.authenticate_and_connect(true, &cx).await {
342 log::error!("failed to connect {}", error);
343 if matches!(*this.status().borrow(), Status::ConnectionError) {
344 this.set_status(
345 Status::ReconnectionError {
346 next_reconnection: Instant::now() + delay,
347 },
348 &cx,
349 );
350 cx.background().timer(delay).await;
351 delay = delay
352 .mul_f32(rng.gen_range(1.0..=2.0))
353 .min(reconnect_interval);
354 } else {
355 break;
356 }
357 }
358 }));
359 }
360 Status::SignedOut | Status::UpgradeRequired => {
361 self.telemetry.set_authenticated_user_info(None, false);
362 state._reconnect_task.take();
363 }
364 _ => {}
365 }
366 }
367
368 pub fn add_view_for_remote_entity<T: View>(
369 self: &Arc<Self>,
370 remote_id: u64,
371 cx: &mut ViewContext<T>,
372 ) -> Subscription {
373 let id = (TypeId::of::<T>(), remote_id);
374 self.state
375 .write()
376 .entities_by_type_and_remote_id
377 .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
378 Subscription::Entity {
379 client: Arc::downgrade(self),
380 id,
381 }
382 }
383
384 pub fn add_model_for_remote_entity<T: Entity>(
385 self: &Arc<Self>,
386 remote_id: u64,
387 cx: &mut ModelContext<T>,
388 ) -> Subscription {
389 let id = (TypeId::of::<T>(), remote_id);
390 self.state
391 .write()
392 .entities_by_type_and_remote_id
393 .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
394 Subscription::Entity {
395 client: Arc::downgrade(self),
396 id,
397 }
398 }
399
400 pub fn add_message_handler<M, E, H, F>(
401 self: &Arc<Self>,
402 model: ModelHandle<E>,
403 handler: H,
404 ) -> Subscription
405 where
406 M: EnvelopedMessage,
407 E: Entity,
408 H: 'static
409 + Send
410 + Sync
411 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
412 F: 'static + Future<Output = Result<()>>,
413 {
414 let message_type_id = TypeId::of::<M>();
415
416 let mut state = self.state.write();
417 state
418 .models_by_message_type
419 .insert(message_type_id, model.downgrade().into());
420
421 let prev_handler = state.message_handlers.insert(
422 message_type_id,
423 Arc::new(move |handle, envelope, client, cx| {
424 let handle = if let AnyEntityHandle::Model(handle) = handle {
425 handle
426 } else {
427 unreachable!();
428 };
429 let model = handle.downcast::<E>().unwrap();
430 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
431 handler(model, *envelope, client.clone(), cx).boxed_local()
432 }),
433 );
434 if prev_handler.is_some() {
435 panic!("registered handler for the same message twice");
436 }
437
438 Subscription::Message {
439 client: Arc::downgrade(self),
440 id: message_type_id,
441 }
442 }
443
444 pub fn add_request_handler<M, E, H, F>(
445 self: &Arc<Self>,
446 model: ModelHandle<E>,
447 handler: H,
448 ) -> Subscription
449 where
450 M: RequestMessage,
451 E: Entity,
452 H: 'static
453 + Send
454 + Sync
455 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
456 F: 'static + Future<Output = Result<M::Response>>,
457 {
458 self.add_message_handler(model, move |handle, envelope, this, cx| {
459 Self::respond_to_request(
460 envelope.receipt(),
461 handler(handle, envelope, this.clone(), cx),
462 this,
463 )
464 })
465 }
466
467 pub fn add_view_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
468 where
469 M: EntityMessage,
470 E: View,
471 H: 'static
472 + Send
473 + Sync
474 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
475 F: 'static + Future<Output = Result<()>>,
476 {
477 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
478 if let AnyEntityHandle::View(handle) = handle {
479 handler(handle.downcast::<E>().unwrap(), message, client, cx)
480 } else {
481 unreachable!();
482 }
483 })
484 }
485
486 pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
487 where
488 M: EntityMessage,
489 E: Entity,
490 H: 'static
491 + Send
492 + Sync
493 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
494 F: 'static + Future<Output = Result<()>>,
495 {
496 self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
497 if let AnyEntityHandle::Model(handle) = handle {
498 handler(handle.downcast::<E>().unwrap(), message, client, cx)
499 } else {
500 unreachable!();
501 }
502 })
503 }
504
505 fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
506 where
507 M: EntityMessage,
508 E: Entity,
509 H: 'static
510 + Send
511 + Sync
512 + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
513 F: 'static + Future<Output = Result<()>>,
514 {
515 let model_type_id = TypeId::of::<E>();
516 let message_type_id = TypeId::of::<M>();
517
518 let mut state = self.state.write();
519 state
520 .entity_types_by_message_type
521 .insert(message_type_id, model_type_id);
522 state
523 .entity_id_extractors
524 .entry(message_type_id)
525 .or_insert_with(|| {
526 |envelope| {
527 envelope
528 .as_any()
529 .downcast_ref::<TypedEnvelope<M>>()
530 .unwrap()
531 .payload
532 .remote_entity_id()
533 }
534 });
535 let prev_handler = state.message_handlers.insert(
536 message_type_id,
537 Arc::new(move |handle, envelope, client, cx| {
538 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
539 handler(handle, *envelope, client.clone(), cx).boxed_local()
540 }),
541 );
542 if prev_handler.is_some() {
543 panic!("registered handler for the same message twice");
544 }
545 }
546
547 pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
548 where
549 M: EntityMessage + RequestMessage,
550 E: Entity,
551 H: 'static
552 + Send
553 + Sync
554 + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
555 F: 'static + Future<Output = Result<M::Response>>,
556 {
557 self.add_model_message_handler(move |entity, envelope, client, cx| {
558 Self::respond_to_request::<M, _>(
559 envelope.receipt(),
560 handler(entity, envelope, client.clone(), cx),
561 client,
562 )
563 })
564 }
565
566 pub fn add_view_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
567 where
568 M: EntityMessage + RequestMessage,
569 E: View,
570 H: 'static
571 + Send
572 + Sync
573 + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
574 F: 'static + Future<Output = Result<M::Response>>,
575 {
576 self.add_view_message_handler(move |entity, envelope, client, cx| {
577 Self::respond_to_request::<M, _>(
578 envelope.receipt(),
579 handler(entity, envelope, client.clone(), cx),
580 client,
581 )
582 })
583 }
584
585 async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
586 receipt: Receipt<T>,
587 response: F,
588 client: Arc<Self>,
589 ) -> Result<()> {
590 match response.await {
591 Ok(response) => {
592 client.respond(receipt, response)?;
593 Ok(())
594 }
595 Err(error) => {
596 client.respond_with_error(
597 receipt,
598 proto::Error {
599 message: format!("{:?}", error),
600 },
601 )?;
602 Err(error)
603 }
604 }
605 }
606
607 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
608 read_credentials_from_keychain(cx).is_some()
609 }
610
    /// Signs in (if needed) and connects to the server.
    ///
    /// Returns `Ok(())` immediately if the client is already connected or a connection
    /// attempt is underway, and fails if the server requires an upgrade.
    ///
    /// Credentials are taken from the following sources, in order:
    /// 1. those stored in memory;
    /// 2. the keychain, when `try_keychain` is set;
    /// 3. `authenticate`. Any status change while this runs cancels the attempt.
    ///
    /// Opening the connection and receiving the server's hello must finish within
    /// `CONNECTION_TIMEOUT`. If keychain credentials are rejected as unauthorized, they
    /// are deleted and the whole flow is retried without the keychain.
    #[async_recursion(?Send)]
    pub async fn authenticate_and_connect(
        self: &Arc<Self>,
        try_keychain: bool,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        // Decide between a fresh sign-in and a reconnection, bailing out when a connection
        // already exists or is being established.
        let was_disconnected = match *self.status().borrow() {
            Status::SignedOut => true,
            Status::ConnectionError
            | Status::ConnectionLost
            | Status::Authenticating { .. }
            | Status::Reauthenticating { .. }
            | Status::ReconnectionError { .. } => false,
            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
                return Ok(())
            }
            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
        };

        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

        let mut read_from_keychain = false;
        let mut credentials = self.state.read().credentials.clone();
        if credentials.is_none() && try_keychain {
            credentials = read_credentials_from_keychain(cx);
            read_from_keychain = credentials.is_some();
            if read_from_keychain {
                self.report_event("read credentials from keychain", Default::default());
            }
        }
        if credentials.is_none() {
            // Consume the current status first, so that the `select` below only fires on a
            // subsequent status change, which cancels authentication.
            let mut status_rx = self.status();
            let _ = status_rx.next().await;
            futures::select_biased! {
                authenticate = self.authenticate(cx).fuse() => {
                    match authenticate {
                        Ok(creds) => credentials = Some(creds),
                        Err(err) => {
                            self.set_status(Status::ConnectionError, cx);
                            return Err(err);
                        }
                    }
                }
                _ = status_rx.next().fuse() => {
                    return Err(anyhow!("authentication canceled"));
                }
            }
        }
        let credentials = credentials.unwrap();

        if was_disconnected {
            self.set_status(Status::Connecting, cx);
        } else {
            self.set_status(Status::Reconnecting, cx);
        }

        // A single timer bounds both opening the connection and waiting for the server's
        // hello inside `set_connection`.
        let mut timeout = cx.background().timer(CONNECTION_TIMEOUT).fuse();
        futures::select_biased! {
            connection = self.establish_connection(&credentials, cx).fuse() => {
                match connection {
                    Ok(conn) => {
                        self.state.write().credentials = Some(credentials.clone());
                        // Persist freshly obtained credentials, except when impersonating.
                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                            write_credentials_to_keychain(&credentials, cx).log_err();
                        }

                        futures::select_biased! {
                            result = self.set_connection(conn, cx).fuse() => result,
                            _ = timeout => {
                                self.set_status(Status::ConnectionError, cx);
                                Err(anyhow!("timed out waiting on hello message from server"))
                            }
                        }
                    }
                    Err(EstablishConnectionError::Unauthorized) => {
                        self.state.write().credentials.take();
                        if read_from_keychain {
                            // Stale keychain credentials: delete them and start over, this
                            // time authenticating without the keychain.
                            cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                            self.set_status(Status::SignedOut, cx);
                            self.authenticate_and_connect(false, cx).await
                        } else {
                            self.set_status(Status::ConnectionError, cx);
                            Err(EstablishConnectionError::Unauthorized)?
                        }
                    }
                    Err(EstablishConnectionError::UpgradeRequired) => {
                        self.set_status(Status::UpgradeRequired, cx);
                        Err(EstablishConnectionError::UpgradeRequired)?
                    }
                    Err(error) => {
                        self.set_status(Status::ConnectionError, cx);
                        Err(error)?
                    }
                }
            }
            _ = &mut timeout => {
                self.set_status(Status::ConnectionError, cx);
                Err(anyhow!("timed out trying to establish connection"))
            }
        }
    }
716
717 async fn set_connection(
718 self: &Arc<Self>,
719 conn: Connection,
720 cx: &AsyncAppContext,
721 ) -> Result<()> {
722 let executor = cx.background();
723 log::info!("add connection to peer");
724 let (connection_id, handle_io, mut incoming) = self
725 .peer
726 .add_connection(conn, move |duration| executor.timer(duration));
727 let handle_io = cx.background().spawn(handle_io);
728
729 let peer_id = async {
730 log::info!("waiting for server hello");
731 let message = incoming
732 .next()
733 .await
734 .ok_or_else(|| anyhow!("no hello message received"))?;
735 log::info!("got server hello");
736 let hello_message_type_name = message.payload_type_name().to_string();
737 let hello = message
738 .into_any()
739 .downcast::<TypedEnvelope<proto::Hello>>()
740 .map_err(|_| {
741 anyhow!(
742 "invalid hello message received: {:?}",
743 hello_message_type_name
744 )
745 })?;
746 Ok(PeerId(hello.payload.peer_id))
747 };
748
749 let peer_id = match peer_id.await {
750 Ok(peer_id) => peer_id,
751 Err(error) => {
752 self.peer.disconnect(connection_id);
753 return Err(error);
754 }
755 };
756
757 log::info!(
758 "set status to connected (connection id: {}, peer id: {})",
759 connection_id,
760 peer_id
761 );
762 self.set_status(
763 Status::Connected {
764 peer_id,
765 connection_id,
766 },
767 cx,
768 );
769 cx.foreground()
770 .spawn({
771 let cx = cx.clone();
772 let this = self.clone();
773 async move {
774 let mut message_id = 0_usize;
775 while let Some(message) = incoming.next().await {
776 let mut state = this.state.write();
777 message_id += 1;
778 let type_name = message.payload_type_name();
779 let payload_type_id = message.payload_type_id();
780 let sender_id = message.original_sender_id().map(|id| id.0);
781
782 let model = state
783 .models_by_message_type
784 .get(&payload_type_id)
785 .and_then(|model| model.upgrade(&cx))
786 .map(AnyEntityHandle::Model)
787 .or_else(|| {
788 let entity_type_id =
789 *state.entity_types_by_message_type.get(&payload_type_id)?;
790 let entity_id = state
791 .entity_id_extractors
792 .get(&message.payload_type_id())
793 .map(|extract_entity_id| {
794 (extract_entity_id)(message.as_ref())
795 })?;
796
797 let entity = state
798 .entities_by_type_and_remote_id
799 .get(&(entity_type_id, entity_id))?;
800 if let Some(entity) = entity.upgrade(&cx) {
801 Some(entity)
802 } else {
803 state
804 .entities_by_type_and_remote_id
805 .remove(&(entity_type_id, entity_id));
806 None
807 }
808 });
809
810 let model = if let Some(model) = model {
811 model
812 } else {
813 log::info!("unhandled message {}", type_name);
814 continue;
815 };
816
817 if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
818 {
819 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
820 let future = handler(model, message, &this, cx.clone());
821
822 let client_id = this.id;
823 log::debug!(
824 "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
825 client_id,
826 message_id,
827 sender_id,
828 type_name
829 );
830 cx.foreground()
831 .spawn(async move {
832 match future.await {
833 Ok(()) => {
834 log::debug!(
835 "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
836 client_id,
837 message_id,
838 sender_id,
839 type_name
840 );
841 }
842 Err(error) => {
843 log::error!(
844 "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
845 client_id,
846 message_id,
847 sender_id,
848 type_name,
849 error
850 );
851 }
852 }
853 })
854 .detach();
855 } else {
856 log::info!("unhandled message {}", type_name);
857 }
858
859 // Don't starve the main thread when receiving lots of messages at once.
860 smol::future::yield_now().await;
861 }
862 }
863 })
864 .detach();
865
866 let this = self.clone();
867 let cx = cx.clone();
868 cx.foreground()
869 .spawn(async move {
870 match handle_io.await {
871 Ok(()) => {
872 if *this.status().borrow()
873 == (Status::Connected {
874 connection_id,
875 peer_id,
876 })
877 {
878 this.set_status(Status::SignedOut, &cx);
879 }
880 }
881 Err(err) => {
882 log::error!("connection error: {:?}", err);
883 this.set_status(Status::ConnectionLost, &cx);
884 }
885 }
886 })
887 .detach();
888
889 Ok(())
890 }
891
892 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
893 #[cfg(any(test, feature = "test-support"))]
894 if let Some(callback) = self.authenticate.read().as_ref() {
895 return callback(cx);
896 }
897
898 self.authenticate_with_browser(cx)
899 }
900
901 fn establish_connection(
902 self: &Arc<Self>,
903 credentials: &Credentials,
904 cx: &AsyncAppContext,
905 ) -> Task<Result<Connection, EstablishConnectionError>> {
906 #[cfg(any(test, feature = "test-support"))]
907 if let Some(callback) = self.establish_connection.read().as_ref() {
908 return callback(credentials, cx);
909 }
910
911 self.establish_websocket_connection(credentials, cx)
912 }
913
914 fn establish_websocket_connection(
915 self: &Arc<Self>,
916 credentials: &Credentials,
917 cx: &AsyncAppContext,
918 ) -> Task<Result<Connection, EstablishConnectionError>> {
919 let request = Request::builder()
920 .header(
921 "Authorization",
922 format!("{} {}", credentials.user_id, credentials.access_token),
923 )
924 .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION);
925
926 let http = self.http.clone();
927 cx.background().spawn(async move {
928 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
929 let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
930 if rpc_response.status().is_redirection() {
931 rpc_url = rpc_response
932 .headers()
933 .get("Location")
934 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
935 .to_str()
936 .map_err(EstablishConnectionError::other)?
937 .to_string();
938 }
939 // Until we switch the zed.dev domain to point to the new Next.js app, there
940 // will be no redirect required, and the app will connect directly to
941 // wss://zed.dev/rpc.
942 else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
943 Err(anyhow!(
944 "unexpected /rpc response status {}",
945 rpc_response.status()
946 ))?
947 }
948
949 let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
950 let rpc_host = rpc_url
951 .host_str()
952 .zip(rpc_url.port_or_known_default())
953 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
954 let stream = smol::net::TcpStream::connect(rpc_host).await?;
955
956 log::info!("connected to rpc endpoint {}", rpc_url);
957
958 match rpc_url.scheme() {
959 "https" => {
960 rpc_url.set_scheme("wss").unwrap();
961 let request = request.uri(rpc_url.as_str()).body(())?;
962 let (stream, _) =
963 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
964 Ok(Connection::new(
965 stream
966 .map_err(|error| anyhow!(error))
967 .sink_map_err(|error| anyhow!(error)),
968 ))
969 }
970 "http" => {
971 rpc_url.set_scheme("ws").unwrap();
972 let request = request.uri(rpc_url.as_str()).body(())?;
973 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
974 Ok(Connection::new(
975 stream
976 .map_err(|error| anyhow!(error))
977 .sink_map_err(|error| anyhow!(error)),
978 ))
979 }
980 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
981 }
982 })
983 }
984
985 pub fn authenticate_with_browser(
986 self: &Arc<Self>,
987 cx: &AsyncAppContext,
988 ) -> Task<Result<Credentials>> {
989 let platform = cx.platform();
990 let executor = cx.background();
991 let telemetry = self.telemetry.clone();
992 executor.clone().spawn(async move {
993 // Generate a pair of asymmetric encryption keys. The public key will be used by the
994 // zed server to encrypt the user's access token, so that it can'be intercepted by
995 // any other app running on the user's device.
996 let (public_key, private_key) =
997 rpc::auth::keypair().expect("failed to generate keypair for auth");
998 let public_key_string =
999 String::try_from(public_key).expect("failed to serialize public key for auth");
1000
1001 // Start an HTTP server to receive the redirect from Zed's sign-in page.
1002 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
1003 let port = server.server_addr().port();
1004
1005 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
1006 // that the user is signing in from a Zed app running on the same device.
1007 let mut url = format!(
1008 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
1009 *ZED_SERVER_URL, port, public_key_string
1010 );
1011
1012 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
1013 log::info!("impersonating user @{}", impersonate_login);
1014 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
1015 }
1016
1017 platform.open_url(&url);
1018
1019 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
1020 // access token from the query params.
1021 //
1022 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
1023 // custom URL scheme instead of this local HTTP server.
1024 let (user_id, access_token) = executor
1025 .spawn(async move {
1026 for _ in 0..100 {
1027 if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
1028 let path = req.url();
1029 let mut user_id = None;
1030 let mut access_token = None;
1031 let url = Url::parse(&format!("http://example.com{}", path))
1032 .context("failed to parse login notification url")?;
1033 for (key, value) in url.query_pairs() {
1034 if key == "access_token" {
1035 access_token = Some(value.to_string());
1036 } else if key == "user_id" {
1037 user_id = Some(value.to_string());
1038 }
1039 }
1040
1041 let post_auth_url =
1042 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
1043 req.respond(
1044 tiny_http::Response::empty(302).with_header(
1045 tiny_http::Header::from_bytes(
1046 &b"Location"[..],
1047 post_auth_url.as_bytes(),
1048 )
1049 .unwrap(),
1050 ),
1051 )
1052 .context("failed to respond to login http request")?;
1053 return Ok((
1054 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
1055 access_token
1056 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
1057 ));
1058 }
1059 }
1060
1061 Err(anyhow!("didn't receive login redirect"))
1062 })
1063 .await?;
1064
1065 let access_token = private_key
1066 .decrypt_string(&access_token)
1067 .context("failed to decrypt access token")?;
1068 platform.activate(true);
1069
1070 telemetry.report_event("authenticate with browser", Default::default());
1071
1072 Ok(Credentials {
1073 user_id: user_id.parse()?,
1074 access_token,
1075 })
1076 })
1077 }
1078
1079 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
1080 let conn_id = self.connection_id()?;
1081 self.peer.disconnect(conn_id);
1082 self.set_status(Status::SignedOut, cx);
1083 Ok(())
1084 }
1085
1086 fn connection_id(&self) -> Result<ConnectionId> {
1087 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
1088 Ok(connection_id)
1089 } else {
1090 Err(anyhow!("not connected"))
1091 }
1092 }
1093
1094 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
1095 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
1096 self.peer.send(self.connection_id()?, message)
1097 }
1098
1099 pub fn request<T: RequestMessage>(
1100 &self,
1101 request: T,
1102 ) -> impl Future<Output = Result<T::Response>> {
1103 let client_id = self.id;
1104 log::debug!(
1105 "rpc request start. client_id:{}. name:{}",
1106 client_id,
1107 T::NAME
1108 );
1109 let response = self
1110 .connection_id()
1111 .map(|conn_id| self.peer.request(conn_id, request));
1112 async move {
1113 let response = response?.await;
1114 log::debug!(
1115 "rpc request finish. client_id:{}. name:{}",
1116 client_id,
1117 T::NAME
1118 );
1119 response
1120 }
1121 }
1122
1123 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
1124 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1125 self.peer.respond(receipt, response)
1126 }
1127
1128 fn respond_with_error<T: RequestMessage>(
1129 &self,
1130 receipt: Receipt<T>,
1131 error: proto::Error,
1132 ) -> Result<()> {
1133 log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
1134 self.peer.respond_with_error(receipt, error)
1135 }
1136
1137 pub fn start_telemetry(&self, db: Arc<Db>) {
1138 self.telemetry.start(db);
1139 }
1140
1141 pub fn report_event(&self, kind: &str, properties: Value) {
1142 self.telemetry.report_event(kind, properties)
1143 }
1144
1145 pub fn telemetry_log_file_path(&self) -> Option<PathBuf> {
1146 self.telemetry.log_file_path()
1147 }
1148}
1149
1150impl AnyWeakEntityHandle {
1151 fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
1152 match self {
1153 AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
1154 AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
1155 }
1156 }
1157}
1158
1159fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
1160 if IMPERSONATE_LOGIN.is_some() {
1161 return None;
1162 }
1163
1164 let (user_id, access_token) = cx
1165 .platform()
1166 .read_credentials(&ZED_SERVER_URL)
1167 .log_err()
1168 .flatten()?;
1169 Some(Credentials {
1170 user_id: user_id.parse().ok()?,
1171 access_token: String::from_utf8(access_token).ok()?,
1172 })
1173}
1174
1175fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
1176 cx.platform().write_credentials(
1177 &ZED_SERVER_URL,
1178 &credentials.user_id.to_string(),
1179 credentials.access_token.as_bytes(),
1180 )
1181}
1182
/// Scheme and path prefix shared by [`encode_worktree_url`] and [`decode_worktree_url`].
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/<id>/<access_token>` URL.
///
/// [`decode_worktree_url`] recovers `(id, access_token)` from the result, provided the token
/// is non-empty and has no leading or trailing whitespace (decoding trims the URL).
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into `(id, access_token)`.
///
/// Whitespace around `url` is ignored. Everything after the first `/` that follows the id is
/// taken as the access token, so tokens that themselves contain `/` round-trip intact.
///
/// Returns `None` when:
/// - the prefix is missing;
/// - the id is not a valid `u64`;
/// - the `/` separator is absent;
/// - the access token is empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    // Split only at the first `/`: a numeric id never contains one, but the token may.
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
1199
// Tests for the RPC client: automatic reconnection, connection timeouts, overlapping
// authentication attempts, worktree URL encoding, and message-handler subscriptions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::{executor::Deterministic, TestAppContext};
    use parking_lot::Mutex;
    use std::future;

    // After the server drops the connection, the client should reconnect on its own. It should
    // reuse its cached credentials while they remain valid, and re-authenticate once the server
    // has rolled the access token.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, then wait for a failed reconnect attempt.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // With connections allowed again, advancing the clock lets a retry go through.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails. Rolling the token makes the
        // cached one invalid, so the successful reconnect must go through a fresh authentication.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Both the initial connection and a later reconnection attempt should fail with an error
    // once CONNECTION_TIMEOUT elapses, if establishing the connection never completes.
    #[gpui::test(iterations = 10)]
    async fn test_connection_timeout(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
        deterministic.forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let mut status = client.status();

        // Time out when client tries to connect.
        client.override_authenticate(move |cx| {
            cx.foreground().spawn(async move {
                Ok(Credentials {
                    user_id,
                    access_token: "token".into(),
                })
            })
        });
        // Establishing the connection never resolves, so only the timeout can end the attempt.
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        let auth_and_connect = cx.spawn({
            let client = client.clone();
            |cx| async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert!(matches!(status.next().await, Some(Status::Connecting)));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ConnectionError { .. })
        ));
        auth_and_connect.await.unwrap_err();

        // Allow the connection to be established.
        // NOTE(review): presumably FakeServer::for_client installs its own authenticate and
        // establish-connection overrides, replacing the never-resolving one above.
        let server = FakeServer::for_client(user_id, &client, cx).await;
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));

        // Disconnect client.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Time out when re-establishing the connection.
        server.allow_connections();
        client.override_establish_connection(|_, cx| {
            cx.foreground().spawn(async move {
                future::pending::<()>().await;
                unreachable!()
            })
        });
        // Advance far enough for the next retry to begin.
        // NOTE(review): assumes the retry delay is at most twice INITIAL_RECONNECTION_DELAY.
        deterministic.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
        assert!(matches!(
            status.next().await,
            Some(Status::Reconnecting { .. })
        ));

        deterministic.advance_clock(CONNECTION_TIMEOUT);
        assert!(matches!(
            status.next().await,
            Some(Status::ReconnectionError { .. })
        ));
    }

    // Starting a second authentication while the first is still pending should drop the first
    // in-flight authentication future, rather than leaving both running.
    #[gpui::test(iterations = 10)]
    async fn test_authenticating_more_than_once(
        cx: &mut TestAppContext,
        deterministic: Arc<Deterministic>,
    ) {
        cx.foreground().forbid_parking();

        // `auth_count` counts started authentication attempts. `dropped_auth_count` counts
        // attempts whose future was dropped before completing, via the `util::defer` guard
        // below (presumably it runs its closure when dropped).
        let auth_count = Arc::new(Mutex::new(0));
        let dropped_auth_count = Arc::new(Mutex::new(0));
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        client.override_authenticate({
            let auth_count = auth_count.clone();
            let dropped_auth_count = dropped_auth_count.clone();
            move |cx| {
                let auth_count = auth_count.clone();
                let dropped_auth_count = dropped_auth_count.clone();
                cx.foreground().spawn(async move {
                    *auth_count.lock() += 1;
                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
                    future::pending::<()>().await;
                    unreachable!()
                })
            }
        });

        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 1);
        assert_eq!(*dropped_auth_count.lock(), 0);

        // Shadowing `_authenticate` does not drop the first task (shadowed bindings live until
        // the end of the scope). The drop counted below must therefore come from the client
        // discarding the first pending authentication.
        let _authenticate = cx.spawn(|cx| {
            let client = client.clone();
            async move { client.authenticate_and_connect(false, &cx).await }
        });
        deterministic.run_until_parked();
        assert_eq!(*auth_count.lock(), 2);
        assert_eq!(*dropped_auth_count.lock(), 1);
    }

    // Worktree URLs should round-trip through encode/decode. Decoding should tolerate
    // surrounding whitespace and reject URLs with a different prefix.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // A message for a remote entity should reach only the model registered for that entity
    // id. Dropping one entity's subscription must not affect other entities of the same type.
    #[gpui::test]
    async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        // One channel per expected model, so each delivery can be awaited independently.
        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        // The handler dispatches on the receiving model's `id`; reaching any other model
        // (e.g. model 3) hits `unreachable!`.
        client.add_model_message_handler(
            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::JoinProject>, _, cx| {
                match model.read_with(&cx, |model, _| model.id) {
                    1 => done_tx1.try_send(()).unwrap(),
                    2 => done_tx2.try_send(()).unwrap(),
                    _ => unreachable!(),
                }
                async { Ok(()) }
            },
        );
        let model1 = cx.add_model(|_| Model {
            id: 1,
            subscription: None,
        });
        let model2 = cx.add_model(|_| Model {
            id: 2,
            subscription: None,
        });
        let model3 = cx.add_model(|_| Model {
            id: 3,
            subscription: None,
        });

        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
        drop(subscription3);

        // NOTE(review): `project_id` presumably serves as the remote entity id for
        // JoinProject, which is why these messages are expected to reach models 1 and 2.
        server.send(proto::JoinProject { project_id: 1 });
        server.send(proto::JoinProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a message-handler subscription is dropped, registering a new handler for the same
    // message type on the same model should succeed, and the new handler should receive
    // subsequent messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        // NOTE(review): `_done_rx1` is kept alive, so `try_send` in the first handler would
        // succeed even if that handler ran. The test only checks that the second handler is
        // invoked.
        let (done_tx1, _done_rx1) = smol::channel::unbounded();
        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
        let subscription1 = client.add_message_handler(
            model.clone(),
            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx1.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        drop(subscription1);
        let _subscription2 =
            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                done_tx2.try_send(()).unwrap();
                async { Ok(()) }
            });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler that drops its own subscription while it is running (by taking it out of the
    // model) should still run to completion.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
        let server = FakeServer::for_client(user_id, &client, cx).await;

        let model = cx.add_model(|_| Model::default());
        let (done_tx, mut done_rx) = smol::channel::unbounded();
        let subscription = client.add_message_handler(
            model.clone(),
            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                model.update(&mut cx, |model, _| model.subscription.take());
                done_tx.try_send(()).unwrap();
                async { Ok(()) }
            },
        );
        // Store the subscription on the model so the handler above can take and drop it.
        model.update(cx, |model, _| {
            model.subscription = Some(subscription);
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal entity used as a message-handler target. `id` distinguishes instances, and
    // `subscription` lets a handler drop its own subscription.
    #[derive(Default)]
    struct Model {
        id: usize,
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}