1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
15use gpui::{action, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
16use http::HttpClient;
17use lazy_static::lazy_static;
18use parking_lot::RwLock;
19use postage::watch;
20use rand::prelude::*;
21use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
22use std::{
23 any::{type_name, TypeId},
24 collections::HashMap,
25 convert::TryFrom,
26 fmt::Write as _,
27 future::Future,
28 sync::{
29 atomic::{AtomicUsize, Ordering},
30 Arc, Weak,
31 },
32 time::{Duration, Instant},
33};
34use surf::{http::Method, Url};
35use thiserror::Error;
36use util::{ResultExt, TryFutureExt};
37
38pub use channel::*;
39pub use rpc::*;
40pub use user::*;
41
42lazy_static! {
43 static ref ZED_SERVER_URL: String =
44 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
45 static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
46 .ok()
47 .and_then(|s| if s.is_empty() { None } else { Some(s) });
48}
49
// Global action that signs the client in and connects it; its handler is registered by `init`.
action!(Authenticate);
51
52pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
53 cx.add_global_action(move |_: &Authenticate, cx| {
54 let rpc = rpc.clone();
55 cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
56 .detach();
57 });
58}
59
/// Client side of the connection to the Zed server.
///
/// Owns the RPC [`Peer`], the HTTP client used to locate the RPC endpoint, and the mutable
/// connection state (credentials, status, registered message handlers). Created through
/// [`Client::new`], which hands it out wrapped in an `Arc`.
pub struct Client {
    /// Process-unique identifier, included in this client's log lines.
    id: usize,
    peer: Arc<Peer>,
    /// Used for the initial plain-HTTP request to the server's `/rpc` endpoint.
    http: Arc<dyn HttpClient>,
    state: RwLock<ClientState>,
    /// Replacement for the browser-based sign-in flow; set via the test-only
    /// `override_authenticate`. When `None`, `authenticate_with_browser` is used.
    authenticate:
        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
    /// Replacement for the WebSocket connection logic; set via the test-only
    /// `override_establish_connection`. When `None`, `establish_websocket_connection` is used.
    establish_connection: Option<
        Box<
            dyn 'static
                + Send
                + Sync
                + Fn(
                    &Credentials,
                    &AsyncAppContext,
                ) -> Task<Result<Connection, EstablishConnectionError>>,
        >,
    >,
}
79
/// Reasons why opening a connection to the server can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The handshake was answered with HTTP 426 (Upgrade Required). This is presumably because
    /// the client's protocol version (sent as `X-Zed-Protocol-Version`) is outdated.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The handshake was answered with HTTP 401, i.e. the credentials were rejected.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure, including other WebSocket errors.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An I/O failure, e.g. while opening the TCP stream.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// A failure while building the handshake HTTP request.
    #[error("{0}")]
    Http(#[from] async_tungstenite::tungstenite::http::Error),
}
93
94impl From<WebsocketError> for EstablishConnectionError {
95 fn from(error: WebsocketError) -> Self {
96 if let WebsocketError::Http(response) = &error {
97 match response.status() {
98 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
99 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
100 _ => {}
101 }
102 }
103 EstablishConnectionError::Other(error.into())
104 }
105}
106
107impl EstablishConnectionError {
108 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
109 Self::Other(error.into())
110 }
111}
112
/// Lifecycle state of a [`Client`]'s connection, observable through [`Client::status`].
#[derive(Copy, Clone, Debug)]
pub enum Status {
    /// Not connected and not trying to connect. This is also the initial state.
    SignedOut,
    /// The server demanded a newer client. `authenticate_and_connect` refuses to proceed
    /// from this state.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening a connection, starting from `SignedOut`.
    Connecting,
    /// The last authentication or connection attempt failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection within the peer.
    Connected { connection_id: ConnectionId },
    /// An established connection failed with an error; automatic reconnection is started.
    ConnectionLost,
    /// Obtaining credentials while recovering from a failure.
    Reauthenticating,
    /// Opening a connection while recovering from a failure.
    Reconnecting,
    /// An automatic reconnection attempt failed; the next attempt is due at `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
126
/// Type-erased handler for one kind of incoming message.
///
/// It receives the envelope plus an app context and returns a boxed, non-`Send` future that
/// performs the actual handling.
type ModelHandler = Box<
    dyn Send
        + Sync
        + FnMut(Box<dyn AnyTypedEnvelope>, &AsyncAppContext) -> LocalBoxFuture<'static, Result<()>>,
>;
132
/// Mutable state of a [`Client`], guarded by its `RwLock`.
struct ClientState {
    /// Credentials of the last successful connection, reused by later connection attempts.
    credentials: Option<Credentials>,
    /// Sender and receiver of the status watch channel; `Client::status` clones the receiver.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    /// Per message type, extracts the remote entity id from a type-erased envelope.
    /// Registered by `add_entity_message_handler`.
    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
    /// Handlers keyed by message type and, for entity messages, the remote entity id.
    /// A value is `None` while its handler is temporarily taken out to be invoked.
    model_handlers: HashMap<(TypeId, Option<u64>), Option<ModelHandler>>,
    /// Heartbeat task while connected, or reconnection task after a lost connection.
    /// Cleared on sign-out or when an upgrade is required.
    _maintain_connection: Option<Task<()>>,
    /// Delay between heartbeat pings; also caps the reconnection back-off delay.
    heartbeat_interval: Duration,
}
141
/// Credentials used to authenticate a connection to the server.
///
/// `Debug` is implemented by hand so that the access token never ends up in logs or panic
/// messages; it is rendered as `<redacted>`.
#[derive(Clone)]
pub struct Credentials {
    /// Id of the signed-in user.
    pub user_id: u64,
    /// Secret token sent, together with `user_id`, in the connection's `Authorization` header.
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
147
148impl Default for ClientState {
149 fn default() -> Self {
150 Self {
151 credentials: None,
152 status: watch::channel_with(Status::SignedOut),
153 entity_id_extractors: Default::default(),
154 model_handlers: Default::default(),
155 _maintain_connection: None,
156 heartbeat_interval: Duration::from_secs(5),
157 }
158 }
159}
160
/// Registration of a message handler on a [`Client`]. Dropping it unregisters the handler.
///
/// It holds only a weak reference to the client, so it does not keep the client alive.
pub struct Subscription {
    client: Weak<Client>,
    /// Key of the handler in `ClientState::model_handlers`.
    id: (TypeId, Option<u64>),
}
165
166impl Drop for Subscription {
167 fn drop(&mut self) {
168 if let Some(client) = self.client.upgrade() {
169 let mut state = client.state.write();
170 let _ = state.model_handlers.remove(&self.id).unwrap();
171 }
172 }
173}
174
175impl Client {
176 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
177 lazy_static! {
178 static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
179 }
180
181 Arc::new(Self {
182 id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
183 peer: Peer::new(),
184 http,
185 state: Default::default(),
186 authenticate: None,
187 establish_connection: None,
188 })
189 }
190
191 #[cfg(any(test, feature = "test-support"))]
192 pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
193 where
194 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
195 {
196 self.authenticate = Some(Box::new(authenticate));
197 self
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
202 where
203 F: 'static
204 + Send
205 + Sync
206 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
207 {
208 self.establish_connection = Some(Box::new(connect));
209 self
210 }
211
212 pub fn user_id(&self) -> Option<u64> {
213 self.state
214 .read()
215 .credentials
216 .as_ref()
217 .map(|credentials| credentials.user_id)
218 }
219
220 pub fn status(&self) -> watch::Receiver<Status> {
221 self.state.read().status.1.clone()
222 }
223
    /// Publishes `status` and adjusts the background task that maintains the connection.
    ///
    /// * `Connected`: stores a task that sends a `Ping` request every `heartbeat_interval`.
    ///   Ping failures are ignored.
    /// * `ConnectionLost`: stores a task that calls `authenticate_and_connect` until it
    ///   succeeds. After each failure it logs the error, publishes `ReconnectionError`, and
    ///   waits. The delay starts at 100ms, is multiplied by a random factor in `[1, 2]` after
    ///   each failure, and is capped at `heartbeat_interval`.
    /// * `SignedOut` / `UpgradeRequired`: discards the stored task.
    /// * Any other status leaves the stored task untouched.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        // The state lock is held for the whole call; the spawned tasks call back into
        // `set_status` themselves, presumably only once they are polled later (not while this
        // lock is held) — NOTE(review): relies on spawn not polling synchronously.
        let mut state = self.state.write();
        // Publish the new value on the watch channel read by `status()` receivers.
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                let heartbeat_interval = state.heartbeat_interval;
                let this = self.clone();
                let foreground = cx.foreground();
                // Replaces (and drops) any previously stored maintenance task.
                state._maintain_connection = Some(cx.foreground().spawn(async move {
                    loop {
                        foreground.timer(heartbeat_interval).await;
                        let _ = this.request(proto::Ping {}).await;
                    }
                }));
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let foreground = cx.foreground();
                let heartbeat_interval = state.heartbeat_interval;
                state._maintain_connection = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    // The first attempt happens immediately; the delay applies between retries.
                    while let Err(error) = this.authenticate_and_connect(&cx).await {
                        log::error!("failed to connect {}", error);
                        this.set_status(
                            Status::ReconnectionError {
                                next_reconnection: Instant::now() + delay,
                            },
                            &cx,
                        );
                        foreground.timer(delay).await;
                        // Randomized exponential back-off, capped at the heartbeat interval.
                        delay = delay
                            .mul_f32(rng.gen_range(1.0..=2.0))
                            .min(heartbeat_interval);
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._maintain_connection.take();
            }
            _ => {}
        }
    }
268
269 pub fn add_message_handler<T, M, F, Fut>(
270 self: &Arc<Self>,
271 cx: &mut ModelContext<M>,
272 mut handler: F,
273 ) -> Subscription
274 where
275 T: EnvelopedMessage,
276 M: Entity,
277 F: 'static
278 + Send
279 + Sync
280 + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
281 Fut: 'static + Future<Output = Result<()>>,
282 {
283 let subscription_id = (TypeId::of::<T>(), None);
284 let client = self.clone();
285 let mut state = self.state.write();
286 let model = cx.weak_handle();
287 let prev_handler = state.model_handlers.insert(
288 subscription_id,
289 Some(Box::new(move |envelope, cx| {
290 if let Some(model) = model.upgrade(cx) {
291 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
292 handler(model, *envelope, client.clone(), cx.clone()).boxed_local()
293 } else {
294 async move {
295 Err(anyhow!(
296 "received message for {:?} but model was dropped",
297 type_name::<M>()
298 ))
299 }
300 .boxed_local()
301 }
302 })),
303 );
304 if prev_handler.is_some() {
305 panic!("registered handler for the same message twice");
306 }
307
308 Subscription {
309 client: Arc::downgrade(self),
310 id: subscription_id,
311 }
312 }
313
314 pub fn add_entity_message_handler<T, M, F, Fut>(
315 self: &Arc<Self>,
316 remote_id: u64,
317 cx: &mut ModelContext<M>,
318 mut handler: F,
319 ) -> Subscription
320 where
321 T: EntityMessage,
322 M: Entity,
323 F: 'static
324 + Send
325 + Sync
326 + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
327 Fut: 'static + Future<Output = Result<()>>,
328 {
329 let subscription_id = (TypeId::of::<T>(), Some(remote_id));
330 let client = self.clone();
331 let mut state = self.state.write();
332 let model = cx.weak_handle();
333 state
334 .entity_id_extractors
335 .entry(subscription_id.0)
336 .or_insert_with(|| {
337 Box::new(|envelope| {
338 let envelope = envelope
339 .as_any()
340 .downcast_ref::<TypedEnvelope<T>>()
341 .unwrap();
342 envelope.payload.remote_entity_id()
343 })
344 });
345 let prev_handler = state.model_handlers.insert(
346 subscription_id,
347 Some(Box::new(move |envelope, cx| {
348 if let Some(model) = model.upgrade(cx) {
349 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
350 handler(model, *envelope, client.clone(), cx.clone()).boxed_local()
351 } else {
352 async move {
353 Err(anyhow!(
354 "received message for {:?} but model was dropped",
355 type_name::<M>()
356 ))
357 }
358 .boxed_local()
359 }
360 })),
361 );
362 if prev_handler.is_some() {
363 panic!("registered a handler for the same entity twice")
364 }
365
366 Subscription {
367 client: Arc::downgrade(self),
368 id: subscription_id,
369 }
370 }
371
372 pub fn add_entity_request_handler<T, M, F, Fut>(
373 self: &Arc<Self>,
374 remote_id: u64,
375 cx: &mut ModelContext<M>,
376 mut handler: F,
377 ) -> Subscription
378 where
379 T: EntityMessage + RequestMessage,
380 M: Entity,
381 F: 'static
382 + Send
383 + Sync
384 + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
385 Fut: 'static + Future<Output = Result<T::Response>>,
386 {
387 self.add_entity_message_handler(remote_id, cx, move |model, envelope, client, cx| {
388 let receipt = envelope.receipt();
389 let response = handler(model, envelope, client.clone(), cx);
390 async move {
391 match response.await {
392 Ok(response) => {
393 client.respond(receipt, response)?;
394 Ok(())
395 }
396 Err(error) => {
397 client.respond_with_error(
398 receipt,
399 proto::Error {
400 message: error.to_string(),
401 },
402 )?;
403 Err(error)
404 }
405 }
406 }
407 })
408 }
409
410 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
411 read_credentials_from_keychain(cx).is_some()
412 }
413
414 #[async_recursion(?Send)]
415 pub async fn authenticate_and_connect(
416 self: &Arc<Self>,
417 cx: &AsyncAppContext,
418 ) -> anyhow::Result<()> {
419 let was_disconnected = match *self.status().borrow() {
420 Status::SignedOut => true,
421 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
422 false
423 }
424 Status::Connected { .. }
425 | Status::Connecting { .. }
426 | Status::Reconnecting { .. }
427 | Status::Authenticating
428 | Status::Reauthenticating => return Ok(()),
429 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
430 };
431
432 if was_disconnected {
433 self.set_status(Status::Authenticating, cx);
434 } else {
435 self.set_status(Status::Reauthenticating, cx)
436 }
437
438 let mut used_keychain = false;
439 let credentials = self.state.read().credentials.clone();
440 let credentials = if let Some(credentials) = credentials {
441 credentials
442 } else if let Some(credentials) = read_credentials_from_keychain(cx) {
443 used_keychain = true;
444 credentials
445 } else {
446 let credentials = match self.authenticate(&cx).await {
447 Ok(credentials) => credentials,
448 Err(err) => {
449 self.set_status(Status::ConnectionError, cx);
450 return Err(err);
451 }
452 };
453 credentials
454 };
455
456 if was_disconnected {
457 self.set_status(Status::Connecting, cx);
458 } else {
459 self.set_status(Status::Reconnecting, cx);
460 }
461
462 match self.establish_connection(&credentials, cx).await {
463 Ok(conn) => {
464 self.state.write().credentials = Some(credentials.clone());
465 if !used_keychain && IMPERSONATE_LOGIN.is_none() {
466 write_credentials_to_keychain(&credentials, cx).log_err();
467 }
468 self.set_connection(conn, cx).await;
469 Ok(())
470 }
471 Err(EstablishConnectionError::Unauthorized) => {
472 self.state.write().credentials.take();
473 if used_keychain {
474 cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
475 self.set_status(Status::SignedOut, cx);
476 self.authenticate_and_connect(cx).await
477 } else {
478 self.set_status(Status::ConnectionError, cx);
479 Err(EstablishConnectionError::Unauthorized)?
480 }
481 }
482 Err(EstablishConnectionError::UpgradeRequired) => {
483 self.set_status(Status::UpgradeRequired, cx);
484 Err(EstablishConnectionError::UpgradeRequired)?
485 }
486 Err(error) => {
487 self.set_status(Status::ConnectionError, cx);
488 Err(error)?
489 }
490 }
491 }
492
493 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
494 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
495 cx.foreground()
496 .spawn({
497 let cx = cx.clone();
498 let this = self.clone();
499 async move {
500 while let Some(message) = incoming.next().await {
501 let mut state = this.state.write();
502 let payload_type_id = message.payload_type_id();
503 let entity_id = if let Some(extract_entity_id) =
504 state.entity_id_extractors.get(&message.payload_type_id())
505 {
506 Some((extract_entity_id)(message.as_ref()))
507 } else {
508 None
509 };
510
511 let type_name = message.payload_type_name();
512
513 let handler_key = (payload_type_id, entity_id);
514 if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
515 let mut handler = handler.take().unwrap();
516 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
517 let future = (handler)(message, &cx);
518 {
519 let mut state = this.state.write();
520 if state.model_handlers.contains_key(&handler_key) {
521 state.model_handlers.insert(handler_key, Some(handler));
522 }
523 }
524
525 let client_id = this.id;
526 log::debug!(
527 "rpc message received. client_id:{}, name:{}",
528 client_id,
529 type_name
530 );
531 cx.foreground()
532 .spawn(async move {
533 match future.await {
534 Ok(()) => {
535 log::debug!(
536 "rpc message handled. client_id:{}, name:{}",
537 client_id,
538 type_name
539 );
540 }
541 Err(error) => {
542 log::error!(
543 "error handling rpc message. client_id:{}, name:{}, error:{}",
544 client_id,
545 type_name,
546 error
547 );
548 }
549 }
550 })
551 .detach();
552 } else {
553 log::info!("unhandled message {}", type_name);
554 }
555 }
556 }
557 })
558 .detach();
559
560 self.set_status(Status::Connected { connection_id }, cx);
561
562 let handle_io = cx.background().spawn(handle_io);
563 let this = self.clone();
564 let cx = cx.clone();
565 cx.foreground()
566 .spawn(async move {
567 match handle_io.await {
568 Ok(()) => this.set_status(Status::SignedOut, &cx),
569 Err(err) => {
570 log::error!("connection error: {:?}", err);
571 this.set_status(Status::ConnectionLost, &cx);
572 }
573 }
574 })
575 .detach();
576 }
577
578 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
579 if let Some(callback) = self.authenticate.as_ref() {
580 callback(cx)
581 } else {
582 self.authenticate_with_browser(cx)
583 }
584 }
585
586 fn establish_connection(
587 self: &Arc<Self>,
588 credentials: &Credentials,
589 cx: &AsyncAppContext,
590 ) -> Task<Result<Connection, EstablishConnectionError>> {
591 if let Some(callback) = self.establish_connection.as_ref() {
592 callback(credentials, cx)
593 } else {
594 self.establish_websocket_connection(credentials, cx)
595 }
596 }
597
598 fn establish_websocket_connection(
599 self: &Arc<Self>,
600 credentials: &Credentials,
601 cx: &AsyncAppContext,
602 ) -> Task<Result<Connection, EstablishConnectionError>> {
603 let request = Request::builder()
604 .header(
605 "Authorization",
606 format!("{} {}", credentials.user_id, credentials.access_token),
607 )
608 .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
609
610 let http = self.http.clone();
611 cx.background().spawn(async move {
612 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
613 let rpc_request = surf::Request::new(
614 Method::Get,
615 surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
616 );
617 let rpc_response = http.send(rpc_request).await?;
618
619 if rpc_response.status().is_redirection() {
620 rpc_url = rpc_response
621 .header("Location")
622 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
623 .as_str()
624 .to_string();
625 }
626 // Until we switch the zed.dev domain to point to the new Next.js app, there
627 // will be no redirect required, and the app will connect directly to
628 // wss://zed.dev/rpc.
629 else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
630 Err(anyhow!(
631 "unexpected /rpc response status {}",
632 rpc_response.status()
633 ))?
634 }
635
636 let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
637 let rpc_host = rpc_url
638 .host_str()
639 .zip(rpc_url.port_or_known_default())
640 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
641 let stream = smol::net::TcpStream::connect(rpc_host).await?;
642
643 log::info!("connected to rpc endpoint {}", rpc_url);
644
645 match rpc_url.scheme() {
646 "https" => {
647 rpc_url.set_scheme("wss").unwrap();
648 let request = request.uri(rpc_url.as_str()).body(())?;
649 let (stream, _) =
650 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
651 Ok(Connection::new(stream))
652 }
653 "http" => {
654 rpc_url.set_scheme("ws").unwrap();
655 let request = request.uri(rpc_url.as_str()).body(())?;
656 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
657 Ok(Connection::new(stream))
658 }
659 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
660 }
661 })
662 }
663
664 pub fn authenticate_with_browser(
665 self: &Arc<Self>,
666 cx: &AsyncAppContext,
667 ) -> Task<Result<Credentials>> {
668 let platform = cx.platform();
669 let executor = cx.background();
670 executor.clone().spawn(async move {
671 // Generate a pair of asymmetric encryption keys. The public key will be used by the
672 // zed server to encrypt the user's access token, so that it can'be intercepted by
673 // any other app running on the user's device.
674 let (public_key, private_key) =
675 rpc::auth::keypair().expect("failed to generate keypair for auth");
676 let public_key_string =
677 String::try_from(public_key).expect("failed to serialize public key for auth");
678
679 // Start an HTTP server to receive the redirect from Zed's sign-in page.
680 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
681 let port = server.server_addr().port();
682
683 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
684 // that the user is signing in from a Zed app running on the same device.
685 let mut url = format!(
686 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
687 *ZED_SERVER_URL, port, public_key_string
688 );
689
690 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
691 log::info!("impersonating user @{}", impersonate_login);
692 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
693 }
694
695 platform.open_url(&url);
696
697 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
698 // access token from the query params.
699 //
700 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
701 // custom URL scheme instead of this local HTTP server.
702 let (user_id, access_token) = executor
703 .spawn(async move {
704 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
705 let path = req.url();
706 let mut user_id = None;
707 let mut access_token = None;
708 let url = Url::parse(&format!("http://example.com{}", path))
709 .context("failed to parse login notification url")?;
710 for (key, value) in url.query_pairs() {
711 if key == "access_token" {
712 access_token = Some(value.to_string());
713 } else if key == "user_id" {
714 user_id = Some(value.to_string());
715 }
716 }
717
718 let post_auth_url =
719 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
720 req.respond(
721 tiny_http::Response::empty(302).with_header(
722 tiny_http::Header::from_bytes(
723 &b"Location"[..],
724 post_auth_url.as_bytes(),
725 )
726 .unwrap(),
727 ),
728 )
729 .context("failed to respond to login http request")?;
730 Ok((
731 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
732 access_token
733 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
734 ))
735 } else {
736 Err(anyhow!("didn't receive login redirect"))
737 }
738 })
739 .await?;
740
741 let access_token = private_key
742 .decrypt_string(&access_token)
743 .context("failed to decrypt access token")?;
744 platform.activate(true);
745
746 Ok(Credentials {
747 user_id: user_id.parse()?,
748 access_token,
749 })
750 })
751 }
752
753 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
754 let conn_id = self.connection_id()?;
755 self.peer.disconnect(conn_id);
756 self.set_status(Status::SignedOut, cx);
757 Ok(())
758 }
759
760 fn connection_id(&self) -> Result<ConnectionId> {
761 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
762 Ok(connection_id)
763 } else {
764 Err(anyhow!("not connected"))
765 }
766 }
767
768 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
769 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
770 self.peer.send(self.connection_id()?, message)
771 }
772
773 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
774 log::debug!(
775 "rpc request start. client_id: {}. name:{}",
776 self.id,
777 T::NAME
778 );
779 let response = self.peer.request(self.connection_id()?, request).await;
780 log::debug!(
781 "rpc request finish. client_id: {}. name:{}",
782 self.id,
783 T::NAME
784 );
785 response
786 }
787
788 fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
789 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
790 self.peer.respond(receipt, response)
791 }
792
793 fn respond_with_error<T: RequestMessage>(
794 &self,
795 receipt: Receipt<T>,
796 error: proto::Error,
797 ) -> Result<()> {
798 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
799 self.peer.respond_with_error(receipt, error)
800 }
801}
802
803fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
804 if IMPERSONATE_LOGIN.is_some() {
805 return None;
806 }
807
808 let (user_id, access_token) = cx
809 .platform()
810 .read_credentials(&ZED_SERVER_URL)
811 .log_err()
812 .flatten()?;
813 Some(Credentials {
814 user_id: user_id.parse().ok()?,
815 access_token: String::from_utf8(access_token).ok()?,
816 })
817}
818
819fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
820 cx.platform().write_credentials(
821 &ZED_SERVER_URL,
822 &credentials.user_id.to_string(),
823 credentials.access_token.as_bytes(),
824 )
825}
826
/// Common prefix of every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] back into its id and access token.
///
/// Surrounding whitespace is ignored. Everything after the first `/` that follows the id is
/// the access token, so tokens that themselves contain `/` round-trip intact. Returns `None`
/// when the prefix is missing, the id is not a valid `u64`, or the token is empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let (id, access_token) = path.split_once('/')?;
    let id = id.parse::<u64>().ok()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
