diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index 869319ded5419d815ccb53a1cb533f8d27aa1879..689224dda0e92a6e715950a61601727bc2a7731d 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -68,6 +68,7 @@ struct LlmProviderWithModels { models: Vec, is_authenticated: bool, icon_path: Option, + auth_config: Option, } pub use extension::{ @@ -1476,11 +1477,20 @@ impl ExtensionStore { SharedString::from(absolute_icon_path) }); + let provider_id_arc: Arc = + provider_info.id.as_str().into(); + let auth_config = extension + .manifest + .language_model_providers + .get(&provider_id_arc) + .and_then(|entry| entry.auth.clone()); + llm_providers_with_models.push(LlmProviderWithModels { provider_info, models, is_authenticated, icon_path, + auth_config, }); } } else { @@ -1579,12 +1589,13 @@ impl ExtensionStore { let mods = llm_provider.models.clone(); let auth = llm_provider.is_authenticated; let icon = llm_provider.icon_path.clone(); + let auth_config = llm_provider.auth_config.clone(); this.proxy.register_language_model_provider( provider_id.clone(), Box::new(move |cx: &mut App| { let provider = Arc::new(ExtensionLanguageModelProvider::new( - wasm_ext, pinfo, mods, auth, icon, cx, + wasm_ext, pinfo, mods, auth, icon, auth_config, cx, )); language_model::LanguageModelRegistry::global(cx).update( cx, diff --git a/crates/extension_host/src/extension_settings.rs b/crates/extension_host/src/extension_settings.rs index 736dd6b87ae53a5ffd57b5697aaf3890cedb6f03..3322ea4068cc08ef8f5257e670ad8cb7088b57b7 100644 --- a/crates/extension_host/src/extension_settings.rs +++ b/crates/extension_host/src/extension_settings.rs @@ -1,4 +1,4 @@ -use collections::HashMap; +use collections::{HashMap, HashSet}; use extension::{ DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability, }; @@ -16,6 +16,10 @@ pub struct ExtensionSettings { pub auto_install_extensions: HashMap, bool>, pub auto_update_extensions: HashMap, bool>, pub granted_capabilities: Vec, + /// The extension language model providers that are allowed to read API keys + /// from environment variables. Each entry is a provider ID in the format + /// "extension_id:provider_id". + pub allowed_env_var_providers: HashSet>, } impl ExtensionSettings { @@ -60,6 +64,13 @@ impl Settings for ExtensionSettings { } }) .collect(), + allowed_env_var_providers: content + .extension + .allowed_env_var_providers + .clone() + .unwrap_or_default() + .into_iter() + .collect(), } } } diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index 0ae833080a59a5f2a628aebb1678ce4f1f302c1d..3f16fb31cd11d5b12ec076e67dfbcfa60feff93e 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -1,3 +1,4 @@ +use crate::ExtensionSettings; use crate::wasm_host::WasmExtension; use crate::wasm_host::wit::{ @@ -9,6 +10,7 @@ use crate::wasm_host::wit::{ use anyhow::{Result, anyhow}; use credentials_provider::CredentialsProvider; use editor::Editor; +use extension::LanguageModelAuthConfig; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; @@ -37,12 +39,15 @@ pub struct ExtensionLanguageModelProvider { pub extension: WasmExtension, pub provider_info: LlmProviderInfo, icon_path: Option, + auth_config: Option, state: Entity, } pub struct ExtensionLlmProviderState { is_authenticated: bool, available_models: Vec, + env_var_allowed: bool, + api_key_from_env: bool, } impl EventEmitter<()> for ExtensionLlmProviderState {} @@ -54,17 +59,42 @@ impl ExtensionLanguageModelProvider { models: Vec, is_authenticated: bool, icon_path: Option, + auth_config: Option, cx: &mut App, ) -> Self { + let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id); + let env_var_allowed = ExtensionSettings::get_global(cx) + .allowed_env_var_providers + .contains(provider_id_string.as_str()); + + let (is_authenticated, api_key_from_env) = + if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) { + let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap(); + if let Ok(value) = std::env::var(env_var_name) { + if !value.is_empty() { + (true, true) + } else { + (is_authenticated, false) + } + } else { + (is_authenticated, false) + } + } else { + (is_authenticated, false) + }; + let state = cx.new(|_| ExtensionLlmProviderState { is_authenticated, available_models: models, + env_var_allowed, + api_key_from_env, }); Self { extension, provider_info, icon_path, + auth_config, state, } } @@ -194,13 +224,17 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { let credential_key = self.credential_key(); let extension = self.extension.clone(); let extension_provider_id = self.provider_info.id.clone(); + let full_provider_id = self.provider_id_string(); let state = self.state.clone(); + let auth_config = self.auth_config.clone(); cx.new(|cx| { ExtensionProviderConfigurationView::new( credential_key, extension, extension_provider_id, + full_provider_id, + auth_config, state, window, cx, @@ -274,6 +308,8 @@ struct ExtensionProviderConfigurationView { credential_key: String, extension: WasmExtension, extension_provider_id: String, + full_provider_id: String, + auth_config: Option, state: Entity, settings_markdown: Option>, api_key_editor: Entity, @@ -287,6 +323,8 @@ impl ExtensionProviderConfigurationView { credential_key: String, extension: WasmExtension, extension_provider_id: String, + full_provider_id: String, + auth_config: Option, state: Entity, window: &mut Window, cx: &mut Context, @@ -307,6 +345,8 @@ impl ExtensionProviderConfigurationView { credential_key, extension, extension_provider_id, + full_provider_id, + auth_config, state, settings_markdown: None, api_key_editor, @@ -362,7 +402,20 @@ impl ExtensionProviderConfigurationView { let credentials_provider = ::global(cx); let state = self.state.clone(); + // Check if we should use env var (already set in state during provider construction) + let api_key_from_env = self.state.read(cx).api_key_from_env; + cx.spawn(async move |this, cx| { + // If using env var, we're already authenticated + if api_key_from_env { + this.update(cx, |this, cx| { + this.loading_credentials = false; + cx.notify(); + }) + .log_err(); + return; + } + let credentials = credentials_provider .read_credentials(&credential_key, cx) .await @@ -388,6 +441,92 @@ impl ExtensionProviderConfigurationView { .detach(); } + fn toggle_env_var_permission(&mut self, cx: &mut Context) { + let full_provider_id: Arc = self.full_provider_id.clone().into(); + let env_var_name = match &self.auth_config { + Some(config) => config.env_var.clone(), + None => return, + }; + + let state = self.state.clone(); + let currently_allowed = self.state.read(cx).env_var_allowed; + + // Update settings file + settings::update_settings_file(::global(cx), cx, move |settings, _| { + let providers = settings + .extension + .allowed_env_var_providers + .get_or_insert_with(Vec::new); + + if currently_allowed { + providers.retain(|id| id.as_ref() != full_provider_id.as_ref()); + } else { + if !providers + .iter() + .any(|id| id.as_ref() == full_provider_id.as_ref()) + { + providers.push(full_provider_id.clone()); + } + } + }); + + // Update local state + let new_allowed = !currently_allowed; + let new_from_env = if new_allowed { + if let Some(var_name) = &env_var_name { + if let Ok(value) = std::env::var(var_name) { + !value.is_empty() + } else { + false + } + } else { + false + } + } else { + false + }; + + state.update(cx, |state, cx| { + state.env_var_allowed = new_allowed; + state.api_key_from_env = new_from_env; + if new_from_env { + state.is_authenticated = true; + } + cx.notify(); + }); + + // If env var is being disabled, reload credentials from keychain + if !new_allowed { + self.reload_keychain_credentials(cx); + } + + cx.notify(); + } + + fn reload_keychain_credentials(&mut self, cx: &mut Context) { + let credential_key = self.credential_key.clone(); + let credentials_provider = ::global(cx); + let state = self.state.clone(); + + cx.spawn(async move |_this, cx| { + let credentials = credentials_provider + .read_credentials(&credential_key, cx) + .await + .log_err() + .flatten(); + + let has_credentials = credentials.is_some(); + + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = has_credentials; + cx.notify(); + }); + }); + }) + .detach(); + } + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { let api_key = self.api_key_editor.read(cx).text(cx); if api_key.is_empty() { @@ -456,6 +595,8 @@ impl gpui::Render for ExtensionProviderConfigurationView { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let is_loading = self.loading_settings || self.loading_credentials; let is_authenticated = self.is_authenticated(cx); + let env_var_allowed = self.state.read(cx).env_var_allowed; + let api_key_from_env = self.state.read(cx).api_key_from_env; if is_loading { return v_flex() @@ -478,8 +619,67 @@ impl gpui::Render for ExtensionProviderConfigurationView { ); } + // Render env var checkbox if the extension specifies an env var + if let Some(auth_config) = &self.auth_config { + if let Some(env_var_name) = &auth_config.env_var { + let env_var_name = env_var_name.clone(); + let checkbox_label = + format!("Read API key from {} environment variable", env_var_name); + + content = content.child( + h_flex() + .gap_2() + .child( + ui::Checkbox::new("env-var-permission", env_var_allowed.into()) + .on_click(cx.listener(|this, _, _window, cx| { + this.toggle_env_var_permission(cx); + })), + ) + .child(Label::new(checkbox_label).size(LabelSize::Small)), + ); + + // Show status if env var is allowed + if env_var_allowed { + if api_key_from_env { + content = content.child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Check) + .color(Color::Success) + .size(ui::IconSize::Small), + ) + .child( + Label::new(format!("API key loaded from {}", env_var_name)) + .color(Color::Success), + ), + ); + return content.into_any_element(); + } else { + content = content.child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Warning) + .color(Color::Warning) + .size(ui::IconSize::Small), + ) + .child( + Label::new(format!( + "{} is not set or empty. You can set it and restart Zed, or enter an API key below.", + env_var_name + )) + .color(Color::Warning) + .size(LabelSize::Small), + ), + ); + } + } + } + } + // Render API key section - if is_authenticated { + if is_authenticated && !api_key_from_env { content = content.child( v_flex() .gap_2() @@ -501,13 +701,19 @@ impl gpui::Render for ExtensionProviderConfigurationView { })), ), ); - } else { + } else if !api_key_from_env { + let credential_label = self + .auth_config + .as_ref() + .and_then(|c| c.credential_label.clone()) + .unwrap_or_else(|| "API Key".to_string()); + content = content.child( v_flex() .gap_2() .on_action(cx.listener(Self::save_api_key)) .child( - Label::new("API Key") + Label::new(credential_label) .size(LabelSize::Small) .color(Color::Muted), ) diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs index 20ae6945b69d483336a165f5aae589b591f2e927..b2a6cc8315849d0c8364460011a381eaf041fba0 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_7_0.rs @@ -1,3 +1,4 @@ +use crate::ExtensionSettings; use crate::wasm_host::wit::since_v0_7_0::{ dap::{ AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest, @@ -1195,6 +1196,55 @@ impl ExtensionImports for WasmState { } async fn llm_get_env_var(&mut self, name: String) -> wasmtime::Result> { + let extension_id = self.manifest.id.clone(); + + // Find which provider (if any) declares this env var in its auth config + let mut allowed_provider_id: Option> = None; + for (provider_id, provider_entry) in &self.manifest.language_model_providers { + if let Some(auth_config) = &provider_entry.auth { + if auth_config.env_var.as_deref() == Some(&name) { + allowed_provider_id = Some(provider_id.clone()); + break; + } + } + } + + // If no provider declares this env var, deny access + let Some(provider_id) = allowed_provider_id else { + log::warn!( + "Extension {} attempted to read env var {} which is not declared in any provider auth config", + extension_id, + name + ); + return Ok(None); + }; + + // Check if the user has allowed this provider to read env vars + let full_provider_id = format!("{}:{}", extension_id, provider_id); + let is_allowed = self + .on_main_thread(move |cx| { + async move { + cx.update(|cx| { + ExtensionSettings::get_global(cx) + .allowed_env_var_providers + .contains(full_provider_id.as_str()) + }) + .unwrap_or(false) + } + .boxed_local() + }) + .await; + + if !is_allowed { + log::debug!( + "Extension {} provider {} is not allowed to read env var {}", + extension_id, + provider_id, + name + ); + return Ok(None); + } + Ok(env::var(&name).ok()) } } diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs index 77a46118d2bfabf84a14773520acc8d70956da44..b5984d7a19a462254b606473aa76d8f5d97ab43c 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs @@ -1,3 +1,4 @@ +use crate::ExtensionSettings; use crate::wasm_host::wit::since_v0_8_0::{ dap::{ AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest, @@ -1192,6 +1193,55 @@ impl ExtensionImports for WasmState { } async fn llm_get_env_var(&mut self, name: String) -> wasmtime::Result> { + let extension_id = self.manifest.id.clone(); + + // Find which provider (if any) declares this env var in its auth config + let mut allowed_provider_id: Option> = None; + for (provider_id, provider_entry) in &self.manifest.language_model_providers { + if let Some(auth_config) = &provider_entry.auth { + if auth_config.env_var.as_deref() == Some(&name) { + allowed_provider_id = Some(provider_id.clone()); + break; + } + } + } + + // If no provider declares this env var, deny access + let Some(provider_id) = allowed_provider_id else { + log::warn!( + "Extension {} attempted to read env var {} which is not declared in any provider auth config", + extension_id, + name + ); + return Ok(None); + }; + + // Check if the user has allowed this provider to read env vars + let full_provider_id = format!("{}:{}", extension_id, provider_id); + let is_allowed = self + .on_main_thread(move |cx| { + async move { + cx.update(|cx| { + ExtensionSettings::get_global(cx) + .allowed_env_var_providers + .contains(full_provider_id.as_str()) + }) + .unwrap_or(false) + } + .boxed_local() + }) + .await; + + if !is_allowed { + log::debug!( + "Extension {} provider {} is not allowed to read env var {}", + extension_id, + provider_id, + name + ); + return Ok(None); + } + Ok(env::var(&name).ok()) } } diff --git a/crates/settings/src/settings_content/extension.rs b/crates/settings/src/settings_content/extension.rs index 2fefd4ef38aeb9af133ed745d2732a3cb6ec77f7..64df163f4ec961cf6bfc469c18ac0f8884c39f0b 100644 --- a/crates/settings/src/settings_content/extension.rs +++ b/crates/settings/src/settings_content/extension.rs @@ -20,6 +20,12 @@ pub struct ExtensionSettingsContent { pub auto_update_extensions: HashMap, bool>, /// The capabilities granted to extensions. pub granted_extension_capabilities: Option>, + /// Extension language model providers that are allowed to read API keys from + /// environment variables. Each entry is a provider ID in the format + /// "extension_id:provider_id" (e.g., "openai:openai"). + /// + /// Default: [] + pub allowed_env_var_providers: Option>>, } /// A capability for an extension.