Cargo.lock 🔗
@@ -301,6 +301,7 @@ dependencies = [
name = "agent_settings"
version = "0.1.0"
dependencies = [
+ "agent-client-protocol",
"anyhow",
"cloud_llm_client",
"collections",
Oleksii (Alexey) Orlenko , Danilo Leal , and Bennet Bo Fenner created
This PR solves my main pain point with Zed agent: I have a long list of
available models from different providers, and I switch between a few of
them depending on the context and the project. In particular, I use the
same models from different providers depending on whether I'm working on
a personal project or at my day job. Since I only care about a few
models (none of which are in "recommended") that are scattered all over
the list, switching between them is bothersome, even using search.
This change adds a new option in `settings.json`
(`agent.favorite_models`) and the UI to manipulate it directly from the
list of available models. When any models are marked as favorites, they
appear in a dedicated section at the very top of the list. Each model
has a small icon button that appears on hover and allows to toggle
whether it's marked as favorite.
I implemented this on the UI level (i.e. there's no first-party
knowledge about favorite models in the agent itself; in theory it could
return favorite models as a group but it would make it harder to
implement bespoke UI for the favorite models section and it also
wouldn't work for text threads which don't use the ACP infrastructure).
The feature is only enabled for the native agent but disabled for
external agents because we can't easily map their model IDs to settings
and there could be weird collisions between them.
https://github.com/user-attachments/assets/cf23afe4-3883-45cb-9906-f55de3ea2a97
Closes https://github.com/zed-industries/zed/issues/31507
Release Notes:
- Added the ability to mark language models as favorites and pin them to
the top of the list. This feature is available in the native Zed agent
(including text threads and the inline assistant), but not in external
agents via ACP.
---------
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Cargo.lock | 1
crates/acp_thread/src/connection.rs | 10
crates/agent/src/agent.rs | 4
crates/agent_settings/Cargo.toml | 1
crates/agent_settings/src/agent_settings.rs | 12
crates/agent_ui/src/acp/model_selector.rs | 282 +
crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs | 45
crates/agent_ui/src/agent_model_selector.rs | 37
crates/agent_ui/src/agent_ui.rs | 2
crates/agent_ui/src/favorite_models.rs | 57
crates/agent_ui/src/language_model_selector.rs | 216 +
crates/agent_ui/src/text_thread_editor.rs | 36
crates/agent_ui/src/ui/model_selector_components.rs | 35
crates/settings/src/settings_content/agent.rs | 13
14 files changed, 653 insertions(+), 98 deletions(-)
@@ -301,6 +301,7 @@ dependencies = [
name = "agent_settings"
version = "0.1.0"
dependencies = [
+ "agent-client-protocol",
"anyhow",
"cloud_llm_client",
"collections",
@@ -202,6 +202,12 @@ pub trait AgentModelSelector: 'static {
fn should_render_footer(&self) -> bool {
false
}
+
+ /// Whether this selector supports the favorites feature.
+ /// Only the native agent uses the model ID format that maps to settings.
+ fn supports_favorites(&self) -> bool {
+ false
+ }
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -239,6 +245,10 @@ impl AgentModelList {
AgentModelList::Grouped(groups) => groups.is_empty(),
}
}
+
+ pub fn is_flat(&self) -> bool {
+ matches!(self, AgentModelList::Flat(_))
+ }
}
#[cfg(feature = "test-support")]
@@ -1164,6 +1164,10 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
fn should_render_footer(&self) -> bool {
true
}
+
+ fn supports_favorites(&self) -> bool {
+ true
+ }
}
impl acp_thread::AgentConnection for NativeAgentConnection {
@@ -12,6 +12,7 @@ workspace = true
path = "src/agent_settings.rs"
[dependencies]
+agent-client-protocol.workspace = true
anyhow.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
@@ -2,7 +2,8 @@ mod agent_profile;
use std::sync::Arc;
-use collections::IndexMap;
+use agent_client_protocol::ModelId;
+use collections::{HashSet, IndexMap};
use gpui::{App, Pixels, px};
use language_model::LanguageModel;
use project::DisableAiSettings;
@@ -33,6 +34,7 @@ pub struct AgentSettings {
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
+ pub favorite_models: Vec<LanguageModelSelection>,
pub default_profile: AgentProfileId,
pub default_view: DefaultAgentView,
pub profiles: IndexMap<AgentProfileId, AgentProfileSettings>,
@@ -96,6 +98,13 @@ impl AgentSettings {
pub fn set_message_editor_max_lines(&self) -> usize {
self.message_editor_min_lines * 2
}
+
+ pub fn favorite_model_ids(&self) -> HashSet<ModelId> {
+ self.favorite_models
+ .iter()
+ .map(|sel| ModelId::new(format!("{}/{}", sel.provider.0, sel.model)))
+ .collect()
+ }
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
@@ -164,6 +173,7 @@ impl Settings for AgentSettings {
commit_message_model: agent.commit_message_model,
thread_summary_model: agent.thread_summary_model,
inline_alternatives: agent.inline_alternatives.unwrap_or_default(),
+ favorite_models: agent.favorite_models,
default_profile: AgentProfileId(agent.default_profile.unwrap()),
default_view: agent.default_view.unwrap(),
profiles: agent
@@ -1,18 +1,22 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
+use agent_client_protocol::ModelId;
use agent_servers::AgentServer;
+use agent_settings::AgentSettings;
use anyhow::Result;
-use collections::IndexMap;
+use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::FutureExt;
use fuzzy::{StringMatchCandidate, match_strings};
use gpui::{
Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
};
+use itertools::Itertools;
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
-use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, prelude::*};
+use settings::Settings;
+use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*};
use util::ResultExt;
use zed_actions::agent::OpenSettings;
@@ -38,7 +42,7 @@ pub fn acp_model_selector(
enum AcpModelPickerEntry {
Separator(SharedString),
- Model(AgentModelInfo),
+ Model(AgentModelInfo, bool),
}
pub struct AcpModelPickerDelegate {
@@ -140,7 +144,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
_cx: &mut Context<Picker<Self>>,
) -> bool {
match self.filtered_entries.get(ix) {
- Some(AcpModelPickerEntry::Model(_)) => true,
+ Some(AcpModelPickerEntry::Model(_, _)) => true,
Some(AcpModelPickerEntry::Separator(_)) | None => false,
}
}
@@ -155,6 +159,12 @@ impl PickerDelegate for AcpModelPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
+ let favorites = if self.selector.supports_favorites() {
+ Arc::new(AgentSettings::get_global(cx).favorite_model_ids())
+ } else {
+ Default::default()
+ };
+
cx.spawn_in(window, async move |this, cx| {
let filtered_models = match this
.read_with(cx, |this, cx| {
@@ -171,7 +181,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries =
- info_list_to_picker_entries(filtered_models).collect();
+ info_list_to_picker_entries(filtered_models, favorites);
// Finds the currently selected model in the list
let new_index = this
.delegate
@@ -179,7 +189,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
.as_ref()
.and_then(|selected| {
this.delegate.filtered_entries.iter().position(|entry| {
- if let AcpModelPickerEntry::Model(model_info) = entry {
+ if let AcpModelPickerEntry::Model(model_info, _) = entry {
model_info.id == selected.id
} else {
false
@@ -195,7 +205,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
- if let Some(AcpModelPickerEntry::Model(model_info)) =
+ if let Some(AcpModelPickerEntry::Model(model_info, _)) =
self.filtered_entries.get(self.selected_index)
{
if window.modifiers().secondary() {
@@ -233,7 +243,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
fn render_match(
&self,
ix: usize,
- is_focused: bool,
+ selected: bool,
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
@@ -241,32 +251,53 @@ impl PickerDelegate for AcpModelPickerDelegate {
AcpModelPickerEntry::Separator(title) => {
Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
}
- AcpModelPickerEntry::Model(model_info) => {
+ AcpModelPickerEntry::Model(model_info, is_favorite) => {
let is_selected = Some(model_info) == self.selected_model.as_ref();
let default_model = self.agent_server.default_model(cx);
let is_default = default_model.as_ref() == Some(&model_info.id);
+ let supports_favorites = self.selector.supports_favorites();
+
+ let is_favorite = *is_favorite;
+ let handle_action_click = {
+ let model_id = model_info.id.clone();
+ let fs = self.fs.clone();
+
+ move |cx: &App| {
+ crate::favorite_models::toggle_model_id_in_settings(
+ model_id.clone(),
+ !is_favorite,
+ fs.clone(),
+ cx,
+ );
+ }
+ };
+
Some(
div()
.id(("model-picker-menu-child", ix))
.when_some(model_info.description.clone(), |this, description| {
- this
- .on_hover(cx.listener(move |menu, hovered, _, cx| {
- if *hovered {
- menu.delegate.selected_description = Some((ix, description.clone(), is_default));
- } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
- menu.delegate.selected_description = None;
- }
- cx.notify();
- }))
+ this.on_hover(cx.listener(move |menu, hovered, _, cx| {
+ if *hovered {
+ menu.delegate.selected_description =
+ Some((ix, description.clone(), is_default));
+ } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
+ menu.delegate.selected_description = None;
+ }
+ cx.notify();
+ }))
})
.child(
ModelSelectorListItem::new(ix, model_info.name.clone())
- .is_focused(is_focused)
+ .when_some(model_info.icon, |this, icon| this.icon(icon))
.is_selected(is_selected)
- .when_some(model_info.icon, |this, icon| this.icon(icon)),
+ .is_focused(selected)
+ .when(supports_favorites, |this| {
+ this.is_favorite(is_favorite)
+ .on_toggle_favorite(handle_action_click)
+ }),
)
- .into_any_element()
+ .into_any_element(),
)
}
}
@@ -314,18 +345,51 @@ impl PickerDelegate for AcpModelPickerDelegate {
fn info_list_to_picker_entries(
model_list: AgentModelList,
-) -> impl Iterator<Item = AcpModelPickerEntry> {
+ favorites: Arc<HashSet<ModelId>>,
+) -> Vec<AcpModelPickerEntry> {
+ let mut entries = Vec::new();
+
+ let all_models: Vec<_> = match &model_list {
+ AgentModelList::Flat(list) => list.iter().collect(),
+ AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
+ };
+
+ let favorite_models: Vec<_> = all_models
+ .iter()
+ .filter(|m| favorites.contains(&m.id))
+ .unique_by(|m| &m.id)
+ .collect();
+
+ let has_favorites = !favorite_models.is_empty();
+ if has_favorites {
+ entries.push(AcpModelPickerEntry::Separator("Favorite".into()));
+ for model in favorite_models {
+ entries.push(AcpModelPickerEntry::Model((*model).clone(), true));
+ }
+ }
+
match model_list {
AgentModelList::Flat(list) => {
- itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
+ if has_favorites {
+ entries.push(AcpModelPickerEntry::Separator("All".into()));
+ }
+ for model in list {
+ let is_favorite = favorites.contains(&model.id);
+ entries.push(AcpModelPickerEntry::Model(model, is_favorite));
+ }
}
AgentModelList::Grouped(index_map) => {
- itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
- std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
- .chain(models.into_iter().map(AcpModelPickerEntry::Model))
- }))
+ for (group_name, models) in index_map {
+ entries.push(AcpModelPickerEntry::Separator(group_name.0));
+ for model in models {
+ let is_favorite = favorites.contains(&model.id);
+ entries.push(AcpModelPickerEntry::Model(model, is_favorite));
+ }
+ }
}
}
+
+ entries
}
async fn fuzzy_search(
@@ -447,6 +511,170 @@ mod tests {
}
}
+ fn create_favorites(models: Vec<&str>) -> Arc<HashSet<ModelId>> {
+ Arc::new(
+ models
+ .into_iter()
+ .map(|m| ModelId::new(m.to_string()))
+ .collect(),
+ )
+ }
+
+ fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
+ entries
+ .iter()
+ .filter_map(|entry| match entry {
+ AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
+ _ => None,
+ })
+ .collect()
+ }
+
+ fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
+ entries
+ .iter()
+ .map(|entry| match entry {
+ AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
+ AcpModelPickerEntry::Separator(s) => &s,
+ })
+ .collect()
+ }
+
+ #[gpui::test]
+ fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
+ let models = create_model_list(vec![
+ ("zed", vec!["zed/claude", "zed/gemini"]),
+ ("openai", vec!["openai/gpt-5"]),
+ ]);
+ let favorites = create_favorites(vec!["zed/gemini"]);
+
+ let entries = info_list_to_picker_entries(models, favorites);
+
+ assert!(matches!(
+ entries.first(),
+ Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
+ ));
+
+ let model_ids = get_entry_model_ids(&entries);
+ assert_eq!(model_ids[0], "zed/gemini");
+ }
+
+ #[gpui::test]
+ fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
+ let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
+ let favorites = create_favorites(vec![]);
+
+ let entries = info_list_to_picker_entries(models, favorites);
+
+ assert!(matches!(
+ entries.first(),
+ Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
+ ));
+ }
+
+ #[gpui::test]
+ fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
+ let models = create_model_list(vec![
+ ("zed", vec!["zed/claude", "zed/gemini"]),
+ ("openai", vec!["openai/gpt-5"]),
+ ]);
+ let favorites = create_favorites(vec!["zed/claude"]);
+
+ let entries = info_list_to_picker_entries(models, favorites);
+
+ for entry in &entries {
+ if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
+ if info.id.0.as_ref() == "zed/claude" {
+ assert!(is_favorite, "zed/claude should be a favorite");
+ } else {
+ assert!(!is_favorite, "{} should not be a favorite", info.id.0);
+ }
+ }
+ }
+ }
+
+ #[gpui::test]
+ fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
+ let models = create_model_list(vec![
+ ("zed", vec!["zed/claude", "zed/gemini"]),
+ ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
+ ]);
+ let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);
+
+ let entries = info_list_to_picker_entries(models, favorites);
+ let model_ids = get_entry_model_ids(&entries);
+
+ assert_eq!(model_ids[0], "zed/gemini");
+ assert_eq!(model_ids[1], "openai/gpt-5");
+
+ assert!(model_ids[2..].contains(&"zed/gemini"));
+ assert!(model_ids[2..].contains(&"openai/gpt-5"));
+ }
+
+ #[gpui::test]
+ fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
+ let models = create_model_list(vec![
+ ("Recommended", vec!["zed/claude", "anthropic/claude"]),
+ ("Zed", vec!["zed/claude", "zed/gpt-5"]),
+ ("Antropic", vec!["anthropic/claude"]),
+ ("OpenAI", vec!["openai/gpt-5"]),
+ ]);
+
+ let favorites = create_favorites(vec!["zed/claude"]);
+
+ let entries = info_list_to_picker_entries(models, favorites);
+ let labels = get_entry_labels(&entries);
+
+ assert_eq!(
+ labels,
+ vec![
+ "Favorite",
+ "zed/claude",
+ "Recommended",
+ "zed/claude",
+ "anthropic/claude",
+ "Zed",
+ "zed/claude",
+ "zed/gpt-5",
+ "Antropic",
+ "anthropic/claude",
+ "OpenAI",
+ "openai/gpt-5"
+ ]
+ );
+ }
+
+ #[gpui::test]
+ fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
+ let models = AgentModelList::Flat(vec![
+ acp_thread::AgentModelInfo {
+ id: acp::ModelId::new("zed/claude".to_string()),
+ name: "Claude".into(),
+ description: None,
+ icon: None,
+ },
+ acp_thread::AgentModelInfo {
+ id: acp::ModelId::new("zed/gemini".to_string()),
+ name: "Gemini".into(),
+ description: None,
+ icon: None,
+ },
+ ]);
+ let favorites = create_favorites(vec!["zed/gemini"]);
+
+ let entries = info_list_to_picker_entries(models, favorites);
+
+ assert!(matches!(
+ entries.first(),
+ Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
+ ));
+
+ assert!(entries.iter().any(|e| matches!(
+ e,
+ AcpModelPickerEntry::Separator(s) if s == "All"
+ )));
+ }
+
#[gpui::test]
async fn test_fuzzy_match(cx: &mut TestAppContext) {
let models = create_model_list(vec![
@@ -222,7 +222,6 @@ impl ManageProfilesModal {
let profile_id_for_closure = profile_id.clone();
let model_picker = cx.new(|cx| {
- let fs = fs.clone();
let profile_id = profile_id_for_closure.clone();
language_model_selector(
@@ -250,22 +249,36 @@ impl ManageProfilesModal {
})
}
},
- move |model, cx| {
- let provider = model.provider_id().0.to_string();
- let model_id = model.id().0.to_string();
- let profile_id = profile_id.clone();
-
- update_settings_file(fs.clone(), cx, move |settings, _cx| {
- let agent_settings = settings.agent.get_or_insert_default();
- if let Some(profiles) = agent_settings.profiles.as_mut() {
- if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) {
- profile.default_model = Some(LanguageModelSelection {
- provider: LanguageModelProviderSetting(provider.clone()),
- model: model_id.clone(),
- });
+ {
+ let fs = fs.clone();
+ move |model, cx| {
+ let provider = model.provider_id().0.to_string();
+ let model_id = model.id().0.to_string();
+ let profile_id = profile_id.clone();
+
+ update_settings_file(fs.clone(), cx, move |settings, _cx| {
+ let agent_settings = settings.agent.get_or_insert_default();
+ if let Some(profiles) = agent_settings.profiles.as_mut() {
+ if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) {
+ profile.default_model = Some(LanguageModelSelection {
+ provider: LanguageModelProviderSetting(provider.clone()),
+ model: model_id.clone(),
+ });
+ }
}
- }
- });
+ });
+ }
+ },
+ {
+ let fs = fs.clone();
+ move |model, should_be_favorite, cx| {
+ crate::favorite_models::toggle_in_settings(
+ model,
+ should_be_favorite,
+ fs.clone(),
+ cx,
+ );
+ }
},
false, // Do not use popover styles for the model picker
self.focus_handle.clone(),
@@ -29,26 +29,39 @@ impl AgentModelSelector {
Self {
selector: cx.new(move |cx| {
- let fs = fs.clone();
language_model_selector(
{
let model_context = model_usage_context.clone();
move |cx| model_context.configured_model(cx)
},
- move |model, cx| {
- let provider = model.provider_id().0.to_string();
- let model_id = model.id().0.to_string();
- match &model_usage_context {
- ModelUsageContext::InlineAssistant => {
- update_settings_file(fs.clone(), cx, move |settings, _cx| {
- settings
- .agent
- .get_or_insert_default()
- .set_inline_assistant_model(provider.clone(), model_id);
- });
+ {
+ let fs = fs.clone();
+ move |model, cx| {
+ let provider = model.provider_id().0.to_string();
+ let model_id = model.id().0.to_string();
+ match &model_usage_context {
+ ModelUsageContext::InlineAssistant => {
+ update_settings_file(fs.clone(), cx, move |settings, _cx| {
+ settings
+ .agent
+ .get_or_insert_default()
+ .set_inline_assistant_model(provider.clone(), model_id);
+ });
+ }
}
}
},
+ {
+ let fs = fs.clone();
+ move |model, should_be_favorite, cx| {
+ crate::favorite_models::toggle_in_settings(
+ model,
+ should_be_favorite,
+ fs.clone(),
+ cx,
+ );
+ }
+ },
true, // Use popover styles for picker
focus_handle_clone,
window,
@@ -7,6 +7,7 @@ mod buffer_codegen;
mod completion_provider;
mod context;
mod context_server_configuration;
+mod favorite_models;
mod inline_assistant;
mod inline_prompt_editor;
mod language_model_selector;
@@ -467,6 +468,7 @@ mod tests {
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: vec![],
+ favorite_models: vec![],
default_profile: AgentProfileId::default(),
default_view: DefaultAgentView::Thread,
profiles: Default::default(),
@@ -0,0 +1,57 @@
+use std::sync::Arc;
+
+use agent_client_protocol::ModelId;
+use fs::Fs;
+use language_model::LanguageModel;
+use settings::{LanguageModelSelection, update_settings_file};
+use ui::App;
+
+fn language_model_to_selection(model: &Arc<dyn LanguageModel>) -> LanguageModelSelection {
+ LanguageModelSelection {
+ provider: model.provider_id().to_string().into(),
+ model: model.id().0.to_string(),
+ }
+}
+
+fn model_id_to_selection(model_id: &ModelId) -> LanguageModelSelection {
+ let id = model_id.0.as_ref();
+ let (provider, model) = id.split_once('/').unwrap_or(("", id));
+ LanguageModelSelection {
+ provider: provider.to_owned().into(),
+ model: model.to_owned(),
+ }
+}
+
+pub fn toggle_in_settings(
+ model: Arc<dyn LanguageModel>,
+ should_be_favorite: bool,
+ fs: Arc<dyn Fs>,
+ cx: &App,
+) {
+ let selection = language_model_to_selection(&model);
+ update_settings_file(fs, cx, move |settings, _| {
+ let agent = settings.agent.get_or_insert_default();
+ if should_be_favorite {
+ agent.add_favorite_model(selection.clone());
+ } else {
+ agent.remove_favorite_model(&selection);
+ }
+ });
+}
+
+pub fn toggle_model_id_in_settings(
+ model_id: ModelId,
+ should_be_favorite: bool,
+ fs: Arc<dyn Fs>,
+ cx: &App,
+) {
+ let selection = model_id_to_selection(&model_id);
+ update_settings_file(fs, cx, move |settings, _| {
+ let agent = settings.agent.get_or_insert_default();
+ if should_be_favorite {
+ agent.add_favorite_model(selection.clone());
+ } else {
+ agent.remove_favorite_model(&selection);
+ }
+ });
+}
@@ -1,16 +1,18 @@
use std::{cmp::Reverse, sync::Arc};
-use collections::IndexMap;
+use agent_settings::AgentSettings;
+use collections::{HashMap, HashSet, IndexMap};
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
};
use language_model::{
- AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
- LanguageModelRegistry,
+ AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
+use settings::Settings;
use ui::prelude::*;
use zed_actions::agent::OpenSettings;
@@ -18,12 +20,14 @@ use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem}
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
+type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
pub fn language_model_selector(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
+ on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
popover_styles: bool,
focus_handle: FocusHandle,
window: &mut Window,
@@ -32,6 +36,7 @@ pub fn language_model_selector(
let delegate = LanguageModelPickerDelegate::new(
get_active_model,
on_model_changed,
+ on_toggle_favorite,
popover_styles,
focus_handle,
window,
@@ -49,7 +54,17 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
- let providers = LanguageModelRegistry::global(cx).read(cx).providers();
+ let lm_registry = LanguageModelRegistry::global(cx).read(cx);
+ let providers = lm_registry.providers();
+
+ let mut favorites_index = FavoritesIndex::default();
+
+ for sel in &AgentSettings::get_global(cx).favorite_models {
+ favorites_index
+ .entry(sel.provider.0.clone().into())
+ .or_default()
+ .insert(sel.model.clone().into());
+ }
let recommended = providers
.iter()
@@ -57,10 +72,7 @@ fn all_models(cx: &App) -> GroupedModels {
provider
.recommended_models(cx)
.into_iter()
- .map(|model| ModelInfo {
- model,
- icon: provider.icon(),
- })
+ .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
})
.collect();
@@ -70,25 +82,44 @@ fn all_models(cx: &App) -> GroupedModels {
provider
.provided_models(cx)
.into_iter()
- .map(|model| ModelInfo {
- model,
- icon: provider.icon(),
- })
+ .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
})
.collect();
GroupedModels::new(all, recommended)
}
+type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
+
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
+ is_favorite: bool,
+}
+
+impl ModelInfo {
+ fn new(
+ provider: &dyn LanguageModelProvider,
+ model: Arc<dyn LanguageModel>,
+ favorites_index: &FavoritesIndex,
+ ) -> Self {
+ let is_favorite = favorites_index
+ .get(&provider.id())
+ .map_or(false, |set| set.contains(&model.id()));
+
+ Self {
+ model,
+ icon: provider.icon(),
+ is_favorite,
+ }
+ }
}
pub struct LanguageModelPickerDelegate {
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
+ on_toggle_favorite: OnToggleFavorite,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
@@ -102,6 +133,7 @@ impl LanguageModelPickerDelegate {
fn new(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
+ on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
popover_styles: bool,
focus_handle: FocusHandle,
window: &mut Window,
@@ -117,6 +149,7 @@ impl LanguageModelPickerDelegate {
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
+ on_toggle_favorite: Arc::new(on_toggle_favorite),
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
@@ -219,12 +252,19 @@ impl LanguageModelPickerDelegate {
}
struct GroupedModels {
+ favorites: Vec<ModelInfo>,
recommended: Vec<ModelInfo>,
all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
impl GroupedModels {
pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
+ let favorites = all
+ .iter()
+ .filter(|info| info.is_favorite)
+ .cloned()
+ .collect();
+
let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
for model in all {
let provider = model.model.provider_id();
@@ -236,6 +276,7 @@ impl GroupedModels {
}
Self {
+ favorites,
recommended,
all: all_by_provider,
}
@@ -244,13 +285,18 @@ impl GroupedModels {
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
let mut entries = Vec::new();
+ if !self.favorites.is_empty() {
+ entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
+ for info in &self.favorites {
+ entries.push(LanguageModelPickerEntry::Model(info.clone()));
+ }
+ }
+
if !self.recommended.is_empty() {
entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
- entries.extend(
- self.recommended
- .iter()
- .map(|info| LanguageModelPickerEntry::Model(info.clone())),
- );
+ for info in &self.recommended {
+ entries.push(LanguageModelPickerEntry::Model(info.clone()));
+ }
}
for models in self.all.values() {
@@ -260,12 +306,11 @@ impl GroupedModels {
entries.push(LanguageModelPickerEntry::Separator(
models[0].model.provider_name().0,
));
- entries.extend(
- models
- .iter()
- .map(|info| LanguageModelPickerEntry::Model(info.clone())),
- );
+ for info in models {
+ entries.push(LanguageModelPickerEntry::Model(info.clone()));
+ }
}
+
entries
}
}
@@ -461,7 +506,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
fn render_match(
&self,
ix: usize,
- is_focused: bool,
+ selected: bool,
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
@@ -477,11 +522,20 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
&& Some(model_info.model.id()) == active_model_id;
+ let is_favorite = model_info.is_favorite;
+ let handle_action_click = {
+ let model = model_info.model.clone();
+ let on_toggle_favorite = self.on_toggle_favorite.clone();
+ move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
+ };
+
Some(
ModelSelectorListItem::new(ix, model_info.model.name().0)
- .is_focused(is_focused)
- .is_selected(is_selected)
.icon(model_info.icon)
+ .is_selected(is_selected)
+ .is_focused(selected)
+ .is_favorite(is_favorite)
+ .on_toggle_favorite(handle_action_click)
.into_any_element(),
)
}
@@ -493,12 +547,12 @@ impl PickerDelegate for LanguageModelPickerDelegate {
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> Option<gpui::AnyElement> {
+ let focus_handle = self.focus_handle.clone();
+
if !self.popover_styles {
return None;
}
- let focus_handle = self.focus_handle.clone();
-
Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
}
}
@@ -598,11 +652,24 @@ mod tests {
}
fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
+ create_models_with_favorites(model_specs, vec![])
+ }
+
+ fn create_models_with_favorites(
+ model_specs: Vec<(&str, &str)>,
+ favorites: Vec<(&str, &str)>,
+ ) -> Vec<ModelInfo> {
model_specs
.into_iter()
- .map(|(provider, name)| ModelInfo {
- model: Arc::new(TestLanguageModel::new(name, provider)),
- icon: IconName::Ai,
+ .map(|(provider, name)| {
+ let is_favorite = favorites
+ .iter()
+ .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
+ ModelInfo {
+ model: Arc::new(TestLanguageModel::new(name, provider)),
+ icon: IconName::Ai,
+ is_favorite,
+ }
})
.collect()
}
@@ -740,4 +807,93 @@ mod tests {
vec!["zed/claude", "zed/gemini", "copilot/claude"],
);
}
+
+ #[gpui::test]
+ fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
+ let recommended_models = create_models(vec![("zed", "claude")]);
+ let all_models = create_models_with_favorites(
+ vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
+ vec![("zed", "gemini")],
+ );
+
+ let grouped_models = GroupedModels::new(all_models, recommended_models);
+ let entries = grouped_models.entries();
+
+ assert!(matches!(
+ entries.first(),
+ Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
+ ));
+
+ assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
+ }
+
+ #[gpui::test]
+ fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
+ let recommended_models = create_models(vec![("zed", "claude")]);
+ let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);
+
+ let grouped_models = GroupedModels::new(all_models, recommended_models);
+ let entries = grouped_models.entries();
+
+ assert!(matches!(
+ entries.first(),
+ Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
+ ));
+
+ assert!(grouped_models.favorites.is_empty());
+ }
+
+ #[gpui::test]
+ fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
+ let recommended_models =
+ create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
+ let all_models = create_models_with_favorites(
+ vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
+ vec![("zed", "claude")],
+ );
+
+ let grouped_models = GroupedModels::new(all_models, recommended_models);
+ let entries = grouped_models.entries();
+
+ for entry in &entries {
+ if let LanguageModelPickerEntry::Model(info) = entry {
+ if info.model.telemetry_id() == "zed/claude" {
+ assert!(info.is_favorite, "zed/claude should be a favorite");
+ } else {
+ assert!(
+ !info.is_favorite,
+ "{} should not be a favorite",
+ info.model.telemetry_id()
+ );
+ }
+ }
+ }
+ }
+
+ #[gpui::test]
+ fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
+ let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];
+
+ let recommended_models =
+ create_models_with_favorites(vec![("zed", "claude")], favorites.clone());
+
+ let all_models = create_models_with_favorites(
+ vec![
+ ("zed", "claude"),
+ ("zed", "gemini"),
+ ("openai", "gpt-4"),
+ ("openai", "gpt-3.5"),
+ ],
+ favorites,
+ );
+
+ let grouped_models = GroupedModels::new(all_models, recommended_models);
+
+ assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
+ assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
+ assert_models_eq(
+ grouped_models.all.values().flatten().cloned().collect(),
+ vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
+ );
+ }
}
@@ -304,17 +304,31 @@ impl TextThreadEditor {
language_model_selector: cx.new(|cx| {
language_model_selector(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
- move |model, cx| {
- update_settings_file(fs.clone(), cx, move |settings, _| {
- let provider = model.provider_id().0.to_string();
- let model = model.id().0.to_string();
- settings.agent.get_or_insert_default().set_model(
- LanguageModelSelection {
- provider: LanguageModelProviderSetting(provider),
- model,
- },
- )
- });
+ {
+ let fs = fs.clone();
+ move |model, cx| {
+ update_settings_file(fs.clone(), cx, move |settings, _| {
+ let provider = model.provider_id().0.to_string();
+ let model = model.id().0.to_string();
+ settings.agent.get_or_insert_default().set_model(
+ LanguageModelSelection {
+ provider: LanguageModelProviderSetting(provider),
+ model,
+ },
+ )
+ });
+ }
+ },
+ {
+ let fs = fs.clone();
+ move |model, should_be_favorite, cx| {
+ crate::favorite_models::toggle_in_settings(
+ model,
+ should_be_favorite,
+ fs.clone(),
+ cx,
+ );
+ }
},
true, // Use popover styles for picker
focus_handle,
@@ -1,5 +1,5 @@
use gpui::{Action, FocusHandle, prelude::*};
-use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
+use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
#[derive(IntoElement)]
pub struct ModelSelectorHeader {
@@ -42,6 +42,8 @@ pub struct ModelSelectorListItem {
icon: Option<IconName>,
is_selected: bool,
is_focused: bool,
+ is_favorite: bool,
+ on_toggle_favorite: Option<Box<dyn Fn(&App) + 'static>>,
}
impl ModelSelectorListItem {
@@ -52,6 +54,8 @@ impl ModelSelectorListItem {
icon: None,
is_selected: false,
is_focused: false,
+ is_favorite: false,
+ on_toggle_favorite: None,
}
}
@@ -69,6 +73,16 @@ impl ModelSelectorListItem {
self.is_focused = is_focused;
self
}
+
+ pub fn is_favorite(mut self, is_favorite: bool) -> Self {
+ self.is_favorite = is_favorite;
+ self
+ }
+
+ pub fn on_toggle_favorite(mut self, handler: impl Fn(&App) + 'static) -> Self {
+ self.on_toggle_favorite = Some(Box::new(handler));
+ self
+ }
}
impl RenderOnce for ModelSelectorListItem {
@@ -79,6 +93,8 @@ impl RenderOnce for ModelSelectorListItem {
Color::Muted
};
+ let is_favorite = self.is_favorite;
+
ListItem::new(self.index)
.inset(true)
.spacing(ListItemSpacing::Sparse)
@@ -103,6 +119,23 @@ impl RenderOnce for ModelSelectorListItem {
.size(IconSize::Small),
)
}))
+ .end_hover_slot(div().pr_2().when_some(self.on_toggle_favorite, {
+ |this, handle_click| {
+ let (icon, color, tooltip) = if is_favorite {
+ (IconName::StarFilled, Color::Accent, "Unfavorite Model")
+ } else {
+ (IconName::Star, Color::Default, "Favorite Model")
+ };
+ this.child(
+ IconButton::new(("toggle-favorite", self.index), icon)
+ .layer(ElevationIndex::ElevatedSurface)
+ .icon_color(color)
+ .icon_size(IconSize::Small)
+ .tooltip(Tooltip::text(tooltip))
+ .on_click(move |_, _, cx| (handle_click)(cx)),
+ )
+ }
+ }))
}
}
@@ -38,6 +38,9 @@ pub struct AgentSettingsContent {
pub default_height: Option<f32>,
/// The default model to use when creating new chats and for other features when a specific model is not specified.
pub default_model: Option<LanguageModelSelection>,
+ /// Favorite models to show at the top of the model selector.
+ #[serde(default)]
+ pub favorite_models: Vec<LanguageModelSelection>,
/// Model to use for the inline assistant. Defaults to default_model when not specified.
pub inline_assistant_model: Option<LanguageModelSelection>,
/// Model to use for the inline assistant when streaming tools are enabled.
@@ -176,6 +179,16 @@ impl AgentSettingsContent {
pub fn set_profile(&mut self, profile_id: Arc<str>) {
self.default_profile = Some(profile_id);
}
+
+ pub fn add_favorite_model(&mut self, model: LanguageModelSelection) {
+ if !self.favorite_models.contains(&model) {
+ self.favorite_models.push(model);
+ }
+ }
+
+ pub fn remove_favorite_model(&mut self, model: &LanguageModelSelection) {
+ self.favorite_models.retain(|m| m != model);
+ }
}
#[with_fallible_options]