1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::StreamExt;
15use gpui::{action, AsyncAppContext, Entity, ModelContext, MutableAppContext, Task};
16use http::HttpClient;
17use lazy_static::lazy_static;
18use parking_lot::RwLock;
19use postage::watch;
20use rand::prelude::*;
21use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
22use std::{
23 any::TypeId,
24 collections::HashMap,
25 convert::TryFrom,
26 fmt::Write as _,
27 sync::{
28 atomic::{AtomicUsize, Ordering},
29 Arc, Weak,
30 },
31 time::{Duration, Instant},
32};
33use surf::{http::Method, Url};
34use thiserror::Error;
35use util::{ResultExt, TryFutureExt};
36
37pub use channel::*;
38pub use rpc::*;
39pub use user::*;
40
41lazy_static! {
42 static ref ZED_SERVER_URL: String =
43 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
44 static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
45 .ok()
46 .and_then(|s| if s.is_empty() { None } else { Some(s) });
47}
48
49action!(Authenticate);
50
51pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
52 cx.add_global_action(move |_: &Authenticate, cx| {
53 let rpc = rpc.clone();
54 cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
55 .detach();
56 });
57}
58
59pub struct Client {
60 id: usize,
61 peer: Arc<Peer>,
62 http: Arc<dyn HttpClient>,
63 state: RwLock<ClientState>,
64 authenticate:
65 Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
66 establish_connection: Option<
67 Box<
68 dyn 'static
69 + Send
70 + Sync
71 + Fn(
72 &Credentials,
73 &AsyncAppContext,
74 ) -> Task<Result<Connection, EstablishConnectionError>>,
75 >,
76 >,
77}
78
79#[derive(Error, Debug)]
80pub enum EstablishConnectionError {
81 #[error("upgrade required")]
82 UpgradeRequired,
83 #[error("unauthorized")]
84 Unauthorized,
85 #[error("{0}")]
86 Other(#[from] anyhow::Error),
87 #[error("{0}")]
88 Io(#[from] std::io::Error),
89 #[error("{0}")]
90 Http(#[from] async_tungstenite::tungstenite::http::Error),
91}
92
93impl From<WebsocketError> for EstablishConnectionError {
94 fn from(error: WebsocketError) -> Self {
95 if let WebsocketError::Http(response) = &error {
96 match response.status() {
97 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
98 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
99 _ => {}
100 }
101 }
102 EstablishConnectionError::Other(error.into())
103 }
104}
105
106impl EstablishConnectionError {
107 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
108 Self::Other(error.into())
109 }
110}
111
112#[derive(Copy, Clone, Debug)]
113pub enum Status {
114 SignedOut,
115 UpgradeRequired,
116 Authenticating,
117 Connecting,
118 ConnectionError,
119 Connected { connection_id: ConnectionId },
120 ConnectionLost,
121 Reauthenticating,
122 Reconnecting,
123 ReconnectionError { next_reconnection: Instant },
124}
125
126struct ClientState {
127 credentials: Option<Credentials>,
128 status: (watch::Sender<Status>, watch::Receiver<Status>),
129 entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
130 model_handlers: HashMap<
131 (TypeId, Option<u64>),
132 Option<Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>>,
133 >,
134 _maintain_connection: Option<Task<()>>,
135 heartbeat_interval: Duration,
136}
137
/// Identity used to authenticate the RPC connection.
///
/// `Debug` is implemented by hand so the access token never ends up in logs or
/// panic messages.
#[derive(Clone)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
143
144impl Default for ClientState {
145 fn default() -> Self {
146 Self {
147 credentials: None,
148 status: watch::channel_with(Status::SignedOut),
149 entity_id_extractors: Default::default(),
150 model_handlers: Default::default(),
151 _maintain_connection: None,
152 heartbeat_interval: Duration::from_secs(5),
153 }
154 }
155}
156
157pub struct Subscription {
158 client: Weak<Client>,
159 id: (TypeId, Option<u64>),
160}
161
162impl Drop for Subscription {
163 fn drop(&mut self) {
164 if let Some(client) = self.client.upgrade() {
165 let mut state = client.state.write();
166 let _ = state.model_handlers.remove(&self.id).unwrap();
167 }
168 }
169}
170
171impl Client {
172 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
173 lazy_static! {
174 static ref NEXT_CLIENT_ID: AtomicUsize = AtomicUsize::default();
175 }
176
177 Arc::new(Self {
178 id: NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst),
179 peer: Peer::new(),
180 http,
181 state: Default::default(),
182 authenticate: None,
183 establish_connection: None,
184 })
185 }
186
187 #[cfg(any(test, feature = "test-support"))]
188 pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
189 where
190 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
191 {
192 self.authenticate = Some(Box::new(authenticate));
193 self
194 }
195
196 #[cfg(any(test, feature = "test-support"))]
197 pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
198 where
199 F: 'static
200 + Send
201 + Sync
202 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
203 {
204 self.establish_connection = Some(Box::new(connect));
205 self
206 }
207
208 pub fn user_id(&self) -> Option<u64> {
209 self.state
210 .read()
211 .credentials
212 .as_ref()
213 .map(|credentials| credentials.user_id)
214 }
215
216 pub fn status(&self) -> watch::Receiver<Status> {
217 self.state.read().status.1.clone()
218 }
219
    /// Publishes `status` to all `status()` watchers and updates the background
    /// task stored in `_maintain_connection` to match it:
    ///
    /// - `Connected`: spawns a loop that sends `proto::Ping` every
    ///   `heartbeat_interval`. The ping result is ignored.
    /// - `ConnectionLost`: spawns a loop that retries `authenticate_and_connect`.
    ///   After each failure it publishes `ReconnectionError` and waits before the
    ///   next attempt. The wait starts at 100ms, grows by a random factor in
    ///   [1.0, 2.0] per attempt, and is capped at `heartbeat_interval`.
    /// - `SignedOut` / `UpgradeRequired`: drops any existing maintenance task.
    /// - Any other status leaves the current task untouched.
    ///
    /// The state write lock is held for the whole call, including while spawning.
    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                let heartbeat_interval = state.heartbeat_interval;
                let this = self.clone();
                let foreground = cx.foreground();
                // Assigning drops the previous task (e.g. the reconnection loop).
                // NOTE(review): assumes dropping a gpui `Task` cancels it.
                state._maintain_connection = Some(cx.foreground().spawn(async move {
                    loop {
                        foreground.timer(heartbeat_interval).await;
                        let _ = this.request(proto::Ping {}).await;
                    }
                }));
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let foreground = cx.foreground();
                let heartbeat_interval = state.heartbeat_interval;
                state._maintain_connection = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    // NOTE(review): on success, `authenticate_and_connect` publishes
                    // `Connected`, which replaces (drops) this very task from
                    // inside it. Confirm that this is safe for gpui tasks.
                    while let Err(error) = this.authenticate_and_connect(&cx).await {
                        log::error!("failed to connect {}", error);
                        this.set_status(
                            Status::ReconnectionError {
                                next_reconnection: Instant::now() + delay,
                            },
                            &cx,
                        );
                        foreground.timer(delay).await;
                        // Randomized growth with a cap at the heartbeat interval.
                        delay = delay
                            .mul_f32(rng.gen_range(1.0..=2.0))
                            .min(heartbeat_interval);
                    }
                }));
            }
            Status::SignedOut | Status::UpgradeRequired => {
                state._maintain_connection.take();
            }
            _ => {}
        }
    }
