1#[cfg(any(test, feature = "test-support"))]
2pub mod test;
3
4pub mod channel;
5pub mod http;
6pub mod user;
7
8use anyhow::{anyhow, Context, Result};
9use async_recursion::async_recursion;
10use async_tungstenite::tungstenite::{
11 error::Error as WebsocketError,
12 http::{Request, StatusCode},
13};
14use futures::StreamExt;
15use gpui::{action, AsyncAppContext, Entity, ModelContext, MutableAppContext, Task};
16use http::HttpClient;
17use lazy_static::lazy_static;
18use parking_lot::RwLock;
19use postage::watch;
20use rand::prelude::*;
21use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
22use std::{
23 any::TypeId,
24 collections::HashMap,
25 convert::TryFrom,
26 fmt::Write as _,
27 future::Future,
28 sync::{Arc, Weak},
29 time::{Duration, Instant},
30};
31use surf::{http::Method, Url};
32use thiserror::Error;
33use util::{ResultExt, TryFutureExt};
34
35pub use channel::*;
36pub use rpc::*;
37pub use user::*;
38
39lazy_static! {
40 static ref ZED_SERVER_URL: String =
41 std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
42 static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
43 .ok()
44 .and_then(|s| if s.is_empty() { None } else { Some(s) });
45}
46
47action!(Authenticate);
48
49pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
50 cx.add_global_action(move |_: &Authenticate, cx| {
51 let rpc = rpc.clone();
52 cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
53 .detach();
54 });
55}
56
57pub struct Client {
58 peer: Arc<Peer>,
59 http: Arc<dyn HttpClient>,
60 state: RwLock<ClientState>,
61 authenticate:
62 Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
63 establish_connection: Option<
64 Box<
65 dyn 'static
66 + Send
67 + Sync
68 + Fn(
69 &Credentials,
70 &AsyncAppContext,
71 ) -> Task<Result<Connection, EstablishConnectionError>>,
72 >,
73 >,
74}
75
/// Reasons an attempt to open the RPC connection can fail.
///
/// `Unauthorized` and `UpgradeRequired` are distinguished because `Client::authenticate_and_connect`
/// reacts to them specifically (discarding credentials / entering `Status::UpgradeRequired`);
/// everything else is wrapped in one of the pass-through variants.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the handshake with HTTP 426 Upgrade Required.
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the handshake with HTTP 401 Unauthorized.
    #[error("unauthorized")]
    Unauthorized,
    /// Any other error, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// An I/O error (for example while opening the TCP stream).
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// An error building the HTTP handshake request.
    #[error("{0}")]
    Http(#[from] async_tungstenite::tungstenite::http::Error),
}
89
90impl From<WebsocketError> for EstablishConnectionError {
91 fn from(error: WebsocketError) -> Self {
92 if let WebsocketError::Http(response) = &error {
93 match response.status() {
94 StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
95 StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
96 _ => {}
97 }
98 }
99 EstablishConnectionError::Other(error.into())
100 }
101}
102
103impl EstablishConnectionError {
104 pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
105 Self::Other(error.into())
106 }
107}
108
/// Connection status of a `Client`, published through a watch channel (see `Client::status`).
#[derive(Copy, Clone, Debug)]
pub enum Status {
    /// Initial state; also entered on explicit disconnect, when the connection's I/O ends
    /// cleanly, and when keychain credentials are rejected (before retrying).
    SignedOut,
    /// The server rejected the handshake with HTTP 426; connecting is refused in this state.
    UpgradeRequired,
    /// Obtaining credentials, starting from `SignedOut`.
    Authenticating,
    /// Opening the connection, starting from `SignedOut`.
    Connecting,
    /// Authentication or connection failed.
    ConnectionError,
    /// Connected; `connection_id` identifies the connection within the `Peer`.
    Connected { connection_id: ConnectionId },
    /// The connection's I/O task ended with an error; a reconnection loop is started.
    ConnectionLost,
    /// Obtaining credentials after a previous error or lost connection.
    Reauthenticating,
    /// Opening the connection after a previous error or lost connection.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled for `next_reconnection`.
    ReconnectionError { next_reconnection: Instant },
}
122
/// Mutable state of a `Client`, kept behind its `RwLock`.
struct ClientState {
    // Credentials that last produced a successful connection (cleared when the server
    // answers Unauthorized).
    credentials: Option<Credentials>,
    // Sender and a template receiver of the current `Status`.
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    // Per message type: extracts the remote entity id used to route messages of that type
    // to entity-specific handlers (registered by `subscribe_to_entity`).
    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
    // Handlers keyed by (message type, optional entity id). A value of `None` means the
    // handler has been taken out of the map while it runs.
    model_handlers: HashMap<
        (TypeId, Option<u64>),
        Option<Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>>,
    >,
    // Background task tied to the current status (heartbeat while connected, reconnection
    // loop after a lost connection); replaced or dropped by `set_status`.
    _maintain_connection: Option<Task<()>>,
    // Delay between heartbeat pings; also caps the reconnection backoff.
    heartbeat_interval: Duration,
}
134
/// Credentials used to authenticate the RPC connection.
#[derive(Clone)]
pub struct Credentials {
    /// Id of the signed-in user.
    pub user_id: u64,
    /// Secret token sent in the `Authorization` header of the connection request.
    pub access_token: String,
}

