Add development credentials provider (#11505)

Marshall Bowers created

This PR adds a new development credentials provider for the purpose of
streamlining local development against production collab.

## Problem

Today if you want to run a development build of Zed against the
production collab server, you need to either:

1. Enter your keychain password every time in order to retrieve your
saved credentials
2. Re-authenticate with zed.dev every time
    - This can get annoying as you need to pop out into a browser window
- I've also seen cases where if you re-auth too many times in a row
GitHub will make you confirm the authentication, as it looks suspicious

## Solution

This PR decouples the concept of the credentials provider from the
keychain, and adds a new development credentials provider to address
this specific case.

Now when running a development build of Zed and the
`ZED_DEVELOPMENT_AUTH` environment variable is set to a non-empty value,
the credentials will be saved to disk instead of the system keychain.

While this is not as secure as storing them in the system keychain,
since it is only used for development the tradeoff seems acceptable for
the resulting improvement in UX.

Release Notes:

- N/A

Change summary

crates/client/Cargo.toml    |  19 +-
crates/client/src/client.rs | 229 +++++++++++++++++++++++++++++++-------
crates/zed/src/main.rs      |   6 
3 files changed, 200 insertions(+), 54 deletions(-)

Detailed changes

crates/client/Cargo.toml 🔗

@@ -16,39 +16,38 @@ doctest = false
 test-support = ["clock/test-support", "collections/test-support", "gpui/test-support", "rpc/test-support"]
 
 [dependencies]
+anyhow.workspace = true
+async-recursion = "0.3"
+async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] }
 chrono = { workspace = true, features = ["serde"] }
 clock.workspace = true
 collections.workspace = true
-gpui.workspace = true
-util.workspace = true
-release_channel.workspace = true
-rpc.workspace = true
-text.workspace = true
-settings.workspace = true
 feature_flags.workspace = true
-
-anyhow.workspace = true
-async-recursion = "0.3"
-async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] }
 futures.workspace = true
+gpui.workspace = true
 lazy_static.workspace = true
 log.workspace = true
 once_cell = "1.19.0"
 parking_lot.workspace = true
 postage.workspace = true
 rand.workspace = true
+release_channel.workspace = true
+rpc.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+settings.workspace = true
 sha2.workspace = true
 smol.workspace = true
 sysinfo.workspace = true
 telemetry_events.workspace = true
 tempfile.workspace = true
+text.workspace = true
 thiserror.workspace = true
 time.workspace = true
 tiny_http = "0.8"
 url.workspace = true
+util.workspace = true
 
 [dev-dependencies]
 clock = { workspace = true, features = ["test-support"] }

crates/client/src/client.rs 🔗

@@ -30,6 +30,7 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsSources, SettingsStore};
 use std::fmt;
