diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index fcb901347a12798aa8e2e40942f88b47beee011d..553858881a0144e1808c044592cb5b25b63229e0 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -201,7 +201,7 @@ impl LanguageModels { .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) .collect::>(); - cx.background_spawn(async move { + cx.spawn(async move |cx| { for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { if let Err(err) = authenticate_task.await { match err { @@ -244,6 +244,8 @@ impl LanguageModels { } } } + + cx.update(language_models::update_environment_fallback_model); }) } } @@ -365,7 +367,7 @@ impl NativeAgent { }); let registry = LanguageModelRegistry::read_global(cx); - let summarization_model = registry.thread_summary_model().map(|c| c.model); + let summarization_model = registry.thread_summary_model(cx).map(|c| c.model); let weak = cx.weak_entity(); let weak_thread = thread_handle.downgrade(); @@ -749,13 +751,14 @@ impl NativeAgent { let registry = LanguageModelRegistry::read_global(cx); let default_model = registry.default_model().map(|m| m.model); - let summarization_model = registry.thread_summary_model().map(|m| m.model); + let summarization_model = registry.thread_summary_model(cx).map(|m| m.model); for session in self.sessions.values_mut() { session.thread.update(cx, |thread, cx| { - if thread.model().is_none() - && let Some(model) = default_model.clone() - { + let should_update_model = thread.model().is_none() + || (thread.is_empty() + && matches!(event, language_model::Event::DefaultModelChanged)); + if should_update_model && let Some(model) = default_model.clone() { thread.set_model(model, cx); cx.notify(); } @@ -910,7 +913,7 @@ impl NativeAgent { .get(&project_id) .context("project state not found")?; let summarization_model = LanguageModelRegistry::read_global(cx) - .thread_summary_model() + .thread_summary_model(cx) .map(|c| c.model); Ok(cx.new(|cx| { diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 899542245ab8f3618f6d70d807363cc91af3a257..7de58fd54ffd0d984b3a6079681f15f6a56507ae 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -8,8 +8,8 @@ use gpui::{ Subscription, Task, }; use language_model::{ - AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, - LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, + ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider, + LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; @@ -124,7 +124,6 @@ pub struct LanguageModelPickerDelegate { all_models: Arc, filtered_entries: Vec, selected_index: usize, - _authenticate_all_providers_task: Task<()>, _subscriptions: Vec, popover_styles: bool, focus_handle: FocusHandle, @@ -151,7 +150,6 @@ impl LanguageModelPickerDelegate { filtered_entries: entries, get_active_model: Arc::new(get_active_model), on_toggle_favorite: Arc::new(on_toggle_favorite), - _authenticate_all_providers_task: Self::authenticate_all_providers(cx), _subscriptions: vec![cx.subscribe_in( &LanguageModelRegistry::global(cx), window, @@ -197,56 +195,6 @@ impl LanguageModelPickerDelegate { .unwrap_or(0) } - /// Authenticates all providers in the [`LanguageModelRegistry`]. - /// - /// We do this so that we can populate the language selector with all of the - /// models from the configured providers. - fn authenticate_all_providers(cx: &mut App) -> Task<()> { - let authenticate_all_providers = LanguageModelRegistry::global(cx) - .read(cx) - .visible_providers() - .iter() - .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) - .collect::>(); - - cx.spawn(async move |_cx| { - for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { - if let Err(err) = authenticate_task.await { - if matches!(err, AuthenticateError::CredentialsNotFound) { - // Since we're authenticating these providers in the - // background for the purposes of populating the - // language selector, we don't care about providers - // where the credentials are not found. - } else { - // Some providers have noisy failure states that we - // don't want to spam the logs with every time the - // language model selector is initialized. - // - // Ideally these should have more clear failure modes - // that we know are safe to ignore here, like what we do - // with `CredentialsNotFound` above. - match provider_id.0.as_ref() { - "lmstudio" | "ollama" => { - // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated". - // - // These fail noisily, so we don't log them. - } - "copilot_chat" => { - // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors. - } - _ => { - log::error!( - "Failed to authenticate provider: {}: {err:#}", - provider_name.0 - ); - } - } - } - } - } - }) - } - pub fn active_model(&self, cx: &App) -> Option { (self.get_active_model)(cx) } diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 2357bce33fc2dc01cc009c24b3c0d7685b782840..c8b249a7dff60266f397506b3c79e87fbfcc1dba 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2688,7 +2688,7 @@ impl GitPanel { } let Some(ConfiguredModel { provider, model }) = - LanguageModelRegistry::read_global(cx).commit_message_model() + LanguageModelRegistry::read_global(cx).commit_message_model(cx) else { return; }; @@ -4056,7 +4056,7 @@ impl GitPanel { let model_registry = LanguageModelRegistry::read_global(cx); let has_commit_model_configuration_error = model_registry - .configuration_error(model_registry.commit_message_model(), cx) + .configuration_error(model_registry.commit_message_model(cx), cx) .is_some(); let can_commit = self.can_commit(); diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 680078808ab33cc2a90caead8b304326beccf11b..219d9f4b39e8facbefc56c479dad8acd0b5c53c5 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -6,7 +6,6 @@ use collections::{BTreeMap, HashSet}; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::{str::FromStr, sync::Arc}; use thiserror::Error; -use util::maybe; /// Function type for checking if a built-in provider should be hidden. /// Returns Some(extension_id) if the provider should be hidden when that extension is installed. @@ -46,7 +45,9 @@ impl std::fmt::Debug for ConfigurationError { #[derive(Default)] pub struct LanguageModelRegistry { default_model: Option, - default_fast_model: Option, + /// This model is automatically configured by a user's environment after + /// authenticating all providers. It's only used when `default_model` is not set. + available_fallback_model: Option, inline_assistant_model: Option, commit_message_model: Option, thread_summary_model: Option, @@ -349,22 +350,29 @@ impl LanguageModelRegistry { } pub fn set_default_model(&mut self, model: Option, cx: &mut Context) { - match (self.default_model.as_ref(), model.as_ref()) { + match (self.default_model(), model.as_ref()) { (Some(old), Some(new)) if old.is_same_as(new) => {} (None, None) => {} _ => cx.emit(Event::DefaultModelChanged), } - self.default_fast_model = maybe!({ - let provider = &model.as_ref()?.provider; - let fast_model = provider.default_fast_model(cx)?; - Some(ConfiguredModel { - provider: provider.clone(), - model: fast_model, - }) - }); self.default_model = model; } + pub fn set_environment_fallback_model( + &mut self, + model: Option, + cx: &mut Context, + ) { + if self.default_model.is_none() { + match (self.available_fallback_model.as_ref(), model.as_ref()) { + (Some(old), Some(new)) if old.is_same_as(new) => {} + (None, None) => {} + _ => cx.emit(Event::DefaultModelChanged), + } + } + self.available_fallback_model = model; + } + pub fn set_inline_assistant_model( &mut self, model: Option, @@ -410,7 +418,18 @@ impl LanguageModelRegistry { return None; } - self.default_model.clone() + self.default_model + .clone() + .or_else(|| self.available_fallback_model.clone()) + } + + pub fn default_fast_model(&self, cx: &App) -> Option { + let configured = self.default_model()?; + let fast_model = configured.provider.default_fast_model(cx)?; + Some(ConfiguredModel { + provider: configured.provider, + model: fast_model, + }) } pub fn inline_assistant_model(&self) -> Option { @@ -424,7 +443,7 @@ impl LanguageModelRegistry { .or_else(|| self.default_model.clone()) } - pub fn commit_message_model(&self) -> Option { + pub fn commit_message_model(&self, cx: &App) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; @@ -432,11 +451,11 @@ impl LanguageModelRegistry { self.commit_message_model .clone() - .or_else(|| self.default_fast_model.clone()) - .or_else(|| self.default_model.clone()) + .or_else(|| self.default_fast_model(cx)) + .or_else(|| self.default_model()) } - pub fn thread_summary_model(&self) -> Option { + pub fn thread_summary_model(&self, cx: &App) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; @@ -444,8 +463,8 @@ impl LanguageModelRegistry { self.thread_summary_model .clone() - .or_else(|| self.default_fast_model.clone()) - .or_else(|| self.default_model.clone()) + .or_else(|| self.default_fast_model(cx)) + .or_else(|| self.default_model()) } /// The models to use for inline assists. Returns the union of the active @@ -576,6 +595,35 @@ mod tests { assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into()))); } + #[gpui::test] + async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + }); + + cx.update(|cx| provider.authenticate(cx)).await.unwrap(); + + registry.update(cx, |registry, cx| { + let provider = registry.provider(&provider.id()).unwrap(); + let model = provider.default_model(cx).unwrap(); + + registry.set_environment_fallback_model( + Some(ConfiguredModel { + provider: provider.clone(), + model: model.clone(), + }), + cx, + ); + + let default_model = registry.default_model().unwrap(); + assert_eq!(default_model.model.id(), model.id()); + assert_eq!(default_model.provider.id(), provider.id()); + }); + } + #[gpui::test] fn test_sync_installed_llm_extensions(cx: &mut App) { let registry = cx.new(|_| LanguageModelRegistry::default()); diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 3154db91a43d1381f5b3f122a724be249adeb79b..bd29dbe08dbd16af25be4bd55b44067f47fa2a8a 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -5,7 +5,9 @@ use client::{Client, UserStore}; use collections::HashSet; use credentials_provider::CredentialsProvider; use gpui::{App, Context, Entity}; -use language_model::{LanguageModelProviderId, LanguageModelRegistry}; +use language_model::{ + ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, +}; use provider::deepseek::DeepSeekLanguageModelProvider; pub mod extension; @@ -116,6 +118,20 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { cx, ); }); + + cx.subscribe( + ®istry, + |_registry, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged(_) + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + update_environment_fallback_model(cx); + } + _ => {} + }, + ) + .detach(); + let registry = registry.downgrade(); cx.observe_global::(move |cx| { let Some(registry) = registry.upgrade() else { @@ -143,6 +159,50 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { .detach(); } +/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback +/// model based on currently authenticated providers. +/// +/// Prefers the Zed cloud provider so that, once the user is signed in, we +/// always pick a Zed-hosted model over models from other authenticated +/// providers in the environment. If the Zed cloud provider is authenticated +/// but hasn't finished loading its models yet, we don't fall back to another +/// provider to avoid flickering between providers during sign in. +pub fn update_environment_fallback_model(cx: &mut App) { + let registry = LanguageModelRegistry::global(cx); + let fallback_model = { + let registry = registry.read(cx); + let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID); + if cloud_provider + .as_ref() + .is_some_and(|provider| provider.is_authenticated(cx)) + { + cloud_provider.and_then(|provider| { + let model = provider + .default_model(cx) + .or_else(|| provider.recommended_models(cx).first().cloned())?; + Some(ConfiguredModel { provider, model }) + }) + } else { + registry + .providers() + .iter() + .filter(|provider| provider.is_authenticated(cx)) + .find_map(|provider| { + let model = provider + .default_model(cx) + .or_else(|| provider.recommended_models(cx).first().cloned())?; + Some(ConfiguredModel { + provider: provider.clone(), + model, + }) + }) + } + }; + registry.update(cx, |registry, cx| { + registry.set_environment_fallback_model(fallback_model, cx); + }); +} + fn register_openai_compatible_providers( registry: &mut LanguageModelRegistry, old: &HashSet>,