Fix auth and subscriptions for provider extensions

Richard Feldman created

Change summary

crates/agent/src/agent.rs                           |  27 ++
crates/agent_ui/src/agent_ui.rs                     |  34 +++
crates/agent_ui/src/language_model_selector.rs      | 153 ++++++++++++--
crates/extension/src/extension_host_proxy.rs        |   5 
crates/extension_host/src/extension_host.rs         |  67 +++++
crates/extension_host/src/wasm_host.rs              |   6 
crates/extension_host/src/wasm_host/llm_provider.rs |  13 +
crates/language_model/src/registry.rs               |  31 ++
crates/language_models/src/extension.rs             |  10 
9 files changed, 298 insertions(+), 48 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -103,12 +103,22 @@ impl LanguageModels {
     }
 
     fn refresh_list(&mut self, cx: &App) {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_millis();
+        eprintln!("[{}ms] LanguageModels::refresh_list called", now);
         let providers = LanguageModelRegistry::global(cx)
             .read(cx)
             .providers()
             .into_iter()
             .filter(|provider| provider.is_authenticated(cx))
             .collect::<Vec<_>>();
+        eprintln!(
+            "[{}ms] LanguageModels::refresh_list got {} authenticated providers",
+            now,
+            providers.len()
+        );
 
         let mut language_model_list = IndexMap::default();
         let mut recommended_models = HashSet::default();
@@ -146,6 +156,15 @@ impl LanguageModels {
 
         self.models = models;
         self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_millis();
+        eprintln!(
+            "[{}ms] LanguageModels::refresh_list completed with {} models in list",
+            now,
+            self.models.len()
+        );
         self.refresh_models_tx.send(()).ok();
     }
 
@@ -603,6 +622,14 @@ impl NativeAgent {
         _event: &language_model::Event,
         cx: &mut Context<Self>,
     ) {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_millis();
+        eprintln!(
+            "[{}ms] NativeAgent::handle_models_updated_event called",
+            now
+        );
         self.models.refresh_list(cx);
 
         let registry = LanguageModelRegistry::read_global(cx);

crates/agent_ui/src/agent_ui.rs 🔗

@@ -344,9 +344,37 @@ fn init_language_model_settings(cx: &mut App) {
     cx.subscribe(
         &LanguageModelRegistry::global(cx),
         |_, event: &language_model::Event, cx| match event {
-            language_model::Event::ProviderStateChanged(_)
-            | language_model::Event::AddedProvider(_)
-            | language_model::Event::RemovedProvider(_) => {
+            language_model::Event::ProviderStateChanged(id) => {
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_millis();
+                eprintln!(
+                    "[{}ms] agent_ui global subscription: ProviderStateChanged for {:?}",
+                    now, id
+                );
+                update_active_language_model_from_settings(cx);
+            }
+            language_model::Event::AddedProvider(id) => {
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_millis();
+                eprintln!(
+                    "[{}ms] agent_ui global subscription: AddedProvider for {:?}",
+                    now, id
+                );
+                update_active_language_model_from_settings(cx);
+            }
+            language_model::Event::RemovedProvider(id) => {
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_millis();
+                eprintln!(
+                    "[{}ms] agent_ui global subscription: RemovedProvider for {:?}",
+                    now, id
+                );
                 update_active_language_model_from_settings(cx);
             }
             _ => {}

crates/agent_ui/src/language_model_selector.rs 🔗

@@ -1,10 +1,9 @@
 use std::{cmp::Reverse, sync::Arc};
 
 use collections::IndexMap;
+use futures::{StreamExt, channel::mpsc};
 use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
-use gpui::{
-    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
-};
+use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
 use language_model::{
     AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
     LanguageModelRegistry,
@@ -47,7 +46,13 @@ pub fn language_model_selector(
 }
 
 fn all_models(cx: &App) -> GroupedModels {
+    let now = std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_millis();
+    eprintln!("[{}ms] all_models() called", now);
     let providers = LanguageModelRegistry::global(cx).read(cx).providers();
+    eprintln!("[{}ms] all_models: got {} providers", now, providers.len());
 
     let recommended = providers
         .iter()
@@ -62,19 +67,41 @@ fn all_models(cx: &App) -> GroupedModels {
         })
         .collect();
 
-    let all = providers
+    let all: Vec<ModelInfo> = providers
         .iter()
         .flat_map(|provider| {
-            provider
-                .provided_models(cx)
-                .into_iter()
-                .map(|model| ModelInfo {
-                    model,
-                    icon: provider.icon(),
-                })
+            let now = std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_millis();
+            eprintln!(
+                "[{}ms] all_models: calling provided_models for {:?}",
+                now,
+                provider.id()
+            );
+            let models = provider.provided_models(cx);
+            eprintln!(
+                "[{}ms] all_models: provider {:?} returned {} models",
+                now,
+                provider.id(),
+                models.len()
+            );
+            models.into_iter().map(|model| ModelInfo {
+                model,
+                icon: provider.icon(),
+            })
         })
         .collect();
 
+    let now = std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_millis();
+    eprintln!(
+        "[{}ms] all_models: returning {} total models",
+        now,
+        all.len()
+    );
     GroupedModels::new(all, recommended)
 }
 
@@ -91,7 +118,7 @@ pub struct LanguageModelPickerDelegate {
     filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
     _authenticate_all_providers_task: Task<()>,
-    _subscriptions: Vec<Subscription>,
+    _refresh_models_task: Task<()>,
     popover_styles: bool,
     focus_handle: FocusHandle,
 }
@@ -105,8 +132,18 @@ impl LanguageModelPickerDelegate {
         window: &mut Window,
         cx: &mut Context<Picker<Self>>,
     ) -> Self {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_millis();
+        eprintln!("[{}ms] LanguageModelPickerDelegate::new() called", now);
         let on_model_changed = Arc::new(on_model_changed);
         let models = all_models(cx);
+        eprintln!(
+            "[{}ms] LanguageModelPickerDelegate::new() got {} models from all_models()",
+            now,
+            models.all.len()
+        );
         let entries = models.entries();
 
         Self {
@@ -116,24 +153,88 @@ impl LanguageModelPickerDelegate {
             filtered_entries: entries,
             get_active_model: Arc::new(get_active_model),
             _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
-            _subscriptions: vec![cx.subscribe_in(
-                &LanguageModelRegistry::global(cx),
-                window,
-                |picker, _, event, window, cx| {
+            _refresh_models_task: {
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_millis();
+                eprintln!(
+                    "[{}ms] LanguageModelPickerDelegate::new() setting up refresh task for LanguageModelRegistry",
+                    now
+                );
+
+                // Create a channel to signal when models need refreshing
+                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
+
+                // Subscribe to registry events and send refresh signals through the channel
+                let registry = LanguageModelRegistry::global(cx);
+                eprintln!(
+                    "[{}ms] LanguageModelPickerDelegate::new() subscribing to registry entity_id: {:?}",
+                    now,
+                    registry.entity_id()
+                );
+                cx.subscribe(&registry, move |_picker, _, event, _cx| {
                     match event {
-                        language_model::Event::ProviderStateChanged(_)
-                        | language_model::Event::AddedProvider(_)
-                        | language_model::Event::RemovedProvider(_) => {
-                            let query = picker.query(cx);
-                            picker.delegate.all_models = Arc::new(all_models(cx));
-                            // Update matches will automatically drop the previous task
-                            // if we get a provider event again
-                            picker.update_matches(query, window, cx)
+                        language_model::Event::ProviderStateChanged(id) => {
+                            let now = std::time::SystemTime::now()
+                                .duration_since(std::time::UNIX_EPOCH)
+                                .unwrap_or_default()
+                                .as_millis();
+                            eprintln!(
+                                "[{}ms] LanguageModelSelector: ProviderStateChanged event for {:?}, sending refresh signal",
+                                now, id
+                            );
+                            refresh_tx.unbounded_send(()).ok();
+                        }
+                        language_model::Event::AddedProvider(id) => {
+                            let now = std::time::SystemTime::now()
+                                .duration_since(std::time::UNIX_EPOCH)
+                                .unwrap_or_default()
+                                .as_millis();
+                            eprintln!(
+                                "[{}ms] LanguageModelSelector: AddedProvider event for {:?}, sending refresh signal",
+                                now, id
+                            );
+                            refresh_tx.unbounded_send(()).ok();
+                        }
+                        language_model::Event::RemovedProvider(id) => {
+                            let now = std::time::SystemTime::now()
+                                .duration_since(std::time::UNIX_EPOCH)
+                                .unwrap_or_default()
+                                .as_millis();
+                            eprintln!(
+                                "[{}ms] LanguageModelSelector: RemovedProvider event for {:?}, sending refresh signal",
+                                now, id
+                            );
+                            refresh_tx.unbounded_send(()).ok();
                         }
                         _ => {}
                     }
-                },
-            )],
+                })
+                .detach();
+
+                // Spawn a task that listens for refresh signals and updates the picker
+                cx.spawn_in(window, async move |this, cx| {
+                    while let Some(()) = refresh_rx.next().await {
+                        let now = std::time::SystemTime::now()
+                            .duration_since(std::time::UNIX_EPOCH)
+                            .unwrap_or_default()
+                            .as_millis();
+                        eprintln!(
+                            "[{}ms] LanguageModelSelector: refresh signal received, updating models",
+                            now
+                        );
+                        let result = this.update_in(cx, |picker, window, cx| {
+                            picker.delegate.all_models = Arc::new(all_models(cx));
+                            picker.refresh(window, cx);
+                        });
+                        if result.is_err() {
+                            // Picker was dropped, exit the loop
+                            break;
+                        }
+                    }
+                })
+            },
             popover_styles,
             focus_handle,
         }

crates/extension/src/extension_host_proxy.rs 🔗

@@ -414,9 +414,14 @@ impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
         cx: &mut App,
     ) {
         let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
+            eprintln!(
+                "Failed to register LLM provider '{}': language_model_provider_proxy not yet initialized",
+                provider_id
+            );
             return;
         };
 
+        eprintln!("Registering LLM provider: {}", provider_id);
         proxy.register_language_model_provider(provider_id, register_fn, cx)
     }
 

crates/extension_host/src/extension_host.rs 🔗

@@ -1364,7 +1364,7 @@ impl ExtensionStore {
             let mut wasm_extensions: Vec<(
                 Arc<ExtensionManifest>,
                 WasmExtension,
-                Vec<(LlmProviderInfo, Vec<LlmModelInfo>)>,
+                Vec<(LlmProviderInfo, Vec<LlmModelInfo>, bool)>,
             )> = Vec::new();
             for extension in extension_entries {
                 if extension.manifest.lib.kind.is_none() {
@@ -1384,8 +1384,14 @@ impl ExtensionStore {
                 match wasm_extension {
                     Ok(wasm_extension) => {
                         // Query for LLM providers if the manifest declares any
+                        // Tuple is (provider_info, models, is_authenticated)
                         let mut llm_providers_with_models = Vec::new();
                         if !extension.manifest.language_model_providers.is_empty() {
+                            eprintln!(
+                                "Extension {} declares {} LLM providers in manifest, querying...",
+                                extension.manifest.id,
+                                extension.manifest.language_model_providers.len()
+                            );
                             let providers_result = wasm_extension
                                 .call(|ext, store| {
                                     async move { ext.call_llm_providers(store).await }.boxed()
@@ -1393,6 +1399,11 @@ impl ExtensionStore {
                                 .await;
 
                             if let Ok(Ok(providers)) = providers_result {
+                                eprintln!(
+                                    "Extension {} returned {} LLM providers",
+                                    extension.manifest.id,
+                                    providers.len()
+                                );
                                 for provider_info in providers {
                                     let models_result = wasm_extension
                                         .call({
@@ -1410,7 +1421,7 @@ impl ExtensionStore {
                                     let models: Vec<LlmModelInfo> = match models_result {
                                         Ok(Ok(Ok(models))) => models,
                                         Ok(Ok(Err(e))) => {
-                                            log::error!(
+                                            eprintln!(
                                                 "Failed to get models for LLM provider {} in extension {}: {}",
                                                 provider_info.id,
                                                 extension.manifest.id,
@@ -1419,7 +1430,7 @@ impl ExtensionStore {
                                             Vec::new()
                                         }
                                         Ok(Err(e)) => {
-                                            log::error!(
+                                            eprintln!(
                                                 "Wasm error calling llm_provider_models for {} in extension {}: {:?}",
                                                 provider_info.id,
                                                 extension.manifest.id,
@@ -1428,7 +1439,7 @@ impl ExtensionStore {
                                             Vec::new()
                                         }
                                         Err(e) => {
-                                            log::error!(
+                                            eprintln!(
                                                 "Extension call failed for llm_provider_models {} in extension {}: {:?}",
                                                 provider_info.id,
                                                 extension.manifest.id,
@@ -1438,8 +1449,40 @@ impl ExtensionStore {
                                         }
                                     };
 
-                                    llm_providers_with_models.push((provider_info, models));
+                                    // Query initial authentication state
+                                    let is_authenticated = wasm_extension
+                                        .call({
+                                            let provider_id = provider_info.id.clone();
+                                            |ext, store| {
+                                                async move {
+                                                    ext.call_llm_provider_is_authenticated(
+                                                        store,
+                                                        &provider_id,
+                                                    )
+                                                    .await
+                                                }
+                                                .boxed()
+                                            }
+                                        })
+                                        .await
+                                        .unwrap_or(Ok(false))
+                                        .unwrap_or(false);
+
+                                    eprintln!(
+                                        "LLM provider {} has {} models, is_authenticated={}",
+                                        provider_info.id,
+                                        models.len(),
+                                        is_authenticated
+                                    );
+                                    llm_providers_with_models
+                                        .push((provider_info, models, is_authenticated));
                                 }
+                            } else {
+                                eprintln!(
+                                    "Failed to get LLM providers from extension {}: {:?}",
+                                    extension.manifest.id,
+                                    providers_result
+                                );
                             }
                         }
 
@@ -1522,28 +1565,34 @@ impl ExtensionStore {
                     }
 
                     // Register LLM providers
-                    for (provider_info, models) in llm_providers_with_models {
+                    for (provider_info, models, is_authenticated) in llm_providers_with_models {
                         let provider_id: Arc<str> =
                             format!("{}:{}", manifest.id, provider_info.id).into();
-                        let wasm_ext = wasm_extension.clone();
+                        let wasm_ext = extension.as_ref().clone();
                         let pinfo = provider_info.clone();
                         let mods = models.clone();
+                        let auth = *is_authenticated;
 
                         this.proxy.register_language_model_provider(
-                            provider_id,
+                            provider_id.clone(),
                             Box::new(move |cx: &mut App| {
+                                eprintln!("register_fn closure called, creating provider");
                                 let provider = Arc::new(ExtensionLanguageModelProvider::new(
-                                    wasm_ext, pinfo, mods, cx,
+                                    wasm_ext, pinfo, mods, auth, cx,
                                 ));
+                                eprintln!("Provider created, registering with registry");
                                 language_model::LanguageModelRegistry::global(cx).update(
                                     cx,
                                     |registry, cx| {
+                                        eprintln!("Inside registry.register_provider");
                                         registry.register_provider(provider, cx);
                                     },
                                 );
+                                eprintln!("Provider registered");
                             }),
                             cx,
                         );
+                        eprintln!("register_language_model_provider call completed for {}", provider_id);
                     }
                 }
 

crates/extension_host/src/wasm_host.rs 🔗

@@ -73,12 +73,6 @@ pub struct WasmExtension {
     _task: Arc<Task<Result<(), gpui_tokio::JoinError>>>,
 }
 
-impl Drop for WasmExtension {
-    fn drop(&mut self) {
-        self.tx.close_channel();
-    }
-}
-
 #[async_trait]
 impl extension::Extension for WasmExtension {
     fn manifest(&self) -> Arc<ExtensionManifest> {

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -40,10 +40,11 @@ impl ExtensionLanguageModelProvider {
         extension: WasmExtension,
         provider_info: LlmProviderInfo,
         models: Vec<LlmModelInfo>,
+        is_authenticated: bool,
         cx: &mut App,
     ) -> Self {
         let state = cx.new(|_| ExtensionLlmProviderState {
-            is_authenticated: false,
+            is_authenticated,
             available_models: models,
         });
 
@@ -61,7 +62,9 @@ impl ExtensionLanguageModelProvider {
 
 impl LanguageModelProvider for ExtensionLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId::from(self.provider_id_string())
+        let id = LanguageModelProviderId::from(self.provider_id_string());
+        eprintln!("ExtensionLanguageModelProvider::id() -> {:?}", id);
+        id
     }
 
     fn name(&self) -> LanguageModelProviderName {
@@ -111,10 +114,16 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let state = self.state.read(cx);
+        eprintln!(
+            "ExtensionLanguageModelProvider::provided_models called for {}, returning {} models",
+            self.provider_info.name,
+            state.available_models.len()
+        );
         state
             .available_models
             .iter()
             .map(|model_info| {
+                eprintln!("  - model: {}", model_info.name);
                 Arc::new(ExtensionLanguageModel {
                     extension: self.extension.clone(),
                     model_info: model_info.clone(),

crates/language_model/src/registry.rs 🔗

@@ -153,7 +153,19 @@ impl LanguageModelRegistry {
         }
 
         self.providers.insert(id.clone(), provider);
-        cx.emit(Event::AddedProvider(id));
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_millis();
+        eprintln!(
+            "[{}ms] LanguageModelRegistry: About to emit AddedProvider event for {:?}",
+            now, id
+        );
+        cx.emit(Event::AddedProvider(id.clone()));
+        eprintln!(
+            "[{}ms] LanguageModelRegistry: Emitted AddedProvider event for {:?}",
+            now, id
+        );
     }
 
     pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
@@ -163,6 +175,18 @@ impl LanguageModelRegistry {
     }
 
     pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
+        let now = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap_or_default()
+            .as_millis();
+        eprintln!(
+            "[{}ms] LanguageModelRegistry::providers() called, {} providers in registry",
+            now,
+            self.providers.len()
+        );
+        for (id, _) in &self.providers {
+            eprintln!("  - provider: {:?}", id);
+        }
         let zed_provider_id = LanguageModelProviderId("zed.dev".into());
         let mut providers = Vec::with_capacity(self.providers.len());
         if let Some(provider) = self.providers.get(&zed_provider_id) {
@@ -175,6 +199,11 @@ impl LanguageModelRegistry {
                 None
             }
         }));
+        eprintln!(
+            "[{}ms] LanguageModelRegistry::providers() returning {} providers",
+            now,
+            providers.len()
+        );
         providers
     }
 

crates/language_models/src/extension.rs 🔗

@@ -18,11 +18,19 @@ impl ExtensionLanguageModelProxy {
 impl ExtensionLanguageModelProviderProxy for ExtensionLanguageModelProxy {
     fn register_language_model_provider(
         &self,
-        _provider_id: Arc<str>,
+        provider_id: Arc<str>,
         register_fn: LanguageModelProviderRegistration,
         cx: &mut App,
     ) {
+        eprintln!(
+            "ExtensionLanguageModelProxy::register_language_model_provider called for {}",
+            provider_id
+        );
         register_fn(cx);
+        eprintln!(
+            "ExtensionLanguageModelProxy::register_language_model_provider completed for {}",
+            provider_id
+        );
     }
 
     fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {