Detailed changes
@@ -664,7 +664,7 @@ impl Thread {
}
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
- if self.configured_model.is_none() {
+ if self.configured_model.is_none() || self.messages.is_empty() {
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
}
self.configured_model.clone()
@@ -2097,7 +2097,7 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
- let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
+ let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
println!("No thread summary model");
return;
};
@@ -2416,7 +2416,7 @@ impl Thread {
}
let Some(ConfiguredModel { model, provider }) =
- LanguageModelRegistry::read_global(cx).thread_summary_model()
+ LanguageModelRegistry::read_global(cx).thread_summary_model(cx)
else {
return;
};
@@ -5410,13 +5410,10 @@ fn main() {{
}),
cx,
);
- registry.set_thread_summary_model(
- Some(ConfiguredModel {
- provider,
- model: model.clone(),
- }),
- cx,
- );
+ registry.set_thread_summary_model(Some(ConfiguredModel {
+ provider,
+ model: model.clone(),
+ }));
})
});
@@ -228,7 +228,7 @@ impl NativeAgent {
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
let registry = LanguageModelRegistry::read_global(cx);
- let summarization_model = registry.thread_summary_model().map(|c| c.model);
+ let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
thread_handle.update(cx, |thread, cx| {
thread.set_summarization_model(summarization_model, cx);
@@ -521,7 +521,7 @@ impl NativeAgent {
let registry = LanguageModelRegistry::read_global(cx);
let default_model = registry.default_model().map(|m| m.model);
- let summarization_model = registry.thread_summary_model().map(|m| m.model);
+ let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
@@ -1414,11 +1414,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ Project::init_settings(cx);
+ agent_settings::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
- Project::init_settings(cx);
LanguageModelRegistry::test(cx);
- agent_settings::init(cx);
});
cx.executor().forbid_parking();
@@ -6,8 +6,7 @@ use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
use language_model::{
- AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
- LanguageModelRegistry,
+ ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@@ -77,7 +76,6 @@ pub struct LanguageModelPickerDelegate {
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
- _authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
}
@@ -98,7 +96,6 @@ impl LanguageModelPickerDelegate {
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
- _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
@@ -142,56 +139,6 @@ impl LanguageModelPickerDelegate {
.unwrap_or(0)
}
- /// Authenticates all providers in the [`LanguageModelRegistry`].
- ///
- /// We do this so that we can populate the language selector with all of the
- /// models from the configured providers.
- fn authenticate_all_providers(cx: &mut App) -> Task<()> {
- let authenticate_all_providers = LanguageModelRegistry::global(cx)
- .read(cx)
- .providers()
- .iter()
- .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
- .collect::<Vec<_>>();
-
- cx.spawn(async move |_cx| {
- for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
- if let Err(err) = authenticate_task.await {
- if matches!(err, AuthenticateError::CredentialsNotFound) {
- // Since we're authenticating these providers in the
- // background for the purposes of populating the
- // language selector, we don't care about providers
- // where the credentials are not found.
- } else {
- // Some providers have noisy failure states that we
- // don't want to spam the logs with every time the
- // language model selector is initialized.
- //
- // Ideally these should have more clear failure modes
- // that we know are safe to ignore here, like what we do
- // with `CredentialsNotFound` above.
- match provider_id.0.as_ref() {
- "lmstudio" | "ollama" => {
- // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
- //
- // These fail noisily, so we don't log them.
- }
- "copilot_chat" => {
- // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
- }
- _ => {
- log::error!(
- "Failed to authenticate provider: {}: {err}",
- provider_name.0
- );
- }
- }
- }
- }
- }
- })
- }
-
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.get_active_model)(cx)
}
@@ -4466,7 +4466,7 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
is_enabled
.then(|| {
let ConfiguredModel { provider, model } =
- LanguageModelRegistry::read_global(cx).commit_message_model()?;
+ LanguageModelRegistry::read_global(cx).commit_message_model(cx)?;
provider.is_authenticated(cx).then(|| model)
})
@@ -6,7 +6,6 @@ use collections::BTreeMap;
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
-use util::maybe;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -48,7 +47,9 @@ impl std::fmt::Debug for ConfigurationError {
#[derive(Default)]
pub struct LanguageModelRegistry {
default_model: Option<ConfiguredModel>,
- default_fast_model: Option<ConfiguredModel>,
+    /// This model is automatically configured from the user's environment after
+    /// all providers have been authenticated. It is only used when `default_model`
+    /// is not available.
+ environment_fallback_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
@@ -104,9 +105,6 @@ impl ConfiguredModel {
pub enum Event {
DefaultModelChanged,
- InlineAssistantModelChanged,
- CommitMessageModelChanged,
- ThreadSummaryModelChanged,
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@@ -238,7 +236,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
- self.set_inline_assistant_model(configured_model, cx);
+ self.set_inline_assistant_model(configured_model);
}
pub fn select_commit_message_model(
@@ -247,7 +245,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
- self.set_commit_message_model(configured_model, cx);
+ self.set_commit_message_model(configured_model);
}
pub fn select_thread_summary_model(
@@ -256,7 +254,7 @@ impl LanguageModelRegistry {
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
- self.set_thread_summary_model(configured_model, cx);
+ self.set_thread_summary_model(configured_model);
}
/// Selects and sets the inline alternatives for language models based on
@@ -290,68 +288,60 @@ impl LanguageModelRegistry {
}
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
- match (self.default_model.as_ref(), model.as_ref()) {
+ match (self.default_model(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
- self.default_fast_model = maybe!({
- let provider = &model.as_ref()?.provider;
- let fast_model = provider.default_fast_model(cx)?;
- Some(ConfiguredModel {
- provider: provider.clone(),
- model: fast_model,
- })
- });
self.default_model = model;
}
- pub fn set_inline_assistant_model(
+ pub fn set_environment_fallback_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
- match (self.inline_assistant_model.as_ref(), model.as_ref()) {
- (Some(old), Some(new)) if old.is_same_as(new) => {}
- (None, None) => {}
- _ => cx.emit(Event::InlineAssistantModelChanged),
+ if self.default_model.is_none() {
+ match (self.environment_fallback_model.as_ref(), model.as_ref()) {
+ (Some(old), Some(new)) if old.is_same_as(new) => {}
+ (None, None) => {}
+ _ => cx.emit(Event::DefaultModelChanged),
+ }
}
+ self.environment_fallback_model = model;
+ }
+
+ pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
self.inline_assistant_model = model;
}
- pub fn set_commit_message_model(
- &mut self,
- model: Option<ConfiguredModel>,
- cx: &mut Context<Self>,
- ) {
- match (self.commit_message_model.as_ref(), model.as_ref()) {
- (Some(old), Some(new)) if old.is_same_as(new) => {}
- (None, None) => {}
- _ => cx.emit(Event::CommitMessageModelChanged),
- }
+ pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
self.commit_message_model = model;
}
- pub fn set_thread_summary_model(
- &mut self,
- model: Option<ConfiguredModel>,
- cx: &mut Context<Self>,
- ) {
- match (self.thread_summary_model.as_ref(), model.as_ref()) {
- (Some(old), Some(new)) if old.is_same_as(new) => {}
- (None, None) => {}
- _ => cx.emit(Event::ThreadSummaryModelChanged),
- }
+ pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
self.thread_summary_model = model;
}
+ #[track_caller]
pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
- self.default_model.clone()
+ self.default_model
+ .clone()
+ .or_else(|| self.environment_fallback_model.clone())
+ }
+
+ pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
+ let provider = self.default_model()?.provider;
+ let fast_model = provider.default_fast_model(cx)?;
+ Some(ConfiguredModel {
+ provider,
+ model: fast_model,
+ })
}
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@@ -365,7 +355,7 @@ impl LanguageModelRegistry {
.or_else(|| self.default_model.clone())
}
- pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
+ pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@@ -373,11 +363,11 @@ impl LanguageModelRegistry {
self.commit_message_model
.clone()
- .or_else(|| self.default_fast_model.clone())
+ .or_else(|| self.default_fast_model(cx))
.or_else(|| self.default_model.clone())
}
- pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
+ pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
@@ -385,7 +375,7 @@ impl LanguageModelRegistry {
self.thread_summary_model
.clone()
- .or_else(|| self.default_fast_model.clone())
+ .or_else(|| self.default_fast_model(cx))
.or_else(|| self.default_model.clone())
}
@@ -422,4 +412,34 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
+
+ #[gpui::test]
+ async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = FakeLanguageModelProvider::default();
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+ });
+
+ cx.update(|cx| provider.authenticate(cx)).await.unwrap();
+
+ registry.update(cx, |registry, cx| {
+ let provider = registry.provider(&provider.id()).unwrap();
+
+ registry.set_environment_fallback_model(
+ Some(ConfiguredModel {
+ provider: provider.clone(),
+ model: provider.default_model(cx).unwrap(),
+ }),
+ cx,
+ );
+
+ let default_model = registry.default_model().unwrap();
+ let fallback_model = registry.environment_fallback_model.clone().unwrap();
+
+ assert_eq!(default_model.model.id(), fallback_model.model.id());
+ assert_eq!(default_model.provider.id(), fallback_model.provider.id());
+ });
+ }
}
@@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
+project.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -3,8 +3,12 @@ use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
-use gpui::{App, Context, Entity};
-use language_model::{LanguageModelProviderId, LanguageModelRegistry};
+use futures::future;
+use gpui::{App, AppContext as _, Context, Entity};
+use language_model::{
+ AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
+};
+use project::DisableAiSettings;
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
@@ -13,7 +17,7 @@ pub mod ui;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
-use crate::provider::cloud::CloudLanguageModelProvider;
+use crate::provider::cloud::{self, CloudLanguageModelProvider};
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
@@ -48,6 +52,13 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
cx,
);
});
+
+ let mut already_authenticated = false;
+ if !DisableAiSettings::get_global(cx).disable_ai {
+ authenticate_all_providers(registry.clone(), cx);
+ already_authenticated = true;
+ }
+
cx.observe_global::<SettingsStore>(move |cx| {
let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
.openai_compatible
@@ -65,6 +76,12 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
);
});
openai_compatible_providers = openai_compatible_providers_new;
+ already_authenticated = false;
+ }
+
+ if !DisableAiSettings::get_global(cx).disable_ai && !already_authenticated {
+ authenticate_all_providers(registry.clone(), cx);
+ already_authenticated = true;
}
})
.detach();
@@ -151,3 +168,83 @@ fn register_language_model_providers(
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
+
+/// Authenticates all providers in the [`LanguageModelRegistry`].
+///
+/// We do this so that we can populate the language selector with all of the
+/// models from the configured providers.
+///
+/// Callers are expected to skip calling this when AI is disabled.
+fn authenticate_all_providers(registry: Entity<LanguageModelRegistry>, cx: &mut App) {
+ let providers_to_authenticate = registry
+ .read(cx)
+ .providers()
+ .iter()
+ .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
+ .collect::<Vec<_>>();
+
+ let mut tasks = Vec::with_capacity(providers_to_authenticate.len());
+
+ for (provider_id, provider_name, authenticate_task) in providers_to_authenticate {
+ tasks.push(cx.background_spawn(async move {
+ if let Err(err) = authenticate_task.await {
+ if matches!(err, AuthenticateError::CredentialsNotFound) {
+ // Since we're authenticating these providers in the
+ // background for the purposes of populating the
+ // language selector, we don't care about providers
+ // where the credentials are not found.
+ } else {
+                    // Some providers have noisy failure states that we
+                    // don't want to spam the logs with every time
+                    // providers are authenticated in the background.
+ //
+ // Ideally these should have more clear failure modes
+ // that we know are safe to ignore here, like what we do
+ // with `CredentialsNotFound` above.
+ match provider_id.0.as_ref() {
+ "lmstudio" | "ollama" => {
+ // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
+ //
+ // These fail noisily, so we don't log them.
+ }
+ "copilot_chat" => {
+ // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
+ }
+ _ => {
+ log::error!(
+ "Failed to authenticate provider: {}: {err}",
+ provider_name.0
+ );
+ }
+ }
+ }
+ }
+ }));
+ }
+
+ let all_authenticated_future = future::join_all(tasks);
+
+ cx.spawn(async move |cx| {
+ all_authenticated_future.await;
+
+ registry
+ .update(cx, |registry, cx| {
+ let cloud_provider = registry.provider(&cloud::PROVIDER_ID);
+ let fallback_model = cloud_provider
+ .iter()
+ .chain(registry.providers().iter())
+ .find(|provider| provider.is_authenticated(cx))
+ .and_then(|provider| {
+ Some(ConfiguredModel {
+ provider: provider.clone(),
+ model: provider
+ .default_model(cx)
+ .or_else(|| provider.recommended_models(cx).first().cloned())?,
+ })
+ });
+ registry.set_environment_fallback_model(fallback_model, cx);
+ })
+ .ok();
+ })
+ .detach();
+}
@@ -44,8 +44,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
-const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
-const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
+pub const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
+pub const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@@ -148,7 +148,7 @@ impl State {
default_fast_model: None,
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
- maybe!(async move {
+ maybe!(async {
let (client, llm_api_token) = this
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;