From f517050548fb81aaad159b44b0d6183961fc305d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 24 Feb 2025 08:29:55 +0100 Subject: [PATCH] Partially fix assistant onboarding (#25313) While investigating #24896, I noticed two issues: 1. The default configuration for the `zed.dev` provider was using the wrong string for Claude 3.5 Sonnet. This meant the provider would always result as not configured until the user selected it from the model picker, because we couldn't deserialize that string to a valid `anthropic::Model` enum variant. 2. When clicking on `Open New Chat`/`Start New Thread` in the provider configuration, we would select `Claude 3.5 Haiku` by default instead of Claude 3.5 Sonnet. Release Notes: - Fixed some issues that caused AI providers to sometimes be misconfigured. --- assets/settings/default.json | 2 +- crates/assistant/src/assistant_panel.rs | 2 +- crates/assistant2/src/assistant_panel.rs | 2 +- crates/assistant_settings/src/assistant_settings.rs | 2 +- crates/google_ai/src/google_ai.rs | 3 ++- crates/language_model/src/fake_provider.rs | 4 ++++ crates/language_model/src/language_model.rs | 1 + crates/language_models/src/provider/anthropic.rs | 11 +++++++++++ crates/language_models/src/provider/cloud.rs | 12 ++++++++++++ crates/language_models/src/provider/copilot_chat.rs | 8 ++++++++ crates/language_models/src/provider/deepseek.rs | 11 +++++++++++ crates/language_models/src/provider/google.rs | 11 +++++++++++ crates/language_models/src/provider/lmstudio.rs | 4 ++++ crates/language_models/src/provider/mistral.rs | 11 +++++++++++ crates/language_models/src/provider/ollama.rs | 4 ++++ crates/language_models/src/provider/open_ai.rs | 11 +++++++++++ 16 files changed, 94 insertions(+), 5 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index d27f9a2fd133220f38161a08ef6f1c8bc345eda1..8183c3d60ef49f4c966295239383c8f054fd9550 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -581,7 +581,7 @@ // The provider to use. "provider": "zed.dev", // The model to use. - "model": "claude-3-5-sonnet" + "model": "claude-3-5-sonnet-latest" } }, // The settings for slash commands. diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index ecf2e2c4216e85e15062de61ce88970c9d979e13..e0791e003937acfa15bcdef350825e7cee0ee66d 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -978,7 +978,7 @@ impl AssistantPanel { .active_provider() .map_or(true, |p| p.id() != provider.id()) { - if let Some(model) = provider.provided_models(cx).first().cloned() { + if let Some(model) = provider.default_model(cx) { update_settings_file::( this.fs.clone(), cx, diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index f83a6dc75da6a96611eea0216b78d68575361a66..fb94a18e99ba90effdc778b4e0b1da38cc46939f 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -431,7 +431,7 @@ impl AssistantPanel { active_provider.id() != provider.id() }) { - if let Some(model) = provider.provided_models(cx).first().cloned() { + if let Some(model) = provider.default_model(cx) { update_settings_file::( self.fs.clone(), cx, diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 56e801fad38bebd021ee93ff5619fac34aa67311..5e044282b07b49fc7eafe9cbb9ecf18ff0b5c395 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -512,7 +512,7 @@ mod tests { AssistantSettings::get_global(cx).default_model, LanguageModelSelection { provider: "zed.dev".into(), - model: "claude-3-5-sonnet".into(), + model: "claude-3-5-sonnet-latest".into(), } ); }); diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index ace7ea22c4589d3b5dc5af6490d4a889883cfe43..e885599a0fb3ecde2eed152c3f617ef5e6b2467e 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -299,7 +299,7 @@ pub struct CountTokensResponse { } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] +#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] pub enum Model { #[serde(rename = "gemini-1.5-pro")] Gemini15Pro, @@ -308,6 +308,7 @@ pub enum Model { #[serde(rename = "gemini-2.0-pro-exp")] Gemini20Pro, #[serde(rename = "gemini-2.0-flash")] + #[default] Gemini20Flash, #[serde(rename = "gemini-2.0-flash-thinking-exp")] Gemini20FlashThinking, diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index a955638b2187c89b1e18394241e46d971b52143e..0e4c0748fc95d913af7c832cb4ed98535e192e23 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -46,6 +46,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider { provider_name() } + fn default_model(&self, _cx: &App) -> Option> { + Some(Arc::new(FakeLanguageModel::default())) + } + fn provided_models(&self, _: &App) -> Vec> { vec![Arc::new(FakeLanguageModel::default())] } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 6219fda7397b659d10d27e576e036afcd511e5a5..7b50702a6ee7e02c1f61033f422c959d1623440b 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -247,6 +247,7 @@ pub trait LanguageModelProvider: 'static { fn icon(&self) -> IconName { IconName::ZedAssistant } + fn default_model(&self, cx: &App) -> Option>; fn provided_models(&self, cx: &App) -> Vec>; fn load_model(&self, _model: Arc, _cx: &App) {} fn is_authenticated(&self, cx: &App) -> bool; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index e3ca4998feab0c6456416e8a8b9e7eba48260db3..9908929457bf46ab4d7da23457e57f62f278ded3 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -183,6 +183,17 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { IconName::AiAnthropic } + fn default_model(&self, _cx: &App) -> Option> { + let model = anthropic::Model::default(); + Some(Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 05544f40db27e32096cc62769139eddef7379517..236b78527b29efdb710461bc29030091fa2ab5b3 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -272,6 +272,18 @@ impl LanguageModelProvider for CloudLanguageModelProvider { IconName::AiZed } + fn default_model(&self, cx: &App) -> Option> { + let llm_api_token = self.state.read(cx).llm_api_token.clone(); + let model = CloudModel::Anthropic(anthropic::Model::default()); + Some(Arc::new(CloudLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + llm_api_token: llm_api_token.clone(), + client: self.client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 1c4a4273ac4f956fe8b0fd815cf25599071367ec..7bf2cfe4f6ed73bfe46c27b313c7cf7eedcd57db 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -89,6 +89,14 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider { IconName::Copilot } + fn default_model(&self, _cx: &App) -> Option> { + let model = CopilotChatModel::default(); + Some(Arc::new(CopilotChatLanguageModel { + model, + request_limiter: RateLimiter::new(4), + }) as Arc) + } + fn provided_models(&self, _cx: &App) -> Vec> { CopilotChatModel::iter() .map(|model| { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 91cc02149d74c826b2f4c8db5e6c821c1c3c86be..830e94ecb5c1239f4ecb79cd563efd05864f19d5 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -163,6 +163,17 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { IconName::AiDeepSeek } + fn default_model(&self, _cx: &App) -> Option> { + let model = deepseek::Model::Chat; + Some(Arc::new(DeepSeekLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 9e313935c2145fa9d0e76526fa65dcc74cc91b5d..0bf5001f794abe8df8d6f16eecb702ed54c067b6 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -166,6 +166,17 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { IconName::AiGoogle } + fn default_model(&self, _cx: &App) -> Option> { + let model = google_ai::Model::default(); + Some(Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + rate_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 76832a44e195ef350bc4aabe610b46955e92b6a9..edd07c053a67addf1b84344b204f949551b01a84 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -152,6 +152,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { IconName::AiLmStudio } + fn default_model(&self, cx: &App) -> Option> { + self.provided_models(cx).into_iter().next() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models: BTreeMap = BTreeMap::default(); diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 032ee38c426bf9b52b303696911c8ed8bec0104a..80a5988cffaa033b9a4fbebc48cabe05d16acaae 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -167,6 +167,17 @@ impl LanguageModelProvider for MistralLanguageModelProvider { IconName::AiMistral } + fn default_model(&self, _cx: &App) -> Option> { + let model = mistral::Model::default(); + Some(Arc::new(MistralLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index a982eb3aa794ed33a4fd7990fb42389c3f104ff6..33ad0bcafd61a48435266d8d5cdb59b8e79e14b6 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -157,6 +157,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { IconName::AiOllama } + fn default_model(&self, cx: &App) -> Option> { + self.provided_models(cx).into_iter().next() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models: BTreeMap = BTreeMap::default(); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index ee277247b821c079022286f4376c2b669ba08763..3e46983ebb7518d2afd5ac54eccf781784049b1a 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -169,6 +169,17 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { IconName::AiOpenAi } + fn default_model(&self, _cx: &App) -> Option> { + let model = open_ai::Model::default(); + Some(Arc::new(OpenAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + })) + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default();