Authenticate via the browser if keychain credentials are invalid

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock     | 12 ++++++++++
zed/Cargo.toml |  1 
zed/src/rpc.rs | 58 +++++++++++++++++++++++++++++++++++----------------
3 files changed, 53 insertions(+), 18 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -344,6 +344,17 @@ dependencies = [
  "winapi 0.3.9",
 ]
 
+[[package]]
+name = "async-recursion"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "async-rustls"
 version = "0.1.2"
@@ -5881,6 +5892,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "arrayvec 0.7.1",
+ "async-recursion",
  "async-trait",
  "async-tungstenite",
  "cargo-bundle",

zed/Cargo.toml 🔗

@@ -18,6 +18,7 @@ test-support = ["tempdir", "zrpc/test-support"]
 
 [dependencies]
 anyhow = "1.0.38"
+async-recursion = "0.3"
 async-trait = "0.1"
 arrayvec = "0.7.1"
 async-tungstenite = { version = "0.14", features = ["async-tls"] }

zed/src/rpc.rs 🔗

@@ -1,5 +1,6 @@
 use crate::util::ResultExt;
 use anyhow::{anyhow, Context, Result};
+use async_recursion::async_recursion;
 use async_tungstenite::tungstenite::{
     error::Error as WebsocketError,
     http::{Request, StatusCode},
@@ -282,6 +283,7 @@ impl Client {
         }
     }
 
+    #[async_recursion(?Send)]
     pub async fn authenticate_and_connect(
         self: &Arc<Self>,
         cx: &AsyncAppContext,
@@ -304,9 +306,13 @@ impl Client {
             self.set_status(Status::Reauthenticating, cx)
         }
 
+        let mut read_from_keychain = false;
         let credentials = self.state.read().credentials.clone();
         let credentials = if let Some(credentials) = credentials {
             credentials
+        } else if let Some(credentials) = read_credentials_from_keychain(cx) {
+            read_from_keychain = true;
+            credentials
         } else {
             let credentials = match self.authenticate(&cx).await {
                 Ok(credentials) => credentials,
@@ -328,16 +334,27 @@ impl Client {
         match self.establish_connection(&credentials, cx).await {
             Ok(conn) => {
                 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+                if !read_from_keychain {
+                    write_credentials_to_keychain(&credentials, cx).log_err();
+                }
                 self.set_connection(conn, cx).await;
                 Ok(())
             }
             Err(err) => {
                 if matches!(err, EstablishConnectionError::Unauthorized) {
                     self.state.write().credentials.take();
-                    cx.platform().delete_credentials(&ZED_SERVER_URL).ok();
+                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+                    if read_from_keychain {
+                        self.set_status(Status::SignedOut, cx);
+                        self.authenticate_and_connect(cx).await
+                    } else {
+                        self.set_status(Status::ConnectionError, cx);
+                        Err(err)?
+                    }
+                } else {
+                    self.set_status(Status::ConnectionError, cx);
+                    Err(err)?
                 }
-                self.set_status(Status::ConnectionError, cx);
-                Err(err)?
             }
         }
     }
@@ -449,18 +466,6 @@ impl Client {
         let platform = cx.platform();
         let executor = cx.background();
         executor.clone().spawn(async move {
-            if let Some((user_id, access_token)) = platform
-                .read_credentials(&ZED_SERVER_URL)
-                .log_err()
-                .flatten()
-            {
-                log::info!("already signed in. user_id: {}", user_id);
-                return Ok(Credentials {
-                    user_id: user_id.parse()?,
-                    access_token: String::from_utf8(access_token).unwrap(),
-                });
-            }
-
             // Generate a pair of asymmetric encryption keys. The public key will be used by the
             // zed server to encrypt the user's access token, so that it can'be intercepted by
             // any other app running on the user's device.
@@ -521,9 +526,6 @@ impl Client {
                 .decrypt_string(&access_token)
                 .context("failed to decrypt access token")?;
             platform.activate(true);
-            platform
-                .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
-                .log_err();
 
             Ok(Credentials {
                 user_id: user_id.parse()?,
@@ -564,6 +566,26 @@ impl Client {
     }
 }
 
+fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
+    let (user_id, access_token) = cx
+        .platform()
+        .read_credentials(&ZED_SERVER_URL)
+        .log_err()
+        .flatten()?;
+    Some(Credentials {
+        user_id: user_id.parse().ok()?,
+        access_token: String::from_utf8(access_token).ok()?,
+    })
+}
+
+fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
+    cx.platform().write_credentials(
+        &ZED_SERVER_URL,
+        &credentials.user_id.to_string(),
+        credentials.access_token.as_bytes(),
+    )
+}
+
 const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
 
 pub fn encode_worktree_url(id: u64, access_token: &str) -> String {