843
// Tests for the RPC client, driven against an in-process `FakeServer` with a fake clock.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    // While connected, the client pings the server periodically; after `disconnect`, the
    // server's next `receive` fails.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        // Advancing the clock past the (default 5s) heartbeat interval should yield a ping.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        client.disconnect(&cx.to_async()).unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    // A lost connection is retried automatically. Cached credentials are reused until the
    // server rejects them, at which point the client authenticates again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    // Round-trips a worktree URL and checks that surrounding whitespace is ignored and that
    // foreign URLs are rejected.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    // Entity messages of the same type are routed by entity id.
    #[gpui::test]
    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let _subscription1 = model.update(&mut cx, |_, cx| {
            client.add_entity_message_handler(
                1,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                    async { Ok(()) }
                },
            )
        });
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.add_entity_message_handler(
                2,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                    async { Ok(()) }
                },
            )
        });

        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model.update(&mut cx, |_, cx| {
            client.add_entity_message_handler(
                3,
                cx,
                |_, _: TypedEnvelope<proto::UnshareProject>, _, _| async { Ok(()) },
            )
        });
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    // After dropping a subscription, a new handler can be registered for the same message type
    // without tripping the duplicate-registration panic, and that new handler receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let subscription1 = model.update(&mut cx, |_, cx| {
            client.add_message_handler(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                async { Ok(()) }
            })
        });
        drop(subscription1);
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.add_message_handler(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                async { Ok(()) }
            })
        });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    // A handler may drop its own subscription while it is being invoked.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx, mut done_rx) = postage::oneshot::channel();
        model.update(&mut cx, |model, cx| {
            model.subscription = Some(client.add_message_handler(
                cx,
                move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
                    model.update(&mut cx, |model, _| model.subscription.take());
                    postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
                    async { Ok(()) }
                },
            ));
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    // Minimal model that can own a subscription, used as the handler owner in these tests.
    struct Model {
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}