@@ -5,6 +5,7 @@ use crate::capability_granter::CapabilityGranter;
use crate::{ExtensionManifest, ExtensionSettings};
use anyhow::{Context as _, Result, anyhow, bail};
use async_trait::async_trait;
+use collections::HashSet;
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
use extension::{
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
@@ -59,6 +60,8 @@ pub struct WasmHost {
pub work_dir: PathBuf,
/// The capabilities granted to extensions running on the host.
pub(crate) granted_capabilities: Vec<ExtensionCapability>,
+ /// Extension LLM providers allowed to read API keys from environment variables.
+ pub(crate) allowed_env_var_providers: HashSet<Arc<str>>,
_main_thread_message_task: Task<()>,
main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>,
}
@@ -73,12 +76,6 @@ pub struct WasmExtension {
_task: Arc<Task<Result<(), gpui_tokio::JoinError>>>,
}
-impl Drop for WasmExtension {
- fn drop(&mut self) {
- self.tx.close_channel();
- }
-}
-
#[async_trait]
impl extension::Extension for WasmExtension {
fn manifest(&self) -> Arc<ExtensionManifest> {
@@ -591,6 +588,7 @@ impl WasmHost {
proxy,
release_channel: ReleaseChannel::global(cx),
granted_capabilities: extension_settings.granted_capabilities.clone(),
+ allowed_env_var_providers: extension_settings.allowed_env_var_providers.clone(),
_main_thread_message_task: task,
main_thread_message_tx: tx,
})
@@ -1,4 +1,3 @@
-use crate::ExtensionSettings;
use crate::wasm_host::wit::since_v0_8_0::{
dap::{
AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
@@ -1129,6 +1128,33 @@ impl llm_provider::Host for WasmState {
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
let extension_id = self.manifest.id.clone();
+
+ // Check if this provider has an env var configured and if the user has allowed it
+ let env_var_name = self
+ .manifest
+ .language_model_providers
+ .get(&Arc::<str>::from(provider_id.as_str()))
+ .and_then(|entry| entry.auth.as_ref())
+ .and_then(|auth| auth.env_var.clone());
+
+ if let Some(env_var_name) = env_var_name {
+ let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
+ // Use cached settings from WasmHost instead of going to main thread
+ let is_allowed = self
+ .host
+ .allowed_env_var_providers
+ .contains(&full_provider_id);
+
+ if is_allowed {
+ if let Ok(value) = env::var(&env_var_name) {
+ if !value.is_empty() {
+ return Ok(Some(value));
+ }
+ }
+ }
+ }
+
+ // Fall back to credential store
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
@@ -1214,20 +1240,12 @@ impl llm_provider::Host for WasmState {
};
// Check if the user has allowed this provider to read env vars
- let full_provider_id = format!("{}:{}", extension_id, provider_id);
+ // Use cached settings from WasmHost instead of going to main thread
+ let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
let is_allowed = self
- .on_main_thread(move |cx| {
- async move {
- cx.update(|cx| {
- ExtensionSettings::get_global(cx)
- .allowed_env_var_providers
- .contains(full_provider_id.as_str())
- })
- .unwrap_or(false)
- }
- .boxed_local()
- })
- .await;
+ .host
+ .allowed_env_var_providers
+ .contains(&full_provider_id);
if !is_allowed {
log::debug!(
@@ -1703,8 +1703,7 @@ impl AnthropicEventMapper {
let event = serde_json::from_str::<serde_json::Value>(&tool_use.input_json)
.ok()
.and_then(|input| {
- let input_json_roundtripped =
- serde_json::to_string(&input).ok()?.to_string();
+ let input_json_roundtripped = serde_json::to_string(&input).ok()?;
if !tool_use.input_json.starts_with(&input_json_roundtripped) {
return None;