From e92a40a9d897f0ecea7b0c51dae117a22580a333 Mon Sep 17 00:00:00 2001 From: Anthony Eid <56899983+Anthony-Eid@users.noreply.github.com> Date: Thu, 16 Apr 2026 18:35:47 -0400 Subject: [PATCH] agent: Auto-select user model when there's no default (#54125) Reimplements #36722 while fixing the race that required the revert in #36932. When no default model is configured, this picks an environment fallback by authenticating all providers. It always prefers the Zed cloud provider when it's authenticated, and waits for its models to load before picking another provider as the fallback, so we don't flicker from Zed models to Anthropic while sign-in is in flight. The fallback is recomputed whenever provider state changes (via `ProviderStateChanged`/`AddedProvider`/`RemovedProvider` events), so the selection becomes correct as soon as cloud models arrive. ### What changed vs. the original PR - `language_models::init` now owns `authenticate_all_providers` (previously done in `LanguageModelPickerDelegate` and `agent`'s `LanguageModels`). - After all authentications settle, and on any subsequent provider state change, `update_environment_fallback_model` recomputes the fallback. - The fallback logic prefers Zed cloud: if the cloud provider is authenticated, only use it (waiting for its models to load). Otherwise, fall through to the first authenticated provider with a default or recommended model. - `LanguageModelRegistry::default_model()` falls back to `environment_fallback_model` when no explicit default is set. - Existing `Thread`s that are empty are updated to the new default when `DefaultModelChanged` fires, so a blank thread started before sign-in switches to Zed models once the user signs in. Release Notes: - agent: Automatically select a model when there's no selected model or configured default --- crates/agent/src/agent.rs | 17 ++-- .../agent_ui/src/language_model_selector.rs | 56 +------------ crates/git_ui/src/git_panel.rs | 4 +- crates/language_model/src/registry.rs | 84 +++++++++++++++---- crates/language_models/src/language_models.rs | 62 +++++++++++++- 5 files changed, 141 insertions(+), 82 deletions(-) diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index fcb901347a12798aa8e2e40942f88b47beee011d..553858881a0144e1808c044592cb5b25b63229e0 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -201,7 +201,7 @@ impl LanguageModels { .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) .collect::>(); - cx.background_spawn(async move { + cx.spawn(async move |cx| { for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { if let Err(err) = authenticate_task.await { match err { @@ -244,6 +244,8 @@ impl LanguageModels { } } } + + cx.update(language_models::update_environment_fallback_model); }) } } @@ -365,7 +367,7 @@ impl NativeAgent { }); let registry = LanguageModelRegistry::read_global(cx); - let summarization_model = registry.thread_summary_model().map(|c| c.model); + let summarization_model = registry.thread_summary_model(cx).map(|c| c.model); let weak = cx.weak_entity(); let weak_thread = thread_handle.downgrade(); @@ -749,13 +751,14 @@ impl NativeAgent { let registry = LanguageModelRegistry::read_global(cx); let default_model = registry.default_model().map(|m| m.model); - let summarization_model = registry.thread_summary_model().map(|m| m.model); + let summarization_model = registry.thread_summary_model(cx).map(|m| m.model); for session in self.sessions.values_mut() { session.thread.update(cx, |thread, cx| { - if thread.model().is_none() - && let Some(model) = default_model.clone() - { + let should_update_model = thread.model().is_none() + || (thread.is_empty() + && matches!(event, language_model::Event::DefaultModelChanged)); + if should_update_model && let Some(model) = default_model.clone() { thread.set_model(model, cx); cx.notify(); } @@ -910,7 +913,7 @@ impl NativeAgent { .get(&project_id) .context("project state not found")?; let summarization_model = LanguageModelRegistry::read_global(cx) - .thread_summary_model() + .thread_summary_model(cx) .map(|c| c.model); Ok(cx.new(|cx| { diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 899542245ab8f3618f6d70d807363cc91af3a257..7de58fd54ffd0d984b3a6079681f15f6a56507ae 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -8,8 +8,8 @@ use gpui::{ Subscription, Task, }; use language_model::{ - AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, - LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, + ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider, + LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; @@ -124,7 +124,6 @@ pub struct LanguageModelPickerDelegate { all_models: Arc, filtered_entries: Vec, selected_index: usize, - _authenticate_all_providers_task: Task<()>, _subscriptions: Vec, popover_styles: bool, focus_handle: FocusHandle, @@ -151,7 +150,6 @@ impl LanguageModelPickerDelegate { filtered_entries: entries, get_active_model: Arc::new(get_active_model), on_toggle_favorite: Arc::new(on_toggle_favorite), - _authenticate_all_providers_task: Self::authenticate_all_providers(cx), _subscriptions: vec![cx.subscribe_in( &LanguageModelRegistry::global(cx), window, @@ -197,56 +195,6 @@ impl LanguageModelPickerDelegate { .unwrap_or(0) } - /// Authenticates all providers in the [`LanguageModelRegistry`]. - /// - /// We do this so that we can populate the language selector with all of the - /// models from the configured providers. - fn authenticate_all_providers(cx: &mut App) -> Task<()> { - let authenticate_all_providers = LanguageModelRegistry::global(cx) - .read(cx) - .visible_providers() - .iter() - .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) - .collect::>(); - - cx.spawn(async move |_cx| { - for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { - if let Err(err) = authenticate_task.await { - if matches!(err, AuthenticateError::CredentialsNotFound) { - // Since we're authenticating these providers in the - // background for the purposes of populating the - // language selector, we don't care about providers - // where the credentials are not found. - } else { - // Some providers have noisy failure states that we - // don't want to spam the logs with every time the - // language model selector is initialized. - // - // Ideally these should have more clear failure modes - // that we know are safe to ignore here, like what we do - // with `CredentialsNotFound` above. - match provider_id.0.as_ref() { - "lmstudio" | "ollama" => { - // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated". - // - // These fail noisily, so we don't log them. - } - "copilot_chat" => { - // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors. - } - _ => { - log::error!( - "Failed to authenticate provider: {}: {err:#}", - provider_name.0 - ); - } - } - } - } - } - }) - } - pub fn active_model(&self, cx: &App) -> Option { (self.get_active_model)(cx) } diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 2357bce33fc2dc01cc009c24b3c0d7685b782840..c8b249a7dff60266f397506b3c79e87fbfcc1dba 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2688,7 +2688,7 @@ impl GitPanel { } let Some(ConfiguredModel { provider, model }) = - LanguageModelRegistry::read_global(cx).commit_message_model() + LanguageModelRegistry::read_global(cx).commit_message_model(cx) else { return; }; @@ -4056,7 +4056,7 @@ impl GitPanel { let model_registry = LanguageModelRegistry::read_global(cx); let has_commit_model_configuration_error = model_registry - .configuration_error(model_registry.commit_message_model(), cx) + .configuration_error(model_registry.commit_message_model(cx), cx) .is_some(); let can_commit = self.can_commit(); diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 680078808ab33cc2a90caead8b304326beccf11b..219d9f4b39e8facbefc56c479dad8acd0b5c53c5 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -6,7 +6,6 @@ use collections::{BTreeMap, HashSet}; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::{str::FromStr, sync::Arc}; use thiserror::Error; -use util::maybe; /// Function type for checking if a built-in provider should be hidden. /// Returns Some(extension_id) if the provider should be hidden when that extension is installed. @@ -46,7 +45,9 @@ impl std::fmt::Debug for ConfigurationError { #[derive(Default)] pub struct LanguageModelRegistry { default_model: Option, - default_fast_model: Option, + /// This model is automatically configured by a user's environment after + /// authenticating all providers. It's only used when `default_model` is not set. + available_fallback_model: Option, inline_assistant_model: Option, commit_message_model: Option, thread_summary_model: Option, @@ -349,22 +350,29 @@ impl LanguageModelRegistry { } pub fn set_default_model(&mut self, model: Option, cx: &mut Context) { - match (self.default_model.as_ref(), model.as_ref()) { + match (self.default_model(), model.as_ref()) { (Some(old), Some(new)) if old.is_same_as(new) => {} (None, None) => {} _ => cx.emit(Event::DefaultModelChanged), } - self.default_fast_model = maybe!({ - let provider = &model.as_ref()?.provider; - let fast_model = provider.default_fast_model(cx)?; - Some(ConfiguredModel { - provider: provider.clone(), - model: fast_model, - }) - }); self.default_model = model; } + pub fn set_environment_fallback_model( + &mut self, + model: Option, + cx: &mut Context, + ) { + if self.default_model.is_none() { + match (self.available_fallback_model.as_ref(), model.as_ref()) { + (Some(old), Some(new)) if old.is_same_as(new) => {} + (None, None) => {} + _ => cx.emit(Event::DefaultModelChanged), + } + } + self.available_fallback_model = model; + } + pub fn set_inline_assistant_model( &mut self, model: Option, @@ -410,7 +418,18 @@ impl LanguageModelRegistry { return None; } - self.default_model.clone() + self.default_model + .clone() + .or_else(|| self.available_fallback_model.clone()) + } + + pub fn default_fast_model(&self, cx: &App) -> Option { + let configured = self.default_model()?; + let fast_model = configured.provider.default_fast_model(cx)?; + Some(ConfiguredModel { + provider: configured.provider, + model: fast_model, + }) } pub fn inline_assistant_model(&self) -> Option { @@ -424,7 +443,7 @@ impl LanguageModelRegistry { .or_else(|| self.default_model.clone()) } - pub fn commit_message_model(&self) -> Option { + pub fn commit_message_model(&self, cx: &App) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; @@ -432,11 +451,11 @@ impl LanguageModelRegistry { self.commit_message_model .clone() - .or_else(|| self.default_fast_model.clone()) - .or_else(|| self.default_model.clone()) + .or_else(|| self.default_fast_model(cx)) + .or_else(|| self.default_model()) } - pub fn thread_summary_model(&self) -> Option { + pub fn thread_summary_model(&self, cx: &App) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; @@ -444,8 +463,8 @@ impl LanguageModelRegistry { self.thread_summary_model .clone() - .or_else(|| self.default_fast_model.clone()) - .or_else(|| self.default_model.clone()) + .or_else(|| self.default_fast_model(cx)) + .or_else(|| self.default_model()) } /// The models to use for inline assists. Returns the union of the active @@ -576,6 +595,35 @@ mod tests { assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into()))); } + #[gpui::test] + async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + }); + + cx.update(|cx| provider.authenticate(cx)).await.unwrap(); + + registry.update(cx, |registry, cx| { + let provider = registry.provider(&provider.id()).unwrap(); + let model = provider.default_model(cx).unwrap(); + + registry.set_environment_fallback_model( + Some(ConfiguredModel { + provider: provider.clone(), + model: model.clone(), + }), + cx, + ); + + let default_model = registry.default_model().unwrap(); + assert_eq!(default_model.model.id(), model.id()); + assert_eq!(default_model.provider.id(), provider.id()); + }); + } + #[gpui::test] fn test_sync_installed_llm_extensions(cx: &mut App) { let registry = cx.new(|_| LanguageModelRegistry::default()); diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 3154db91a43d1381f5b3f122a724be249adeb79b..bd29dbe08dbd16af25be4bd55b44067f47fa2a8a 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -5,7 +5,9 @@ use client::{Client, UserStore}; use collections::HashSet; use credentials_provider::CredentialsProvider; use gpui::{App, Context, Entity}; -use language_model::{LanguageModelProviderId, LanguageModelRegistry}; +use language_model::{ + ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, +}; use provider::deepseek::DeepSeekLanguageModelProvider; pub mod extension; @@ -116,6 +118,20 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { cx, ); }); + + cx.subscribe( + ®istry, + |_registry, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged(_) + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + update_environment_fallback_model(cx); + } + _ => {} + }, + ) + .detach(); + let registry = registry.downgrade(); cx.observe_global::(move |cx| { let Some(registry) = registry.upgrade() else { @@ -143,6 +159,50 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { .detach(); } +/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback +/// model based on currently authenticated providers. +/// +/// Prefers the Zed cloud provider so that, once the user is signed in, we +/// always pick a Zed-hosted model over models from other authenticated +/// providers in the environment. If the Zed cloud provider is authenticated +/// but hasn't finished loading its models yet, we don't fall back to another +/// provider to avoid flickering between providers during sign in. +pub fn update_environment_fallback_model(cx: &mut App) { + let registry = LanguageModelRegistry::global(cx); + let fallback_model = { + let registry = registry.read(cx); + let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID); + if cloud_provider + .as_ref() + .is_some_and(|provider| provider.is_authenticated(cx)) + { + cloud_provider.and_then(|provider| { + let model = provider + .default_model(cx) + .or_else(|| provider.recommended_models(cx).first().cloned())?; + Some(ConfiguredModel { provider, model }) + }) + } else { + registry + .providers() + .iter() + .filter(|provider| provider.is_authenticated(cx)) + .find_map(|provider| { + let model = provider + .default_model(cx) + .or_else(|| provider.recommended_models(cx).first().cloned())?; + Some(ConfiguredModel { + provider: provider.clone(), + model, + }) + }) + } + }; + registry.update(cx, |registry, cx| { + registry.set_environment_fallback_model(fallback_model, cx); + }); +} + fn register_openai_compatible_providers( registry: &mut LanguageModelRegistry, old: &HashSet>,