Migrate to extensions with fallback to builtin

Richard Feldman created

Change summary

Cargo.lock                                          |   1 
crates/extension/src/extension_host_proxy.rs        |   8 
crates/extension_host/src/extension_host.rs         | 124 ++++++++++++++
crates/extension_host/src/wasm_host/llm_provider.rs |   3 
crates/language_model/src/registry.rs               |  23 ++
crates/language_models/Cargo.toml                   |   1 
crates/language_models/src/extension.rs             |  54 ++++++
crates/language_models/src/language_models.rs       |   2 
crates/zed/src/main.rs                              |   6 
9 files changed, 213 insertions(+), 9 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8859,6 +8859,7 @@ dependencies = [
  "credentials_provider",
  "deepseek",
  "editor",
+ "extension",
  "extension_host",
  "fs",
  "futures 0.3.31",

crates/extension/src/extension_host_proxy.rs 🔗

@@ -414,9 +414,17 @@ impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
         cx: &mut App,
     ) {
         let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
+            log::warn!(
+                "ExtensionHostProxy::register_language_model_provider: no proxy set for provider {}",
+                provider_id
+            );
             return;
         };
 
+        log::info!(
+            "ExtensionHostProxy::register_language_model_provider: delegating to proxy for {}",
+            provider_id
+        );
         proxy.register_language_model_provider(provider_id, register_fn, cx)
     }
 

crates/extension_host/src/extension_host.rs 🔗

@@ -101,8 +101,17 @@ const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
 /// This migration only runs once per provider - we track which providers have been
 /// migrated in `migrated_llm_providers` to avoid overriding user preferences.
 fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
+    log::info!(
+        "migrate_legacy_llm_provider_env_var called for extension: {}",
+        manifest.id
+    );
+
     // Only apply migration to known legacy LLM extensions
     if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
+        log::info!(
+            "  skipping - not a legacy LLM extension (known: {:?})",
+            LEGACY_LLM_EXTENSION_IDS
+        );
         return;
     }
 
@@ -122,7 +131,15 @@ fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut Ap
             .migrated_llm_providers
             .contains(full_provider_id.as_ref());
 
+        log::info!(
+            "  provider {}: env_var={}, already_migrated={}",
+            full_provider_id,
+            env_var_name,
+            already_migrated
+        );
+
         if already_migrated {
+            log::info!("  skipping - already migrated");
             continue;
         }
 
@@ -131,6 +148,8 @@ fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut Ap
             .map(|v| !v.is_empty())
             .unwrap_or(false);
 
+        log::info!("  env_var_is_set: {}", env_var_is_set);
+
         // Mark as migrated regardless of whether we enable env var reading
         let should_enable_env_var = env_var_is_set;
         settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
@@ -697,28 +716,114 @@ impl ExtensionStore {
     /// This can be used to make certain functionality provided by extensions
     /// available out-of-the-box.
     pub fn auto_install_extensions(&mut self, cx: &mut Context<Self>) {
+        log::info!("auto_install_extensions called");
+
         if cfg!(test) {
+            log::info!("auto_install_extensions: skipping because cfg!(test)");
             return;
         }
 
         let extension_settings = ExtensionSettings::get_global(cx);
 
+        log::info!(
+            "auto_install_extensions: settings has {} extensions: {:?}",
+            extension_settings.auto_install_extensions.len(),
+            extension_settings
+                .auto_install_extensions
+                .keys()
+                .collect::<Vec<_>>()
+        );
+
         let extensions_to_install = extension_settings
             .auto_install_extensions
             .keys()
-            .filter(|extension_id| extension_settings.should_auto_install(extension_id))
+            .filter(|extension_id| {
+                let should = extension_settings.should_auto_install(extension_id);
+                log::info!("  {} should_auto_install: {}", extension_id, should);
+                should
+            })
             .filter(|extension_id| {
                 let is_already_installed = self
                     .extension_index
                     .extensions
                     .contains_key(extension_id.as_ref());
-                !is_already_installed && !SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref())
+                let dominated = SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref());
+                log::info!(
+                    "  {} is_already_installed: {}, suppressed: {}",
+                    extension_id,
+                    is_already_installed,
+                    dominated
+                );
+                !is_already_installed && !dominated
             })
             .cloned()
             .collect::<Vec<_>>();
 
