agent: Auto-select user model when there's no default (#54125)

Anthony Eid created

Reimplements #36722 while fixing the race that required the revert in
#36932.

When no default model is configured, this picks an environment fallback
by authenticating all providers. It always prefers the Zed cloud
provider when it's authenticated, and waits for its models to load
before picking another provider as the fallback, so we don't flicker
from Zed models to Anthropic while sign-in is in flight.

The fallback is recomputed whenever provider state changes (via
`ProviderStateChanged`/`AddedProvider`/`RemovedProvider` events), so the
selection becomes correct as soon as cloud models arrive.

### What changed vs. the original PR

- `language_models::init` now owns `authenticate_all_providers`
(previously done in `LanguageModelPickerDelegate` and `agent`'s
`LanguageModels`).
- After all authentications settle, and on any subsequent provider state
change, `update_environment_fallback_model` recomputes the fallback.
- The fallback logic prefers Zed cloud: if the cloud provider is
authenticated, only use it (waiting for its models to load). Otherwise,
fall through to the first authenticated provider with a default or
recommended model.
- `LanguageModelRegistry::default_model()` falls back to
`environment_fallback_model` when no explicit default is set.
- Existing `Thread`s that are empty are updated to the new default when
`DefaultModelChanged` fires, so a blank thread started before sign-in
switches to Zed models once the user signs in.

Release Notes:

- agent: Automatically select a model when there's no selected model or configured default

Change summary

crates/agent/src/agent.rs                      | 17 ++-
crates/agent_ui/src/language_model_selector.rs | 56 ------------
crates/git_ui/src/git_panel.rs                 |  4 
crates/language_model/src/registry.rs          | 84 +++++++++++++++----
crates/language_models/src/language_models.rs  | 62 ++++++++++++++
5 files changed, 141 insertions(+), 82 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -201,7 +201,7 @@ impl LanguageModels {
             .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
             .collect::<Vec<_>>();
 
-        cx.background_spawn(async move {
+        cx.spawn(async move |cx| {
             for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
                 if let Err(err) = authenticate_task.await {
                     match err {
@@ -244,6 +244,8 @@ impl LanguageModels {
                     }
                 }
             }
+
+            cx.update(language_models::update_environment_fallback_model);
         })
     }
 }
@@ -365,7 +367,7 @@ impl NativeAgent {
         });
 
         let registry = LanguageModelRegistry::read_global(cx);
-        let summarization_model = registry.thread_summary_model().map(|c| c.model);
+        let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
 
         let weak = cx.weak_entity();
         let weak_thread = thread_handle.downgrade();