264
265 pub fn subscribe<T, M, F>(
266 self: &Arc<Self>,
267 cx: &mut ModelContext<M>,
268 mut handler: F,
269 ) -> Subscription
270 where
271 T: EnvelopedMessage,
272 M: Entity,
273 F: 'static
274 + Send
275 + Sync
276 + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
277 {
278 let subscription_id = (TypeId::of::<T>(), None);
279 let client = self.clone();
280 let mut state = self.state.write();
281 let model = cx.weak_handle();
282 let prev_handler = state.model_handlers.insert(
283 subscription_id,
284 Some(Box::new(move |envelope, cx| {
285 if let Some(model) = model.upgrade(cx) {
286 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
287 model.update(cx, |model, cx| {
288 if let Err(error) = handler(model, *envelope, client.clone(), cx) {
289 log::error!("error handling message: {}", error)
290 }
291 });
292 }
293 })),
294 );
295 if prev_handler.is_some() {
296 panic!("registered handler for the same message twice");
297 }
298
299 Subscription {
300 client: Arc::downgrade(self),
301 id: subscription_id,
302 }
303 }
304
305 pub fn subscribe_to_entity<T, M, F>(
306 self: &Arc<Self>,
307 remote_id: u64,
308 cx: &mut ModelContext<M>,
309 mut handler: F,
310 ) -> Subscription
311 where
312 T: EntityMessage,
313 M: Entity,
314 F: 'static
315 + Send
316 + Sync
317 + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
318 {
319 let subscription_id = (TypeId::of::<T>(), Some(remote_id));
320 let client = self.clone();
321 let mut state = self.state.write();
322 let model = cx.weak_handle();
323 state
324 .entity_id_extractors
325 .entry(subscription_id.0)
326 .or_insert_with(|| {
327 Box::new(|envelope| {
328 let envelope = envelope
329 .as_any()
330 .downcast_ref::<TypedEnvelope<T>>()
331 .unwrap();
332 envelope.payload.remote_entity_id()
333 })
334 });
335 let prev_handler = state.model_handlers.insert(
336 subscription_id,
337 Some(Box::new(move |envelope, cx| {
338 if let Some(model) = model.upgrade(cx) {
339 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
340 model.update(cx, |model, cx| {
341 if let Err(error) = handler(model, *envelope, client.clone(), cx) {
342 log::error!("error handling message: {}", error)
343 }
344 });
345 }
346 })),
347 );
348 if prev_handler.is_some() {
349 panic!("registered a handler for the same entity twice")
350 }
351
352 Subscription {
353 client: Arc::downgrade(self),
354 id: subscription_id,
355 }
356 }
357
358 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
359 read_credentials_from_keychain(cx).is_some()
360 }
361
362 #[async_recursion(?Send)]
363 pub async fn authenticate_and_connect(
364 self: &Arc<Self>,
365 cx: &AsyncAppContext,
366 ) -> anyhow::Result<()> {
367 let was_disconnected = match *self.status().borrow() {
368 Status::SignedOut => true,
369 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
370 false
371 }
372 Status::Connected { .. }
373 | Status::Connecting { .. }
374 | Status::Reconnecting { .. }
375 | Status::Authenticating
376 | Status::Reauthenticating => return Ok(()),
377 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
378 };
379
380 if was_disconnected {
381 self.set_status(Status::Authenticating, cx);
382 } else {
383 self.set_status(Status::Reauthenticating, cx)
384 }
385
386 let mut used_keychain = false;
387 let credentials = self.state.read().credentials.clone();
388 let credentials = if let Some(credentials) = credentials {
389 credentials
390 } else if let Some(credentials) = read_credentials_from_keychain(cx) {
391 used_keychain = true;
392 credentials
393 } else {
394 let credentials = match self.authenticate(&cx).await {
395 Ok(credentials) => credentials,
396 Err(err) => {
397 self.set_status(Status::ConnectionError, cx);
398 return Err(err);
399 }
400 };
401 credentials
402 };
403
404 if was_disconnected {
405 self.set_status(Status::Connecting, cx);
406 } else {
407 self.set_status(Status::Reconnecting, cx);
408 }
409
410 match self.establish_connection(&credentials, cx).await {
411 Ok(conn) => {
412 self.state.write().credentials = Some(credentials.clone());
413 if !used_keychain && IMPERSONATE_LOGIN.is_none() {
414 write_credentials_to_keychain(&credentials, cx).log_err();
415 }
416 self.set_connection(conn, cx).await;
417 Ok(())
418 }
419 Err(EstablishConnectionError::Unauthorized) => {
420 self.state.write().credentials.take();
421 if used_keychain {
422 cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
423 self.set_status(Status::SignedOut, cx);
424 self.authenticate_and_connect(cx).await
425 } else {
426 self.set_status(Status::ConnectionError, cx);
427 Err(EstablishConnectionError::Unauthorized)?
428 }
429 }
430 Err(EstablishConnectionError::UpgradeRequired) => {
431 self.set_status(Status::UpgradeRequired, cx);
432 Err(EstablishConnectionError::UpgradeRequired)?
433 }
434 Err(error) => {
435 self.set_status(Status::ConnectionError, cx);
436 Err(error)?
437 }
438 }
439 }
440
441 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
442 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
443 cx.foreground()
444 .spawn({
445 let mut cx = cx.clone();
446 let this = self.clone();
447 async move {
448 while let Some(message) = incoming.next().await {
449 let mut state = this.state.write();
450 let payload_type_id = message.payload_type_id();
451 let entity_id = if let Some(extract_entity_id) =
452 state.entity_id_extractors.get(&message.payload_type_id())
453 {
454 Some((extract_entity_id)(message.as_ref()))
455 } else {
456 None
457 };
458
459 let type_name = message.payload_type_name();
460
461 let handler_key = (payload_type_id, entity_id);
462 if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
463 let mut handler = handler.take().unwrap();
464 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
465
466 log::debug!(
467 "rpc message received. client_id:{}, name:{}",
468 this.id,
469 type_name
470 );
471 (handler)(message, &mut cx);
472 log::debug!(
473 "rpc message handled. client_id:{}, name:{}",
474 this.id,
475 type_name
476 );
477
478 let mut state = this.state.write();
479 if state.model_handlers.contains_key(&handler_key) {
480 state.model_handlers.insert(handler_key, Some(handler));
481 }
482 } else {
483 log::info!("unhandled message {}", type_name);
484 }
485 }
486 }
487 })
488 .detach();
489
490 self.set_status(Status::Connected { connection_id }, cx);
491
492 let handle_io = cx.background().spawn(handle_io);
493 let this = self.clone();
494 let cx = cx.clone();
495 cx.foreground()
496 .spawn(async move {
497 match handle_io.await {
498 Ok(()) => this.set_status(Status::SignedOut, &cx),
499 Err(err) => {
500 log::error!("connection error: {:?}", err);
501 this.set_status(Status::ConnectionLost, &cx);
502 }
503 }
504 })
505 .detach();
506 }
507
508 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
509 if let Some(callback) = self.authenticate.as_ref() {
510 callback(cx)
511 } else {
512 self.authenticate_with_browser(cx)
513 }
514 }
515
516 fn establish_connection(
517 self: &Arc<Self>,
518 credentials: &Credentials,
519 cx: &AsyncAppContext,
520 ) -> Task<Result<Connection, EstablishConnectionError>> {
521 if let Some(callback) = self.establish_connection.as_ref() {
522 callback(credentials, cx)
523 } else {
524 self.establish_websocket_connection(credentials, cx)
525 }
526 }
527
528 fn establish_websocket_connection(
529 self: &Arc<Self>,
530 credentials: &Credentials,
531 cx: &AsyncAppContext,
532 ) -> Task<Result<Connection, EstablishConnectionError>> {
533 let request = Request::builder()
534 .header(
535 "Authorization",
536 format!("{} {}", credentials.user_id, credentials.access_token),
537 )
538 .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
539
540 let http = self.http.clone();
541 cx.background().spawn(async move {
542 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
543 let rpc_request = surf::Request::new(
544 Method::Get,
545 surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
546 );
547 let rpc_response = http.send(rpc_request).await?;
548
549 if rpc_response.status().is_redirection() {
550 rpc_url = rpc_response
551 .header("Location")
552 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
553 .as_str()
554 .to_string();
555 }
556 // Until we switch the zed.dev domain to point to the new Next.js app, there
557 // will be no redirect required, and the app will connect directly to
558 // wss://zed.dev/rpc.
559 else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
560 Err(anyhow!(
561 "unexpected /rpc response status {}",
562 rpc_response.status()
563 ))?
564 }
565
566 let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
567 let rpc_host = rpc_url
568 .host_str()
569 .zip(rpc_url.port_or_known_default())
570 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
571 let stream = smol::net::TcpStream::connect(rpc_host).await?;
572
573 log::info!("connected to rpc endpoint {}", rpc_url);
574
575 match rpc_url.scheme() {
576 "https" => {
577 rpc_url.set_scheme("wss").unwrap();
578 let request = request.uri(rpc_url.as_str()).body(())?;
579 let (stream, _) =
580 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
581 Ok(Connection::new(stream))
582 }
583 "http" => {
584 rpc_url.set_scheme("ws").unwrap();
585 let request = request.uri(rpc_url.as_str()).body(())?;
586 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
587 Ok(Connection::new(stream))
588 }
589 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
590 }
591 })
592 }
593
594 pub fn authenticate_with_browser(
595 self: &Arc<Self>,
596 cx: &AsyncAppContext,
597 ) -> Task<Result<Credentials>> {
598 let platform = cx.platform();
599 let executor = cx.background();
600 executor.clone().spawn(async move {
601 // Generate a pair of asymmetric encryption keys. The public key will be used by the
602 // zed server to encrypt the user's access token, so that it can'be intercepted by
603 // any other app running on the user's device.
604 let (public_key, private_key) =
605 rpc::auth::keypair().expect("failed to generate keypair for auth");
606 let public_key_string =
607 String::try_from(public_key).expect("failed to serialize public key for auth");
608
609 // Start an HTTP server to receive the redirect from Zed's sign-in page.
610 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
611 let port = server.server_addr().port();
612
613 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
614 // that the user is signing in from a Zed app running on the same device.
615 let mut url = format!(
616 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
617 *ZED_SERVER_URL, port, public_key_string
618 );
619
620 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
621 log::info!("impersonating user @{}", impersonate_login);
622 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
623 }
624
625 platform.open_url(&url);
626
627 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
628 // access token from the query params.
629 //
630 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
631 // custom URL scheme instead of this local HTTP server.
632 let (user_id, access_token) = executor
633 .spawn(async move {
634 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
635 let path = req.url();
636 let mut user_id = None;
637 let mut access_token = None;
638 let url = Url::parse(&format!("http://example.com{}", path))
639 .context("failed to parse login notification url")?;
640 for (key, value) in url.query_pairs() {
641 if key == "access_token" {
642 access_token = Some(value.to_string());
643 } else if key == "user_id" {
644 user_id = Some(value.to_string());
645 }
646 }
647
648 let post_auth_url =
649 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
650 req.respond(
651 tiny_http::Response::empty(302).with_header(
652 tiny_http::Header::from_bytes(
653 &b"Location"[..],
654 post_auth_url.as_bytes(),
655 )
656 .unwrap(),
657 ),
658 )
659 .context("failed to respond to login http request")?;
660 Ok((
661 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
662 access_token
663 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
664 ))
665 } else {
666 Err(anyhow!("didn't receive login redirect"))
667 }
668 })
669 .await?;
670
671 let access_token = private_key
672 .decrypt_string(&access_token)
673 .context("failed to decrypt access token")?;
674 platform.activate(true);
675
676 Ok(Credentials {
677 user_id: user_id.parse()?,
678 access_token,
679 })
680 })
681 }
682
683 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
684 let conn_id = self.connection_id()?;
685 self.peer.disconnect(conn_id);
686 self.set_status(Status::SignedOut, cx);
687 Ok(())
688 }
689
690 fn connection_id(&self) -> Result<ConnectionId> {
691 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
692 Ok(connection_id)
693 } else {
694 Err(anyhow!("not connected"))
695 }
696 }
697
698 pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
699 log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
700 self.peer.send(self.connection_id()?, message)
701 }
702
703 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
704 log::debug!(
705 "rpc request start. client_id: {}. name:{}",
706 self.id,
707 T::NAME
708 );
709 let response = self.peer.request(self.connection_id()?, request).await;
710 log::debug!(
711 "rpc request finish. client_id: {}. name:{}",
712 self.id,
713 T::NAME
714 );
715 response
716 }
717
718 pub fn respond<T: RequestMessage>(
719 &self,
720 receipt: Receipt<T>,
721 response: T::Response,
722 ) -> Result<()> {
723 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
724 self.peer.respond(receipt, response)
725 }
726
727 pub fn respond_with_error<T: RequestMessage>(
728 &self,
729 receipt: Receipt<T>,
730 error: proto::Error,
731 ) -> Result<()> {
732 log::debug!("rpc respond. client_id: {}. name:{}", self.id, T::NAME);
733 self.peer.respond_with_error(receipt, error)
734 }
735}
736
737fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
738 if IMPERSONATE_LOGIN.is_some() {
739 return None;
740 }
741
742 let (user_id, access_token) = cx
743 .platform()
744 .read_credentials(&ZED_SERVER_URL)
745 .log_err()
746 .flatten()?;
747 Some(Credentials {
748 user_id: user_id.parse().ok()?,
749 access_token: String::from_utf8(access_token).ok()?,
750 })
751}
752
753fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
754 cx.platform().write_credentials(
755 &ZED_SERVER_URL,
756 &credentials.user_id.to_string(),
757 credentials.access_token.as_bytes(),
758 )
759}
760
/// Scheme and path prefix shared by every worktree URL.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/{id}/{access_token}` URL for sharing a worktree.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL built by [`encode_worktree_url`] back into its worktree id and
/// access token.
///
/// Surrounding whitespace is ignored. Everything after the first `/` that
/// follows the id is the access token, so tokens containing `/` round-trip
/// intact.
///
/// Returns `None` if the prefix is missing, the id is not a valid `u64`, or the
/// token is absent or empty.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
777
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    /// While connected, the client keeps sending `proto::Ping` heartbeats. Once it
    /// has disconnected, the server's next `receive` fails.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        // Advance the fake clock past the heartbeat interval (presumably still the
        // 5s default; FakeServer is not visible here) so that a ping is sent.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        client.disconnect(&cx.to_async()).unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    /// After the connection drops, the client keeps retrying. It reuses cached
    /// credentials while they are valid and re-authenticates once they are
    /// rejected.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones: a reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Once connections are allowed again, a later retry succeeds.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// Worktree URLs survive an encode/decode round trip, tolerate surrounding
    /// whitespace, and reject foreign URLs.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    /// Entity-scoped subscriptions for the same message type are dispatched by
    /// entity id, independently of each other.
    #[gpui::test]
    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let _subscription1 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                1,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                    Ok(())
                },
            )
        });
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                2,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                    Ok(())
                },
            )
        });

        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                3,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| Ok(()),
            )
        });
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 });
        server.send(proto::UnshareProject { project_id: 2 });
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    /// Dropping a subscription frees its key, so the same message type can be
    /// subscribed to again, and only the new handler runs.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let subscription1 = model.update(&mut cx, |_, cx| {
            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                Ok(())
            })
        });
        drop(subscription1);
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                Ok(())
            })
        });
        server.send(proto::Ping {});
        done_rx2.next().await.unwrap();
    }

    /// A handler may drop its own subscription while it runs without panicking
    /// or deadlocking the client.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx, mut done_rx) = postage::oneshot::channel();
        model.update(&mut cx, |model, cx| {
            model.subscription = Some(client.subscribe(
                cx,
                move |model, _: TypedEnvelope<proto::Ping>, _, _| {
                    model.subscription.take();
                    postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
                    Ok(())
                },
            ));
        });
        server.send(proto::Ping {});
        done_rx.next().await.unwrap();
    }

    /// Minimal model that can own a subscription, for the tests above.
    struct Model {
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}