+        log::info!(
+            "auto_install_extensions: will install {:?}",
+            extensions_to_install
+        );
+
         cx.spawn(async move |this, cx| {
             for extension_id in extensions_to_install {
+                // HACK: In debug builds, check if extension exists locally in repo's extensions/ dir
+                // and install as dev extension instead of fetching from registry.
+                // This allows testing unpublished extensions.
+                #[cfg(debug_assertions)]
+                {
+                    let local_extension_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+                        .parent()
+                        .unwrap()
+                        .parent()
+                        .unwrap()
+                        .join("extensions")
+                        .join(extension_id.as_ref());
+
+                    if local_extension_path.exists() {
+                        log::info!(
+                            "Auto-installing local dev extension: {} from {:?}",
+                            extension_id,
+                            local_extension_path
+                        );
+
+                        // Force-remove existing extension directory if it exists and isn't a symlink
+                        // This handles the case where the extension was previously installed from the registry
+                        if let Some(installed_dir) = this
+                            .update(cx, |this, _cx| this.installed_dir.clone())
+                            .ok()
+                        {
+                            let existing_path = installed_dir.join(extension_id.as_ref());
+                            if existing_path.exists() {
+                                let metadata = std::fs::symlink_metadata(&existing_path);
+                                let is_symlink = metadata.map(|m| m.is_symlink()).unwrap_or(false);
+                                if !is_symlink {
+                                    log::info!(
+                                        "Removing existing non-dev extension directory: {:?}",
+                                        existing_path
+                                    );
+                                    if let Err(e) = std::fs::remove_dir_all(&existing_path) {
+                                        log::error!(
+                                            "Failed to remove existing extension directory {:?}: {}",
+                                            existing_path,
+                                            e
+                                        );
+                                    }
+                                }
+                            }
+                        }
+
+                        if let Some(task) = this
+                            .update(cx, |this, cx| {
+                                this.install_dev_extension(local_extension_path, cx)
+                            })
+                            .ok()
+                        {
+                            task.await.log_err();
+                        }
+                        continue;
+                    }
+                }
+
                 this.update(cx, |this, cx| {
                     this.install_latest_extension(extension_id.clone(), cx);
                 })
@@ -1139,6 +1244,11 @@ impl ExtensionStore {
 
             this.update(cx, |this, cx| this.reload(None, cx))?.await;
             this.update(cx, |this, cx| {
+                // Run migration for legacy LLM provider env vars
+                if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
+                    migrate_legacy_llm_provider_env_var(&manifest, cx);
+                }
+
                 cx.emit(Event::ExtensionInstalled(extension_id.clone()));
                 if let Some(events) = ExtensionEvents::try_global(cx)
                     && let Some(manifest) = this.extension_manifest_for_id(&extension_id)
@@ -1696,9 +1806,19 @@ impl ExtensionStore {
                     }
 
                     // Register LLM providers
+                    log::info!(
+                        "Extension {} has {} LLM providers to register",
+                        manifest.id,
+                        llm_providers_with_models.len()
+                    );
                     for llm_provider in llm_providers_with_models {
                         let provider_id: Arc<str> =
                             format!("{}:{}", manifest.id, llm_provider.provider_info.id).into();
+                        log::info!(
+                            "Registering LLM provider {} with {} models",
+                            provider_id,
+                            llm_provider.models.len()
+                        );
                         let wasm_ext = extension.as_ref().clone();
                         let pinfo = llm_provider.provider_info.clone();
                         let mods = llm_provider.models.clone();

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -770,7 +770,8 @@ impl LanguageModel for ExtensionLanguageModel {
     }
 
     fn name(&self) -> LanguageModelName {
-        LanguageModelName::from(self.model_info.name.clone())
+        // HACK: Add "(Extension)" prefix to help distinguish extension models during debugging
+        LanguageModelName::from(format!("(Extension) {}", self.model_info.name))
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {

crates/language_model/src/registry.rs 🔗

@@ -157,6 +157,11 @@ impl LanguageModelRegistry {
         cx: &mut Context<Self>,
     ) {
         let id = provider.id();
+        log::info!(
+            "LanguageModelRegistry::register_provider: {} (name: {})",
+            id,
+            provider.name()
+        );
 
         let subscription = provider.subscribe(cx, {
             let id = id.clone();
@@ -196,8 +201,22 @@ impl LanguageModelRegistry {
 
     /// Returns providers, filtering out hidden built-in providers.
     pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
-        self.providers()
-            .into_iter()
+        let all = self.providers();
+        log::info!(
+            "LanguageModelRegistry::visible_providers called, all_providers={}, installed_llm_extension_ids={:?}",
+            all.len(),
+            self.installed_llm_extension_ids
+        );
+        for p in &all {
+            let hidden = self.should_hide_provider(&p.id());
+            log::info!(
+                "  provider {} (id: {}): hidden={}",
+                p.name(),
+                p.id(),
+                hidden
+            );
+        }
+        all.into_iter()
             .filter(|p| !self.should_hide_provider(&p.id()))
             .collect()
     }

crates/language_models/Cargo.toml 🔗

@@ -28,6 +28,7 @@ convert_case.workspace = true
 copilot.workspace = true
 credentials_provider.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
+extension.workspace = true
 extension_host.workspace = true
 fs.workspace = true
 futures.workspace = true

crates/language_models/src/extension.rs 🔗

@@ -1,5 +1,10 @@
+use ::extension::{
+    ExtensionHostProxy, ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration,
+};
 use collections::HashMap;
-use std::sync::LazyLock;
+use gpui::{App, Entity};
+use language_model::{LanguageModelProviderId, LanguageModelRegistry};
+use std::sync::{Arc, LazyLock};
 
 /// Maps built-in provider IDs to their corresponding extension IDs.
 /// When an extension with this ID is installed, the built-in provider should be hidden.
@@ -9,7 +14,7 @@ static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
         map.insert("anthropic", "anthropic");
         map.insert("openai", "openai");
         map.insert("google", "google-ai");
-        map.insert("open_router", "open-router");
+        map.insert("openrouter", "open-router");
         map.insert("copilot_chat", "copilot-chat");
         map
     });
@@ -18,3 +23,48 @@ static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
 pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
     BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
 }
+
+/// Proxy that registers extension language model providers with the LanguageModelRegistry.
+pub struct LanguageModelProviderRegistryProxy {
+    registry: Entity<LanguageModelRegistry>,
+}
+
+impl LanguageModelProviderRegistryProxy {
+    pub fn new(registry: Entity<LanguageModelRegistry>) -> Self {
+        Self { registry }
+    }
+}
+
+impl ExtensionLanguageModelProviderProxy for LanguageModelProviderRegistryProxy {
+    fn register_language_model_provider(
+        &self,
+        provider_id: Arc<str>,
+        register_fn: LanguageModelProviderRegistration,
+        cx: &mut App,
+    ) {
+        log::info!(
+            "LanguageModelProviderRegistryProxy::register_language_model_provider called for: {}",
+            provider_id
+        );
+        // The register_fn closure will call registry.register_provider internally
+        register_fn(cx);
+    }
+
+    fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
+        self.registry.update(cx, |registry, cx| {
+            registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx);
+        });
+    }
+}
+
+/// Initialize the extension language model provider proxy.
+/// This must be called BEFORE extension_host::init to ensure the proxy is available
+/// when extensions try to register their language model providers.
+pub fn init_proxy(cx: &mut App) {
+    let proxy = ExtensionHostProxy::default_global(cx);
+    let registry = LanguageModelRegistry::global(cx);
+    log::info!(
+        "language_models::extension::init_proxy: registering LanguageModelProviderRegistryProxy"
+    );
+    proxy.register_language_model_provider_proxy(LanguageModelProviderRegistryProxy::new(registry));
+}

crates/language_models/src/language_models.rs 🔗

@@ -14,7 +14,7 @@ pub mod provider;
 mod settings;
 pub mod ui;
 
-pub use crate::extension::extension_for_builtin_provider;
+pub use crate::extension::{extension_for_builtin_provider, init_proxy as init_extension_proxy};
 pub use crate::google_ai_api_key::api_key_for_gemini_cli;
 use crate::provider::anthropic::AnthropicLanguageModelProvider;
 use crate::provider::bedrock::BedrockLanguageModelProvider;

crates/zed/src/main.rs 🔗

@@ -554,6 +554,11 @@ pub fn main() {
         dap_adapters::init(cx);
         auto_update_ui::init(cx);
         reliability::init(client.clone(), cx);
+        // Initialize the language model registry first, then set up the extension proxy
+        // BEFORE extension_host::init so that extensions can register their LLM providers
+        // when they load.
+        language_model::init(app_state.client.clone(), cx);
+        language_models::init_extension_proxy(cx);
         extension_host::init(
             extension_host_proxy.clone(),
             app_state.fs.clone(),
@@ -579,7 +584,6 @@ pub fn main() {
             cx,
         );
         supermaven::init(app_state.client.clone(), cx);
-        language_model::init(app_state.client.clone(), cx);
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         acp_tools::init(cx);
         edit_prediction_ui::init(cx);