Merge pull request #1443 from zed-industries/non-functional-sign-in

Antonio Scandurra created

Allow signing in again if authentication is pending or was unsuccessful

Change summary

crates/client/src/client.rs                 | 145 +++++++++++++++-------
crates/contacts_panel/src/contacts_panel.rs |  15 +-
crates/workspace/src/workspace.rs           |   2 
3 files changed, 109 insertions(+), 53 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -569,14 +569,14 @@ impl Client {
     ) -> anyhow::Result<()> {
         let was_disconnected = match *self.status().borrow() {
             Status::SignedOut => true,
-            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
-                false
+            Status::ConnectionError
+            | Status::ConnectionLost
+            | Status::Authenticating { .. }
+            | Status::Reauthenticating { .. }
+            | Status::ReconnectionError { .. } => false,
+            Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
+                return Ok(())
             }
-            Status::Connected { .. }
-            | Status::Connecting { .. }
-            | Status::Reconnecting { .. }
-            | Status::Authenticating
-            | Status::Reauthenticating => return Ok(()),
             Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
         };
 
@@ -593,13 +593,22 @@ impl Client {
             read_from_keychain = credentials.is_some();
         }
         if credentials.is_none() {
-            credentials = Some(match self.authenticate(&cx).await {
-                Ok(credentials) => credentials,
-                Err(err) => {
-                    self.set_status(Status::ConnectionError, cx);
-                    return Err(err);
+            let mut status_rx = self.status();
+            let _ = status_rx.next().await;
+            futures::select_biased! {
+                authenticate = self.authenticate(&cx).fuse() => {
+                    match authenticate {
+                        Ok(creds) => credentials = Some(creds),
+                        Err(err) => {
+                            self.set_status(Status::ConnectionError, cx);
+                            return Err(err);
+                        }
+                    }
                 }
-            });
+                _ = status_rx.next().fuse() => {
+                    return Err(anyhow!("authentication canceled"));
+                }
+            }
         }
         let credentials = credentials.unwrap();
 
@@ -899,40 +908,42 @@ impl Client {
             // custom URL scheme instead of this local HTTP server.
             let (user_id, access_token) = executor
                 .spawn(async move {
-                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
-                        let path = req.url();
-                        let mut user_id = None;
-                        let mut access_token = None;
-                        let url = Url::parse(&format!("http://example.com{}", path))
-                            .context("failed to parse login notification url")?;
-                        for (key, value) in url.query_pairs() {
-                            if key == "access_token" {
-                                access_token = Some(value.to_string());
-                            } else if key == "user_id" {
-                                user_id = Some(value.to_string());
+                    for _ in 0..100 {
+                        if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
+                            let path = req.url();
+                            let mut user_id = None;
+                            let mut access_token = None;
+                            let url = Url::parse(&format!("http://example.com{}", path))
+                                .context("failed to parse login notification url")?;
+                            for (key, value) in url.query_pairs() {
+                                if key == "access_token" {
+                                    access_token = Some(value.to_string());
+                                } else if key == "user_id" {
+                                    user_id = Some(value.to_string());
+                                }
                             }
-                        }
 
-                        let post_auth_url =
-                            format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
-                        req.respond(
-                            tiny_http::Response::empty(302).with_header(
-                                tiny_http::Header::from_bytes(
-                                    &b"Location"[..],
-                                    post_auth_url.as_bytes(),
-                                )
-                                .unwrap(),
-                            ),
-                        )
-                        .context("failed to respond to login http request")?;
-                        Ok((
-                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
-                            access_token
-                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
-                        ))
-                    } else {
-                        Err(anyhow!("didn't receive login redirect"))
+                            let post_auth_url =
+                                format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL);
+                            req.respond(
+                                tiny_http::Response::empty(302).with_header(
+                                    tiny_http::Header::from_bytes(
+                                        &b"Location"[..],
+                                        post_auth_url.as_bytes(),
+                                    )
+                                    .unwrap(),
+                                ),
+                            )
+                            .context("failed to respond to login http request")?;
+                            return Ok((
+                                user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
+                                access_token
+                                    .ok_or_else(|| anyhow!("missing access_token parameter"))?,
+                            ));
+                        }
                     }
+
+                    Err(anyhow!("didn't receive login redirect"))
                 })
                 .await?;
 
@@ -1061,7 +1072,9 @@ pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
 mod tests {
     use super::*;
     use crate::test::{FakeHttpClient, FakeServer};
-    use gpui::TestAppContext;
+    use gpui::{executor::Deterministic, TestAppContext};
+    use parking_lot::Mutex;
+    use std::future;
 
     #[gpui::test(iterations = 10)]
     async fn test_reconnection(cx: &mut TestAppContext) {
@@ -1098,6 +1111,48 @@ mod tests {
         assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_authenticating_more_than_once(
+        cx: &mut TestAppContext,
+        deterministic: Arc<Deterministic>,
+    ) {
+        cx.foreground().forbid_parking();
+
+        let auth_count = Arc::new(Mutex::new(0));
+        let dropped_auth_count = Arc::new(Mutex::new(0));
+        let client = Client::new(FakeHttpClient::with_404_response());
+        client.override_authenticate({
+            let auth_count = auth_count.clone();
+            let dropped_auth_count = dropped_auth_count.clone();
+            move |cx| {
+                let auth_count = auth_count.clone();
+                let dropped_auth_count = dropped_auth_count.clone();
+                cx.foreground().spawn(async move {
+                    *auth_count.lock() += 1;
+                    let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
+                    future::pending::<()>().await;
+                    unreachable!()
+                })
+            }
+        });
+
+        let _authenticate = cx.spawn(|cx| {
+            let client = client.clone();
+            async move { client.authenticate_and_connect(false, &cx).await }
+        });
+        deterministic.run_until_parked();
+        assert_eq!(*auth_count.lock(), 1);
+        assert_eq!(*dropped_auth_count.lock(), 0);
+
+        let _authenticate = cx.spawn(|cx| {
+            let client = client.clone();
+            async move { client.authenticate_and_connect(false, &cx).await }
+        });
+        deterministic.run_until_parked();
+        assert_eq!(*auth_count.lock(), 2);
+        assert_eq!(*dropped_auth_count.lock(), 1);
+    }
+
     #[test]
     fn test_encode_and_decode_worktree_url() {
         let url = encode_worktree_url(5, "deadbeef");

crates/contacts_panel/src/contacts_panel.rs 🔗

@@ -1270,13 +1270,6 @@ mod tests {
             .detach();
         });
 
-        let request = server.receive::<proto::RegisterProject>().await.unwrap();
-        server
-            .respond(
-                request.receipt(),
-                proto::RegisterProjectResponse { project_id: 200 },
-            )
-            .await;
         let get_users_request = server.receive::<proto::GetUsers>().await.unwrap();
         server
             .respond(
@@ -1307,6 +1300,14 @@ mod tests {
             )
             .await;
 
+        let request = server.receive::<proto::RegisterProject>().await.unwrap();
+        server
+            .respond(
+                request.receipt(),
+                proto::RegisterProjectResponse { project_id: 200 },
+            )
+            .await;
+
         server.send(proto::UpdateContacts {
             incoming_requests: vec![proto::IncomingContactRequest {
                 requester_id: 1,

crates/workspace/src/workspace.rs 🔗

@@ -1811,7 +1811,7 @@ impl Workspace {
         match &*self.client.status().borrow() {
             client::Status::ConnectionError
             | client::Status::ConnectionLost
-            | client::Status::Reauthenticating
+            | client::Status::Reauthenticating { .. }
             | client::Status::Reconnecting { .. }
             | client::Status::ReconnectionError { .. } => Some(
                 Container::new(