Fix bugs around logging out from provider extensions

Richard Feldman created

Change summary

crates/agent_ui/src/agent_ui.rs                         | 27 ++++
crates/extension_host/src/extension_host.rs             | 66 +++++++----
crates/extension_host/src/extension_settings.rs         | 10 +
crates/extension_host/src/wasm_host.rs                  |  5 
crates/extension_host/src/wasm_host/llm_provider.rs     | 14 ++
crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs | 40 +++++-
crates/settings/src/settings_content/extension.rs       |  3 
7 files changed, 127 insertions(+), 38 deletions(-)

Detailed changes

crates/agent_ui/src/agent_ui.rs 🔗

@@ -371,26 +371,49 @@ fn update_active_language_model_from_settings(cx: &mut App) {
         }
     }
 
-    let default = settings.default_model.as_ref().map(to_selected_model);
+    // Filter out models from providers that are not authenticated
+    fn is_provider_authenticated(
+        selection: &LanguageModelSelection,
+        registry: &LanguageModelRegistry,
+        cx: &App,
+    ) -> bool {
+        let provider_id = LanguageModelProviderId::from(selection.provider.0.clone());
+        registry
+            .provider(&provider_id)
+            .map_or(false, |provider| provider.is_authenticated(cx))
+    }
+
+    let registry = LanguageModelRegistry::global(cx);
+    let registry_ref = registry.read(cx);
+
+    let default = settings
+        .default_model
+        .as_ref()
+        .filter(|s| is_provider_authenticated(s, registry_ref, cx))
+        .map(to_selected_model);
     let inline_assistant = settings
         .inline_assistant_model
         .as_ref()
+        .filter(|s| is_provider_authenticated(s, registry_ref, cx))
         .map(to_selected_model);
     let commit_message = settings
         .commit_message_model
         .as_ref()
+        .filter(|s| is_provider_authenticated(s, registry_ref, cx))
         .map(to_selected_model);
     let thread_summary = settings
         .thread_summary_model
         .as_ref()
+        .filter(|s| is_provider_authenticated(s, registry_ref, cx))
         .map(to_selected_model);
     let inline_alternatives = settings
         .inline_alternatives
         .iter()
+        .filter(|s| is_provider_authenticated(s, registry_ref, cx))
         .map(to_selected_model)
         .collect::<Vec<_>>();
 
-    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+    registry.update(cx, |registry, cx| {
         registry.select_default_model(default.as_ref(), cx);
         registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
         registry.select_commit_message_model(commit_message.as_ref(), cx);

crates/extension_host/src/extension_host.rs 🔗

@@ -82,7 +82,7 @@ const FS_WATCH_LATENCY: Duration = Duration::from_millis(100);
 
 /// Extension IDs that are being migrated from hardcoded LLM providers.
 /// For backwards compatibility, if the user has the corresponding env var set,
-/// we automatically enable env var reading for these extensions.
+/// we automatically enable env var reading for these extensions on first install.
 const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
     "anthropic",
     "copilot_chat",
@@ -93,6 +93,9 @@ const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
 
 /// Migrates legacy LLM provider extensions by auto-enabling env var reading
 /// if the env var is currently present in the environment.
+///
+/// This migration only runs once per provider - we track which providers have been
+/// migrated in `migrated_llm_providers` to avoid overriding user preferences.
 fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
     // Only apply migration to known legacy LLM extensions
     if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
@@ -108,49 +111,64 @@ fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut Ap
             continue;
         };
 
-        // Check if the env var is present and non-empty
-        let env_var_is_set = std::env::var(env_var_name)
-            .map(|v| !v.is_empty())
-            .unwrap_or(false);
-
-        if !env_var_is_set {
-            continue;
-        }
-
         let full_provider_id: Arc<str> = format!("{}:{}", manifest.id, provider_id).into();
 
-        // Check if already in settings
-        let already_allowed = ExtensionSettings::get_global(cx)
-            .allowed_env_var_providers
+        // Check if we've already run migration for this provider (regardless of outcome)
+        let already_migrated = ExtensionSettings::get_global(cx)
+            .migrated_llm_providers
             .contains(full_provider_id.as_ref());
 
-        if already_allowed {
+        if already_migrated {
             continue;
         }
 
-        // Auto-enable env var reading for this provider
-        log::info!(
-            "Migrating legacy LLM provider {}: auto-enabling {} env var reading",
-            full_provider_id,
-            env_var_name
-        );
+        // Check if the env var is present and non-empty
+        let env_var_is_set = std::env::var(env_var_name)
+            .map(|v| !v.is_empty())
+            .unwrap_or(false);
 
+        // Mark as migrated regardless of whether we enable env var reading
         settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
             let full_provider_id = full_provider_id.clone();
+            let env_var_is_set = env_var_is_set;
             move |settings, _| {
-                let providers = settings
+                // Always mark as migrated
+                let migrated = settings
                     .extension
-                    .allowed_env_var_providers
+                    .migrated_llm_providers
                     .get_or_insert_with(Vec::new);
 
-                if !providers
+                if !migrated
                     .iter()
                     .any(|id| id.as_ref() == full_provider_id.as_ref())
                 {
-                    providers.push(full_provider_id);
+                    migrated.push(full_provider_id.clone());
+                }
+
+                // Only enable env var reading if the env var is set
+                if env_var_is_set {
+                    let providers = settings
+                        .extension
+                        .allowed_env_var_providers
+                        .get_or_insert_with(Vec::new);
+
+                    if !providers
+                        .iter()
+                        .any(|id| id.as_ref() == full_provider_id.as_ref())
+                    {
+                        providers.push(full_provider_id);
+                    }
                 }
             }
         });
