1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
15use gpui::{action, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
16use http::HttpClient;
17use lazy_static::lazy_static;
18use parking_lot::RwLock;
19use postage::watch;
20use rand::prelude::*;
21use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
22use std::{
23 any::{type_name, TypeId},
24 collections::HashMap,
25 convert::TryFrom,
26 fmt::Write as _,
27 future::Future,
28 sync::{
29 atomic::{AtomicUsize, Ordering},
30 Arc, Weak,
31 },
32 time::{Duration, Instant},
33};
34use surf::{http::Method, Url};
35use thiserror::Error;
36use util::{ResultExt, TryFutureExt};
37
38pub use channel::*;
39pub use rpc::*;
40pub use user::*;
41
42lazy_static! {
43 static ref ZED_SERVER_URL: String =
44 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
45 static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
46 .ok()
47 .and_then(|s| if s.is_empty() { None } else { Some(s) });
48}
49
50action!(Authenticate);
51
52pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
53 cx.add_global_action(move |_: &Authenticate, cx| {
54 let rpc = rpc.clone();
55 cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
56 .detach();
57 });
58}
59
60pub struct Client {
61 id: usize,
62 peer: Arc<Peer>,
63 http: Arc<dyn HttpClient>,
64 state: RwLock<ClientState>,
65 authenticate:
66 Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
67 establish_connection: Option<
68 Box<
69 dyn 'static
70 + Send
71 + Sync
72 + Fn(
73 &Credentials,
74 &AsyncAppContext,
75 ) -> Task<Result<Connection, EstablishConnectionError>>,
76 >,
77 >,
78}
79
80#[derive(Error, Debug)]
81pub enum EstablishConnectionError {
82 #[error("upgrade required")]
83 UpgradeRequired,
84 #[error("unauthorized")]
85 Unauthorized,
86 #[error("{0}")]
87 Other(#[from] anyhow::Error),
88 #[error("{0}")]
89 Io(#[from] std::io::Error),
90 #[error("{0}")]
91 Http(#[from] async_tungstenite::tungstenite::http::Error),
92}
93
94impl From<WebsocketError> for EstablishConnectionError {
95 fn from(error: WebsocketError) -> Self {
96 if let WebsocketError::Http(response) = &error {
97 match response.status() {
98 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
99 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
100 _ => {}
101 }
102 }
103 EstablishConnectionError::Other(error.into())
104 }
105}
106
107impl EstablishConnectionError {
108 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
109 Self::Other(error.into())
110 }
111}
112
113#[derive(Copy, Clone, Debug)]
114pub enum Status {
115 SignedOut,
116 UpgradeRequired,
117 Authenticating,
118 Connecting,
119 ConnectionError,
120 Connected { connection_id: ConnectionId },
121 ConnectionLost,
122 Reauthenticating,
123 Reconnecting,
124 ReconnectionError { next_reconnection: Instant },
125}
126
127type ModelHandler = Box<
128 dyn Send
129 + Sync
130 + FnMut(Box<dyn AnyTypedEnvelope>, &AsyncAppContext) -> LocalBoxFuture<'static, Result<()>>,
131>;
132
133struct ClientState {
134 credentials: Option<Credentials>,
135 status: (watch::Sender<Status>, watch::Receiver<Status>),
136 entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
137 model_handlers: HashMap<(TypeId, Option<u64>), Option<ModelHandler>>,
138 _maintain_connection: Option<Task<()>>,
139 heartbeat_interval: Duration,
140}
141
/// Credentials used to authenticate with the Zed server.
///
/// `Debug` is implemented by hand so that formatting credentials (e.g. in a log line) never
/// reveals the access token.
#[derive(Clone)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
147
148impl Default for ClientState {
149 fn default() -> Self {
150 Self {
151 credentials: None,
152 status: watch::channel_with(Status::SignedOut),
153 entity_id_extractors: Default::default(),
154 model_handlers: Default::default(),
155 _maintain_connection: None,
156 heartbeat_interval: Duration::from_secs(5),
157 }
158 }
159}
160
161pub struct Subscription {
162 client: Weak<Client>,
163 id: (TypeId, Option<u64>),
164}
165
166impl Drop for Subscription {
167 fn drop(&mut self) {
168 if let Some(client) = self.client.upgrade() {
169 let mut state = client.state.write();
170 let _ = state.model_handlers.remove(&self.id).unwrap();
171 }
172 }
173}
174
175impl Client {
176 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
177 lazy_static! {
178 static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
179 }
180
181 Arc::new(Self {
182 id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
183 peer: Peer::new(),
184 http,
185 state: Default::default(),
186 authenticate: None,
187 establish_connection: None,
188 })
189 }
190
191 #[cfg(any(test, feature = "test-support"))]
192 pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
193 where
194 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
195 {
196 self.authenticate = Some(Box::new(authenticate));
197 self
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
202 where
203 F: 'static
204 + Send
205 + Sync
206 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
207 {
208 self.establish_connection = Some(Box::new(connect));
209 self
210 }
211
212 pub fn user_id(&self) -> Option<u64> {
213 self.state
214 .read()
215 .credentials
216 .as_ref()
217 .map(|credentials| credentials.user_id)
218 }
219
220 pub fn status(&self) -> watch::Receiver<Status> {
221 self.state.read().status.1.clone()
222 }
223
224 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
225 let mut state = self.state.write();
226 *state.status.0.borrow_mut() = status;
227
228 match status {
229 Status::Connected { .. } => {
230 let heartbeat_interval = state.heartbeat_interval;
231 let this = self.clone();
232 let foreground = cx.foreground();
233 state._maintain_connection = Some(cx.foreground().spawn(async move {
234 loop {
235 foreground.timer(heartbeat_interval).await;
236 let _ = this.request(proto::Ping {}).await;
237 }
238 }));
239 }
240 Status::ConnectionLost => {
241 let this = self.clone();
242 let foreground = cx.foreground();
243 let heartbeat_interval = state.heartbeat_interval;
244 state._maintain_connection = Some(cx.spawn(|cx| async move {
245 let mut rng = StdRng::from_entropy();
246 let mut delay = Duration::from_millis(100);
247 while let Err(error) = this.authenticate_and_connect(&cx).await {
248 log::error!("failed to connect {}", error);
249 this.set_status(
250 Status::ReconnectionError {
251 next_reconnection: Instant::now() + delay,
252 },
253 &cx,
254 );
255 foreground.timer(delay).await;
256 delay = delay
257 .mul_f32(rng.gen_range(1.0..=2.0))
258 .min(heartbeat_interval);
259 }
260 }));
261 }
262 Status::SignedOut | Status::UpgradeRequired => {
263 state._maintain_connection.take();
264 }
265 _ => {}
266 }
267 }
268
269 pub fn subscribe<T, M, F, Fut>(
270 self: &Arc<Self>,
271 cx: &mut ModelContext<M>,
272 mut handler: F,
273 ) -> Subscription
274 where
275 T: EnvelopedMessage,
276 M: Entity,
277 F: 'static
278 + Send
279 + Sync
280 + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
281 Fut: 'static + Future<Output = Result<()>>,
282 {
283 let subscription_id = (TypeId::of::<T>(), None);
284 let client = self.clone();
285 let mut state = self.state.write();
286 let model = cx.weak_handle();
287 let prev_handler = state.model_handlers.insert(
288 subscription_id,
289 Some(Box::new(move |envelope, cx| {
290 if let Some(model) = model.upgrade(cx) {
291 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
292 handler(model, *envelope, client.clone(), cx.clone()).boxed_local()
293 } else {
294 async move {
295 Err(anyhow!(
296 "received message for {:?} but model was dropped",
297 type_name::<M>()
298 ))
299 }
300 .boxed_local()
301 }
302 })),
303 );
304 if prev_handler.is_some() {
305 panic!("registered handler for the same message twice");
306 }
307
308 Subscription {
309 client: Arc::downgrade(self),
310 id: subscription_id,
311 }
312 }
313
314 pub fn subscribe_to_entity<T, M, F, Fut>(
315 self: &Arc<Self>,
316 remote_id: u64,
317 cx: &mut ModelContext<M>,
318 mut handler: F,
319 ) -> Subscription
320 where
321 T: EntityMessage,
322 M: Entity,
323 F: 'static
324 + Send
325 + Sync
326 + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
327 Fut: 'static + Future<Output = Result<()>>,
328 {
329 let subscription_id = (TypeId::of::<T>(), Some(remote_id));
330 let client = self.clone();
331 let mut state = self.state.write();
332 let model = cx.weak_handle();
333 state
334 .entity_id_extractors
335 .entry(subscription_id.0)
336 .or_insert_with(|| {
337 Box::new(|envelope| {
338 let envelope = envelope
339 .as_any()
340 .downcast_ref::<TypedEnvelope<T>>()
341 .unwrap();
342 envelope.payload.remote_entity_id()
343 })
344 });
345 let prev_handler = state.model_handlers.insert(
346 subscription_id,
347 Some(Box::new(move |envelope, cx| {
348 if let Some(model) = model.upgrade(cx) {
349 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
350 handler(model, *envelope, client.clone(), cx.clone()).boxed_local()
351 } else {
352 async move {
353 Err(anyhow!(
354 "received message for {:?} but model was dropped",
355 type_name::<M>()
356 ))
357 }
358 .boxed_local()
359 }
360 })),
361 );
362 if prev_handler.is_some() {
363 panic!("registered a handler for the same entity twice")
364 }
365
366 Subscription {
367 client: Arc::downgrade(self),
368 id: subscription_id,
369 }
370 }
371
372 pub fn subscribe_to_entity_request<T, M, F, Fut>(
373 self: &Arc<Self>,
374 remote_id: u64,
375 cx: &mut ModelContext<M>,
376 mut handler: F,
377 ) -> Subscription
378 where
379 T: EntityMessage + RequestMessage,
380 M: Entity,
381 F: 'static
382 + Send
383 + Sync
384 + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
385 Fut: 'static + Future<Output = Result<T::Response>>,
386 {
387 self.subscribe_to_entity(remote_id, cx, move |model, envelope, client, cx| {
388 let receipt = envelope.receipt();
389 let response = handler(model, envelope, client.clone(), cx);
390 async move {
391 match response.await {
392 Ok(response) => {
393 client.respond(receipt, response)?;
394 Ok(())
395 }
396 Err(error) => {
397 client.respond_with_error(
398 receipt,
399 proto::Error {
400 message: error.to_string(),
401 },
402 )?;
403 Err(error)
404 }
405 }
406 }
407 })
408 }
409
410 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
411 read_credentials_from_keychain(cx).is_some()
412 }
413
414 #[async_recursion(?Send)]
415 pub async fn authenticate_and_connect(
416 self: &Arc<Self>,
417 cx: &AsyncAppContext,
418 ) -> anyhow::Result<()> {
419 let was_disconnected = match *self.status().borrow() {
420 Status::SignedOut => true,
421 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
422 false
423 }
424 Status::Connected { .. }
425 | Status::Connecting { .. }
426 | Status::Reconnecting { .. }
427 | Status::Authenticating
428 | Status::Reauthenticating => return Ok(()),
429 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
430 };
431
432 if was_disconnected {
433 self.set_status(Status::Authenticating, cx);
434 } else {
435 self.set_status(Status::Reauthenticating, cx)
436 }
437
438 let mut used_keychain = false;
439 let credentials = self.state.read().credentials.clone();
440 let credentials = if let Some(credentials) = credentials {
441 credentials
442 } else if let Some(credentials) = read_credentials_from_keychain(cx) {
443 used_keychain = true;
444 credentials
445 } else {
446 let credentials = match self.authenticate(&cx).await {
447 Ok(credentials) => credentials,
448 Err(err) => {
449 self.set_status(Status::ConnectionError, cx);
450 return Err(err);
451 }
452 };
453 credentials
454 };
455
456 if was_disconnected {
457 self.set_status(Status::Connecting, cx);
458 } else {
459 self.set_status(Status::Reconnecting, cx);
460 }
461
462 match self.establish_connection(&credentials, cx).await {
463 Ok(conn) => {
464 self.state.write().credentials = Some(credentials.clone());
465 if !used_keychain && IMPERSONATE_LOGIN.is_none() {
466 write_credentials_to_keychain(&credentials, cx).log_err();
467 }
468 self.set_connection(conn, cx).await;
469 Ok(())
470 }
471 Err(EstablishConnectionError::Unauthorized) => {
472 self.state.write().credentials.take();
473 if used_keychain {
474 cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
475 self.set_status(Status::SignedOut, cx);
476 self.authenticate_and_connect(cx).await
477 } else {
478 self.set_status(Status::ConnectionError, cx);
479 Err(EstablishConnectionError::Unauthorized)?
480 }
481 }
482 Err(EstablishConnectionError::UpgradeRequired) => {
483 self.set_status(Status::UpgradeRequired, cx);
484 Err(EstablishConnectionError::UpgradeRequired)?
485 }
486 Err(error) => {
487 self.set_status(Status::ConnectionError, cx);
488 Err(error)?
489 }
490 }
491 }
492
493 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
494 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
495 cx.foreground()
496 .spawn({
497 let cx = cx.clone();
498 let this = self.clone();
499 async move {
500 while let Some(message) = incoming.next().await {
501 let mut state = this.state.write();
502 let payload_type_id = message.payload_type_id();
503 let entity_id = if let Some(extract_entity_id) =
504 state.entity_id_extractors.get(&message.payload_type_id())
505 {
506 Some((extract_entity_id)(message.as_ref()))
507 } else {
508 None
509 };
510
511 let type_name = message.payload_type_name();
512
513 let handler_key = (payload_type_id, entity_id);
514 if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
515 let mut handler = handler.take().unwrap();
516 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
517
518 log::debug!(
519 "rpc message received. client_id:{}, name:{}",
520 this.id,
521 type_name
522 );
523
524 let future = (handler)(message, &cx);
525 let client_id = this.id;
526 cx.foreground()
527 .spawn(async move {
528 match future.await {
529 Ok(()) => {
530 log::debug!(
531 "rpc message handled. client_id:{}, name:{}",
532 client_id,
533 type_name
534 );
535 }
536 Err(error) => {
537 log::error!(
538 "error handling rpc message. client_id:{}, name:{}, error: {}",
539 client_id, type_name, error
540 );
541 }
542 }
543 })
544 .detach();
545
546 let mut state = this.state.write();
547 if state.model_handlers.contains_key(&handler_key) {
548 state.model_handlers.insert(handler_key, Some(handler));
549 }
550 } else {
551 log::info!("unhandled message {}", type_name);
552 }
553 }
554 }
555 })
556 .detach();
557
558 self.set_status(Status::Connected { connection_id }, cx);
559
560 let handle_io = cx.background().spawn(handle_io);
561 let this = self.clone();
562 let cx = cx.clone();
563 cx.foreground()
564 .spawn(async move {
565 match handle_io.await {
566 Ok(()) => this.set_status(Status::SignedOut, &cx),
567 Err(err) => {
568 log::error!("connection error: {:?}", err);
569 this.set_status(Status::ConnectionLost, &cx);
570 }
571 }
572 })
573 .detach();
574 }
575
576 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
577 if let Some(callback) = self.authenticate.as_ref() {
578 callback(cx)
579 } else {
580 self.authenticate_with_browser(cx)
581 }
582 }
583
584 fn establish_connection(
585 self: &Arc<Self>,
586 credentials: &Credentials,
587 cx: &AsyncAppContext,
588 ) -> Task<Result<Connection, EstablishConnectionError>> {
589 if let Some(callback) = self.establish_connection.as_ref() {
590 callback(credentials, cx)
591 } else {
592 self.establish_websocket_connection(credentials, cx)
593 }
594 }
595
596 fn establish_websocket_connection(
597 self: &Arc<Self>,
598 credentials: &Credentials,
599 cx: &AsyncAppContext,
600 ) -> Task<Result<Connection, EstablishConnectionError>> {
601 let request = Request::builder()
602 .header(
603 "Authorization",
604 format!("{} {}", credentials.user_id, credentials.access_token),
605 )
606 .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
607
608 let http = self.http.clone();
609 cx.background().spawn(async move {
610 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
611 let rpc_request = surf::Request::new(
612 Method::Get,
613 surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
614 );
615 let rpc_response = http.send(rpc_request).await?;
616
617 if rpc_response.status().is_redirection() {
618 rpc_url = rpc_response
619 .header("Location")
620 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
621 .as_str()
622 .to_string();
623 }
624 // Until we switch the zed.dev domain to point to the new Next.js app, there
625 // will be no redirect required, and the app will connect directly to
626 // wss://zed.dev/rpc.
627 else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
628 Err(anyhow!(
629 "unexpected /rpc response status {}",
630 rpc_response.status()
631 ))?
632 }
633
634 let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
635 let rpc_host = rpc_url
636 .host_str()
637 .zip(rpc_url.port_or_known_default())
638 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
639 let stream = smol::net::TcpStream::connect(rpc_host).await?;
640
641 log::info!("connected to rpc endpoint {}", rpc_url);
642
643 match rpc_url.scheme() {
644 "https" => {
645 rpc_url.set_scheme("wss").unwrap();
646 let request = request.uri(rpc_url.as_str()).body(())?;
647 let (stream, _) =
648 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
649 Ok(Connection::new(stream))
650 }
651 "http" => {
652 rpc_url.set_scheme("ws").unwrap();
653 let request = request.uri(rpc_url.as_str()).body(())?;
654 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
655 Ok(Connection::new(stream))
656 }
657 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
658 }
659 })
660 }
661
662 pub fn authenticate_with_browser(
663 self: &Arc<Self>,
664 cx: &AsyncAppContext,
665 ) -> Task<Result<Credentials>> {
666 let platform = cx.platform();
667 let executor = cx.background();
668 executor.clone().spawn(async move {
669 // Generate a pair of asymmetric encryption keys. The public key will be used by the
670 // zed server to encrypt the user's access token, so that it can'be intercepted by
671 // any other app running on the user's device.
672 let (public_key, private_key) =
673 rpc::auth::keypair().expect("failed to generate keypair for auth");
674 let public_key_string =
675 String::try_from(public_key).expect("failed to serialize public key for auth");
676
677 // Start an HTTP server to receive the redirect from Zed's sign-in page.
678 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
679 let port = server.server_addr().port();
680
681 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
682 // that the user is signing in from a Zed app running on the same device.
683 let mut url = format!(
684 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
685 *ZED_SERVER_URL, port, public_key_string
686 );
687
688 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
689 log::info!("impersonating user @{}", impersonate_login);
690 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
691 }
692
693 platform.open_url(&url);
694
695 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
696 // access token from the query params.
697 //
698 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
699 // custom URL scheme instead of this local HTTP server.
700 let (user_id, access_token) = executor
701 .spawn(async move {
702 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
703 let path = req.url();
704 let mut user_id = None;
705 let mut access_token = None;
706 let url = Url::parse(&format!("http://example.com{}", path))
707 .context("failed to parse login notification url")?;
708 for (key, value) in url.query_pairs() {
709 if key == "access_token" {
710 access_token = Some(value.to_string());
711 } else if key == "user_id" {
712 user_id = Some(value.to_string());
713 }
714 }
715
716 let post_auth_url =
717 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
718 req.respond(
719 tiny_http::Response::empty(302).with_header(
720 tiny_http::Header::from_bytes(
721 &b"Location"[..],
722 post_auth_url.as_bytes(),
723 )
724 .unwrap(),
725 ),
726 )
727 .context("failed to respond to login http request")?;
728 Ok((
729 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
730 access_token
731 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
732 ))
733 } else {
734 Err(anyhow!("didn't receive login redirect"))
735 }
736 })
737 .await?;
738
739 let access_token = private_key
740 .decrypt_string(&access_token)
741 .context("failed to decrypt access token")?;
742 platform.activate(true);
743
744 Ok(Credentials {
745 user_id: user_id.parse()?,
746 access_token,
747 })
748 })
749 }
750
751 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
752 let conn_id = self.connection_id()?;
753 self.peer.disconnect(conn_id);
754 self.set_status(Status::SignedOut, cx);
755 Ok(())
756 }
757
758 fn connection_id(&self) -> Result<ConnectionId> {
759 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
760 Ok(connection_id)
761 } else {
762 Err(anyhow!("not connected"))
763 }
764 }
765
766 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
767 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
768 self.peer.send(self.connection_id()?, message)
769 }
770
771 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
772 log::debug!(
773 "rpc request start. client_id: {}. name:{}",
774 self.id,
775 T::NAME
776 );
777 let response = self.peer.request(self.connection_id()?, request).await;
778 log::debug!(
779 "rpc request finish. client_id: {}. name:{}",
780 self.id,
781 T::NAME
782 );
783 response
784 }
785
786 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
787 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
788 self.peer.respond(receipt, response)
789 }
790
791 fn respond_with_error<T: RequestMessage>(
792 &self,
793 receipt: Receipt<T>,
794 error: proto::Error,
795 ) -> Result<()> {
796 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
797 self.peer.respond_with_error(receipt, error)
798 }
799}
800
801fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
802 if IMPERSONATE_LOGIN.is_some() {
803 return None;
804 }
805
806 let (user_id, access_token) = cx
807 .platform()
808 .read_credentials(&ZED_SERVER_URL)
809 .log_err()
810 .flatten()?;
811 Some(Credentials {
812 user_id: user_id.parse().ok()?,
813 access_token: String::from_utf8(access_token).ok()?,
814 })
815}
816
817fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
818 cx.platform().write_credentials(
819 &ZED_SERVER_URL,
820 &credentials.user_id.to_string(),
821 credentials.access_token.as_bytes(),
822 )
823}
824
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let id = id.to_string();
    [WORKTREE_URL_PREFIX, id.as_str(), "/", access_token].concat()
}