@@ -749,13 +751,14 @@ impl NativeAgent {
 
         let registry = LanguageModelRegistry::read_global(cx);
         let default_model = registry.default_model().map(|m| m.model);
-        let summarization_model = registry.thread_summary_model().map(|m| m.model);
+        let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
 
         for session in self.sessions.values_mut() {
             session.thread.update(cx, |thread, cx| {
-                if thread.model().is_none()
-                    && let Some(model) = default_model.clone()
-                {
+                let should_update_model = thread.model().is_none()
+                    || (thread.is_empty()
+                        && matches!(event, language_model::Event::DefaultModelChanged));
+                if should_update_model && let Some(model) = default_model.clone() {
                     thread.set_model(model, cx);
                     cx.notify();
                 }
@@ -910,7 +913,7 @@ impl NativeAgent {
                     .get(&project_id)
                     .context("project state not found")?;
                 let summarization_model = LanguageModelRegistry::read_global(cx)
-                    .thread_summary_model()
+                    .thread_summary_model(cx)
                     .map(|c| c.model);
 
                 Ok(cx.new(|cx| {

crates/agent_ui/src/language_model_selector.rs 🔗

@@ -8,8 +8,8 @@ use gpui::{
     Subscription, Task,
 };
 use language_model::{
-    AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
+    ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider,
+    LanguageModelProviderId, LanguageModelRegistry,
 };
 use ordered_float::OrderedFloat;
 use picker::{Picker, PickerDelegate};
@@ -124,7 +124,6 @@ pub struct LanguageModelPickerDelegate {
     all_models: Arc<GroupedModels>,
     filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
-    _authenticate_all_providers_task: Task<()>,
     _subscriptions: Vec<Subscription>,
     popover_styles: bool,
     focus_handle: FocusHandle,
@@ -151,7 +150,6 @@ impl LanguageModelPickerDelegate {
             filtered_entries: entries,
             get_active_model: Arc::new(get_active_model),
             on_toggle_favorite: Arc::new(on_toggle_favorite),
-            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
             _subscriptions: vec![cx.subscribe_in(
                 &LanguageModelRegistry::global(cx),
                 window,
@@ -197,56 +195,6 @@ impl LanguageModelPickerDelegate {
             .unwrap_or(0)
     }
 
-    /// Authenticates all providers in the [`LanguageModelRegistry`].
-    ///
-    /// We do this so that we can populate the language selector with all of the
-    /// models from the configured providers.
-    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
-        let authenticate_all_providers = LanguageModelRegistry::global(cx)
-            .read(cx)
-            .visible_providers()
-            .iter()
-            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
-            .collect::<Vec<_>>();
-
-        cx.spawn(async move |_cx| {
-            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
-                if let Err(err) = authenticate_task.await {
-                    if matches!(err, AuthenticateError::CredentialsNotFound) {
-                        // Since we're authenticating these providers in the
-                        // background for the purposes of populating the
-                        // language selector, we don't care about providers
-                        // where the credentials are not found.
-                    } else {
-                        // Some providers have noisy failure states that we
-                        // don't want to spam the logs with every time the
-                        // language model selector is initialized.
-                        //
-                        // Ideally these should have more clear failure modes
-                        // that we know are safe to ignore here, like what we do
-                        // with `CredentialsNotFound` above.
-                        match provider_id.0.as_ref() {
-                            "lmstudio" | "ollama" => {
-                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
-                                //
-                                // These fail noisily, so we don't log them.
-                            }
-                            "copilot_chat" => {
-                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
-                            }
-                            _ => {
-                                log::error!(
-                                    "Failed to authenticate provider: {}: {err:#}",
-                                    provider_name.0
-                                );
-                            }
-                        }
-                    }
-                }
-            }
-        })
-    }
-
     pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
         (self.get_active_model)(cx)
     }

crates/git_ui/src/git_panel.rs 🔗

@@ -2688,7 +2688,7 @@ impl GitPanel {
         }
 
         let Some(ConfiguredModel { provider, model }) =
-            LanguageModelRegistry::read_global(cx).commit_message_model()
+            LanguageModelRegistry::read_global(cx).commit_message_model(cx)
         else {
             return;
         };
@@ -4056,7 +4056,7 @@ impl GitPanel {
 
         let model_registry = LanguageModelRegistry::read_global(cx);
         let has_commit_model_configuration_error = model_registry
-            .configuration_error(model_registry.commit_message_model(), cx)
+            .configuration_error(model_registry.commit_message_model(cx), cx)
             .is_some();
         let can_commit = self.can_commit();
 

crates/language_model/src/registry.rs 🔗

@@ -6,7 +6,6 @@ use collections::{BTreeMap, HashSet};
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
 use std::{str::FromStr, sync::Arc};
 use thiserror::Error;
-use util::maybe;
 
 /// Function type for checking if a built-in provider should be hidden.
 /// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
@@ -46,7 +45,9 @@ impl std::fmt::Debug for ConfigurationError {
 #[derive(Default)]
 pub struct LanguageModelRegistry {
     default_model: Option<ConfiguredModel>,
-    default_fast_model: Option<ConfiguredModel>,
+    /// This model is automatically configured by a user's environment after
+    /// authenticating all providers. It's only used when `default_model` is not set.
+    available_fallback_model: Option<ConfiguredModel>,
     inline_assistant_model: Option<ConfiguredModel>,
     commit_message_model: Option<ConfiguredModel>,
     thread_summary_model: Option<ConfiguredModel>,
@@ -349,22 +350,29 @@ impl LanguageModelRegistry {
     }
 
     pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
-        match (self.default_model.as_ref(), model.as_ref()) {
+        match (self.default_model(), model.as_ref()) {
             (Some(old), Some(new)) if old.is_same_as(new) => {}
             (None, None) => {}
             _ => cx.emit(Event::DefaultModelChanged),
         }
-        self.default_fast_model = maybe!({
-            let provider = &model.as_ref()?.provider;
-            let fast_model = provider.default_fast_model(cx)?;
-            Some(ConfiguredModel {
-                provider: provider.clone(),
-                model: fast_model,
-            })
-        });
         self.default_model = model;
     }
 
+    pub fn set_environment_fallback_model(
+        &mut self,
+        model: Option<ConfiguredModel>,
+        cx: &mut Context<Self>,
+    ) {
+        if self.default_model.is_none() {
+            match (self.available_fallback_model.as_ref(), model.as_ref()) {
+                (Some(old), Some(new)) if old.is_same_as(new) => {}
+                (None, None) => {}
+                _ => cx.emit(Event::DefaultModelChanged),
+            }
+        }
+        self.available_fallback_model = model;
+    }
+
     pub fn set_inline_assistant_model(
         &mut self,
         model: Option<ConfiguredModel>,
@@ -410,7 +418,18 @@ impl LanguageModelRegistry {
             return None;
         }
 
-        self.default_model.clone()
+        self.default_model
+            .clone()
+            .or_else(|| self.available_fallback_model.clone())
+    }
+
+    pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
+        let configured = self.default_model()?;
+        let fast_model = configured.provider.default_fast_model(cx)?;
+        Some(ConfiguredModel {
+            provider: configured.provider,
+            model: fast_model,
+        })
     }
 
     pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@@ -424,7 +443,7 @@ impl LanguageModelRegistry {
             .or_else(|| self.default_model.clone())
     }
 
-    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+    pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
         #[cfg(debug_assertions)]
         if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
             return None;
@@ -432,11 +451,11 @@ impl LanguageModelRegistry {
 
         self.commit_message_model
             .clone()
-            .or_else(|| self.default_fast_model.clone())
-            .or_else(|| self.default_model.clone())
+            .or_else(|| self.default_fast_model(cx))
+            .or_else(|| self.default_model())
     }
 
-    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+    pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
         #[cfg(debug_assertions)]
         if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
             return None;
@@ -444,8 +463,8 @@ impl LanguageModelRegistry {
 
         self.thread_summary_model
             .clone()
-            .or_else(|| self.default_fast_model.clone())
-            .or_else(|| self.default_model.clone())
+            .or_else(|| self.default_fast_model(cx))
+            .or_else(|| self.default_model())
     }
 
     /// The models to use for inline assists. Returns the union of the active
@@ -576,6 +595,35 @@ mod tests {
         assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
     }
 
+    #[gpui::test]
+    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
+        let registry = cx.new(|_| LanguageModelRegistry::default());
+
+        let provider = Arc::new(FakeLanguageModelProvider::default());
+        registry.update(cx, |registry, cx| {
+            registry.register_provider(provider.clone(), cx);
+        });
+
+        cx.update(|cx| provider.authenticate(cx)).await.unwrap();
+
+        registry.update(cx, |registry, cx| {
+            let provider = registry.provider(&provider.id()).unwrap();
+            let model = provider.default_model(cx).unwrap();
+
+            registry.set_environment_fallback_model(
+                Some(ConfiguredModel {
+                    provider: provider.clone(),
+                    model: model.clone(),
+                }),
+                cx,
+            );
+
+            let default_model = registry.default_model().unwrap();
+            assert_eq!(default_model.model.id(), model.id());
+            assert_eq!(default_model.provider.id(), provider.id());
+        });
+    }
+
     #[gpui::test]
     fn test_sync_installed_llm_extensions(cx: &mut App) {
         let registry = cx.new(|_| LanguageModelRegistry::default());

crates/language_models/src/language_models.rs 🔗

@@ -5,7 +5,9 @@ use client::{Client, UserStore};
 use collections::HashSet;
 use credentials_provider::CredentialsProvider;
 use gpui::{App, Context, Entity};
-use language_model::{LanguageModelProviderId, LanguageModelRegistry};
+use language_model::{
+    ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
+};
 use provider::deepseek::DeepSeekLanguageModelProvider;
 
 pub mod extension;
@@ -116,6 +118,20 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
             cx,
         );
     });
+
+    cx.subscribe(
+        &registry,
+        |_registry, event: &language_model::Event, cx| match event {
+            language_model::Event::ProviderStateChanged(_)
+            | language_model::Event::AddedProvider(_)
+            | language_model::Event::RemovedProvider(_) => {
+                update_environment_fallback_model(cx);
+            }
+            _ => {}
+        },
+    )
+    .detach();
+
     let registry = registry.downgrade();
     cx.observe_global::<SettingsStore>(move |cx| {
         let Some(registry) = registry.upgrade() else {
@@ -143,6 +159,50 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
     .detach();
 }
 
+/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback
+/// model based on currently authenticated providers.
+///
+/// Prefers the Zed cloud provider so that, once the user is signed in, we
+/// always pick a Zed-hosted model over models from other authenticated
+/// providers in the environment. If the Zed cloud provider is authenticated
+/// but hasn't finished loading its models yet, we don't fall back to another
+/// provider to avoid flickering between providers during sign in.
+pub fn update_environment_fallback_model(cx: &mut App) {
+    let registry = LanguageModelRegistry::global(cx);
+    let fallback_model = {
+        let registry = registry.read(cx);
+        let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID);
+        if cloud_provider
+            .as_ref()
+            .is_some_and(|provider| provider.is_authenticated(cx))
+        {
+            cloud_provider.and_then(|provider| {
+                let model = provider
+                    .default_model(cx)
+                    .or_else(|| provider.recommended_models(cx).first().cloned())?;
+                Some(ConfiguredModel { provider, model })
+            })
+        } else {
+            registry
+                .providers()
+                .iter()
+                .filter(|provider| provider.is_authenticated(cx))
+                .find_map(|provider| {
+                    let model = provider
+                        .default_model(cx)
+                        .or_else(|| provider.recommended_models(cx).first().cloned())?;
+                    Some(ConfiguredModel {
+                        provider: provider.clone(),
+                        model,
+                    })
+                })
+        }
+    };
+    registry.update(cx, |registry, cx| {
+        registry.set_environment_fallback_model(fallback_model, cx);
+    });
+}
+
 fn register_openai_compatible_providers(
     registry: &mut LanguageModelRegistry,
     old: &HashSet<Arc<str>>,