From d7caae30de9520b8110e01f13e0c66de9d7e7eff Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Wed, 3 Dec 2025 13:00:53 -0500 Subject: [PATCH] Fix auth and subscriptions for provider extensions --- crates/agent/src/agent.rs | 27 ++++ crates/agent_ui/src/agent_ui.rs | 34 +++- .../agent_ui/src/language_model_selector.rs | 153 +++++++++++++++--- crates/extension/src/extension_host_proxy.rs | 5 + crates/extension_host/src/extension_host.rs | 67 ++++++-- crates/extension_host/src/wasm_host.rs | 6 - .../src/wasm_host/llm_provider.rs | 13 +- crates/language_model/src/registry.rs | 31 +++- crates/language_models/src/extension.rs | 10 +- 9 files changed, 298 insertions(+), 48 deletions(-) diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 404cd6549e5786b92c49379918346b83fcc0e0c1..fdfcf8ca5863ff50678ae6a86767a15a661a220a 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -103,12 +103,22 @@ impl LanguageModels { } fn refresh_list(&mut self, cx: &App) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!("[{}ms] LanguageModels::refresh_list called", now); let providers = LanguageModelRegistry::global(cx) .read(cx) .providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); + eprintln!( + "[{}ms] LanguageModels::refresh_list got {} authenticated providers", + now, + providers.len() + ); let mut language_model_list = IndexMap::default(); let mut recommended_models = HashSet::default(); @@ -146,6 +156,15 @@ impl LanguageModels { self.models = models; self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModels::refresh_list completed with {} models in list", + now, + self.models.len() + ); self.refresh_models_tx.send(()).ok(); } @@ -603,6 +622,14 @@ impl NativeAgent { _event: &language_model::Event, cx: &mut Context, ) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] NativeAgent::handle_models_updated_event called", + now + ); self.models.refresh_list(cx); let registry = LanguageModelRegistry::read_global(cx); diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 5f5682b7dcc90d2b779744ba353380987a5907a1..f5fd35f5a7fe0b7b8ba7b1d2865b522e57bde21a 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -344,9 +344,37 @@ fn init_language_model_settings(cx: &mut App) { cx.subscribe( &LanguageModelRegistry::global(cx), |_, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged(_) - | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { + language_model::Event::ProviderStateChanged(id) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] agent_ui global subscription: ProviderStateChanged for {:?}", + now, id + ); + update_active_language_model_from_settings(cx); + } + language_model::Event::AddedProvider(id) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] agent_ui global subscription: AddedProvider for {:?}", + now, id + ); + update_active_language_model_from_settings(cx); + } + language_model::Event::RemovedProvider(id) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] agent_ui global subscription: RemovedProvider for {:?}", + now, id + ); update_active_language_model_from_settings(cx); } _ => {} diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 5b5a4513c6dca32e985c966e07ad84e84fc9a872..05384b4a93083bbc471f07b83b00e4bff4a82604 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,10 +1,9 @@ use std::{cmp::Reverse, sync::Arc}; use collections::IndexMap; +use futures::{StreamExt, channel::mpsc}; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; -use gpui::{ - Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task, -}; +use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task}; use language_model::{ AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, @@ -47,7 +46,13 @@ pub fn language_model_selector( } fn all_models(cx: &App) -> GroupedModels { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!("[{}ms] all_models() called", now); let providers = LanguageModelRegistry::global(cx).read(cx).providers(); + eprintln!("[{}ms] all_models: got {} providers", now, providers.len()); let recommended = providers .iter() @@ -62,19 +67,41 @@ fn all_models(cx: &App) -> GroupedModels { }) .collect(); - let all = providers + let all: Vec = providers .iter() .flat_map(|provider| { - provider - .provided_models(cx) - .into_iter() - .map(|model| ModelInfo { - model, - icon: provider.icon(), - }) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] all_models: calling provided_models for {:?}", + now, + provider.id() + ); + let models = provider.provided_models(cx); + eprintln!( + "[{}ms] all_models: provider {:?} returned {} models", + now, + provider.id(), + models.len() + ); + models.into_iter().map(|model| ModelInfo { + model, + icon: provider.icon(), + }) }) .collect(); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] all_models: returning {} total models", + now, + all.len() + ); GroupedModels::new(all, recommended) } @@ -91,7 +118,7 @@ pub struct LanguageModelPickerDelegate { filtered_entries: Vec, selected_index: usize, _authenticate_all_providers_task: Task<()>, - _subscriptions: Vec, + _refresh_models_task: Task<()>, popover_styles: bool, focus_handle: FocusHandle, } @@ -105,8 +132,18 @@ impl LanguageModelPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Self { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!("[{}ms] LanguageModelPickerDelegate::new() called", now); let on_model_changed = Arc::new(on_model_changed); let models = all_models(cx); + eprintln!( + "[{}ms] LanguageModelPickerDelegate::new() got {} models from all_models()", + now, + models.all.len() + ); let entries = models.entries(); Self { @@ -116,24 +153,88 @@ impl LanguageModelPickerDelegate { filtered_entries: entries, get_active_model: Arc::new(get_active_model), _authenticate_all_providers_task: Self::authenticate_all_providers(cx), - _subscriptions: vec![cx.subscribe_in( - &LanguageModelRegistry::global(cx), - window, - |picker, _, event, window, cx| { + _refresh_models_task: { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelPickerDelegate::new() setting up refresh task for LanguageModelRegistry", + now + ); + + // Create a channel to signal when models need refreshing + let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>(); + + // Subscribe to registry events and send refresh signals through the channel + let registry = LanguageModelRegistry::global(cx); + eprintln!( + "[{}ms] LanguageModelPickerDelegate::new() subscribing to registry entity_id: {:?}", + now, + registry.entity_id() + ); + cx.subscribe(®istry, move |_picker, _, event, _cx| { match event { - language_model::Event::ProviderStateChanged(_) - | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { - let query = picker.query(cx); - picker.delegate.all_models = Arc::new(all_models(cx)); - // Update matches will automatically drop the previous task - // if we get a provider event again - picker.update_matches(query, window, cx) + language_model::Event::ProviderStateChanged(id) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelSelector: ProviderStateChanged event for {:?}, sending refresh signal", + now, id + ); + refresh_tx.unbounded_send(()).ok(); + } + language_model::Event::AddedProvider(id) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelSelector: AddedProvider event for {:?}, sending refresh signal", + now, id + ); + refresh_tx.unbounded_send(()).ok(); + } + language_model::Event::RemovedProvider(id) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelSelector: RemovedProvider event for {:?}, sending refresh signal", + now, id + ); + refresh_tx.unbounded_send(()).ok(); } _ => {} } - }, - )], + }) + .detach(); + + // Spawn a task that listens for refresh signals and updates the picker + cx.spawn_in(window, async move |this, cx| { + while let Some(()) = refresh_rx.next().await { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelSelector: refresh signal received, updating models", + now + ); + let result = this.update_in(cx, |picker, window, cx| { + picker.delegate.all_models = Arc::new(all_models(cx)); + picker.refresh(window, cx); + }); + if result.is_err() { + // Picker was dropped, exit the loop + break; + } + } + }) + }, popover_styles, focus_handle, } diff --git a/crates/extension/src/extension_host_proxy.rs b/crates/extension/src/extension_host_proxy.rs index dc395d6a937c7b72a3e3a95ff9fc0513a4088e3d..6fe9e1f8084cf5074dddaac9263cee666153640c 100644 --- a/crates/extension/src/extension_host_proxy.rs +++ b/crates/extension/src/extension_host_proxy.rs @@ -414,9 +414,14 @@ impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy { cx: &mut App, ) { let Some(proxy) = self.language_model_provider_proxy.read().clone() else { + eprintln!( + "Failed to register LLM provider '{}': language_model_provider_proxy not yet initialized", + provider_id + ); return; }; + eprintln!("Registering LLM provider: {}", provider_id); proxy.register_language_model_provider(provider_id, register_fn, cx) } diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index f399354ecec6a1890a99b85fe50a002fab31e15c..1fd568e9063817386f5868836fcd27259d251aeb 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1364,7 +1364,7 @@ impl ExtensionStore { let mut wasm_extensions: Vec<( Arc, WasmExtension, - Vec<(LlmProviderInfo, Vec)>, + Vec<(LlmProviderInfo, Vec, bool)>, )> = Vec::new(); for extension in extension_entries { if extension.manifest.lib.kind.is_none() { @@ -1384,8 +1384,14 @@ impl ExtensionStore { match wasm_extension { Ok(wasm_extension) => { // Query for LLM providers if the manifest declares any + // Tuple is (provider_info, models, is_authenticated) let mut llm_providers_with_models = Vec::new(); if !extension.manifest.language_model_providers.is_empty() { + eprintln!( + "Extension {} declares {} LLM providers in manifest, querying...", + extension.manifest.id, + extension.manifest.language_model_providers.len() + ); let providers_result = wasm_extension .call(|ext, store| { async move { ext.call_llm_providers(store).await }.boxed() @@ -1393,6 +1399,11 @@ impl ExtensionStore { .await; if let Ok(Ok(providers)) = providers_result { + eprintln!( + "Extension {} returned {} LLM providers", + extension.manifest.id, + providers.len() + ); for provider_info in providers { let models_result = wasm_extension .call({ @@ -1410,7 +1421,7 @@ impl ExtensionStore { let models: Vec = match models_result { Ok(Ok(Ok(models))) => models, Ok(Ok(Err(e))) => { - log::error!( + eprintln!( "Failed to get models for LLM provider {} in extension {}: {}", provider_info.id, extension.manifest.id, @@ -1419,7 +1430,7 @@ impl ExtensionStore { Vec::new() } Ok(Err(e)) => { - log::error!( + eprintln!( "Wasm error calling llm_provider_models for {} in extension {}: {:?}", provider_info.id, extension.manifest.id, @@ -1428,7 +1439,7 @@ impl ExtensionStore { Vec::new() } Err(e) => { - log::error!( + eprintln!( "Extension call failed for llm_provider_models {} in extension {}: {:?}", provider_info.id, extension.manifest.id, @@ -1438,8 +1449,40 @@ impl ExtensionStore { } }; - llm_providers_with_models.push((provider_info, models)); + // Query initial authentication state + let is_authenticated = wasm_extension + .call({ + let provider_id = provider_info.id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_is_authenticated( + store, + &provider_id, + ) + .await + } + .boxed() + } + }) + .await + .unwrap_or(Ok(false)) + .unwrap_or(false); + + eprintln!( + "LLM provider {} has {} models, is_authenticated={}", + provider_info.id, + models.len(), + is_authenticated + ); + llm_providers_with_models + .push((provider_info, models, is_authenticated)); } + } else { + eprintln!( + "Failed to get LLM providers from extension {}: {:?}", + extension.manifest.id, + providers_result + ); } } @@ -1522,28 +1565,34 @@ impl ExtensionStore { } // Register LLM providers - for (provider_info, models) in llm_providers_with_models { + for (provider_info, models, is_authenticated) in llm_providers_with_models { let provider_id: Arc = format!("{}:{}", manifest.id, provider_info.id).into(); - let wasm_ext = wasm_extension.clone(); + let wasm_ext = extension.as_ref().clone(); let pinfo = provider_info.clone(); let mods = models.clone(); + let auth = *is_authenticated; this.proxy.register_language_model_provider( - provider_id, + provider_id.clone(), Box::new(move |cx: &mut App| { + eprintln!("register_fn closure called, creating provider"); let provider = Arc::new(ExtensionLanguageModelProvider::new( - wasm_ext, pinfo, mods, cx, + wasm_ext, pinfo, mods, auth, cx, )); + eprintln!("Provider created, registering with registry"); language_model::LanguageModelRegistry::global(cx).update( cx, |registry, cx| { + eprintln!("Inside registry.register_provider"); registry.register_provider(provider, cx); }, ); + eprintln!("Provider registered"); }), cx, ); + eprintln!("register_language_model_provider call completed for {}", provider_id); } } diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index cd0b99cc02499bbce73cf46e1822b9cebecd2aa3..93b8a1de9b723e27a69720051ab4443277b38fcb 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -73,12 +73,6 @@ pub struct WasmExtension { _task: Arc>>, } -impl Drop for WasmExtension { - fn drop(&mut self) { - self.tx.close_channel(); - } -} - #[async_trait] impl extension::Extension for WasmExtension { fn manifest(&self) -> Arc { diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index 7e98e0b400ef4db3fb903758c0859f35b8c62839..02cc9722f3ca8af3537a7977b37a39895cc0e278 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -40,10 +40,11 @@ impl ExtensionLanguageModelProvider { extension: WasmExtension, provider_info: LlmProviderInfo, models: Vec, + is_authenticated: bool, cx: &mut App, ) -> Self { let state = cx.new(|_| ExtensionLlmProviderState { - is_authenticated: false, + is_authenticated, available_models: models, }); @@ -61,7 +62,9 @@ impl ExtensionLanguageModelProvider { impl LanguageModelProvider for ExtensionLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId::from(self.provider_id_string()) + let id = LanguageModelProviderId::from(self.provider_id_string()); + eprintln!("ExtensionLanguageModelProvider::id() -> {:?}", id); + id } fn name(&self) -> LanguageModelProviderName { @@ -111,10 +114,16 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { fn provided_models(&self, cx: &App) -> Vec> { let state = self.state.read(cx); + eprintln!( + "ExtensionLanguageModelProvider::provided_models called for {}, returning {} models", + self.provider_info.name, + state.available_models.len() + ); state .available_models .iter() .map(|model_info| { + eprintln!(" - model: {}", model_info.name); Arc::new(ExtensionLanguageModel { extension: self.extension.clone(), model_info: model_info.clone(), diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 6ed8bf07c4e976c88fecebd929843335333b1fa6..e7b511e8d9de08a6c6f92df567743cb17b8a2a31 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -153,7 +153,19 @@ impl LanguageModelRegistry { } self.providers.insert(id.clone(), provider); - cx.emit(Event::AddedProvider(id)); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelRegistry: About to emit AddedProvider event for {:?}", + now, id + ); + cx.emit(Event::AddedProvider(id.clone())); + eprintln!( + "[{}ms] LanguageModelRegistry: Emitted AddedProvider event for {:?}", + now, id + ); } pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context) { @@ -163,6 +175,18 @@ impl LanguageModelRegistry { } pub fn providers(&self) -> Vec> { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + eprintln!( + "[{}ms] LanguageModelRegistry::providers() called, {} providers in registry", + now, + self.providers.len() + ); + for (id, _) in &self.providers { + eprintln!(" - provider: {:?}", id); + } let zed_provider_id = LanguageModelProviderId("zed.dev".into()); let mut providers = Vec::with_capacity(self.providers.len()); if let Some(provider) = self.providers.get(&zed_provider_id) { @@ -175,6 +199,11 @@ impl LanguageModelRegistry { None } })); + eprintln!( + "[{}ms] LanguageModelRegistry::providers() returning {} providers", + now, + providers.len() + ); providers } diff --git a/crates/language_models/src/extension.rs b/crates/language_models/src/extension.rs index 139ff632657cd8a47feba9c4fd14aed07589d1f6..e17e581fef6bafb86bdb32bceaed49bb61c6f4c0 100644 --- a/crates/language_models/src/extension.rs +++ b/crates/language_models/src/extension.rs @@ -18,11 +18,19 @@ impl ExtensionLanguageModelProxy { impl ExtensionLanguageModelProviderProxy for ExtensionLanguageModelProxy { fn register_language_model_provider( &self, - _provider_id: Arc, + provider_id: Arc, register_fn: LanguageModelProviderRegistration, cx: &mut App, ) { + eprintln!( + "ExtensionLanguageModelProxy::register_language_model_provider called for {}", + provider_id + ); register_fn(cx); + eprintln!( + "ExtensionLanguageModelProxy::register_language_model_provider completed for {}", + provider_id + ); } fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App) {