diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 0e9ec4076ad43754a53e11f833935c9b807f025c..af084dc88d45eff6acbcf57c3de73499e34fbee6 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -569,14 +569,14 @@ impl Client { ) -> anyhow::Result<()> { let was_disconnected = match *self.status().borrow() { Status::SignedOut => true, - Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { - false + Status::ConnectionError + | Status::ConnectionLost + | Status::Authenticating { .. } + | Status::Reauthenticating { .. } + | Status::ReconnectionError { .. } => false, + Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { + return Ok(()) } - Status::Connected { .. } - | Status::Connecting { .. } - | Status::Reconnecting { .. } - | Status::Authenticating - | Status::Reauthenticating => return Ok(()), Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?, }; @@ -593,13 +593,22 @@ impl Client { read_from_keychain = credentials.is_some(); } if credentials.is_none() { - credentials = Some(match self.authenticate(&cx).await { - Ok(credentials) => credentials, - Err(err) => { - self.set_status(Status::ConnectionError, cx); - return Err(err); + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! { + authenticate = self.authenticate(&cx).fuse() => { + match authenticate { + Ok(creds) => credentials = Some(creds), + Err(err) => { + self.set_status(Status::ConnectionError, cx); + return Err(err); + } + } } - }); + _ = status_rx.next().fuse() => { + return Err(anyhow!("authentication canceled")); + } + } } let credentials = credentials.unwrap(); @@ -899,40 +908,42 @@ impl Client { // custom URL scheme instead of this local HTTP server. let (user_id, access_token) = executor .spawn(async move { - if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? { - let path = req.url(); - let mut user_id = None; - let mut access_token = None; - let url = Url::parse(&format!("http://example.com{}", path)) - .context("failed to parse login notification url")?; - for (key, value) in url.query_pairs() { - if key == "access_token" { - access_token = Some(value.to_string()); - } else if key == "user_id" { - user_id = Some(value.to_string()); + for _ in 0..100 { + if let Some(req) = server.recv_timeout(Duration::from_secs(1))? { + let path = req.url(); + let mut user_id = None; + let mut access_token = None; + let url = Url::parse(&format!("http://example.com{}", path)) + .context("failed to parse login notification url")?; + for (key, value) in url.query_pairs() { + if key == "access_token" { + access_token = Some(value.to_string()); + } else if key == "user_id" { + user_id = Some(value.to_string()); + } } - } - let post_auth_url = - format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL); - req.respond( - tiny_http::Response::empty(302).with_header( - tiny_http::Header::from_bytes( - &b"Location"[..], - post_auth_url.as_bytes(), - ) - .unwrap(), - ), - ) - .context("failed to respond to login http request")?; - Ok(( - user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?, - access_token - .ok_or_else(|| anyhow!("missing access_token parameter"))?, - )) - } else { - Err(anyhow!("didn't receive login redirect")) + let post_auth_url = + format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL); + req.respond( + tiny_http::Response::empty(302).with_header( + tiny_http::Header::from_bytes( + &b"Location"[..], + post_auth_url.as_bytes(), + ) + .unwrap(), + ), + ) + .context("failed to respond to login http request")?; + return Ok(( + user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?, + access_token + .ok_or_else(|| anyhow!("missing access_token parameter"))?, + )); + } } + + Err(anyhow!("didn't receive login redirect")) }) .await?; @@ -1061,7 +1072,9 @@ pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> { mod tests { use super::*; use crate::test::{FakeHttpClient, FakeServer}; - use gpui::TestAppContext; + use gpui::{executor::Deterministic, TestAppContext}; + use parking_lot::Mutex; + use std::future; #[gpui::test(iterations = 10)] async fn test_reconnection(cx: &mut TestAppContext) { @@ -1098,6 +1111,48 @@ mod tests { assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } + #[gpui::test(iterations = 10)] + async fn test_authenticating_more_than_once( + cx: &mut TestAppContext, + deterministic: Arc, + ) { + cx.foreground().forbid_parking(); + + let auth_count = Arc::new(Mutex::new(0)); + let dropped_auth_count = Arc::new(Mutex::new(0)); + let client = Client::new(FakeHttpClient::with_404_response()); + client.override_authenticate({ + let auth_count = auth_count.clone(); + let dropped_auth_count = dropped_auth_count.clone(); + move |cx| { + let auth_count = auth_count.clone(); + let dropped_auth_count = dropped_auth_count.clone(); + cx.foreground().spawn(async move { + *auth_count.lock() += 1; + let _drop = util::defer(move || *dropped_auth_count.lock() += 1); + future::pending::<()>().await; + unreachable!() + }) + } + }); + + let _authenticate = cx.spawn(|cx| { + let client = client.clone(); + async move { client.authenticate_and_connect(false, &cx).await } + }); + deterministic.run_until_parked(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(*dropped_auth_count.lock(), 0); + + let _authenticate = cx.spawn(|cx| { + let client = client.clone(); + async move { client.authenticate_and_connect(false, &cx).await } + }); + deterministic.run_until_parked(); + assert_eq!(*auth_count.lock(), 2); + assert_eq!(*dropped_auth_count.lock(), 1); + } + #[test] fn test_encode_and_decode_worktree_url() { let url = encode_worktree_url(5, "deadbeef"); diff --git a/crates/contacts_panel/src/contacts_panel.rs b/crates/contacts_panel/src/contacts_panel.rs index 532d545b9121ddcbdd809b23567fb768fb265cfa..aef90378791a9f49f6e08dddd647324ae0ee361c 100644 --- a/crates/contacts_panel/src/contacts_panel.rs +++ b/crates/contacts_panel/src/contacts_panel.rs @@ -1270,13 +1270,6 @@ mod tests { .detach(); }); - let request = server.receive::().await.unwrap(); - server - .respond( - request.receipt(), - proto::RegisterProjectResponse { project_id: 200 }, - ) - .await; let get_users_request = server.receive::().await.unwrap(); server .respond( @@ -1307,6 +1300,14 @@ mod tests { ) .await; + let request = server.receive::().await.unwrap(); + server + .respond( + request.receipt(), + proto::RegisterProjectResponse { project_id: 200 }, + ) + .await; + server.send(proto::UpdateContacts { incoming_requests: vec![proto::IncomingContactRequest { requester_id: 1, diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index ac3d1c36960d8914d40506a2e30cf7bdbec8f302..c060f57072692767305e12e55fed7c2c80ce7481 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1811,7 +1811,7 @@ impl Workspace { match &*self.client.status().borrow() { client::Status::ConnectionError | client::Status::ConnectionLost - | client::Status::Reauthenticating + | client::Status::Reauthenticating { .. } | client::Status::Reconnecting { .. } | client::Status::ReconnectionError { .. } => Some( Container::new(