+
+        if env_var_is_set {
+            log::info!(
+                "Migrating legacy LLM provider {}: auto-enabling {} env var reading",
+                full_provider_id,
+                env_var_name
+            );
+        }
     }
 }
 

crates/extension_host/src/extension_settings.rs 🔗

@@ -20,6 +20,9 @@ pub struct ExtensionSettings {
     /// from environment variables. Each entry is a provider ID in the format
     /// "extension_id:provider_id".
     pub allowed_env_var_providers: HashSet<Arc<str>>,
+    /// Tracks which legacy LLM providers have been migrated.
+    /// This prevents the migration from running multiple times and overriding user preferences.
+    pub migrated_llm_providers: HashSet<Arc<str>>,
 }
 
 impl ExtensionSettings {
@@ -71,6 +74,13 @@ impl Settings for ExtensionSettings {
                 .unwrap_or_default()
                 .into_iter()
                 .collect(),
+            migrated_llm_providers: content
+                .extension
+                .migrated_llm_providers
+                .clone()
+                .unwrap_or_default()
+                .into_iter()
+                .collect(),
         }
     }
 }

crates/extension_host/src/wasm_host.rs 🔗

@@ -5,7 +5,7 @@ use crate::capability_granter::CapabilityGranter;
 use crate::{ExtensionManifest, ExtensionSettings};
 use anyhow::{Context as _, Result, anyhow, bail};
 use async_trait::async_trait;
-use collections::HashSet;
+
 use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
 use extension::{
     CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
@@ -60,8 +60,6 @@ pub struct WasmHost {
     pub work_dir: PathBuf,
     /// The capabilities granted to extensions running on the host.
     pub(crate) granted_capabilities: Vec<ExtensionCapability>,
-    /// Extension LLM providers allowed to read API keys from environment variables.
-    pub(crate) allowed_env_var_providers: HashSet<Arc<str>>,
     _main_thread_message_task: Task<()>,
     main_thread_message_tx: mpsc::UnboundedSender<MainThreadCall>,
 }
@@ -597,7 +595,6 @@ impl WasmHost {
             proxy,
             release_channel: ReleaseChannel::global(cx),
             granted_capabilities: extension_settings.granted_capabilities.clone(),
-            allowed_env_var_providers: extension_settings.allowed_env_var_providers.clone(),
             _main_thread_message_task: task,
             main_thread_message_tx: tx,
         })

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -495,6 +495,20 @@ impl ExtensionProviderConfigurationView {
             cx.notify();
         });
 
+        // If env var is being enabled, clear any stored keychain credentials
+        // so there's only one source of truth for the API key
+        if new_allowed {
+            let credential_key = self.credential_key.clone();
+            let credentials_provider = <dyn CredentialsProvider>::global(cx);
+            cx.spawn(async move |_this, cx| {
+                credentials_provider
+                    .delete_credentials(&credential_key, cx)
+                    .await
+                    .log_err();
+            })
+            .detach();
+        }
+
         // If env var is being disabled, reload credentials from keychain
         if !new_allowed {
             self.reload_keychain_credentials(cx);

crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs 🔗

@@ -1139,11 +1139,23 @@ impl llm_provider::Host for WasmState {
 
         if let Some(env_var_name) = env_var_name {
             let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
-            // Use cached settings from WasmHost instead of going to main thread
+            // Read settings dynamically to get current allowed_env_var_providers
             let is_allowed = self
-                .host
-                .allowed_env_var_providers
-                .contains(&full_provider_id);
+                .on_main_thread({
+                    let full_provider_id = full_provider_id.clone();
+                    move |cx| {
+                        async move {
+                            cx.update(|cx| {
+                                crate::extension_settings::ExtensionSettings::get_global(cx)
+                                    .allowed_env_var_providers
+                                    .contains(&full_provider_id)
+                            })
+                        }
+                        .boxed_local()
+                    }
+                })
+                .await
+                .unwrap_or(false);
 
             if is_allowed {
                 if let Ok(value) = env::var(&env_var_name) {
@@ -1240,12 +1252,24 @@ impl llm_provider::Host for WasmState {
         };
 
         // Check if the user has allowed this provider to read env vars
-        // Use cached settings from WasmHost instead of going to main thread
+        // Read settings dynamically to get current allowed_env_var_providers
         let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
         let is_allowed = self
-            .host
-            .allowed_env_var_providers
-            .contains(&full_provider_id);
+            .on_main_thread({
+                let full_provider_id = full_provider_id.clone();
+                move |cx| {
+                    async move {
+                        cx.update(|cx| {
+                            crate::extension_settings::ExtensionSettings::get_global(cx)
+                                .allowed_env_var_providers
+                                .contains(&full_provider_id)
+                        })
+                    }
+                    .boxed_local()
+                }
+            })
+            .await
+            .unwrap_or(false);
 
         if !is_allowed {
             log::debug!(

crates/settings/src/settings_content/extension.rs 🔗

@@ -26,6 +26,9 @@ pub struct ExtensionSettingsContent {
     ///
     /// Default: []
     pub allowed_env_var_providers: Option<Vec<Arc<str>>>,
+    /// Tracks which legacy LLM providers have been migrated. This is an internal
+    /// setting used to prevent the migration from running multiple times.
+    pub migrated_llm_providers: Option<Vec<Arc<str>>>,
 }
 
 /// A capability for an extension.