Revert "ai: Auto select user model when there's no default" (#36932)

Bennet Bo Fenner created

Reverts zed-industries/zed#36722

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs                     |  17 +-
crates/agent2/src/agent.rs                     |   4 
crates/agent2/src/tests/mod.rs                 |   4 
crates/agent_ui/src/language_model_selector.rs |  55 +++++++++
crates/git_ui/src/git_panel.rs                 |   2 
crates/language_model/src/registry.rs          | 114 ++++++++-----------
crates/language_models/Cargo.toml              |   1 
crates/language_models/src/language_models.rs  | 103 -----------------
crates/language_models/src/provider/cloud.rs   |   6 
9 files changed, 122 insertions(+), 184 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -664,7 +664,7 @@ impl Thread {
     }
 
     pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
-        if self.configured_model.is_none() || self.messages.is_empty() {
+        if self.configured_model.is_none() {
             self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
         }
         self.configured_model.clone()
@@ -2097,7 +2097,7 @@ impl Thread {
     }
 
     pub fn summarize(&mut self, cx: &mut Context<Self>) {
-        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
+        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
             println!("No thread summary model");
             return;
         };
@@ -2416,7 +2416,7 @@ impl Thread {
         }
 
         let Some(ConfiguredModel { model, provider }) =
-            LanguageModelRegistry::read_global(cx).thread_summary_model(cx)
+            LanguageModelRegistry::read_global(cx).thread_summary_model()
         else {
             return;
         };
@@ -5410,10 +5410,13 @@ fn main() {{
                     }),
                     cx,
                 );
-                registry.set_thread_summary_model(Some(ConfiguredModel {
-                    provider,
-                    model: model.clone(),
-                }));
+                registry.set_thread_summary_model(
+                    Some(ConfiguredModel {
+                        provider,
+                        model: model.clone(),
+                    }),
+                    cx,
+                );
             })
         });
 

crates/agent2/src/agent.rs 🔗

