diff --git a/Cargo.lock b/Cargo.lock index aea0c49da87ae5aeed63e66f65aad561b1471870..60b2864c5bd3d70dcbe44bc1547f85d68e852869 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -344,6 +344,17 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "async-recursion" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-rustls" version = "0.1.2" @@ -5881,6 +5892,7 @@ version = "0.1.0" dependencies = [ "anyhow", "arrayvec 0.7.1", + "async-recursion", "async-trait", "async-tungstenite", "cargo-bundle", diff --git a/zed/Cargo.toml b/zed/Cargo.toml index d9c2cc6a584535caace7d267b5a04b10dfffb721..bb4fab264e589d96a83960db49ad69d6b88488a2 100644 --- a/zed/Cargo.toml +++ b/zed/Cargo.toml @@ -18,6 +18,7 @@ test-support = ["tempdir", "zrpc/test-support"] [dependencies] anyhow = "1.0.38" +async-recursion = "0.3" async-trait = "0.1" arrayvec = "0.7.1" async-tungstenite = { version = "0.14", features = ["async-tls"] } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 9596b671edee8daea398292d3ac9f662877a6481..a01d14193fa833d6e3151eefaff37d555acf4b60 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -1,5 +1,6 @@ use crate::util::ResultExt; use anyhow::{anyhow, Context, Result}; +use async_recursion::async_recursion; use async_tungstenite::tungstenite::{ error::Error as WebsocketError, http::{Request, StatusCode}, @@ -282,6 +283,7 @@ impl Client { } } + #[async_recursion(?Send)] pub async fn authenticate_and_connect( self: &Arc, cx: &AsyncAppContext, @@ -304,9 +306,13 @@ impl Client { self.set_status(Status::Reauthenticating, cx) } + let mut read_from_keychain = false; let credentials = self.state.read().credentials.clone(); let credentials = if let Some(credentials) = credentials { credentials + } else if let Some(credentials) = read_credentials_from_keychain(cx) { + read_from_keychain = true; + credentials } else { let credentials = match self.authenticate(&cx).await { Ok(credentials) => credentials, @@ -328,16 +334,27 @@ impl Client { match self.establish_connection(&credentials, cx).await { Ok(conn) => { log::info!("connected to rpc address {}", *ZED_SERVER_URL); + if !read_from_keychain { + write_credentials_to_keychain(&credentials, cx).log_err(); + } self.set_connection(conn, cx).await; Ok(()) } Err(err) => { if matches!(err, EstablishConnectionError::Unauthorized) { self.state.write().credentials.take(); - cx.platform().delete_credentials(&ZED_SERVER_URL).ok(); + cx.platform().delete_credentials(&ZED_SERVER_URL).log_err(); + if read_from_keychain { + self.set_status(Status::SignedOut, cx); + self.authenticate_and_connect(cx).await + } else { + self.set_status(Status::ConnectionError, cx); + Err(err)? + } + } else { + self.set_status(Status::ConnectionError, cx); + Err(err)? } - self.set_status(Status::ConnectionError, cx); - Err(err)? } } } @@ -449,18 +466,6 @@ impl Client { let platform = cx.platform(); let executor = cx.background(); executor.clone().spawn(async move { - if let Some((user_id, access_token)) = platform - .read_credentials(&ZED_SERVER_URL) - .log_err() - .flatten() - { - log::info!("already signed in. user_id: {}", user_id); - return Ok(Credentials { - user_id: user_id.parse()?, - access_token: String::from_utf8(access_token).unwrap(), - }); - } - // Generate a pair of asymmetric encryption keys. The public key will be used by the // zed server to encrypt the user's access token, so that it can'be intercepted by // any other app running on the user's device. @@ -521,9 +526,6 @@ impl Client { .decrypt_string(&access_token) .context("failed to decrypt access token")?; platform.activate(true); - platform - .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes()) - .log_err(); Ok(Credentials { user_id: user_id.parse()?, @@ -564,6 +566,26 @@ impl Client { } } +fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option { + let (user_id, access_token) = cx + .platform() + .read_credentials(&ZED_SERVER_URL) + .log_err() + .flatten()?; + Some(Credentials { + user_id: user_id.parse().ok()?, + access_token: String::from_utf8(access_token).ok()?, + }) +} + +fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> { + cx.platform().write_credentials( + &ZED_SERVER_URL, + &credentials.user_id.to_string(), + credentials.access_token.as_bytes(), + ) +} + const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/"; pub fn encode_worktree_url(id: u64, access_token: &str) -> String {