/// Parses a URL produced by `encode_worktree_url`, ignoring surrounding whitespace.
///
/// Returns `None` if:
/// - the prefix is missing;
/// - the id is not a `u64`;
/// - the access-token segment is missing or empty.
///
/// Any path segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = path.split('/');
    let id = segments.next().and_then(|segment| segment.parse::<u64>().ok());
    match (id, segments.next()) {
        (Some(id), Some(token)) if !token.is_empty() => Some((id, token.to_owned())),
        _ => None,
    }
}
841
// Unit tests for `Client`: heartbeat, reconnection, worktree URL encoding, and
// subscription bookkeeping. Most tests run against the `FakeServer`/`FakeHttpClient`
// helpers from `crate::test`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // While connected, the client pings the server each time the clock advances past the
    // heartbeat interval. Once the client disconnects, `receive` on the server is expected
    // to fail.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        // Advancing 10s (past the default 5s interval) should yield a Ping; acknowledge it.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        client.disconnect(&cx.to_async()).unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    // After a dropped connection the client keeps retrying. It reuses its cached
    // credentials while they are valid, and re-authenticates once the server rejects them.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, so a reconnect attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Worktree URLs round-trip, tolerate surrounding whitespace, and reject other schemes.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages are routed by entity id. Dropping one entity's subscription must not
    // break routing for other ids of the same message type.
    #[gpui::test]
    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let _subscription1 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                1,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                    async { Ok(()) }
                },
            )
        });
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                2,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                    async { Ok(()) }
                },
            )
        });

        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                3,
                cx,
                |_, _: TypedEnvelope<proto::UnshareProject>, _, _| async { Ok(()) },
            )
        });
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After a subscription is dropped, a new subscription for the same message type can be
    // registered (no duplicate-handler panic) and receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let subscription1 = model.update(&mut cx, |_, cx| {
            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                async { Ok(()) }
            })
        });
        drop(subscription1);
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                async { Ok(()) }
            })
        });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while running. This exercises the path
    // where the client releases its state lock before invoking the handler.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx, mut done_rx) = postage::oneshot::channel();
        model.update(&mut cx, |model, cx| {
            model.subscription = Some(client.subscribe(
                cx,
                move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                    model.update(&mut cx, |model, _| model.subscription.take());
                    postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
                    async { Ok(()) }
                },
            ));
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model that can hold a subscription, used as the subscriber in the tests above.
    struct Model {
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}