@@ -228,7 +228,7 @@ impl NativeAgent {
     ) -> Entity<AcpThread> {
         let connection = Rc::new(NativeAgentConnection(cx.entity()));
         let registry = LanguageModelRegistry::read_global(cx);
-        let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
+        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 
         thread_handle.update(cx, |thread, cx| {
             thread.set_summarization_model(summarization_model, cx);
@@ -524,7 +524,7 @@ impl NativeAgent {
 
         let registry = LanguageModelRegistry::read_global(cx);
         let default_model = registry.default_model().map(|m| m.model);
-        let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
+        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 
         for session in self.sessions.values_mut() {
             session.thread.update(cx, |thread, cx| {

crates/agent2/src/tests/mod.rs 🔗

@@ -1822,11 +1822,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
         let clock = Arc::new(clock::FakeSystemClock::new());
         let client = Client::new(clock, http_client, cx);
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-        Project::init_settings(cx);
-        agent_settings::init(cx);
         language_model::init(client.clone(), cx);
         language_models::init(user_store, client.clone(), cx);
+        Project::init_settings(cx);
         LanguageModelRegistry::test(cx);
+        agent_settings::init(cx);
     });
     cx.executor().forbid_parking();
 

crates/agent_ui/src/language_model_selector.rs 🔗

@@ -6,7 +6,8 @@ use feature_flags::ZedProFeatureFlag;
 use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
 use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
 use language_model::{
-    ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
+    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
+    LanguageModelRegistry,
 };
 use ordered_float::OrderedFloat;
 use picker::{Picker, PickerDelegate};
@@ -76,6 +77,7 @@ pub struct LanguageModelPickerDelegate {
     all_models: Arc<GroupedModels>,
     filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
+    _authenticate_all_providers_task: Task<()>,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -96,6 +98,7 @@ impl LanguageModelPickerDelegate {
             selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
             filtered_entries: entries,
             get_active_model: Arc::new(get_active_model),
+            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
             _subscriptions: vec![cx.subscribe_in(
                 &LanguageModelRegistry::global(cx),
                 window,
@@ -139,6 +142,56 @@ impl LanguageModelPickerDelegate {
             .unwrap_or(0)
     }
 
+    /// Authenticates all providers in the [`LanguageModelRegistry`].
+    ///
+    /// We do this so that we can populate the language selector with all of the
+    /// models from the configured providers.
+    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
+        let authenticate_all_providers = LanguageModelRegistry::global(cx)
+            .read(cx)
+            .providers()
+            .iter()
+            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
+            .collect::<Vec<_>>();
+
+        cx.spawn(async move |_cx| {
+            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
+                if let Err(err) = authenticate_task.await {
+                    if matches!(err, AuthenticateError::CredentialsNotFound) {
+                        // Since we're authenticating these providers in the
+                        // background for the purposes of populating the
+                        // language selector, we don't care about providers
+                        // where the credentials are not found.
+                    } else {
+                        // Some providers have noisy failure states that we
+                        // don't want to spam the logs with every time the
+                        // language model selector is initialized.
+                        //
+                        // Ideally these should have more clear failure modes
+                        // that we know are safe to ignore here, like what we do
+                        // with `CredentialsNotFound` above.
+                        match provider_id.0.as_ref() {
+                            "lmstudio" | "ollama" => {
+                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
+                                //
+                                // These fail noisily, so we don't log them.
+                            }
+                            "copilot_chat" => {
+                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
+                            }
+                            _ => {
+                                log::error!(
+                                    "Failed to authenticate provider: {}: {err}",
+                                    provider_name.0
+                                );
+                            }
+                        }
+                    }
+                }
+            }
+        })
+    }
+
     pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
         (self.get_active_model)(cx)
     }

crates/git_ui/src/git_panel.rs 🔗

@@ -4466,7 +4466,7 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
     is_enabled
         .then(|| {
             let ConfiguredModel { provider, model } =
-                LanguageModelRegistry::read_global(cx).commit_message_model(cx)?;
+                LanguageModelRegistry::read_global(cx).commit_message_model()?;
 
             provider.is_authenticated(cx).then(|| model)
         })

crates/language_model/src/registry.rs 🔗

@@ -6,6 +6,7 @@ use collections::BTreeMap;
 use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
 use std::{str::FromStr, sync::Arc};
 use thiserror::Error;
+use util::maybe;
 
 pub fn init(cx: &mut App) {
     let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -41,9 +42,7 @@ impl std::fmt::Debug for ConfigurationError {
 #[derive(Default)]
 pub struct LanguageModelRegistry {
     default_model: Option<ConfiguredModel>,
-    /// This model is automatically configured by a user's environment after
-    /// authenticating all providers. It's only used when default_model is not available.
-    environment_fallback_model: Option<ConfiguredModel>,
+    default_fast_model: Option<ConfiguredModel>,
     inline_assistant_model: Option<ConfiguredModel>,
     commit_message_model: Option<ConfiguredModel>,
     thread_summary_model: Option<ConfiguredModel>,
@@ -99,6 +98,9 @@ impl ConfiguredModel {
 
 pub enum Event {
     DefaultModelChanged,
+    InlineAssistantModelChanged,
+    CommitMessageModelChanged,
+    ThreadSummaryModelChanged,
     ProviderStateChanged(LanguageModelProviderId),
     AddedProvider(LanguageModelProviderId),
     RemovedProvider(LanguageModelProviderId),
@@ -224,7 +226,7 @@ impl LanguageModelRegistry {
         cx: &mut Context<Self>,
     ) {
         let configured_model = model.and_then(|model| self.select_model(model, cx));
-        self.set_inline_assistant_model(configured_model);
+        self.set_inline_assistant_model(configured_model, cx);
     }
 
     pub fn select_commit_message_model(
@@ -233,7 +235,7 @@ impl LanguageModelRegistry {
         cx: &mut Context<Self>,
     ) {
         let configured_model = model.and_then(|model| self.select_model(model, cx));
-        self.set_commit_message_model(configured_model);
+        self.set_commit_message_model(configured_model, cx);
     }
 
     pub fn select_thread_summary_model(
@@ -242,7 +244,7 @@ impl LanguageModelRegistry {
         cx: &mut Context<Self>,
     ) {
         let configured_model = model.and_then(|model| self.select_model(model, cx));
-        self.set_thread_summary_model(configured_model);
+        self.set_thread_summary_model(configured_model, cx);
     }
 
     /// Selects and sets the inline alternatives for language models based on
@@ -276,60 +278,68 @@ impl LanguageModelRegistry {
     }
 
     pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
-        match (self.default_model(), model.as_ref()) {
+        match (self.default_model.as_ref(), model.as_ref()) {
             (Some(old), Some(new)) if old.is_same_as(new) => {}
             (None, None) => {}
             _ => cx.emit(Event::DefaultModelChanged),
         }
+        self.default_fast_model = maybe!({
+            let provider = &model.as_ref()?.provider;
+            let fast_model = provider.default_fast_model(cx)?;
+            Some(ConfiguredModel {
+                provider: provider.clone(),
+                model: fast_model,
+            })
+        });
         self.default_model = model;
     }
 
-    pub fn set_environment_fallback_model(
+    pub fn set_inline_assistant_model(
         &mut self,
         model: Option<ConfiguredModel>,
         cx: &mut Context<Self>,
     ) {
-        if self.default_model.is_none() {
-            match (self.environment_fallback_model.as_ref(), model.as_ref()) {
-                (Some(old), Some(new)) if old.is_same_as(new) => {}
-                (None, None) => {}
-                _ => cx.emit(Event::DefaultModelChanged),
-            }
+        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::InlineAssistantModelChanged),
         }
-        self.environment_fallback_model = model;
-    }
-
-    pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
         self.inline_assistant_model = model;
     }
 
-    pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
+    pub fn set_commit_message_model(
+        &mut self,
+        model: Option<ConfiguredModel>,
+        cx: &mut Context<Self>,
+    ) {
+        match (self.commit_message_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::CommitMessageModelChanged),
+        }
         self.commit_message_model = model;
     }
 
-    pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
+    pub fn set_thread_summary_model(
+        &mut self,
+        model: Option<ConfiguredModel>,
+        cx: &mut Context<Self>,
+    ) {
+        match (self.thread_summary_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::ThreadSummaryModelChanged),
+        }
         self.thread_summary_model = model;
     }
 
-    #[track_caller]
     pub fn default_model(&self) -> Option<ConfiguredModel> {
         #[cfg(debug_assertions)]
         if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
             return None;
         }
 
-        self.default_model
-            .clone()
-            .or_else(|| self.environment_fallback_model.clone())
-    }
-
-    pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
-        let provider = self.default_model()?.provider;
-        let fast_model = provider.default_fast_model(cx)?;
-        Some(ConfiguredModel {
-            provider,
-            model: fast_model,
-        })
+        self.default_model.clone()
     }
 
     pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@@ -343,7 +353,7 @@ impl LanguageModelRegistry {
             .or_else(|| self.default_model.clone())
     }
 
-    pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
+    pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
         #[cfg(debug_assertions)]
         if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
             return None;
@@ -351,11 +361,11 @@ impl LanguageModelRegistry {
 
         self.commit_message_model
             .clone()
-            .or_else(|| self.default_fast_model(cx))
+            .or_else(|| self.default_fast_model.clone())
             .or_else(|| self.default_model.clone())
     }
 
-    pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
+    pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
         #[cfg(debug_assertions)]
         if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
             return None;
@@ -363,7 +373,7 @@ impl LanguageModelRegistry {
 
         self.thread_summary_model
             .clone()
-            .or_else(|| self.default_fast_model(cx))
+            .or_else(|| self.default_fast_model.clone())
             .or_else(|| self.default_model.clone())
     }
 
@@ -400,34 +410,4 @@ mod tests {
         let providers = registry.read(cx).providers();
         assert!(providers.is_empty());
     }
-
-    #[gpui::test]
-    async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
-        let registry = cx.new(|_| LanguageModelRegistry::default());
-
-        let provider = FakeLanguageModelProvider::default();
-        registry.update(cx, |registry, cx| {
-            registry.register_provider(provider.clone(), cx);
-        });
-
-        cx.update(|cx| provider.authenticate(cx)).await.unwrap();
-
-        registry.update(cx, |registry, cx| {
-            let provider = registry.provider(&provider.id()).unwrap();
-
-            registry.set_environment_fallback_model(
-                Some(ConfiguredModel {
-                    provider: provider.clone(),
-                    model: provider.default_model(cx).unwrap(),
-                }),
-                cx,
-            );
-
-            let default_model = registry.default_model().unwrap();
-            let fallback_model = registry.environment_fallback_model.clone().unwrap();
-
-            assert_eq!(default_model.model.id(), fallback_model.model.id());
-            assert_eq!(default_model.provider.id(), fallback_model.provider.id());
-        });
-    }
 }

crates/language_models/Cargo.toml 🔗

@@ -44,7 +44,6 @@ ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 open_router = { workspace = true, features = ["schemars"] }
 partial-json-fixer.workspace = true
-project.workspace = true
 release_channel.workspace = true
 schemars.workspace = true
 serde.workspace = true

crates/language_models/src/language_models.rs 🔗

@@ -3,12 +3,8 @@ use std::sync::Arc;
 use ::settings::{Settings, SettingsStore};
 use client::{Client, UserStore};
 use collections::HashSet;
-use futures::future;
-use gpui::{App, AppContext as _, Context, Entity};
-use language_model::{
-    AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
-};
-use project::DisableAiSettings;
+use gpui::{App, Context, Entity};
+use language_model::{LanguageModelProviderId, LanguageModelRegistry};
 use provider::deepseek::DeepSeekLanguageModelProvider;
 
 pub mod provider;
@@ -17,7 +13,7 @@ pub mod ui;
 
 use crate::provider::anthropic::AnthropicLanguageModelProvider;
 use crate::provider::bedrock::BedrockLanguageModelProvider;
-use crate::provider::cloud::{self, CloudLanguageModelProvider};
+use crate::provider::cloud::CloudLanguageModelProvider;
 use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 use crate::provider::google::GoogleLanguageModelProvider;
 use crate::provider::lmstudio::LmStudioLanguageModelProvider;
@@ -52,13 +48,6 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
             cx,
         );
     });
-
-    let mut already_authenticated = false;
-    if !DisableAiSettings::get_global(cx).disable_ai {
-        authenticate_all_providers(registry.clone(), cx);
-        already_authenticated = true;
-    }
-
     cx.observe_global::<SettingsStore>(move |cx| {
         let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
             .openai_compatible
@@ -76,12 +65,6 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
                 );
             });
             openai_compatible_providers = openai_compatible_providers_new;
-            already_authenticated = false;
-        }
-
-        if !DisableAiSettings::get_global(cx).disable_ai && !already_authenticated {
-            authenticate_all_providers(registry.clone(), cx);
-            already_authenticated = true;
         }
     })
     .detach();
