Fix API key bug

Richard Feldman created

Change summary

crates/extension_host/src/wasm_host.rs                  | 10 -
crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs | 46 +++++++---
crates/language_models/src/provider/cloud.rs            |  3 
3 files changed, 37 insertions(+), 22 deletions(-)

Detailed changes

crates/extension_host/src/wasm_host.rs 🔗

@@ -5,6 +5,7 @@ use crate::capability_granter::CapabilityGranter;
 use crate::{ExtensionManifest, ExtensionSettings};
 use anyhow::{Context as _, Result, anyhow, bail};
 use async_trait::async_trait;
+use collections::HashSet;
 use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
 use extension::{
     CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
@@ -59,6 +60,8 @@ pub struct WasmHost {
     pub work_dir: PathBuf,
     /// The capabilities granted to extensions running on the host.
     pub(crate) granted_capabilities: Vec<ExtensionCapability>,
+    /// Extension LLM providers allowed to read API keys from environment variables.
+    pub(crate) allowed_env_var_providers: HashSet<Arc<str>>,
     _main_thread_message_task: Task<()>,
     main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>,
 }
@@ -73,12 +76,6 @@ pub struct WasmExtension {
     _task: Arc<Task<Result<(), gpui_tokio::JoinError>>>,
 }
 
-impl Drop for WasmExtension {
-    fn drop(&mut self) {
-        self.tx.close_channel();
-    }
-}
-
 #[async_trait]
 impl extension::Extension for WasmExtension {
     fn manifest(&self) -> Arc<ExtensionManifest> {
@@ -591,6 +588,7 @@ impl WasmHost {
             proxy,
             release_channel: ReleaseChannel::global(cx),
             granted_capabilities: extension_settings.granted_capabilities.clone(),
+            allowed_env_var_providers: extension_settings.allowed_env_var_providers.clone(),
             _main_thread_message_task: task,
             main_thread_message_tx: tx,
         })

crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs 🔗

@@ -1,4 +1,3 @@
-use crate::ExtensionSettings;
 use crate::wasm_host::wit::since_v0_8_0::{
     dap::{
         AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
@@ -1129,6 +1128,33 @@ impl llm_provider::Host for WasmState {
 
     async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
         let extension_id = self.manifest.id.clone();
+
+        // Check if this provider has an env var configured and if the user has allowed it
+        let env_var_name = self
+            .manifest
+            .language_model_providers
+            .get(&Arc::<str>::from(provider_id.as_str()))
+            .and_then(|entry| entry.auth.as_ref())
+            .and_then(|auth| auth.env_var.clone());
+
+        if let Some(env_var_name) = env_var_name {
+            let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
+            // Use cached settings from WasmHost instead of going to main thread
+            let is_allowed = self
+                .host
+                .allowed_env_var_providers
+                .contains(&full_provider_id);
+
+            if is_allowed {
+                if let Ok(value) = env::var(&env_var_name) {
+                    if !value.is_empty() {
+                        return Ok(Some(value));
+                    }
+                }
+            }
+        }
+
+        // Fall back to credential store
         let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
 
         self.on_main_thread(move |cx| {
@@ -1214,20 +1240,12 @@ impl llm_provider::Host for WasmState {
         };
 
         // Check if the user has allowed this provider to read env vars
-        let full_provider_id = format!("{}:{}", extension_id, provider_id);
+        // Use cached settings from WasmHost instead of going to main thread
+        let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
         let is_allowed = self
-            .on_main_thread(move |cx| {
-                async move {
-                    cx.update(|cx| {
-                        ExtensionSettings::get_global(cx)
-                            .allowed_env_var_providers
-                            .contains(full_provider_id.as_str())
-                    })
-                    .unwrap_or(false)
-                }
-                .boxed_local()
-            })
-            .await;
+            .host
+            .allowed_env_var_providers
+            .contains(&full_provider_id);
 
         if !is_allowed {
             log::debug!(

crates/language_models/src/provider/cloud.rs 🔗

@@ -1703,8 +1703,7 @@ impl AnthropicEventMapper {
                         let event = serde_json::from_str::<serde_json::Value>(&tool_use.input_json)
                             .ok()
                             .and_then(|input| {
-                                let input_json_roundtripped =
-                                    serde_json::to_string(&input).ok()?.to_string();
+                                let input_json_roundtripped = serde_json::to_string(&input).ok()?;
 
                                 if !tool_use.input_json.starts_with(&input_json_roundtripped) {
                                     return None;