+use std::pin::Pin;
 use std::{
     any::TypeId,
     convert::TryFrom,
@@ -65,6 +66,13 @@ impl fmt::Display for DevServerToken {
 lazy_static! {
     static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
     static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
+    /// An environment variable whose presence indicates that the development auth
+    /// provider should be used.
+    ///
+    /// Only works in development. Setting this environment variable in other release
+    /// channels is a no-op.
+    pub static ref ZED_DEVELOPMENT_AUTH: bool =
+        std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
     pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
         .ok()
         .and_then(|s| if s.is_empty() { None } else { Some(s) });
@@ -161,6 +169,7 @@ pub struct Client {
     peer: Arc<Peer>,
     http: Arc<HttpClientWithUrl>,
     telemetry: Arc<Telemetry>,
+    credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
     state: RwLock<ClientState>,
 
     #[allow(clippy::type_complexity)]
@@ -298,6 +307,32 @@ impl Credentials {
     }
 }
 
+/// A provider for [`Credentials`].
+///
+/// Used to abstract over reading and writing credentials to some form of
+/// persistence (like the system keychain).
+trait CredentialsProvider {
+    /// Reads the credentials from the provider.
+    fn read_credentials<'a>(
+        &'a self,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;
+
+    /// Writes the credentials to the provider.
+    fn write_credentials<'a>(
+        &'a self,
+        user_id: u64,
+        access_token: String,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
+
+    /// Deletes the credentials from the provider.
+    fn delete_credentials<'a>(
+        &'a self,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
+}
+
 impl Default for ClientState {
     fn default() -> Self {
         Self {
@@ -443,11 +478,27 @@ impl Client {
         http: Arc<HttpClientWithUrl>,
         cx: &mut AppContext,
     ) -> Arc<Self> {
+        let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
+            Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
+            Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
+            | None => false,
+        };
+
+        let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
+            if use_zed_development_auth {
+                Arc::new(DevelopmentCredentialsProvider {
+                    path: util::paths::CONFIG_DIR.join("development_auth"),
+                })
+            } else {
+                Arc::new(KeychainCredentialsProvider)
+            };
+
         Arc::new(Self {
             id: AtomicU64::new(0),
             peer: Peer::new(0),
             telemetry: Telemetry::new(clock, http.clone(), cx),
             http,
+            credentials_provider,
             state: Default::default(),
 
             #[cfg(any(test, feature = "test-support"))]
@@ -763,8 +814,11 @@ impl Client {
         }
     }
 
-    pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
-        read_credentials_from_keychain(cx).await.is_some()
+    pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
+        self.credentials_provider
+            .read_credentials(cx)
+            .await
+            .is_some()
     }
 
     pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
@@ -775,7 +829,7 @@ impl Client {
     #[async_recursion(?Send)]
     pub async fn authenticate_and_connect(
         self: &Arc<Self>,
-        try_keychain: bool,
+        try_provider: bool,
         cx: &AsyncAppContext,
     ) -> anyhow::Result<()> {
         let was_disconnected = match *self.status().borrow() {
@@ -796,12 +850,13 @@ impl Client {
             self.set_status(Status::Reauthenticating, cx)
         }
 
-        let mut read_from_keychain = false;
+        let mut read_from_provider = false;
         let mut credentials = self.state.read().credentials.clone();
-        if credentials.is_none() && try_keychain {
-            credentials = read_credentials_from_keychain(cx).await;
-            read_from_keychain = credentials.is_some();
+        if credentials.is_none() && try_provider {
+            credentials = self.credentials_provider.read_credentials(cx).await;
+            read_from_provider = credentials.is_some();
         }
+
         if credentials.is_none() {
             let mut status_rx = self.status();
             let _ = status_rx.next().await;
@@ -838,9 +893,9 @@ impl Client {
                 match connection {
                     Ok(conn) => {
                         self.state.write().credentials = Some(credentials.clone());
-                        if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
+                        if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
                             if let Credentials::User{user_id, access_token} = credentials {
-                                write_credentials_to_keychain(user_id, access_token, cx).await.log_err();
+                                self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
                             }
                         }
 
@@ -854,8 +909,8 @@ impl Client {
                     }
                     Err(EstablishConnectionError::Unauthorized) => {
                         self.state.write().credentials.take();
-                        if read_from_keychain {
-                            delete_credentials_from_keychain(cx).await.log_err();
+                        if read_from_provider {
+                            self.credentials_provider.delete_credentials(cx).await.log_err();
                             self.set_status(Status::SignedOut, cx);
                             self.authenticate_and_connect(false, cx).await
                         } else {
@@ -1264,8 +1319,11 @@ impl Client {
         self.state.write().credentials = None;
         self.disconnect(&cx);
 
-        if self.has_keychain_credentials(cx).await {
-            delete_credentials_from_keychain(cx).await.log_err();
+        if self.has_credentials(cx).await {
+            self.credentials_provider
+                .delete_credentials(cx)
+                .await
+                .log_err();
         }
     }
 
@@ -1465,41 +1523,128 @@ impl Client {
     }
 }
 
-async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
-    if IMPERSONATE_LOGIN.is_some() {
-        return None;
+#[derive(Serialize, Deserialize)]
+struct DevelopmentCredentials {
+    user_id: u64,
+    access_token: String,
+}
+
+/// A credentials provider that stores credentials in a local file.
+///
+/// This MUST only be used in development, as this is not a secure way of storing
+/// credentials on user machines.
+///
+/// Its existence is purely to work around the annoyance of having to constantly
+/// re-allow access to the system keychain when developing Zed.
+struct DevelopmentCredentialsProvider {
+    path: PathBuf,
+}
+
+impl CredentialsProvider for DevelopmentCredentialsProvider {
+    fn read_credentials<'a>(
+        &'a self,
+        _cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
+        async move {
+            if IMPERSONATE_LOGIN.is_some() {
+                return None;
+            }
+
+            let json = std::fs::read(&self.path).log_err()?;
+
+            let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
+
+            Some(Credentials::User {
+                user_id: credentials.user_id,
+                access_token: credentials.access_token,
+            })
+        }
+        .boxed_local()
     }
 
-    let (user_id, access_token) = cx
-        .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
-        .log_err()?
-        .await
-        .log_err()??;
+    fn write_credentials<'a>(
+        &'a self,
+        user_id: u64,
+        access_token: String,
+        _cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move {
+            let json = serde_json::to_string(&DevelopmentCredentials {
+                user_id,
+                access_token,
+            })?;
 
-    Some(Credentials::User {
-        user_id: user_id.parse().ok()?,
-        access_token: String::from_utf8(access_token).ok()?,
-    })
-}
+            std::fs::write(&self.path, json)?;
 
-async fn write_credentials_to_keychain(
-    user_id: u64,
-    access_token: String,
-    cx: &AsyncAppContext,
-) -> Result<()> {
-    cx.update(move |cx| {
-        cx.write_credentials(
-            &ClientSettings::get_global(cx).server_url,
-            &user_id.to_string(),
-            access_token.as_bytes(),
-        )
-    })?
-    .await
+            Ok(())
+        }
+        .boxed_local()
+    }
+
+    fn delete_credentials<'a>(
+        &'a self,
+        _cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
+    }
 }
 
-async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
-    cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
-        .await
+/// A credentials provider that stores credentials in the system keychain.
+struct KeychainCredentialsProvider;
+
+impl CredentialsProvider for KeychainCredentialsProvider {
+    fn read_credentials<'a>(
+        &'a self,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
+        async move {
+            if IMPERSONATE_LOGIN.is_some() {
+                return None;
+            }
+
+            let (user_id, access_token) = cx
+                .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
+                .log_err()?
+                .await
+                .log_err()??;
+
+            Some(Credentials::User {
+                user_id: user_id.parse().ok()?,
+                access_token: String::from_utf8(access_token).ok()?,
+            })
+        }
+        .boxed_local()
+    }
+
+    fn write_credentials<'a>(
+        &'a self,
+        user_id: u64,
+        access_token: String,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move {
+            cx.update(move |cx| {
+                cx.write_credentials(
+                    &ClientSettings::get_global(cx).server_url,
+                    &user_id.to_string(),
+                    access_token.as_bytes(),
+                )
+            })?
+            .await
+        }
+        .boxed_local()
+    }
+
+    fn delete_credentials<'a>(
+        &'a self,
+        cx: &'a AsyncAppContext,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+        async move {
+            cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
+                .await
+        }
+        .boxed_local()
+    }
 }
 
 /// prefix for the zed:// url scheme

crates/zed/src/main.rs 🔗

@@ -542,10 +542,12 @@ fn handle_open_request(
 
 async fn authenticate(client: Arc<Client>, cx: &AsyncAppContext) -> Result<()> {
     if stdout_is_a_pty() {
-        if client::IMPERSONATE_LOGIN.is_some() {
+        if *client::ZED_DEVELOPMENT_AUTH {
+            client.authenticate_and_connect(true, &cx).await?;
+        } else if client::IMPERSONATE_LOGIN.is_some() {
             client.authenticate_and_connect(false, &cx).await?;
         }
-    } else if client.has_keychain_credentials(&cx).await {
+    } else if client.has_credentials(&cx).await {
         client.authenticate_and_connect(true, &cx).await?;
     }
     Ok::<_, anyhow::Error>(())