Detailed changes
@@ -201,7 +201,7 @@ impl LanguageModels {
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
.collect::<Vec<_>>();
- cx.background_spawn(async move {
+ cx.spawn(async move |cx| {
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
if let Err(err) = authenticate_task.await {
match err {
@@ -244,6 +244,8 @@ impl LanguageModels {
}
}
}
+
+ cx.update(language_models::update_environment_fallback_model);
})
}
}
@@ -365,7 +367,7 @@ impl NativeAgent {
});
let registry = LanguageModelRegistry::read_global(cx);
- let summarization_model = registry.thread_summary_model().map(|c| c.model);
+ let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
let weak = cx.weak_entity();
let weak_thread = thread_handle.downgrade();
@@ -749,13 +751,14 @@ impl NativeAgent {
let registry = LanguageModelRegistry::read_global(cx);
let default_model = registry.default_model().map(|m| m.model);
- let summarization_model = registry.thread_summary_model().map(|m| m.model);
+ let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
- if thread.model().is_none()
- && let Some(model) = default_model.clone()
- {
+ let should_update_model = thread.model().is_none()
+ || (thread.is_empty()
+ && matches!(event, language_model::Event::DefaultModelChanged));
+ if should_update_model && let Some(model) = default_model.clone() {
thread.set_model(model, cx);
cx.notify();
}
@@ -910,7 +913,7 @@ impl NativeAgent {
.get(&project_id)
.context("project state not found")?;
let summarization_model = LanguageModelRegistry::read_global(cx)
- .thread_summary_model()
+ .thread_summary_model(cx)
.map(|c| c.model);
Ok(cx.new(|cx| {
@@ -8,8 +8,8 @@ use gpui::{
Subscription, Task,
};
use language_model::{
- AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
- LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
+ ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@@ -124,7 +124,6 @@ pub struct LanguageModelPickerDelegate {
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
- _authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
popover_styles: bool,
focus_handle: FocusHandle,
@@ -151,7 +150,6 @@ impl LanguageModelPickerDelegate {
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
on_toggle_favorite: Arc::new(on_toggle_favorite),
- _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
@@ -197,56 +195,6 @@ impl LanguageModelPickerDelegate {
.unwrap_or(0)
}
- /// Authenticates all providers in the [`LanguageModelRegistry`].
- ///
- /// We do this so that we can populate the language selector with all of the
- /// models from the configured providers.
- fn authenticate_all_providers(cx: &mut App) -> Task<()> {
- let authenticate_all_providers = LanguageModelRegistry::global(cx)
- .read(cx)
- .visible_providers()
- .iter()
- .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
- .collect::<Vec<_>>();
-
- cx.spawn(async move |_cx| {
- for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
- if let Err(err) = authenticate_task.await {
- if matches!(err, AuthenticateError::CredentialsNotFound) {
- // Since we're authenticating these providers in the
- // background for the purposes of populating the
- // language selector, we don't care about providers
- // where the credentials are not found.
- } else {
- // Some providers have noisy failure states that we
- // don't want to spam the logs with every time the
- // language model selector is initialized.
- //
- // Ideally these should have more clear failure modes
- // that we know are safe to ignore here, like what we do
- // with `CredentialsNotFound` above.
- match provider_id.0.as_ref() {
- "lmstudio" | "ollama" => {
- // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
- //
- // These fail noisily, so we don't log them.
- }
- "copilot_chat" => {
- // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
- }
- _ => {
- log::error!(
- "Failed to authenticate provider: {}: {err:#}",
- provider_name.0
- );
- }
- }
- }
- }
- }
- })
- }
-
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.get_active_model)(cx)
}
@@ -2688,7 +2688,7 @@ impl GitPanel {
}
let Some(ConfiguredModel { provider, model }) =
- LanguageModelRegistry::read_global(cx).commit_message_model()
+ LanguageModelRegistry::read_global(cx).commit_message_model(cx)
else {
return;
};
@@ -4056,7 +4056,7 @@ impl GitPanel {
let model_registry = LanguageModelRegistry::read_global(cx);
let has_commit_model_configuration_error = model_registry
- .configuration_error(model_registry.commit_message_model(), cx)
+ .configuration_error(model_registry.commit_message_model(cx), cx)
.is_some();
let can_commit = self.can_commit();
@@ -6,7 +6,6 @@ use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
-use util::maybe;
/// Function type for checking if a built-in provider should be hidden.
/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
@@ -46,7 +45,9 @@ impl std::fmt::Debug for ConfigurationError {
#[derive(Default)]
pub struct LanguageModelRegistry {
default_model: Option<ConfiguredModel>,
- default_fast_model: Option<ConfiguredModel>,
+ /// This model is automatically configured by a user's environment after
+ /// authenticating all providers. It's only used when `default_model` is not set.
+ available_fallback_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
@@ -349,22 +350,29 @@ impl LanguageModelRegistry {
}
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
- match (self.default_model.as_ref(), model.as_ref()) {
+ match (self.default_model(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
- self.default_fast_model = maybe!({
- let provider = &model.as_ref()?.provider;
- let fast_model = provider.default_fast_model(cx)?;
- Some(ConfiguredModel {
- provider: provider.clone(),
- model: fast_model,
- })
- });
self.default_model = model;
}
+ pub fn set_environment_fallback_model(
+ &mut self,
+ model: Option<ConfiguredModel>,
+ cx: &mut Context<Self>,
+ ) {
+ if self.default_model.is_none() {
+ match (self.available_fallback_model.as_ref(), model.as_ref()) {
+ (Some(old), Some(new)) if old.is_same_as(new) => {}
+ (None, None) => {}
+ _ => cx.emit(Event::DefaultModelChanged),
+ }
+ }
+ self.available_fallback_model = model;
+ }
+
pub fn set_inline_assistant_model(
&mut self,
model: Option<ConfiguredModel>,
@@ -410,7 +418,18 @@ impl LanguageModelRegistry {
return None;
}
- self.default_model.clone()
+ self.default_model
+ .clone()
+ .or_else(|| self.available_fallback_model.clone())
+ }
+
+ pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
+ let configured = self.default_model()?;
+ let fast_model = configured.provider.default_fast_model(cx)?;
+ Some(ConfiguredModel {
+ provider: configured.provider,
+ model: fast_model,
+ })
}
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@@ -424,7 +443,7 @@ impl LanguageModelRegistry {
.or_else(|| self.default_model.clone())
}
- pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+ pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@@ -432,11 +451,11 @@ impl LanguageModelRegistry {
self.commit_message_model
.clone()
- .or_else(|| self.default_fast_model.clone())
- .or_else(|| self.default_model.clone())
+ .or_else(|| self.default_fast_model(cx))
+ .or_else(|| self.default_model())
}
- pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+ pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@@ -444,8 +463,8 @@ impl LanguageModelRegistry {
self.thread_summary_model
.clone()
- .or_else(|| self.default_fast_model.clone())
- .or_else(|| self.default_model.clone())
+ .or_else(|| self.default_fast_model(cx))
+ .or_else(|| self.default_model())
}
/// The models to use for inline assists. Returns the union of the active
@@ -576,6 +595,35 @@ mod tests {
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
}
+ #[gpui::test]
+ async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+ });
+
+ cx.update(|cx| provider.authenticate(cx)).await.unwrap();
+
+ registry.update(cx, |registry, cx| {
+ let provider = registry.provider(&provider.id()).unwrap();
+ let model = provider.default_model(cx).unwrap();
+
+ registry.set_environment_fallback_model(
+ Some(ConfiguredModel {
+ provider: provider.clone(),
+ model: model.clone(),
+ }),
+ cx,
+ );
+
+ let default_model = registry.default_model().unwrap();
+ assert_eq!(default_model.model.id(), model.id());
+ assert_eq!(default_model.provider.id(), provider.id());
+ });
+ }
+
#[gpui::test]
fn test_sync_installed_llm_extensions(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
@@ -5,7 +5,9 @@ use client::{Client, UserStore};
use collections::HashSet;
use credentials_provider::CredentialsProvider;
use gpui::{App, Context, Entity};
-use language_model::{LanguageModelProviderId, LanguageModelRegistry};
+use language_model::{
+ ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
+};
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod extension;
@@ -116,6 +118,20 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
cx,
);
});
+
+ cx.subscribe(
+ &registry,
+ |_registry, event: &language_model::Event, cx| match event {
+ language_model::Event::ProviderStateChanged(_)
+ | language_model::Event::AddedProvider(_)
+ | language_model::Event::RemovedProvider(_) => {
+ update_environment_fallback_model(cx);
+ }
+ _ => {}
+ },
+ )
+ .detach();
+
let registry = registry.downgrade();
cx.observe_global::<SettingsStore>(move |cx| {
let Some(registry) = registry.upgrade() else {
@@ -143,6 +159,50 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
.detach();
}
+/// Recomputes and sets the [`LanguageModelRegistry`]'s environment fallback
+/// model based on currently authenticated providers.
+///
+/// Prefers the Zed cloud provider so that, once the user is signed in, we
+/// always pick a Zed-hosted model over models from other authenticated
+/// providers in the environment. If the Zed cloud provider is authenticated
+/// but hasn't finished loading its models yet, we don't fall back to another
+/// provider to avoid flickering between providers during sign in.
+pub fn update_environment_fallback_model(cx: &mut App) {
+ let registry = LanguageModelRegistry::global(cx);
+ let fallback_model = {
+ let registry = registry.read(cx);
+ let cloud_provider = registry.provider(&ZED_CLOUD_PROVIDER_ID);
+ if cloud_provider
+ .as_ref()
+ .is_some_and(|provider| provider.is_authenticated(cx))
+ {
+ cloud_provider.and_then(|provider| {
+ let model = provider
+ .default_model(cx)
+ .or_else(|| provider.recommended_models(cx).first().cloned())?;
+ Some(ConfiguredModel { provider, model })
+ })
+ } else {
+ registry
+ .providers()
+ .iter()
+ .filter(|provider| provider.is_authenticated(cx))
+ .find_map(|provider| {
+ let model = provider
+ .default_model(cx)
+ .or_else(|| provider.recommended_models(cx).first().cloned())?;
+ Some(ConfiguredModel {
+ provider: provider.clone(),
+ model,
+ })
+ })
+ }
+ };
+ registry.update(cx, |registry, cx| {
+ registry.set_environment_fallback_model(fallback_model, cx);
+ });
+}
+
fn register_openai_compatible_providers(
registry: &mut LanguageModelRegistry,
old: &HashSet<Arc<str>>,