@@ -569,14 +569,14 @@ impl Client {
) -> anyhow::Result<()> {
let was_disconnected = match *self.status().borrow() {
Status::SignedOut => true,
- Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
- false
+ Status::ConnectionError
+ | Status::ConnectionLost
+ | Status::Authenticating { .. }
+ | Status::Reauthenticating { .. }
+ | Status::ReconnectionError { .. } => false,
+ Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
+ return Ok(())
}
- Status::Connected { .. }
- | Status::Connecting { .. }
- | Status::Reconnecting { .. }
- | Status::Authenticating
- | Status::Reauthenticating => return Ok(()),
Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
};
@@ -593,13 +593,22 @@ impl Client {
read_from_keychain = credentials.is_some();
}
if credentials.is_none() {
- credentials = Some(match self.authenticate(&cx).await {
- Ok(credentials) => credentials,
- Err(err) => {
- self.set_status(Status::ConnectionError, cx);
- return Err(err);
+ let mut status_rx = self.status();
+ let _ = status_rx.next().await;
+ futures::select_biased! {
+ authenticate = self.authenticate(&cx).fuse() => {
+ match authenticate {
+ Ok(creds) => credentials = Some(creds),
+ Err(err) => {
+ self.set_status(Status::ConnectionError, cx);
+ return Err(err);
+ }
+ }
}
- });
+ _ = status_rx.next().fuse() => {
+ return Err(anyhow!("authentication canceled"));
+ }
+ }
}
let credentials = credentials.unwrap();
@@ -899,40 +908,42 @@ impl Client {
// custom URL scheme instead of this local HTTP server.
let (user_id, access_token) = executor
.spawn(async move {
- if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
- let path = req.url();
- let mut user_id = None;
- let mut access_token = None;
- let url = Url::parse(&format!("http://example.com{}", path))
- .context("failed to parse login notification url")?;
- for (key, value) in url.query_pairs() {
- if key == "access_token" {
- access_token = Some(value.to_string());
- } else if key == "user_id" {
- user_id = Some(value.to_string());
+ for _ in 0..100 {
+ if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
+ let path = req.url();
+ let mut user_id = None;
+ let mut access_token = None;
+ let url = Url::parse(&format!("http://example.com{}", path))
+ .context("failed to parse login notification url")?;
+ for (key, value) in url.query_pairs() {
+ if key == "access_token" {
+ access_token = Some(value.to_string());
+ } else if key == "user_id" {
+ user_id = Some(value.to_string());
+ }
}
- }
- let post_auth_url =
- format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
- req.respond(
- tiny_http::Response::empty(302).with_header(
- tiny_http::Header::from_bytes(
- &b"Location"[..],
- post_auth_url.as_bytes(),
- )
- .unwrap(),
- ),
- )
- .context("failed to respond to login http request")?;
- Ok((
- user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
- access_token
- .ok_or_else(|| anyhow!("missing access_token parameter"))?,
- ))
- } else {
- Err(anyhow!("didn't receive login redirect"))
+ let post_auth_url =
+ format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
+ req.respond(
+ tiny_http::Response::empty(302).with_header(
+ tiny_http::Header::from_bytes(
+ &b"Location"[..],
+ post_auth_url.as_bytes(),
+ )
+ .unwrap(),
+ ),
+ )
+ .context("failed to respond to login http request")?;
+ return Ok((
+ user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
+ access_token
+ .ok_or_else(|| anyhow!("missing access_token parameter"))?,
+ ));
+ }
}
+
+ Err(anyhow!("didn't receive login redirect"))
})
.await?;
@@ -1061,7 +1072,9 @@ pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
mod tests {
use super::*;
use crate::test::{FakeHttpClient, FakeServer};
- use gpui::TestAppContext;
+ use gpui::{executor::Deterministic, TestAppContext};
+ use parking_lot::Mutex;
+ use std::future;
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: &mut TestAppContext) {
@@ -1098,6 +1111,48 @@ mod tests {
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
}
+ #[gpui::test(iterations = 10)]
+ async fn test_authenticating_more_than_once(
+ cx: &mut TestAppContext,
+ deterministic: Arc<Deterministic>,
+ ) {
+ cx.foreground().forbid_parking();
+
+ let auth_count = Arc::new(Mutex::new(0));
+ let dropped_auth_count = Arc::new(Mutex::new(0));
+ let client = Client::new(FakeHttpClient::with_404_response());
+ client.override_authenticate({
+ let auth_count = auth_count.clone();
+ let dropped_auth_count = dropped_auth_count.clone();
+ move |cx| {
+ let auth_count = auth_count.clone();
+ let dropped_auth_count = dropped_auth_count.clone();
+ cx.foreground().spawn(async move {
+ *auth_count.lock() += 1;
+ let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
+ future::pending::<()>().await;
+ unreachable!()
+ })
+ }
+ });
+
+ let _authenticate = cx.spawn(|cx| {
+ let client = client.clone();
+ async move { client.authenticate_and_connect(false, &cx).await }
+ });
+ deterministic.run_until_parked();
+ assert_eq!(*auth_count.lock(), 1);
+ assert_eq!(*dropped_auth_count.lock(), 0);
+
+ let _authenticate = cx.spawn(|cx| {
+ let client = client.clone();
+ async move { client.authenticate_and_connect(false, &cx).await }
+ });
+ deterministic.run_until_parked();
+ assert_eq!(*auth_count.lock(), 2);
+ assert_eq!(*dropped_auth_count.lock(), 1);
+ }
+
#[test]
fn test_encode_and_decode_worktree_url() {
let url = encode_worktree_url(5, "deadbeef");
@@ -1270,13 +1270,6 @@ mod tests {
.detach();
});
- let request = server.receive::<proto::RegisterProject>().await.unwrap();
- server
- .respond(
- request.receipt(),
- proto::RegisterProjectResponse { project_id: 200 },
- )
- .await;
let get_users_request = server.receive::<proto::GetUsers>().await.unwrap();
server
.respond(
@@ -1307,6 +1300,14 @@ mod tests {
)
.await;
+ let request = server.receive::<proto::RegisterProject>().await.unwrap();
+ server
+ .respond(
+ request.receipt(),
+ proto::RegisterProjectResponse { project_id: 200 },
+ )
+ .await;
+
server.send(proto::UpdateContacts {
incoming_requests: vec![proto::IncomingContactRequest {
requester_id: 1,