@@ -168,83 +151,3 @@ fn register_language_model_providers(
     registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
     registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 }
-
-/// Authenticates all providers in the [`LanguageModelRegistry`].
-///
-/// We do this so that we can populate the language selector with all of the
-/// models from the configured providers.
-///
-/// This function won't do anything if AI is disabled.
-fn authenticate_all_providers(registry: Entity<LanguageModelRegistry>, cx: &mut App) {
-    let providers_to_authenticate = registry
-        .read(cx)
-        .providers()
-        .iter()
-        .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
-        .collect::<Vec<_>>();
-
-    let mut tasks = Vec::with_capacity(providers_to_authenticate.len());
-
-    for (provider_id, provider_name, authenticate_task) in providers_to_authenticate {
-        tasks.push(cx.background_spawn(async move {
-            if let Err(err) = authenticate_task.await {
-                if matches!(err, AuthenticateError::CredentialsNotFound) {
-                    // Since we're authenticating these providers in the
-                    // background for the purposes of populating the
-                    // language selector, we don't care about providers
-                    // where the credentials are not found.
-                } else {
-                    // Some providers have noisy failure states that we
-                    // don't want to spam the logs with every time the
-                    // language model selector is initialized.
-                    //
-                    // Ideally these should have more clear failure modes
-                    // that we know are safe to ignore here, like what we do
-                    // with `CredentialsNotFound` above.
-                    match provider_id.0.as_ref() {
-                        "lmstudio" | "ollama" => {
-                            // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
-                            //
-                            // These fail noisily, so we don't log them.
-                        }
-                        "copilot_chat" => {
-                            // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
-                        }
-                        _ => {
-                            log::error!(
-                                "Failed to authenticate provider: {}: {err}",
-                                provider_name.0
-                            );
-                        }
-                    }
-                }
-            }
-        }));
-    }
-
-    let all_authenticated_future = future::join_all(tasks);
-
-    cx.spawn(async move |cx| {
-        all_authenticated_future.await;
-
-        registry
-            .update(cx, |registry, cx| {
-                let cloud_provider = registry.provider(&cloud::PROVIDER_ID);
-                let fallback_model = cloud_provider
-                    .iter()
-                    .chain(registry.providers().iter())
-                    .find(|provider| provider.is_authenticated(cx))
-                    .and_then(|provider| {
-                        Some(ConfiguredModel {
-                            provider: provider.clone(),
-                            model: provider
-                                .default_model(cx)
-                                .or_else(|| provider.recommended_models(cx).first().cloned())?,
-                        })
-                    });
-                registry.set_environment_fallback_model(fallback_model, cx);
-            })
-            .ok();
-    })
-    .detach();
-}

crates/language_models/src/provider/cloud.rs 🔗

@@ -44,8 +44,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
 
-pub const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
-pub const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
+const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
@@ -146,7 +146,7 @@ impl State {
             default_fast_model: None,
             recommended_models: Vec::new(),
             _fetch_models_task: cx.spawn(async move |this, cx| {
-                maybe!(async {
+                maybe!(async move {
                     let (client, llm_api_token) = this
                         .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;