From 0c91f061c360468da86bd1cb88768ceae6f71308 Mon Sep 17 00:00:00 2001 From: "Oleksii (Alexey) Orlenko" Date: Tue, 16 Dec 2025 20:22:30 +0100 Subject: [PATCH] agent_ui: Implement favorite models selection (#44297) This PR solves my main pain point with Zed agent: I have a long list of available models from different providers, and I switch between a few of them depending on the context and the project. In particular, I use the same models from different providers depending on whether I'm working on a personal project or at my day job. Since I only care about a few models (none of which are in "recommended") that are scattered all over the list, switching between them is bothersome, even using search. This change adds a new option in `settings.json` (`agent.favorite_models`) and the UI to manipulate it directly from the list of available models. When any models are marked as favorites, they appear in a dedicated section at the very top of the list. Each model has a small icon button that appears on hover and allows to toggle whether it's marked as favorite. I implemented this on the UI level (i.e. there's no first-party knowledge about favorite models in the agent itself; in theory it could return favorite models as a group but it would make it harder to implement bespoke UI for the favorite models section and it also wouldn't work for text threads which don't use the ACP infrastructure). The feature is only enabled for the native agent but disabled for external agents because we can't easily map their model IDs to settings and there could be weird collisions between them. https://github.com/user-attachments/assets/cf23afe4-3883-45cb-9906-f55de3ea2a97 Closes https://github.com/zed-industries/zed/issues/31507 Release Notes: - Added the ability to mark language models as favorites and pin them to the top of the list. This feature is available in the native Zed agent (including text threads and the inline assistant), but not in external agents via ACP. --------- Co-authored-by: Danilo Leal Co-authored-by: Bennet Bo Fenner --- Cargo.lock | 1 + crates/acp_thread/src/connection.rs | 10 + crates/agent/src/agent.rs | 4 + crates/agent_settings/Cargo.toml | 1 + crates/agent_settings/src/agent_settings.rs | 12 +- crates/agent_ui/src/acp/model_selector.rs | 282 ++++++++++++++++-- .../manage_profiles_modal.rs | 45 ++- crates/agent_ui/src/agent_model_selector.rs | 37 ++- crates/agent_ui/src/agent_ui.rs | 2 + crates/agent_ui/src/favorite_models.rs | 57 ++++ .../agent_ui/src/language_model_selector.rs | 216 ++++++++++++-- crates/agent_ui/src/text_thread_editor.rs | 36 ++- .../src/ui/model_selector_components.rs | 35 ++- crates/settings/src/settings_content/agent.rs | 13 + 14 files changed, 653 insertions(+), 98 deletions(-) create mode 100644 crates/agent_ui/src/favorite_models.rs diff --git a/Cargo.lock b/Cargo.lock index 2d0cb8235d547c5486ffd89e3d54fd8d46a54f0c..6908a8ed5185ea71cc51a34d63990decaaf082d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -301,6 +301,7 @@ dependencies = [ name = "agent_settings" version = "0.1.0" dependencies = [ + "agent-client-protocol", "anyhow", "cloud_llm_client", "collections", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 3c8c56b2c02cd775be030cb4c4b05a9c75f0d10f..a670ba601159ec323ad2c88695c30bf4aeae4118 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -202,6 +202,12 @@ pub trait AgentModelSelector: 'static { fn should_render_footer(&self) -> bool { false } + + /// Whether this selector supports the favorites feature. + /// Only the native agent uses the model ID format that maps to settings. + fn supports_favorites(&self) -> bool { + false + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -239,6 +245,10 @@ impl AgentModelList { AgentModelList::Grouped(groups) => groups.is_empty(), } } + + pub fn is_flat(&self) -> bool { + matches!(self, AgentModelList::Flat(_)) + } } #[cfg(feature = "test-support")] diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 693d3abd4497c057a75b4f01c07bd51f311f1fdb..5e16f74682ef95a4e990ed5a124a0d6031acfb0e 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -1164,6 +1164,10 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector { fn should_render_footer(&self) -> bool { true } + + fn supports_favorites(&self) -> bool { + true + } } impl acp_thread::AgentConnection for NativeAgentConnection { diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 8ddcac24fe054d1226f2bbac49498fd35d6ed1c3..0d7163549f0a4b172773c9ac95dcbc84b7212667 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -12,6 +12,7 @@ workspace = true path = "src/agent_settings.rs" [dependencies] +agent-client-protocol.workspace = true anyhow.workspace = true cloud_llm_client.workspace = true collections.workspace = true diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 25ca5c78d6b76145a1b1b5d19ac86246ff419d1d..b513ec1a70b6f7ab02382dfa312ea2d4d6a47234 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -2,7 +2,8 @@ mod agent_profile; use std::sync::Arc; -use collections::IndexMap; +use agent_client_protocol::ModelId; +use collections::{HashSet, IndexMap}; use gpui::{App, Pixels, px}; use language_model::LanguageModel; use project::DisableAiSettings; @@ -33,6 +34,7 @@ pub struct AgentSettings { pub commit_message_model: Option, pub thread_summary_model: Option, pub inline_alternatives: Vec, + pub favorite_models: Vec, pub default_profile: AgentProfileId, pub default_view: DefaultAgentView, pub profiles: IndexMap, @@ -96,6 +98,13 @@ impl AgentSettings { pub fn set_message_editor_max_lines(&self) -> usize { self.message_editor_min_lines * 2 } + + pub fn favorite_model_ids(&self) -> HashSet { + self.favorite_models + .iter() + .map(|sel| ModelId::new(format!("{}/{}", sel.provider.0, sel.model))) + .collect() + } } #[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)] @@ -164,6 +173,7 @@ impl Settings for AgentSettings { commit_message_model: agent.commit_message_model, thread_summary_model: agent.thread_summary_model, inline_alternatives: agent.inline_alternatives.unwrap_or_default(), + favorite_models: agent.favorite_models, default_profile: AgentProfileId(agent.default_profile.unwrap()), default_view: agent.default_view.unwrap(), profiles: agent diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index 658b88e0c2a4f0b4203c5f1191c0a49cb4ad6fd5..f885ff12e598168abdf7727dc03e4814e5de3b49 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -1,18 +1,22 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc}; use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use agent_client_protocol::ModelId; use agent_servers::AgentServer; +use agent_settings::AgentSettings; use anyhow::Result; -use collections::IndexMap; +use collections::{HashSet, IndexMap}; use fs::Fs; use futures::FutureExt; use fuzzy::{StringMatchCandidate, match_strings}; use gpui::{ Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity, }; +use itertools::Itertools; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; -use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, prelude::*}; +use settings::Settings; +use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*}; use util::ResultExt; use zed_actions::agent::OpenSettings; @@ -38,7 +42,7 @@ pub fn acp_model_selector( enum AcpModelPickerEntry { Separator(SharedString), - Model(AgentModelInfo), + Model(AgentModelInfo, bool), } pub struct AcpModelPickerDelegate { @@ -140,7 +144,7 @@ impl PickerDelegate for AcpModelPickerDelegate { _cx: &mut Context>, ) -> bool { match self.filtered_entries.get(ix) { - Some(AcpModelPickerEntry::Model(_)) => true, + Some(AcpModelPickerEntry::Model(_, _)) => true, Some(AcpModelPickerEntry::Separator(_)) | None => false, } } @@ -155,6 +159,12 @@ impl PickerDelegate for AcpModelPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { + let favorites = if self.selector.supports_favorites() { + Arc::new(AgentSettings::get_global(cx).favorite_model_ids()) + } else { + Default::default() + }; + cx.spawn_in(window, async move |this, cx| { let filtered_models = match this .read_with(cx, |this, cx| { @@ -171,7 +181,7 @@ impl PickerDelegate for AcpModelPickerDelegate { this.update_in(cx, |this, window, cx| { this.delegate.filtered_entries = - info_list_to_picker_entries(filtered_models).collect(); + info_list_to_picker_entries(filtered_models, favorites); // Finds the currently selected model in the list let new_index = this .delegate @@ -179,7 +189,7 @@ impl PickerDelegate for AcpModelPickerDelegate { .as_ref() .and_then(|selected| { this.delegate.filtered_entries.iter().position(|entry| { - if let AcpModelPickerEntry::Model(model_info) = entry { + if let AcpModelPickerEntry::Model(model_info, _) = entry { model_info.id == selected.id } else { false @@ -195,7 +205,7 @@ impl PickerDelegate for AcpModelPickerDelegate { } fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { - if let Some(AcpModelPickerEntry::Model(model_info)) = + if let Some(AcpModelPickerEntry::Model(model_info, _)) = self.filtered_entries.get(self.selected_index) { if window.modifiers().secondary() { @@ -233,7 +243,7 @@ impl PickerDelegate for AcpModelPickerDelegate { fn render_match( &self, ix: usize, - is_focused: bool, + selected: bool, _: &mut Window, cx: &mut Context>, ) -> Option { @@ -241,32 +251,53 @@ impl PickerDelegate for AcpModelPickerDelegate { AcpModelPickerEntry::Separator(title) => { Some(ModelSelectorHeader::new(title, ix > 1).into_any_element()) } - AcpModelPickerEntry::Model(model_info) => { + AcpModelPickerEntry::Model(model_info, is_favorite) => { let is_selected = Some(model_info) == self.selected_model.as_ref(); let default_model = self.agent_server.default_model(cx); let is_default = default_model.as_ref() == Some(&model_info.id); + let supports_favorites = self.selector.supports_favorites(); + + let is_favorite = *is_favorite; + let handle_action_click = { + let model_id = model_info.id.clone(); + let fs = self.fs.clone(); + + move |cx: &App| { + crate::favorite_models::toggle_model_id_in_settings( + model_id.clone(), + !is_favorite, + fs.clone(), + cx, + ); + } + }; + Some( div() .id(("model-picker-menu-child", ix)) .when_some(model_info.description.clone(), |this, description| { - this - .on_hover(cx.listener(move |menu, hovered, _, cx| { - if *hovered { - menu.delegate.selected_description = Some((ix, description.clone(), is_default)); - } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) { - menu.delegate.selected_description = None; - } - cx.notify(); - })) + this.on_hover(cx.listener(move |menu, hovered, _, cx| { + if *hovered { + menu.delegate.selected_description = + Some((ix, description.clone(), is_default)); + } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) { + menu.delegate.selected_description = None; + } + cx.notify(); + })) }) .child( ModelSelectorListItem::new(ix, model_info.name.clone()) - .is_focused(is_focused) + .when_some(model_info.icon, |this, icon| this.icon(icon)) .is_selected(is_selected) - .when_some(model_info.icon, |this, icon| this.icon(icon)), + .is_focused(selected) + .when(supports_favorites, |this| { + this.is_favorite(is_favorite) + .on_toggle_favorite(handle_action_click) + }), ) - .into_any_element() + .into_any_element(), ) } } @@ -314,18 +345,51 @@ impl PickerDelegate for AcpModelPickerDelegate { fn info_list_to_picker_entries( model_list: AgentModelList, -) -> impl Iterator { + favorites: Arc>, +) -> Vec { + let mut entries = Vec::new(); + + let all_models: Vec<_> = match &model_list { + AgentModelList::Flat(list) => list.iter().collect(), + AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(), + }; + + let favorite_models: Vec<_> = all_models + .iter() + .filter(|m| favorites.contains(&m.id)) + .unique_by(|m| &m.id) + .collect(); + + let has_favorites = !favorite_models.is_empty(); + if has_favorites { + entries.push(AcpModelPickerEntry::Separator("Favorite".into())); + for model in favorite_models { + entries.push(AcpModelPickerEntry::Model((*model).clone(), true)); + } + } + match model_list { AgentModelList::Flat(list) => { - itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model)) + if has_favorites { + entries.push(AcpModelPickerEntry::Separator("All".into())); + } + for model in list { + let is_favorite = favorites.contains(&model.id); + entries.push(AcpModelPickerEntry::Model(model, is_favorite)); + } } AgentModelList::Grouped(index_map) => { - itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| { - std::iter::once(AcpModelPickerEntry::Separator(group_name.0)) - .chain(models.into_iter().map(AcpModelPickerEntry::Model)) - })) + for (group_name, models) in index_map { + entries.push(AcpModelPickerEntry::Separator(group_name.0)); + for model in models { + let is_favorite = favorites.contains(&model.id); + entries.push(AcpModelPickerEntry::Model(model, is_favorite)); + } + } } } + + entries } async fn fuzzy_search( @@ -447,6 +511,170 @@ mod tests { } } + fn create_favorites(models: Vec<&str>) -> Arc> { + Arc::new( + models + .into_iter() + .map(|m| ModelId::new(m.to_string())) + .collect(), + ) + } + + fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> { + entries + .iter() + .filter_map(|entry| match entry { + AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()), + _ => None, + }) + .collect() + } + + fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> { + entries + .iter() + .map(|entry| match entry { + AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(), + AcpModelPickerEntry::Separator(s) => &s, + }) + .collect() + } + + #[gpui::test] + fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("zed", vec!["zed/claude", "zed/gemini"]), + ("openai", vec!["openai/gpt-5"]), + ]); + let favorites = create_favorites(vec!["zed/gemini"]); + + let entries = info_list_to_picker_entries(models, favorites); + + assert!(matches!( + entries.first(), + Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite" + )); + + let model_ids = get_entry_model_ids(&entries); + assert_eq!(model_ids[0], "zed/gemini"); + } + + #[gpui::test] + fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) { + let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]); + let favorites = create_favorites(vec![]); + + let entries = info_list_to_picker_entries(models, favorites); + + assert!(matches!( + entries.first(), + Some(AcpModelPickerEntry::Separator(s)) if s == "zed" + )); + } + + #[gpui::test] + fn test_models_have_correct_actions(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("zed", vec!["zed/claude", "zed/gemini"]), + ("openai", vec!["openai/gpt-5"]), + ]); + let favorites = create_favorites(vec!["zed/claude"]); + + let entries = info_list_to_picker_entries(models, favorites); + + for entry in &entries { + if let AcpModelPickerEntry::Model(info, is_favorite) = entry { + if info.id.0.as_ref() == "zed/claude" { + assert!(is_favorite, "zed/claude should be a favorite"); + } else { + assert!(!is_favorite, "{} should not be a favorite", info.id.0); + } + } + } + } + + #[gpui::test] + fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("zed", vec!["zed/claude", "zed/gemini"]), + ("openai", vec!["openai/gpt-5", "openai/gpt-4"]), + ]); + let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]); + + let entries = info_list_to_picker_entries(models, favorites); + let model_ids = get_entry_model_ids(&entries); + + assert_eq!(model_ids[0], "zed/gemini"); + assert_eq!(model_ids[1], "openai/gpt-5"); + + assert!(model_ids[2..].contains(&"zed/gemini")); + assert!(model_ids[2..].contains(&"openai/gpt-5")); + } + + #[gpui::test] + fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ("Recommended", vec!["zed/claude", "anthropic/claude"]), + ("Zed", vec!["zed/claude", "zed/gpt-5"]), + ("Antropic", vec!["anthropic/claude"]), + ("OpenAI", vec!["openai/gpt-5"]), + ]); + + let favorites = create_favorites(vec!["zed/claude"]); + + let entries = info_list_to_picker_entries(models, favorites); + let labels = get_entry_labels(&entries); + + assert_eq!( + labels, + vec![ + "Favorite", + "zed/claude", + "Recommended", + "zed/claude", + "anthropic/claude", + "Zed", + "zed/claude", + "zed/gpt-5", + "Antropic", + "anthropic/claude", + "OpenAI", + "openai/gpt-5" + ] + ); + } + + #[gpui::test] + fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) { + let models = AgentModelList::Flat(vec![ + acp_thread::AgentModelInfo { + id: acp::ModelId::new("zed/claude".to_string()), + name: "Claude".into(), + description: None, + icon: None, + }, + acp_thread::AgentModelInfo { + id: acp::ModelId::new("zed/gemini".to_string()), + name: "Gemini".into(), + description: None, + icon: None, + }, + ]); + let favorites = create_favorites(vec!["zed/gemini"]); + + let entries = info_list_to_picker_entries(models, favorites); + + assert!(matches!( + entries.first(), + Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite" + )); + + assert!(entries.iter().any(|e| matches!( + e, + AcpModelPickerEntry::Separator(s) if s == "All" + ))); + } + #[gpui::test] async fn test_fuzzy_match(cx: &mut TestAppContext) { let models = create_model_list(vec![ diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index ed00b2b5c716fdf27abc1c9d7c5850b36fce830f..127852fd50e81cf56ae37a7af430f88ae2accf99 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -222,7 +222,6 @@ impl ManageProfilesModal { let profile_id_for_closure = profile_id.clone(); let model_picker = cx.new(|cx| { - let fs = fs.clone(); let profile_id = profile_id_for_closure.clone(); language_model_selector( @@ -250,22 +249,36 @@ impl ManageProfilesModal { }) } }, - move |model, cx| { - let provider = model.provider_id().0.to_string(); - let model_id = model.id().0.to_string(); - let profile_id = profile_id.clone(); - - update_settings_file(fs.clone(), cx, move |settings, _cx| { - let agent_settings = settings.agent.get_or_insert_default(); - if let Some(profiles) = agent_settings.profiles.as_mut() { - if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) { - profile.default_model = Some(LanguageModelSelection { - provider: LanguageModelProviderSetting(provider.clone()), - model: model_id.clone(), - }); + { + let fs = fs.clone(); + move |model, cx| { + let provider = model.provider_id().0.to_string(); + let model_id = model.id().0.to_string(); + let profile_id = profile_id.clone(); + + update_settings_file(fs.clone(), cx, move |settings, _cx| { + let agent_settings = settings.agent.get_or_insert_default(); + if let Some(profiles) = agent_settings.profiles.as_mut() { + if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) { + profile.default_model = Some(LanguageModelSelection { + provider: LanguageModelProviderSetting(provider.clone()), + model: model_id.clone(), + }); + } } - } - }); + }); + } + }, + { + let fs = fs.clone(); + move |model, should_be_favorite, cx| { + crate::favorite_models::toggle_in_settings( + model, + should_be_favorite, + fs.clone(), + cx, + ); + } }, false, // Do not use popover styles for the model picker self.focus_handle.clone(), diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index 9c2634143099d2097b5c6492f81c56aa51f12491..ac57ed575d9d1b6de2c53d3e0e4a91b4bd16ab1a 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -29,26 +29,39 @@ impl AgentModelSelector { Self { selector: cx.new(move |cx| { - let fs = fs.clone(); language_model_selector( { let model_context = model_usage_context.clone(); move |cx| model_context.configured_model(cx) }, - move |model, cx| { - let provider = model.provider_id().0.to_string(); - let model_id = model.id().0.to_string(); - match &model_usage_context { - ModelUsageContext::InlineAssistant => { - update_settings_file(fs.clone(), cx, move |settings, _cx| { - settings - .agent - .get_or_insert_default() - .set_inline_assistant_model(provider.clone(), model_id); - }); + { + let fs = fs.clone(); + move |model, cx| { + let provider = model.provider_id().0.to_string(); + let model_id = model.id().0.to_string(); + match &model_usage_context { + ModelUsageContext::InlineAssistant => { + update_settings_file(fs.clone(), cx, move |settings, _cx| { + settings + .agent + .get_or_insert_default() + .set_inline_assistant_model(provider.clone(), model_id); + }); + } } } }, + { + let fs = fs.clone(); + move |model, should_be_favorite, cx| { + crate::favorite_models::toggle_in_settings( + model, + should_be_favorite, + fs.clone(), + cx, + ); + } + }, true, // Use popover styles for picker focus_handle_clone, window, diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 4f759d6a9c7687d2cdf29752c489db2fcb1ffe68..1622d17f5852d825b9c8d69996fad7c89bb89dce 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -7,6 +7,7 @@ mod buffer_codegen; mod completion_provider; mod context; mod context_server_configuration; +mod favorite_models; mod inline_assistant; mod inline_prompt_editor; mod language_model_selector; @@ -467,6 +468,7 @@ mod tests { commit_message_model: None, thread_summary_model: None, inline_alternatives: vec![], + favorite_models: vec![], default_profile: AgentProfileId::default(), default_view: DefaultAgentView::Thread, profiles: Default::default(), diff --git a/crates/agent_ui/src/favorite_models.rs b/crates/agent_ui/src/favorite_models.rs new file mode 100644 index 0000000000000000000000000000000000000000..d8d4db976fc9916973eedd9174925fba75a06b2b --- /dev/null +++ b/crates/agent_ui/src/favorite_models.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use agent_client_protocol::ModelId; +use fs::Fs; +use language_model::LanguageModel; +use settings::{LanguageModelSelection, update_settings_file}; +use ui::App; + +fn language_model_to_selection(model: &Arc) -> LanguageModelSelection { + LanguageModelSelection { + provider: model.provider_id().to_string().into(), + model: model.id().0.to_string(), + } +} + +fn model_id_to_selection(model_id: &ModelId) -> LanguageModelSelection { + let id = model_id.0.as_ref(); + let (provider, model) = id.split_once('/').unwrap_or(("", id)); + LanguageModelSelection { + provider: provider.to_owned().into(), + model: model.to_owned(), + } +} + +pub fn toggle_in_settings( + model: Arc, + should_be_favorite: bool, + fs: Arc, + cx: &App, +) { + let selection = language_model_to_selection(&model); + update_settings_file(fs, cx, move |settings, _| { + let agent = settings.agent.get_or_insert_default(); + if should_be_favorite { + agent.add_favorite_model(selection.clone()); + } else { + agent.remove_favorite_model(&selection); + } + }); +} + +pub fn toggle_model_id_in_settings( + model_id: ModelId, + should_be_favorite: bool, + fs: Arc, + cx: &App, +) { + let selection = model_id_to_selection(&model_id); + update_settings_file(fs, cx, move |settings, _| { + let agent = settings.agent.get_or_insert_default(); + if should_be_favorite { + agent.add_favorite_model(selection.clone()); + } else { + agent.remove_favorite_model(&selection); + } + }); +} diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 7e1c35eba45bf9a79d42b59374c8cdb2aa0cac21..7bb42fb330dcccb4b5401217d0181d3d616fe66f 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,16 +1,18 @@ use std::{cmp::Reverse, sync::Arc}; -use collections::IndexMap; +use agent_settings::AgentSettings; +use collections::{HashMap, HashSet, IndexMap}; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; use gpui::{ Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task, }; use language_model::{ - AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, - LanguageModelRegistry, + AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider, + LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; +use settings::Settings; use ui::prelude::*; use zed_actions::agent::OpenSettings; @@ -18,12 +20,14 @@ use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem} type OnModelChanged = Arc, &mut App) + 'static>; type GetActiveModel = Arc Option + 'static>; +type OnToggleFavorite = Arc, bool, &App) + 'static>; pub type LanguageModelSelector = Picker; pub fn language_model_selector( get_active_model: impl Fn(&App) -> Option + 'static, on_model_changed: impl Fn(Arc, &mut App) + 'static, + on_toggle_favorite: impl Fn(Arc, bool, &App) + 'static, popover_styles: bool, focus_handle: FocusHandle, window: &mut Window, @@ -32,6 +36,7 @@ pub fn language_model_selector( let delegate = LanguageModelPickerDelegate::new( get_active_model, on_model_changed, + on_toggle_favorite, popover_styles, focus_handle, window, @@ -49,7 +54,17 @@ pub fn language_model_selector( } fn all_models(cx: &App) -> GroupedModels { - let providers = LanguageModelRegistry::global(cx).read(cx).providers(); + let lm_registry = LanguageModelRegistry::global(cx).read(cx); + let providers = lm_registry.providers(); + + let mut favorites_index = FavoritesIndex::default(); + + for sel in &AgentSettings::get_global(cx).favorite_models { + favorites_index + .entry(sel.provider.0.clone().into()) + .or_default() + .insert(sel.model.clone().into()); + } let recommended = providers .iter() @@ -57,10 +72,7 @@ fn all_models(cx: &App) -> GroupedModels { provider .recommended_models(cx) .into_iter() - .map(|model| ModelInfo { - model, - icon: provider.icon(), - }) + .map(|model| ModelInfo::new(&**provider, model, &favorites_index)) }) .collect(); @@ -70,25 +82,44 @@ fn all_models(cx: &App) -> GroupedModels { provider .provided_models(cx) .into_iter() - .map(|model| ModelInfo { - model, - icon: provider.icon(), - }) + .map(|model| ModelInfo::new(&**provider, model, &favorites_index)) }) .collect(); GroupedModels::new(all, recommended) } +type FavoritesIndex = HashMap>; + #[derive(Clone)] struct ModelInfo { model: Arc, icon: IconName, + is_favorite: bool, +} + +impl ModelInfo { + fn new( + provider: &dyn LanguageModelProvider, + model: Arc, + favorites_index: &FavoritesIndex, + ) -> Self { + let is_favorite = favorites_index + .get(&provider.id()) + .map_or(false, |set| set.contains(&model.id())); + + Self { + model, + icon: provider.icon(), + is_favorite, + } + } } pub struct LanguageModelPickerDelegate { on_model_changed: OnModelChanged, get_active_model: GetActiveModel, + on_toggle_favorite: OnToggleFavorite, all_models: Arc, filtered_entries: Vec, selected_index: usize, @@ -102,6 +133,7 @@ impl LanguageModelPickerDelegate { fn new( get_active_model: impl Fn(&App) -> Option + 'static, on_model_changed: impl Fn(Arc, &mut App) + 'static, + on_toggle_favorite: impl Fn(Arc, bool, &App) + 'static, popover_styles: bool, focus_handle: FocusHandle, window: &mut Window, @@ -117,6 +149,7 @@ impl LanguageModelPickerDelegate { selected_index: Self::get_active_model_index(&entries, get_active_model(cx)), filtered_entries: entries, get_active_model: Arc::new(get_active_model), + on_toggle_favorite: Arc::new(on_toggle_favorite), _authenticate_all_providers_task: Self::authenticate_all_providers(cx), _subscriptions: vec![cx.subscribe_in( &LanguageModelRegistry::global(cx), @@ -219,12 +252,19 @@ impl LanguageModelPickerDelegate { } struct GroupedModels { + favorites: Vec, recommended: Vec, all: IndexMap>, } impl GroupedModels { pub fn new(all: Vec, recommended: Vec) -> Self { + let favorites = all + .iter() + .filter(|info| info.is_favorite) + .cloned() + .collect(); + let mut all_by_provider: IndexMap<_, Vec> = IndexMap::default(); for model in all { let provider = model.model.provider_id(); @@ -236,6 +276,7 @@ impl GroupedModels { } Self { + favorites, recommended, all: all_by_provider, } @@ -244,13 +285,18 @@ impl GroupedModels { fn entries(&self) -> Vec { let mut entries = Vec::new(); + if !self.favorites.is_empty() { + entries.push(LanguageModelPickerEntry::Separator("Favorite".into())); + for info in &self.favorites { + entries.push(LanguageModelPickerEntry::Model(info.clone())); + } + } + if !self.recommended.is_empty() { entries.push(LanguageModelPickerEntry::Separator("Recommended".into())); - entries.extend( - self.recommended - .iter() - .map(|info| LanguageModelPickerEntry::Model(info.clone())), - ); + for info in &self.recommended { + entries.push(LanguageModelPickerEntry::Model(info.clone())); + } } for models in self.all.values() { @@ -260,12 +306,11 @@ impl GroupedModels { entries.push(LanguageModelPickerEntry::Separator( models[0].model.provider_name().0, )); - entries.extend( - models - .iter() - .map(|info| LanguageModelPickerEntry::Model(info.clone())), - ); + for info in models { + entries.push(LanguageModelPickerEntry::Model(info.clone())); + } } + entries } } @@ -461,7 +506,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { fn render_match( &self, ix: usize, - is_focused: bool, + selected: bool, _: &mut Window, cx: &mut Context>, ) -> Option { @@ -477,11 +522,20 @@ impl PickerDelegate for LanguageModelPickerDelegate { let is_selected = Some(model_info.model.provider_id()) == active_provider_id && Some(model_info.model.id()) == active_model_id; + let is_favorite = model_info.is_favorite; + let handle_action_click = { + let model = model_info.model.clone(); + let on_toggle_favorite = self.on_toggle_favorite.clone(); + move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx) + }; + Some( ModelSelectorListItem::new(ix, model_info.model.name().0) - .is_focused(is_focused) - .is_selected(is_selected) .icon(model_info.icon) + .is_selected(is_selected) + .is_focused(selected) + .is_favorite(is_favorite) + .on_toggle_favorite(handle_action_click) .into_any_element(), ) } @@ -493,12 +547,12 @@ impl PickerDelegate for LanguageModelPickerDelegate { _window: &mut Window, _cx: &mut Context>, ) -> Option { + let focus_handle = self.focus_handle.clone(); + if !self.popover_styles { return None; } - let focus_handle = self.focus_handle.clone(); - Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element()) } } @@ -598,11 +652,24 @@ mod tests { } fn create_models(model_specs: Vec<(&str, &str)>) -> Vec { + create_models_with_favorites(model_specs, vec![]) + } + + fn create_models_with_favorites( + model_specs: Vec<(&str, &str)>, + favorites: Vec<(&str, &str)>, + ) -> Vec { model_specs .into_iter() - .map(|(provider, name)| ModelInfo { - model: Arc::new(TestLanguageModel::new(name, provider)), - icon: IconName::Ai, + .map(|(provider, name)| { + let is_favorite = favorites + .iter() + .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name); + ModelInfo { + model: Arc::new(TestLanguageModel::new(name, provider)), + icon: IconName::Ai, + is_favorite, + } }) .collect() } @@ -740,4 +807,93 @@ mod tests { vec!["zed/claude", "zed/gemini", "copilot/claude"], ); } + + #[gpui::test] + fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) { + let recommended_models = create_models(vec![("zed", "claude")]); + let all_models = create_models_with_favorites( + vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")], + vec![("zed", "gemini")], + ); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + let entries = grouped_models.entries(); + + assert!(matches!( + entries.first(), + Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite" + )); + + assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]); + } + + #[gpui::test] + fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) { + let recommended_models = create_models(vec![("zed", "claude")]); + let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + let entries = grouped_models.entries(); + + assert!(matches!( + entries.first(), + Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended" + )); + + assert!(grouped_models.favorites.is_empty()); + } + + #[gpui::test] + fn test_models_have_correct_actions(_cx: &mut TestAppContext) { + let recommended_models = + create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]); + let all_models = create_models_with_favorites( + vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")], + vec![("zed", "claude")], + ); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + let entries = grouped_models.entries(); + + for entry in &entries { + if let LanguageModelPickerEntry::Model(info) = entry { + if info.model.telemetry_id() == "zed/claude" { + assert!(info.is_favorite, "zed/claude should be a favorite"); + } else { + assert!( + !info.is_favorite, + "{} should not be a favorite", + info.model.telemetry_id() + ); + } + } + } + } + + #[gpui::test] + fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) { + let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")]; + + let recommended_models = + create_models_with_favorites(vec![("zed", "claude")], favorites.clone()); + + let all_models = create_models_with_favorites( + vec![ + ("zed", "claude"), + ("zed", "gemini"), + ("openai", "gpt-4"), + ("openai", "gpt-3.5"), + ], + favorites, + ); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + + assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]); + assert_models_eq(grouped_models.recommended, vec!["zed/claude"]); + assert_models_eq( + grouped_models.all.values().flatten().cloned().collect(), + vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"], + ); + } } diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 5e3f348c17de3cd0dae9f5fe41a2477211d6ddd8..881eb213a3886b894a778a34cb6ba129bf42c1a4 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -304,17 +304,31 @@ impl TextThreadEditor { language_model_selector: cx.new(|cx| { language_model_selector( |cx| LanguageModelRegistry::read_global(cx).default_model(), - move |model, cx| { - update_settings_file(fs.clone(), cx, move |settings, _| { - let provider = model.provider_id().0.to_string(); - let model = model.id().0.to_string(); - settings.agent.get_or_insert_default().set_model( - LanguageModelSelection { - provider: LanguageModelProviderSetting(provider), - model, - }, - ) - }); + { + let fs = fs.clone(); + move |model, cx| { + update_settings_file(fs.clone(), cx, move |settings, _| { + let provider = model.provider_id().0.to_string(); + let model = model.id().0.to_string(); + settings.agent.get_or_insert_default().set_model( + LanguageModelSelection { + provider: LanguageModelProviderSetting(provider), + model, + }, + ) + }); + } + }, + { + let fs = fs.clone(); + move |model, should_be_favorite, cx| { + crate::favorite_models::toggle_in_settings( + model, + should_be_favorite, + fs.clone(), + cx, + ); + } }, true, // Use popover styles for picker focus_handle, diff --git a/crates/agent_ui/src/ui/model_selector_components.rs b/crates/agent_ui/src/ui/model_selector_components.rs index 3218daef7c9aadae5cd45b2fc65807d8a32254bd..184c8e0ba2d3ea307c869e42a13b75f36e713c42 100644 --- a/crates/agent_ui/src/ui/model_selector_components.rs +++ b/crates/agent_ui/src/ui/model_selector_components.rs @@ -1,5 +1,5 @@ use gpui::{Action, FocusHandle, prelude::*}; -use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*}; +use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*}; #[derive(IntoElement)] pub struct ModelSelectorHeader { @@ -42,6 +42,8 @@ pub struct ModelSelectorListItem { icon: Option, is_selected: bool, is_focused: bool, + is_favorite: bool, + on_toggle_favorite: Option>, } impl ModelSelectorListItem { @@ -52,6 +54,8 @@ impl ModelSelectorListItem { icon: None, is_selected: false, is_focused: false, + is_favorite: false, + on_toggle_favorite: None, } } @@ -69,6 +73,16 @@ impl ModelSelectorListItem { self.is_focused = is_focused; self } + + pub fn is_favorite(mut self, is_favorite: bool) -> Self { + self.is_favorite = is_favorite; + self + } + + pub fn on_toggle_favorite(mut self, handler: impl Fn(&App) + 'static) -> Self { + self.on_toggle_favorite = Some(Box::new(handler)); + self + } } impl RenderOnce for ModelSelectorListItem { @@ -79,6 +93,8 @@ impl RenderOnce for ModelSelectorListItem { Color::Muted }; + let is_favorite = self.is_favorite; + ListItem::new(self.index) .inset(true) .spacing(ListItemSpacing::Sparse) @@ -103,6 +119,23 @@ impl RenderOnce for ModelSelectorListItem { .size(IconSize::Small), ) })) + .end_hover_slot(div().pr_2().when_some(self.on_toggle_favorite, { + |this, handle_click| { + let (icon, color, tooltip) = if is_favorite { + (IconName::StarFilled, Color::Accent, "Unfavorite Model") + } else { + (IconName::Star, Color::Default, "Favorite Model") + }; + this.child( + IconButton::new(("toggle-favorite", self.index), icon) + .layer(ElevationIndex::ElevatedSurface) + .icon_color(color) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text(tooltip)) + .on_click(move |_, _, cx| (handle_click)(cx)), + ) + } + })) } } diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index f7a88deb7d8ba88db6497da2cf79035a64446456..d3a8e40084fc5db7fd348908b1b721617c7c8206 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -38,6 +38,9 @@ pub struct AgentSettingsContent { pub default_height: Option, /// The default model to use when creating new chats and for other features when a specific model is not specified. pub default_model: Option, + /// Favorite models to show at the top of the model selector. + #[serde(default)] + pub favorite_models: Vec, /// Model to use for the inline assistant. Defaults to default_model when not specified. pub inline_assistant_model: Option, /// Model to use for the inline assistant when streaming tools are enabled. @@ -176,6 +179,16 @@ impl AgentSettingsContent { pub fn set_profile(&mut self, profile_id: Arc) { self.default_profile = Some(profile_id); } + + pub fn add_favorite_model(&mut self, model: LanguageModelSelection) { + if !self.favorite_models.contains(&model) { + self.favorite_models.push(model); + } + } + + pub fn remove_favorite_model(&mut self, model: &LanguageModelSelection) { + self.favorite_models.retain(|m| m != model); + } } #[with_fallible_options]