Merge pull request #666 from zed-industries/keychain

Max Brunsfeld created

Refine authentication and keychain interaction

Change summary

crates/chat_panel/src/chat_panel.rs |  7 +++++
crates/client/src/client.rs         | 36 +++++++++++++++---------------
crates/client/src/test.rs           |  2 
crates/project/src/project.rs       |  2 
crates/server/src/rpc.rs            |  2 
crates/zed/src/main.rs              | 10 ++++++-
6 files changed, 35 insertions(+), 24 deletions(-)

Detailed changes

crates/chat_panel/src/chat_panel.rs 🔗

@@ -327,7 +327,12 @@ impl ChatPanel {
                 let rpc = rpc.clone();
                 let this = this.clone();
                 cx.spawn(|mut cx| async move {
-                    if rpc.authenticate_and_connect(&cx).log_err().await.is_some() {
+                    if rpc
+                        .authenticate_and_connect(true, &cx)
+                        .log_err()
+                        .await
+                        .is_some()
+                    {
                         cx.update(|cx| {
                             if let Some(this) = this.upgrade(cx) {
                                 if this.is_focused(cx) {

crates/client/src/client.rs 🔗

@@ -45,7 +45,7 @@ pub use user::*;
 lazy_static! {
     static ref ZED_SERVER_URL: String =
         std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
-    static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
+    pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
         .ok()
         .and_then(|s| if s.is_empty() { None } else { Some(s) });
 }
@@ -55,7 +55,7 @@ action!(Authenticate);
 pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
     cx.add_global_action(move |_: &Authenticate, cx| {
         let rpc = rpc.clone();
-        cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
+        cx.spawn(|cx| async move { rpc.authenticate_and_connect(true, &cx).log_err().await })
             .detach();
     });
 }
@@ -302,7 +302,7 @@ impl Client {
                 state._reconnect_task = Some(cx.spawn(|cx| async move {
                     let mut rng = StdRng::from_entropy();
                     let mut delay = Duration::from_millis(100);
-                    while let Err(error) = this.authenticate_and_connect(&cx).await {
+                    while let Err(error) = this.authenticate_and_connect(true, &cx).await {
                         log::error!("failed to connect {}", error);
                         this.set_status(
                             Status::ReconnectionError {
@@ -547,6 +547,7 @@ impl Client {
     #[async_recursion(?Send)]
     pub async fn authenticate_and_connect(
         self: &Arc<Self>,
+        try_keychain: bool,
         cx: &AsyncAppContext,
     ) -> anyhow::Result<()> {
         let was_disconnected = match *self.status().borrow() {
@@ -568,23 +569,22 @@ impl Client {
             self.set_status(Status::Reauthenticating, cx)
         }
 
-        let mut used_keychain = false;
-        let credentials = self.state.read().credentials.clone();
-        let credentials = if let Some(credentials) = credentials {
-            credentials
-        } else if let Some(credentials) = read_credentials_from_keychain(cx) {
-            used_keychain = true;
-            credentials
-        } else {
-            let credentials = match self.authenticate(&cx).await {
+        let mut read_from_keychain = false;
+        let mut credentials = self.state.read().credentials.clone();
+        if credentials.is_none() && try_keychain {
+            credentials = read_credentials_from_keychain(cx);
+            read_from_keychain = credentials.is_some();
+        }
+        if credentials.is_none() {
+            credentials = Some(match self.authenticate(&cx).await {
                 Ok(credentials) => credentials,
                 Err(err) => {
                     self.set_status(Status::ConnectionError, cx);
                     return Err(err);
                 }
-            };
-            credentials
-        };
+            });
+        }
+        let credentials = credentials.unwrap();
 
         if was_disconnected {
             self.set_status(Status::Connecting, cx);
@@ -595,7 +595,7 @@ impl Client {
         match self.establish_connection(&credentials, cx).await {
             Ok(conn) => {
                 self.state.write().credentials = Some(credentials.clone());
-                if !used_keychain && IMPERSONATE_LOGIN.is_none() {
+                if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
                     write_credentials_to_keychain(&credentials, cx).log_err();
                 }
                 self.set_connection(conn, cx).await;
@@ -603,10 +603,10 @@ impl Client {
             }
             Err(EstablishConnectionError::Unauthorized) => {
                 self.state.write().credentials.take();
-                if used_keychain {
+                if read_from_keychain {
                     cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
                     self.set_status(Status::SignedOut, cx);
-                    self.authenticate_and_connect(cx).await
+                    self.authenticate_and_connect(false, cx).await
                 } else {
                     self.set_status(Status::ConnectionError, cx);
                     Err(EstablishConnectionError::Unauthorized)?

crates/client/src/test.rs 🔗

@@ -91,7 +91,7 @@ impl FakeServer {
             });
 
         client
-            .authenticate_and_connect(&cx.to_async())
+            .authenticate_and_connect(false, &cx.to_async())
             .await
             .unwrap();
         server

crates/project/src/project.rs 🔗

@@ -377,7 +377,7 @@ impl Project {
         fs: Arc<dyn Fs>,
         cx: &mut AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
-        client.authenticate_and_connect(&cx).await?;
+        client.authenticate_and_connect(true, &cx).await?;
 
         let response = client
             .request(proto::JoinProject {

crates/server/src/rpc.rs 🔗

@@ -5021,7 +5021,7 @@ mod tests {
                 });
 
             client
-                .authenticate_and_connect(&cx.to_async())
+                .authenticate_and_connect(false, &cx.to_async())
                 .await
                 .unwrap();
 

crates/zed/src/main.rs 🔗

@@ -81,8 +81,14 @@ fn main() {
         cx.spawn({
             let client = client.clone();
             |cx| async move {
-                if client.has_keychain_credentials(&cx) {
-                    client.authenticate_and_connect(&cx).await?;
+                if stdout_is_a_pty() {
+                    if client::IMPERSONATE_LOGIN.is_some() {
+                        client.authenticate_and_connect(false, &cx).await?;
+                    }
+                } else {
+                    if client.has_keychain_credentials(&cx) {
+                        client.authenticate_and_connect(true, &cx).await?;
+                    }
                 }
                 Ok::<_, anyhow::Error>(())
             }