Add env var checkbox

Richard Feldman created

Change summary

crates/extension_host/src/extension_host.rs             |  13 
crates/extension_host/src/extension_settings.rs         |  13 
crates/extension_host/src/wasm_host/llm_provider.rs     | 212 ++++++++++
crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs |  50 ++
crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs |  50 ++
crates/settings/src/settings_content/extension.rs       |   6 
6 files changed, 339 insertions(+), 5 deletions(-)

Detailed changes

crates/extension_host/src/extension_host.rs 🔗

@@ -68,6 +68,7 @@ struct LlmProviderWithModels {
     models: Vec<LlmModelInfo>,
     is_authenticated: bool,
     icon_path: Option<SharedString>,
+    auth_config: Option<extension::LanguageModelAuthConfig>,
 }
 
 pub use extension::{
@@ -1476,11 +1477,20 @@ impl ExtensionStore {
                                         SharedString::from(absolute_icon_path)
                                     });
 
+                                    let provider_id_arc: Arc<str> =
+                                        provider_info.id.as_str().into();
+                                    let auth_config = extension
+                                        .manifest
+                                        .language_model_providers
+                                        .get(&provider_id_arc)
+                                        .and_then(|entry| entry.auth.clone());
+
                                     llm_providers_with_models.push(LlmProviderWithModels {
                                         provider_info,
                                         models,
                                         is_authenticated,
                                         icon_path,
+                                        auth_config,
                                     });
                                 }
                             } else {
@@ -1579,12 +1589,13 @@ impl ExtensionStore {
                         let mods = llm_provider.models.clone();
                         let auth = llm_provider.is_authenticated;
                         let icon = llm_provider.icon_path.clone();
+                        let auth_config = llm_provider.auth_config.clone();
 
                         this.proxy.register_language_model_provider(
                             provider_id.clone(),
                             Box::new(move |cx: &mut App| {
                                 let provider = Arc::new(ExtensionLanguageModelProvider::new(
-                                    wasm_ext, pinfo, mods, auth, icon, cx,
+                                    wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
                                 ));
                                 language_model::LanguageModelRegistry::global(cx).update(
                                     cx,

crates/extension_host/src/extension_settings.rs 🔗

@@ -1,4 +1,4 @@
-use collections::HashMap;
+use collections::{HashMap, HashSet};
 use extension::{
     DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability,
 };
@@ -16,6 +16,10 @@ pub struct ExtensionSettings {
     pub auto_install_extensions: HashMap<Arc<str>, bool>,
     pub auto_update_extensions: HashMap<Arc<str>, bool>,
     pub granted_capabilities: Vec<ExtensionCapability>,
+    /// The extension language model providers that are allowed to read API keys
+    /// from environment variables. Each entry is a provider ID in the format
+    /// "extension_id:provider_id".
+    pub allowed_env_var_providers: HashSet<Arc<str>>,
 }
 
 impl ExtensionSettings {
@@ -60,6 +64,13 @@ impl Settings for ExtensionSettings {
                     }
                 })
                 .collect(),
+            allowed_env_var_providers: content
+                .extension
+                .allowed_env_var_providers
+                .clone()
+                .unwrap_or_default()
+                .into_iter()
+                .collect(),
         }
     }
 }

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -1,3 +1,4 @@
+use crate::ExtensionSettings;
 use crate::wasm_host::WasmExtension;
 
 use crate::wasm_host::wit::{
@@ -9,6 +10,7 @@ use crate::wasm_host::wit::{
 use anyhow::{Result, anyhow};
 use credentials_provider::CredentialsProvider;
 use editor::Editor;
+use extension::LanguageModelAuthConfig;
 use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt};
@@ -37,12 +39,15 @@ pub struct ExtensionLanguageModelProvider {
     pub extension: WasmExtension,
     pub provider_info: LlmProviderInfo,
     icon_path: Option<SharedString>,
+    auth_config: Option<LanguageModelAuthConfig>,
     state: Entity<ExtensionLlmProviderState>,
 }
 
 pub struct ExtensionLlmProviderState {
     is_authenticated: bool,
     available_models: Vec<LlmModelInfo>,
+    env_var_allowed: bool,
+    api_key_from_env: bool,
 }
 
 impl EventEmitter<()> for ExtensionLlmProviderState {}
@@ -54,17 +59,42 @@ impl ExtensionLanguageModelProvider {
         models: Vec<LlmModelInfo>,
         is_authenticated: bool,
         icon_path: Option<SharedString>,
+        auth_config: Option<LanguageModelAuthConfig>,
         cx: &mut App,
     ) -> Self {
+        let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id);
+        let env_var_allowed = ExtensionSettings::get_global(cx)
+            .allowed_env_var_providers
+            .contains(provider_id_string.as_str());
+
+        let (is_authenticated, api_key_from_env) =
+            if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) {
+                let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap();
+                if let Ok(value) = std::env::var(env_var_name) {
+                    if !value.is_empty() {
+                        (true, true)
+                    } else {
+                        (is_authenticated, false)
+                    }
+                } else {
+                    (is_authenticated, false)
+                }
+            } else {
+                (is_authenticated, false)
+            };
+
         let state = cx.new(|_| ExtensionLlmProviderState {
             is_authenticated,
             available_models: models,
+            env_var_allowed,
+            api_key_from_env,
         });
 
         Self {
             extension,
             provider_info,
             icon_path,
+            auth_config,
             state,
         }
     }
@@ -194,13 +224,17 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
         let credential_key = self.credential_key();
         let extension = self.extension.clone();
         let extension_provider_id = self.provider_info.id.clone();
+        let full_provider_id = self.provider_id_string();
         let state = self.state.clone();
+        let auth_config = self.auth_config.clone();
 
         cx.new(|cx| {
             ExtensionProviderConfigurationView::new(
                 credential_key,
                 extension,
                 extension_provider_id,
+                full_provider_id,
+                auth_config,
                 state,
                 window,
                 cx,
@@ -274,6 +308,8 @@ struct ExtensionProviderConfigurationView {
     credential_key: String,
     extension: WasmExtension,
     extension_provider_id: String,
+    full_provider_id: String,
+    auth_config: Option<LanguageModelAuthConfig>,
     state: Entity<ExtensionLlmProviderState>,
     settings_markdown: Option<Entity<Markdown>>,
     api_key_editor: Entity<Editor>,
@@ -287,6 +323,8 @@ impl ExtensionProviderConfigurationView {
         credential_key: String,
         extension: WasmExtension,
         extension_provider_id: String,
+        full_provider_id: String,
+        auth_config: Option<LanguageModelAuthConfig>,
         state: Entity<ExtensionLlmProviderState>,
         window: &mut Window,
         cx: &mut Context<Self>,
@@ -307,6 +345,8 @@ impl ExtensionProviderConfigurationView {
             credential_key,
             extension,
             extension_provider_id,
+            full_provider_id,
+            auth_config,
             state,
             settings_markdown: None,
             api_key_editor,
@@ -362,7 +402,20 @@ impl ExtensionProviderConfigurationView {
         let credentials_provider = <dyn CredentialsProvider>::global(cx);
         let state = self.state.clone();
 
+        // Check if we should use env var (already set in state during provider construction)
+        let api_key_from_env = self.state.read(cx).api_key_from_env;
+
         cx.spawn(async move |this, cx| {
+            // If using env var, we're already authenticated
+            if api_key_from_env {
+                this.update(cx, |this, cx| {
+                    this.loading_credentials = false;
+                    cx.notify();
+                })
+                .log_err();
+                return;
+            }
+
             let credentials = credentials_provider
                 .read_credentials(&credential_key, cx)
                 .await
@@ -388,6 +441,92 @@ impl ExtensionProviderConfigurationView {
         .detach();
     }
 
+    fn toggle_env_var_permission(&mut self, cx: &mut Context<Self>) {
+        let full_provider_id: Arc<str> = self.full_provider_id.clone().into();
+        let env_var_name = match &self.auth_config {
+            Some(config) => config.env_var.clone(),
+            None => return,
+        };
+
+        let state = self.state.clone();
+        let currently_allowed = self.state.read(cx).env_var_allowed;
+
+        // Update settings file
+        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, move |settings, _| {
+            let providers = settings
+                .extension
+                .allowed_env_var_providers
+                .get_or_insert_with(Vec::new);
+
+            if currently_allowed {
+                providers.retain(|id| id.as_ref() != full_provider_id.as_ref());
+            } else {
+                if !providers
+                    .iter()
+                    .any(|id| id.as_ref() == full_provider_id.as_ref())
+                {
+                    providers.push(full_provider_id.clone());
+                }
+            }
+        });
+
+        // Update local state
+        let new_allowed = !currently_allowed;
+        let new_from_env = if new_allowed {
+            if let Some(var_name) = &env_var_name {
+                if let Ok(value) = std::env::var(var_name) {
+                    !value.is_empty()
+                } else {
+                    false
+                }
+            } else {
+                false
+            }
+        } else {
+            false
+        };
+
+        state.update(cx, |state, cx| {
+            state.env_var_allowed = new_allowed;
+            state.api_key_from_env = new_from_env;
+            if new_from_env {
+                state.is_authenticated = true;
+            }
+            cx.notify();
+        });
+
+        // If env var is being disabled, reload credentials from keychain
+        if !new_allowed {
+            self.reload_keychain_credentials(cx);
+        }
+
+        cx.notify();
+    }
+
+    fn reload_keychain_credentials(&mut self, cx: &mut Context<Self>) {
+        let credential_key = self.credential_key.clone();
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let state = self.state.clone();
+
+        cx.spawn(async move |_this, cx| {
+            let credentials = credentials_provider
+                .read_credentials(&credential_key, cx)
+                .await
+                .log_err()
+                .flatten();
+
+            let has_credentials = credentials.is_some();
+
+            let _ = cx.update(|cx| {
+                state.update(cx, |state, cx| {
+                    state.is_authenticated = has_credentials;
+                    cx.notify();
+                });
+            });
+        })
+        .detach();
+    }
+
     fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
         let api_key = self.api_key_editor.read(cx).text(cx);
         if api_key.is_empty() {
@@ -456,6 +595,8 @@ impl gpui::Render for ExtensionProviderConfigurationView {
     fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let is_loading = self.loading_settings || self.loading_credentials;
         let is_authenticated = self.is_authenticated(cx);
+        let env_var_allowed = self.state.read(cx).env_var_allowed;
+        let api_key_from_env = self.state.read(cx).api_key_from_env;
 
         if is_loading {
             return v_flex()
@@ -478,8 +619,67 @@ impl gpui::Render for ExtensionProviderConfigurationView {
             );
         }
 
+        // Render env var checkbox if the extension specifies an env var
+        if let Some(auth_config) = &self.auth_config {
+            if let Some(env_var_name) = &auth_config.env_var {
+                let env_var_name = env_var_name.clone();
+                let checkbox_label =
+                    format!("Read API key from {} environment variable", env_var_name);
+
+                content = content.child(
+                    h_flex()
+                        .gap_2()
+                        .child(
+                            ui::Checkbox::new("env-var-permission", env_var_allowed.into())
+                                .on_click(cx.listener(|this, _, _window, cx| {
+                                    this.toggle_env_var_permission(cx);
+                                })),
+                        )
+                        .child(Label::new(checkbox_label).size(LabelSize::Small)),
+                );
+
+                // Show status if env var is allowed
+                if env_var_allowed {
+                    if api_key_from_env {
+                        content = content.child(
+                            h_flex()
+                                .gap_2()
+                                .child(
+                                    ui::Icon::new(ui::IconName::Check)
+                                        .color(Color::Success)
+                                        .size(ui::IconSize::Small),
+                                )
+                                .child(
+                                    Label::new(format!("API key loaded from {}", env_var_name))
+                                        .color(Color::Success),
+                                ),
+                        );
+                        return content.into_any_element();
+                    } else {
+                        content = content.child(
+                            h_flex()
+                                .gap_2()
+                                .child(
+                                    ui::Icon::new(ui::IconName::Warning)
+                                        .color(Color::Warning)
+                                        .size(ui::IconSize::Small),
+                                )
+                                .child(
+                                    Label::new(format!(
+                                        "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
+                                        env_var_name
+                                    ))
+                                    .color(Color::Warning)
+                                    .size(LabelSize::Small),
+                                ),
+                        );
+                    }
+                }
+            }
+        }
+
         // Render API key section
-        if is_authenticated {
+        if is_authenticated && !api_key_from_env {
             content = content.child(
                 v_flex()
                     .gap_2()
@@ -501,13 +701,19 @@ impl gpui::Render for ExtensionProviderConfigurationView {
                             })),
                     ),
             );
-        } else {
+        } else if !api_key_from_env {
+            let credential_label = self
+                .auth_config
+                .as_ref()
+                .and_then(|c| c.credential_label.clone())
+                .unwrap_or_else(|| "API Key".to_string());
+
             content = content.child(
                 v_flex()
                     .gap_2()
                     .on_action(cx.listener(Self::save_api_key))
                     .child(
-                        Label::new("API Key")
+                        Label::new(credential_label)
                             .size(LabelSize::Small)
                             .color(Color::Muted),
                     )

crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs 🔗

@@ -1,3 +1,4 @@
+use crate::ExtensionSettings;
 use crate::wasm_host::wit::since_v0_7_0::{
     dap::{
         AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
@@ -1195,6 +1196,55 @@ impl ExtensionImports for WasmState {
     }
 
     async fn llm_get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
+        let extension_id = self.manifest.id.clone();
+
+        // Find which provider (if any) declares this env var in its auth config
+        let mut allowed_provider_id: Option<Arc<str>> = None;
+        for (provider_id, provider_entry) in &self.manifest.language_model_providers {
+            if let Some(auth_config) = &provider_entry.auth {
+                if auth_config.env_var.as_deref() == Some(&name) {
+                    allowed_provider_id = Some(provider_id.clone());
+                    break;
+                }
+            }
+        }
+
+        // If no provider declares this env var, deny access
+        let Some(provider_id) = allowed_provider_id else {
+            log::warn!(
+                "Extension {} attempted to read env var {} which is not declared in any provider auth config",
+                extension_id,
+                name
+            );
+            return Ok(None);
+        };
+
+        // Check if the user has allowed this provider to read env vars
+        let full_provider_id = format!("{}:{}", extension_id, provider_id);
+        let is_allowed = self
+            .on_main_thread(move |cx| {
+                async move {
+                    cx.update(|cx| {
+                        ExtensionSettings::get_global(cx)
+                            .allowed_env_var_providers
+                            .contains(full_provider_id.as_str())
+                    })
+                    .unwrap_or(false)
+                }
+                .boxed_local()
+            })
+            .await;
+
+        if !is_allowed {
+            log::debug!(
+                "Extension {} provider {} is not allowed to read env var {}",
+                extension_id,
+                provider_id,
+                name
+            );
+            return Ok(None);
+        }
+
         Ok(env::var(&name).ok())
     }
 }

crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs 🔗

@@ -1,3 +1,4 @@
+use crate::ExtensionSettings;
 use crate::wasm_host::wit::since_v0_8_0::{
     dap::{
         AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
@@ -1192,6 +1193,55 @@ impl ExtensionImports for WasmState {
     }
 
     async fn llm_get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
+        let extension_id = self.manifest.id.clone();
+
+        // Find which provider (if any) declares this env var in its auth config
+        let mut allowed_provider_id: Option<Arc<str>> = None;
+        for (provider_id, provider_entry) in &self.manifest.language_model_providers {
+            if let Some(auth_config) = &provider_entry.auth {
+                if auth_config.env_var.as_deref() == Some(&name) {
+                    allowed_provider_id = Some(provider_id.clone());
+                    break;
+                }
+            }
+        }
+
+        // If no provider declares this env var, deny access
+        let Some(provider_id) = allowed_provider_id else {
+            log::warn!(
+                "Extension {} attempted to read env var {} which is not declared in any provider auth config",
+                extension_id,
+                name
+            );
+            return Ok(None);
+        };
+
+        // Check if the user has allowed this provider to read env vars
+        let full_provider_id = format!("{}:{}", extension_id, provider_id);
+        let is_allowed = self
+            .on_main_thread(move |cx| {
+                async move {
+                    cx.update(|cx| {
+                        ExtensionSettings::get_global(cx)
+                            .allowed_env_var_providers
+                            .contains(full_provider_id.as_str())
+                    })
+                    .unwrap_or(false)
+                }
+                .boxed_local()
+            })
+            .await;
+
+        if !is_allowed {
+            log::debug!(
+                "Extension {} provider {} is not allowed to read env var {}",
+                extension_id,
+                provider_id,
+                name
+            );
+            return Ok(None);
+        }
+
         Ok(env::var(&name).ok())
     }
 }

crates/settings/src/settings_content/extension.rs 🔗

@@ -20,6 +20,12 @@ pub struct ExtensionSettingsContent {
     pub auto_update_extensions: HashMap<Arc<str>, bool>,
     /// The capabilities granted to extensions.
     pub granted_extension_capabilities: Option<Vec<ExtensionCapabilityContent>>,
+    /// Extension language model providers that are allowed to read API keys from
+    /// environment variables. Each entry is a provider ID in the format
+    /// "extension_id:provider_id" (e.g., "openai:openai").
+    ///
+    /// Default: []
+    pub allowed_env_var_providers: Option<Vec<Arc<str>>>,
 }
 
 /// A capability for an extension.