// `Debug` is implemented by hand so that formatting credentials (e.g. in logs or error
// messages) never reveals the access token.
impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("user_id", &self.user_id)
            .field("access_token", &"<redacted>")
            .finish()
    }
}
140
141impl Default for ClientState {
142 fn default() -> Self {
143 Self {
144 credentials: None,
145 status: watch::channel_with(Status::SignedOut),
146 entity_id_extractors: Default::default(),
147 model_handlers: Default::default(),
148 _maintain_connection: None,
149 heartbeat_interval: Duration::from_secs(5),
150 }
151 }
152}
153
/// Handle returned by `Client::subscribe` / `Client::subscribe_to_entity`; dropping it
/// unregisters the corresponding handler.
pub struct Subscription {
    // Weak, so the subscription does not keep the client alive.
    client: Weak<Client>,
    // Key of the handler in `ClientState::model_handlers`: (message type, optional entity id).
    id: (TypeId, Option<u64>),
}
158
159impl Drop for Subscription {
160 fn drop(&mut self) {
161 if let Some(client) = self.client.upgrade() {
162 let mut state = client.state.write();
163 let _ = state.model_handlers.remove(&self.id).unwrap();
164 }
165 }
166}
167
168impl Client {
169 pub fn new(http: Arc<dyn HttpClient>) -> Arc<Self> {
170 Arc::new(Self {
171 peer: Peer::new(),
172 http,
173 state: Default::default(),
174 authenticate: None,
175 establish_connection: None,
176 })
177 }
178
179 #[cfg(any(test, feature = "test-support"))]
180 pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
181 where
182 F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
183 {
184 self.authenticate = Some(Box::new(authenticate));
185 self
186 }
187
188 #[cfg(any(test, feature = "test-support"))]
189 pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
190 where
191 F: 'static
192 + Send
193 + Sync
194 + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
195 {
196 self.establish_connection = Some(Box::new(connect));
197 self
198 }
199
200 pub fn user_id(&self) -> Option<u64> {
201 self.state
202 .read()
203 .credentials
204 .as_ref()
205 .map(|credentials| credentials.user_id)
206 }
207
208 pub fn status(&self) -> watch::Receiver<Status> {
209 self.state.read().status.1.clone()
210 }
211
212 fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
213 let mut state = self.state.write();
214 *state.status.0.borrow_mut() = status;
215
216 match status {
217 Status::Connected { .. } => {
218 let heartbeat_interval = state.heartbeat_interval;
219 let this = self.clone();
220 let foreground = cx.foreground();
221 state._maintain_connection = Some(cx.foreground().spawn(async move {
222 loop {
223 foreground.timer(heartbeat_interval).await;
224 let _ = this.request(proto::Ping {}).await;
225 }
226 }));
227 }
228 Status::ConnectionLost => {
229 let this = self.clone();
230 let foreground = cx.foreground();
231 let heartbeat_interval = state.heartbeat_interval;
232 state._maintain_connection = Some(cx.spawn(|cx| async move {
233 let mut rng = StdRng::from_entropy();
234 let mut delay = Duration::from_millis(100);
235 while let Err(error) = this.authenticate_and_connect(&cx).await {
236 log::error!("failed to connect {}", error);
237 this.set_status(
238 Status::ReconnectionError {
239 next_reconnection: Instant::now() + delay,
240 },
241 &cx,
242 );
243 foreground.timer(delay).await;
244 delay = delay
245 .mul_f32(rng.gen_range(1.0..=2.0))
246 .min(heartbeat_interval);
247 }
248 }));
249 }
250 Status::SignedOut | Status::UpgradeRequired => {
251 state._maintain_connection.take();
252 }
253 _ => {}
254 }
255 }
256
257 pub fn subscribe<T, M, F>(
258 self: &Arc<Self>,
259 cx: &mut ModelContext<M>,
260 mut handler: F,
261 ) -> Subscription
262 where
263 T: EnvelopedMessage,
264 M: Entity,
265 F: 'static
266 + Send
267 + Sync
268 + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
269 {
270 let subscription_id = (TypeId::of::<T>(), None);
271 let client = self.clone();
272 let mut state = self.state.write();
273 let model = cx.weak_handle();
274 let prev_handler = state.model_handlers.insert(
275 subscription_id,
276 Some(Box::new(move |envelope, cx| {
277 if let Some(model) = model.upgrade(cx) {
278 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
279 model.update(cx, |model, cx| {
280 if let Err(error) = handler(model, *envelope, client.clone(), cx) {
281 log::error!("error handling message: {}", error)
282 }
283 });
284 }
285 })),
286 );
287 if prev_handler.is_some() {
288 panic!("registered handler for the same message twice");
289 }
290
291 Subscription {
292 client: Arc::downgrade(self),
293 id: subscription_id,
294 }
295 }
296
297 pub fn subscribe_to_entity<T, M, F>(
298 self: &Arc<Self>,
299 remote_id: u64,
300 cx: &mut ModelContext<M>,
301 mut handler: F,
302 ) -> Subscription
303 where
304 T: EntityMessage,
305 M: Entity,
306 F: 'static
307 + Send
308 + Sync
309 + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
310 {
311 let subscription_id = (TypeId::of::<T>(), Some(remote_id));
312 let client = self.clone();
313 let mut state = self.state.write();
314 let model = cx.weak_handle();
315 state
316 .entity_id_extractors
317 .entry(subscription_id.0)
318 .or_insert_with(|| {
319 Box::new(|envelope| {
320 let envelope = envelope
321 .as_any()
322 .downcast_ref::<TypedEnvelope<T>>()
323 .unwrap();
324 envelope.payload.remote_entity_id()
325 })
326 });
327 let prev_handler = state.model_handlers.insert(
328 subscription_id,
329 Some(Box::new(move |envelope, cx| {
330 if let Some(model) = model.upgrade(cx) {
331 let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
332 model.update(cx, |model, cx| {
333 if let Err(error) = handler(model, *envelope, client.clone(), cx) {
334 log::error!("error handling message: {}", error)
335 }
336 });
337 }
338 })),
339 );
340 if prev_handler.is_some() {
341 panic!("registered a handler for the same entity twice")
342 }
343
344 Subscription {
345 client: Arc::downgrade(self),
346 id: subscription_id,
347 }
348 }
349
350 pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
351 read_credentials_from_keychain(cx).is_some()
352 }
353
354 #[async_recursion(?Send)]
355 pub async fn authenticate_and_connect(
356 self: &Arc<Self>,
357 cx: &AsyncAppContext,
358 ) -> anyhow::Result<()> {
359 let was_disconnected = match *self.status().borrow() {
360 Status::SignedOut => true,
361 Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
362 false
363 }
364 Status::Connected { .. }
365 | Status::Connecting { .. }
366 | Status::Reconnecting { .. }
367 | Status::Authenticating
368 | Status::Reauthenticating => return Ok(()),
369 Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
370 };
371
372 if was_disconnected {
373 self.set_status(Status::Authenticating, cx);
374 } else {
375 self.set_status(Status::Reauthenticating, cx)
376 }
377
378 let mut used_keychain = false;
379 let credentials = self.state.read().credentials.clone();
380 let credentials = if let Some(credentials) = credentials {
381 credentials
382 } else if let Some(credentials) = read_credentials_from_keychain(cx) {
383 used_keychain = true;
384 credentials
385 } else {
386 let credentials = match self.authenticate(&cx).await {
387 Ok(credentials) => credentials,
388 Err(err) => {
389 self.set_status(Status::ConnectionError, cx);
390 return Err(err);
391 }
392 };
393 credentials
394 };
395
396 if was_disconnected {
397 self.set_status(Status::Connecting, cx);
398 } else {
399 self.set_status(Status::Reconnecting, cx);
400 }
401
402 match self.establish_connection(&credentials, cx).await {
403 Ok(conn) => {
404 self.state.write().credentials = Some(credentials.clone());
405 if !used_keychain && IMPERSONATE_LOGIN.is_none() {
406 write_credentials_to_keychain(&credentials, cx).log_err();
407 }
408 self.set_connection(conn, cx).await;
409 Ok(())
410 }
411 Err(EstablishConnectionError::Unauthorized) => {
412 self.state.write().credentials.take();
413 if used_keychain {
414 cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
415 self.set_status(Status::SignedOut, cx);
416 self.authenticate_and_connect(cx).await
417 } else {
418 self.set_status(Status::ConnectionError, cx);
419 Err(EstablishConnectionError::Unauthorized)?
420 }
421 }
422 Err(EstablishConnectionError::UpgradeRequired) => {
423 self.set_status(Status::UpgradeRequired, cx);
424 Err(EstablishConnectionError::UpgradeRequired)?
425 }
426 Err(error) => {
427 self.set_status(Status::ConnectionError, cx);
428 Err(error)?
429 }
430 }
431 }
432
433 async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
434 let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
435 cx.foreground()
436 .spawn({
437 let mut cx = cx.clone();
438 let this = self.clone();
439 async move {
440 while let Some(message) = incoming.next().await {
441 let mut state = this.state.write();
442 let payload_type_id = message.payload_type_id();
443 let entity_id = if let Some(extract_entity_id) =
444 state.entity_id_extractors.get(&message.payload_type_id())
445 {
446 Some((extract_entity_id)(message.as_ref()))
447 } else {
448 None
449 };
450
451 let handler_key = (payload_type_id, entity_id);
452 if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
453 let mut handler = handler.take().unwrap();
454 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
455 let start_time = Instant::now();
456 log::info!("RPC client message {}", message.payload_type_name());
457 (handler)(message, &mut cx);
458 log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
459
460 let mut state = this.state.write();
461 if state.model_handlers.contains_key(&handler_key) {
462 state.model_handlers.insert(handler_key, Some(handler));
463 }
464 } else {
465 log::info!("unhandled message {}", message.payload_type_name());
466 }
467 }
468 }
469 })
470 .detach();
471
472 self.set_status(Status::Connected { connection_id }, cx);
473
474 let handle_io = cx.background().spawn(handle_io);
475 let this = self.clone();
476 let cx = cx.clone();
477 cx.foreground()
478 .spawn(async move {
479 match handle_io.await {
480 Ok(()) => this.set_status(Status::SignedOut, &cx),
481 Err(err) => {
482 log::error!("connection error: {:?}", err);
483 this.set_status(Status::ConnectionLost, &cx);
484 }
485 }
486 })
487 .detach();
488 }
489
490 fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
491 if let Some(callback) = self.authenticate.as_ref() {
492 callback(cx)
493 } else {
494 self.authenticate_with_browser(cx)
495 }
496 }
497
498 fn establish_connection(
499 self: &Arc<Self>,
500 credentials: &Credentials,
501 cx: &AsyncAppContext,
502 ) -> Task<Result<Connection, EstablishConnectionError>> {
503 if let Some(callback) = self.establish_connection.as_ref() {
504 callback(credentials, cx)
505 } else {
506 self.establish_websocket_connection(credentials, cx)
507 }
508 }
509
510 fn establish_websocket_connection(
511 self: &Arc<Self>,
512 credentials: &Credentials,
513 cx: &AsyncAppContext,
514 ) -> Task<Result<Connection, EstablishConnectionError>> {
515 let request = Request::builder()
516 .header(
517 "Authorization",
518 format!("{} {}", credentials.user_id, credentials.access_token),
519 )
520 .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
521
522 let http = self.http.clone();
523 cx.background().spawn(async move {
524 let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
525 let rpc_request = surf::Request::new(
526 Method::Get,
527 surf::Url::parse(&rpc_url).context("invalid ZED_SERVER_URL")?,
528 );
529 let rpc_response = http.send(rpc_request).await?;
530
531 if rpc_response.status().is_redirection() {
532 rpc_url = rpc_response
533 .header("Location")
534 .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
535 .as_str()
536 .to_string();
537 }
538 // Until we switch the zed.dev domain to point to the new Next.js app, there
539 // will be no redirect required, and the app will connect directly to
540 // wss://zed.dev/rpc.
541 else if rpc_response.status() != surf::StatusCode::UpgradeRequired {
542 Err(anyhow!(
543 "unexpected /rpc response status {}",
544 rpc_response.status()
545 ))?
546 }
547
548 let mut rpc_url = surf::Url::parse(&rpc_url).context("invalid rpc url")?;
549 let rpc_host = rpc_url
550 .host_str()
551 .zip(rpc_url.port_or_known_default())
552 .ok_or_else(|| anyhow!("missing host in rpc url"))?;
553 let stream = smol::net::TcpStream::connect(rpc_host).await?;
554
555 log::info!("connected to rpc endpoint {}", rpc_url);
556
557 match rpc_url.scheme() {
558 "https" => {
559 rpc_url.set_scheme("wss").unwrap();
560 let request = request.uri(rpc_url.as_str()).body(())?;
561 let (stream, _) =
562 async_tungstenite::async_tls::client_async_tls(request, stream).await?;
563 Ok(Connection::new(stream))
564 }
565 "http" => {
566 rpc_url.set_scheme("ws").unwrap();
567 let request = request.uri(rpc_url.as_str()).body(())?;
568 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
569 Ok(Connection::new(stream))
570 }
571 _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
572 }
573 })
574 }
575
576 pub fn authenticate_with_browser(
577 self: &Arc<Self>,
578 cx: &AsyncAppContext,
579 ) -> Task<Result<Credentials>> {
580 let platform = cx.platform();
581 let executor = cx.background();
582 executor.clone().spawn(async move {
583 // Generate a pair of asymmetric encryption keys. The public key will be used by the
584 // zed server to encrypt the user's access token, so that it can'be intercepted by
585 // any other app running on the user's device.
586 let (public_key, private_key) =
587 rpc::auth::keypair().expect("failed to generate keypair for auth");
588 let public_key_string =
589 String::try_from(public_key).expect("failed to serialize public key for auth");
590
591 // Start an HTTP server to receive the redirect from Zed's sign-in page.
592 let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
593 let port = server.server_addr().port();
594
595 // Open the Zed sign-in page in the user's browser, with query parameters that indicate
596 // that the user is signing in from a Zed app running on the same device.
597 let mut url = format!(
598 "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
599 *ZED_SERVER_URL, port, public_key_string
600 );
601
602 if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
603 log::info!("impersonating user @{}", impersonate_login);
604 write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
605 }
606
607 platform.open_url(&url);
608
609 // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
610 // access token from the query params.
611 //
612 // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
613 // custom URL scheme instead of this local HTTP server.
614 let (user_id, access_token) = executor
615 .spawn(async move {
616 if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
617 let path = req.url();
618 let mut user_id = None;
619 let mut access_token = None;
620 let url = Url::parse(&format!("http://example.com{}", path))
621 .context("failed to parse login notification url")?;
622 for (key, value) in url.query_pairs() {
623 if key == "access_token" {
624 access_token = Some(value.to_string());
625 } else if key == "user_id" {
626 user_id = Some(value.to_string());
627 }
628 }
629
630 let post_auth_url =
631 format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
632 req.respond(
633 tiny_http::Response::empty(302).with_header(
634 tiny_http::Header::from_bytes(
635 &b"Location"[..],
636 post_auth_url.as_bytes(),
637 )
638 .unwrap(),
639 ),
640 )
641 .context("failed to respond to login http request")?;
642 Ok((
643 user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
644 access_token
645 .ok_or_else(|| anyhow!("missing access_token parameter"))?,
646 ))
647 } else {
648 Err(anyhow!("didn't receive login redirect"))
649 }
650 })
651 .await?;
652
653 let access_token = private_key
654 .decrypt_string(&access_token)
655 .context("failed to decrypt access token")?;
656 platform.activate(true);
657
658 Ok(Credentials {
659 user_id: user_id.parse()?,
660 access_token,
661 })
662 })
663 }
664
665 pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
666 let conn_id = self.connection_id()?;
667 self.peer.disconnect(conn_id);
668 self.set_status(Status::SignedOut, cx);
669 Ok(())
670 }
671
672 fn connection_id(&self) -> Result<ConnectionId> {
673 if let Status::Connected { connection_id, .. } = *self.status().borrow() {
674 Ok(connection_id)
675 } else {
676 Err(anyhow!("not connected"))
677 }
678 }
679
680 pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
681 self.peer.send(self.connection_id()?, message).await
682 }
683
684 pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
685 self.peer.request(self.connection_id()?, request).await
686 }
687
688 pub fn respond<T: RequestMessage>(
689 &self,
690 receipt: Receipt<T>,
691 response: T::Response,
692 ) -> impl Future<Output = Result<()>> {
693 self.peer.respond(receipt, response)
694 }
695
696 pub fn respond_with_error<T: RequestMessage>(
697 &self,
698 receipt: Receipt<T>,
699 error: proto::Error,
700 ) -> impl Future<Output = Result<()>> {
701 self.peer.respond_with_error(receipt, error)
702 }
703}
704
705fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
706 if IMPERSONATE_LOGIN.is_some() {
707 return None;
708 }
709
710 let (user_id, access_token) = cx
711 .platform()
712 .read_credentials(&ZED_SERVER_URL)
713 .log_err()
714 .flatten()?;
715 Some(Credentials {
716 user_id: user_id.parse().ok()?,
717 access_token: String::from_utf8(access_token).ok()?,
718 })
719}
720
721fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
722 cx.platform().write_credentials(
723 &ZED_SERVER_URL,
724 &credentials.user_id.to_string(),
725 credentials.access_token.as_bytes(),
726 )
727}
728
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`] into `(id, access_token)`.
///
/// Surrounding whitespace is ignored. Returns `None` if the prefix is missing, the id is not a
/// `u64`, or the access token is absent or empty. Everything after the first `/` that follows
/// the id is the access token, so any token produced by `encode_worktree_url` (including one
/// containing `/`) round-trips.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut parts = path.splitn(2, '/');
    let id = parts.next()?.parse::<u64>().ok()?;
    let access_token = parts.next()?;
    if access_token.is_empty() {
        return None;
    }
    Some((id, access_token.to_string()))
}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::{FakeHttpClient, FakeServer};
    use gpui::TestAppContext;

    /// While connected, the client pings the server periodically; after `disconnect`, the
    /// server no longer receives pings.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        // Advance past the heartbeat interval so a ping is sent, and acknowledge it.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        client.disconnect(&cx.to_async()).unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    /// After a lost connection the client retries, reusing cached credentials while they are
    /// accepted and re-authenticating once the server rejects them.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.next().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones until a reconnection attempt fails.
        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect();
        while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.next().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// Worktree URLs round-trip, tolerate surrounding whitespace, and reject other formats.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }

    /// Entity messages are routed to the subscription for the matching entity id.
    #[gpui::test]
    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let _subscription1 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                1,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                    Ok(())
                },
            )
        });
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                2,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
                    postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                    Ok(())
                },
            )
        });

        // Ensure dropping a subscription for the same entity type still allows receiving of
        // messages for other entity IDs of the same type.
        let subscription3 = model.update(&mut cx, |_, cx| {
            client.subscribe_to_entity(
                3,
                cx,
                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| Ok(()),
            )
        });
        drop(subscription3);

        server.send(proto::UnshareProject { project_id: 1 }).await;
        server.send(proto::UnshareProject { project_id: 2 }).await;
        done_rx1.next().await.unwrap();
        done_rx2.next().await.unwrap();
    }

    /// After a subscription is dropped, a new one for the same message type can be registered
    /// and receives messages.
    #[gpui::test]
    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
        let subscription1 = model.update(&mut cx, |_, cx| {
            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
                Ok(())
            })
        });
        drop(subscription1);
        let _subscription2 = model.update(&mut cx, |_, cx| {
            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
                postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
                Ok(())
            })
        });
        server.send(proto::Ping {}).await;
        done_rx2.next().await.unwrap();
    }

    /// A handler may drop its own subscription while it runs.
    #[gpui::test]
    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new(FakeHttpClient::with_404_response());
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        let model = cx.add_model(|_| Model { subscription: None });
        let (mut done_tx, mut done_rx) = postage::oneshot::channel();
        model.update(&mut cx, |model, cx| {
            model.subscription = Some(client.subscribe(
                cx,
                move |model, _: TypedEnvelope<proto::Ping>, _, _| {
                    model.subscription.take();
                    postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
                    Ok(())
                },
            ));
        });
        server.send(proto::Ping {}).await;
        done_rx.next().await.unwrap();
    }

    /// Minimal model used as the subscriber in these tests; it can own a subscription so a
    /// handler can drop it.
    struct Model {
        subscription: Option<Subscription>,
    }

    impl Entity for Model {
        type Event = ();
    }
}