From 2d84af91bf3e272770c07ddb4056618cb5e5e9a8 Mon Sep 17 00:00:00 2001 From: David <688326+dvcrn@users.noreply.github.com> Date: Mon, 10 Nov 2025 21:11:24 +0700 Subject: [PATCH] agent: Add ability to set a default_model per profile (#39220) Split off from https://github.com/zed-industries/zed/pull/39175 Requires https://github.com/zed-industries/zed/pull/39219 to be merged first Adds support for `default_model` for profiles: ``` "my-profile": { "name": "Coding Agent", "tools": {}, "enable_all_context_servers": false, "context_servers": {}, "default_model": { "provider": "copilot_chat", "model": "grok-code-fast-1" } } ``` Which will then switch to the default model whenever the profile is activated ![2025-09-30 17 09 06](https://github.com/user-attachments/assets/43f07b7b-85d9-4aff-82ce-25d6f5050d50) Release Notes: - Added `default_model` configuration to agent profile --------- Co-authored-by: Danilo Leal --- crates/agent/src/tests/mod.rs | 12 +- crates/agent/src/thread.rs | 64 ++++++++-- crates/agent_settings/src/agent_profile.rs | 58 ++++++--- crates/agent_ui/src/acp/thread_view.rs | 5 +- .../src/agent_configuration/tool_picker.rs | 1 + crates/agent_ui/src/profile_selector.rs | 114 ++++++++++-------- crates/eval/src/instance.rs | 2 +- crates/settings/src/settings_content/agent.rs | 2 + 8 files changed, 177 insertions(+), 81 deletions(-) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index d80edca35de03578d0d557eb320dc77471a3b8fb..5d4bdce27cc05d1cf46a4b73821f0a97878fd6f4 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -933,7 +933,7 @@ async fn test_profiles(cx: &mut TestAppContext) { // Test that test-1 profile (default) has echo and delay tools thread .update(cx, |thread, cx| { - thread.set_profile(AgentProfileId("test-1".into())); + thread.set_profile(AgentProfileId("test-1".into()), cx); thread.send(UserMessageId::new(), ["test"], cx) }) .unwrap(); @@ -953,7 +953,7 @@ async fn test_profiles(cx: &mut TestAppContext) { // Switch to test-2 profile, and verify that it has only the infinite tool. thread .update(cx, |thread, cx| { - thread.set_profile(AgentProfileId("test-2".into())); + thread.set_profile(AgentProfileId("test-2".into()), cx); thread.send(UserMessageId::new(), ["test2"], cx) }) .unwrap(); @@ -1002,8 +1002,8 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { ) .await; cx.run_until_parked(); - thread.update(cx, |thread, _| { - thread.set_profile(AgentProfileId("test".into())) + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test".into()), cx) }); let mut mcp_tool_calls = setup_context_server( @@ -1169,8 +1169,8 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { .await; cx.run_until_parked(); - thread.update(cx, |thread, _| { - thread.set_profile(AgentProfileId("test".into())); + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test".into()), cx); thread.add_tool(EchoTool); thread.add_tool(DelayTool); thread.add_tool(WordListTool); diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 78f20152b4daf461de40cfa7746216092f82cf41..5cf230629c8e542a23ea7ffc5bdb0fa5a1c73a53 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -30,16 +30,17 @@ use gpui::{ }; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, - LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, - LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID, + LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, + ZED_CLOUD_PROVIDER_ID, }; use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; -use settings::{Settings, update_settings_file}; +use settings::{LanguageModelSelection, Settings, update_settings_file}; use smol::stream::StreamExt; use std::{ collections::BTreeMap, @@ -798,7 +799,8 @@ impl Thread { let profile_id = db_thread .profile .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); - let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + + let mut model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { db_thread .model .and_then(|model| { @@ -811,6 +813,16 @@ impl Thread { .or_else(|| registry.default_model()) .map(|model| model.model) }); + + if model.is_none() { + model = Self::resolve_profile_model(&profile_id, cx); + } + if model.is_none() { + model = LanguageModelRegistry::global(cx).update(cx, |registry, _cx| { + registry.default_model().map(|model| model.model) + }); + } + let (prompt_capabilities_tx, prompt_capabilities_rx) = watch::channel(Self::prompt_capabilities(model.as_deref())); @@ -1007,8 +1019,17 @@ impl Thread { &self.profile_id } - pub fn set_profile(&mut self, profile_id: AgentProfileId) { + pub fn set_profile(&mut self, profile_id: AgentProfileId, cx: &mut Context) { + if self.profile_id == profile_id { + return; + } + self.profile_id = profile_id; + + // Swap to the profile's preferred model when available. + if let Some(model) = Self::resolve_profile_model(&self.profile_id, cx) { + self.set_model(model, cx); + } } pub fn cancel(&mut self, cx: &mut Context) { @@ -1065,6 +1086,35 @@ impl Thread { }) } + /// Look up the active profile and resolve its preferred model if one is configured. + fn resolve_profile_model( + profile_id: &AgentProfileId, + cx: &mut Context, + ) -> Option> { + let selection = AgentSettings::get_global(cx) + .profiles + .get(profile_id)? + .default_model + .clone()?; + Self::resolve_model_from_selection(&selection, cx) + } + + /// Translate a stored model selection into the configured model from the registry. + fn resolve_model_from_selection( + selection: &LanguageModelSelection, + cx: &mut Context, + ) -> Option> { + let selected = SelectedModel { + provider: LanguageModelProviderId::from(selection.provider.0.clone()), + model: LanguageModelId::from(selection.model.clone()), + }; + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .select_model(&selected, cx) + .map(|configured| configured.model) + }) + } + pub fn resume( &mut self, cx: &mut Context, diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index 999ddc8083a1a4b4c271ea9bde4c1e45307e9542..aff666e01111dc5db539b370cd440fa88438fe8d 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -6,8 +6,8 @@ use convert_case::{Case, Casing as _}; use fs::Fs; use gpui::{App, SharedString}; use settings::{ - AgentProfileContent, ContextServerPresetContent, Settings as _, SettingsContent, - update_settings_file, + AgentProfileContent, ContextServerPresetContent, LanguageModelSelection, Settings as _, + SettingsContent, update_settings_file, }; use util::ResultExt as _; @@ -53,19 +53,30 @@ impl AgentProfile { let base_profile = base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned()); + // Copy toggles from the base profile so the new profile starts with familiar defaults. + let tools = base_profile + .as_ref() + .map(|profile| profile.tools.clone()) + .unwrap_or_default(); + let enable_all_context_servers = base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(); + let context_servers = base_profile + .as_ref() + .map(|profile| profile.context_servers.clone()) + .unwrap_or_default(); + // Preserve the base profile's model preference when cloning into a new profile. + let default_model = base_profile + .as_ref() + .and_then(|profile| profile.default_model.clone()); + let profile_settings = AgentProfileSettings { name: name.into(), - tools: base_profile - .as_ref() - .map(|profile| profile.tools.clone()) - .unwrap_or_default(), - enable_all_context_servers: base_profile - .as_ref() - .map(|profile| profile.enable_all_context_servers) - .unwrap_or_default(), - context_servers: base_profile - .map(|profile| profile.context_servers) - .unwrap_or_default(), + tools, + enable_all_context_servers, + context_servers, + default_model, }; update_settings_file(fs, cx, { @@ -96,6 +107,8 @@ pub struct AgentProfileSettings { pub tools: IndexMap, bool>, pub enable_all_context_servers: bool, pub context_servers: IndexMap, ContextServerPreset>, + /// Default language model to apply when this profile becomes active. + pub default_model: Option, } impl AgentProfileSettings { @@ -144,6 +157,7 @@ impl AgentProfileSettings { ) }) .collect(), + default_model: self.default_model.clone(), }, ); @@ -153,15 +167,23 @@ impl AgentProfileSettings { impl From for AgentProfileSettings { fn from(content: AgentProfileContent) -> Self { + let AgentProfileContent { + name, + tools, + enable_all_context_servers, + context_servers, + default_model, + } = content; + Self { - name: content.name.into(), - tools: content.tools, - enable_all_context_servers: content.enable_all_context_servers.unwrap_or_default(), - context_servers: content - .context_servers + name: name.into(), + tools, + enable_all_context_servers: enable_all_context_servers.unwrap_or_default(), + context_servers: context_servers .into_iter() .map(|(server_id, preset)| (server_id, preset.into())) .collect(), + default_model, } } } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index daf2249909fc9a29df6969dba1ad51cc099c891c..306976473d772f55cfdf1ee9caa65eab4f1d5552 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -125,8 +125,9 @@ impl ProfileProvider for Entity { } fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) { - self.update(cx, |thread, _cx| { - thread.set_profile(profile_id); + self.update(cx, |thread, cx| { + // Apply the profile and let the thread swap to its default model. + thread.set_profile(profile_id, cx); }); } diff --git a/crates/agent_ui/src/agent_configuration/tool_picker.rs b/crates/agent_ui/src/agent_configuration/tool_picker.rs index 6b84205e1bd6336d70751090d8f0451b1b1925b0..1c99f665ab1c8fc995d47682f92365852bbc9637 100644 --- a/crates/agent_ui/src/agent_configuration/tool_picker.rs +++ b/crates/agent_ui/src/agent_configuration/tool_picker.rs @@ -314,6 +314,7 @@ impl PickerDelegate for ToolPickerDelegate { ) }) .collect(), + default_model: default_profile.default_model.clone(), }); if let Some(server_id) = server_id { diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index 2f9fe19eb33667d6ca6bb2f5502fbd1c9f094e9c..c1949d22e268e8744db7834a58d1a3303fa4e236 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -15,8 +15,8 @@ use std::{ sync::{Arc, atomic::AtomicBool}, }; use ui::{ - DocumentationAside, DocumentationEdge, DocumentationSide, HighlightedLabel, LabelSize, - ListItem, ListItemSpacing, PopoverMenuHandle, TintColor, Tooltip, prelude::*, + DocumentationAside, DocumentationEdge, DocumentationSide, HighlightedLabel, KeyBinding, + LabelSize, ListItem, ListItemSpacing, PopoverMenuHandle, TintColor, Tooltip, prelude::*, }; /// Trait for types that can provide and manage agent profiles @@ -81,6 +81,7 @@ impl ProfileSelector { self.provider.clone(), self.profiles.clone(), cx.background_executor().clone(), + self.focus_handle.clone(), cx, ); @@ -207,6 +208,7 @@ pub(crate) struct ProfilePickerDelegate { selected_index: usize, query: String, cancel: Option>, + focus_handle: FocusHandle, } impl ProfilePickerDelegate { @@ -215,6 +217,7 @@ impl ProfilePickerDelegate { provider: Arc, profiles: AvailableProfiles, background: BackgroundExecutor, + focus_handle: FocusHandle, cx: &mut Context, ) -> Self { let candidates = Self::candidates_from(profiles); @@ -231,6 +234,7 @@ impl ProfilePickerDelegate { selected_index: 0, query: String::new(), cancel: None, + focus_handle, }; this.selected_index = this @@ -594,20 +598,26 @@ impl PickerDelegate for ProfilePickerDelegate { _: &mut Window, cx: &mut Context>, ) -> Option { + let focus_handle = self.focus_handle.clone(); + Some( h_flex() .w_full() .border_t_1() .border_color(cx.theme().colors().border_variant) - .p_1() - .gap_4() - .justify_between() + .p_1p5() .child( Button::new("configure", "Configure") - .icon(IconName::Settings) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) + .full_width() + .style(ButtonStyle::Outlined) + .key_binding( + KeyBinding::for_action_in( + &ManageProfiles::default(), + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) .on_click(|_, window, cx| { window.dispatch_action(ManageProfiles::default().boxed_clone(), cx); }), @@ -659,20 +669,25 @@ mod tests { is_builtin: true, }]; - let delegate = ProfilePickerDelegate { - fs: FakeFs::new(cx.executor()), - provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))), - background: cx.executor(), - candidates, - string_candidates: Arc::new(Vec::new()), - filtered_entries: Vec::new(), - selected_index: 0, - query: String::new(), - cancel: None, - }; - - let matches = Vec::new(); // No matches - let _entries = delegate.entries_from_matches(matches); + cx.update(|cx| { + let focus_handle = cx.focus_handle(); + + let delegate = ProfilePickerDelegate { + fs: FakeFs::new(cx.background_executor().clone()), + provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))), + background: cx.background_executor().clone(), + candidates, + string_candidates: Arc::new(Vec::new()), + filtered_entries: Vec::new(), + selected_index: 0, + query: String::new(), + cancel: None, + focus_handle, + }; + + let matches = Vec::new(); // No matches + let _entries = delegate.entries_from_matches(matches); + }); } #[gpui::test] @@ -690,30 +705,35 @@ mod tests { }, ]; - let delegate = ProfilePickerDelegate { - fs: FakeFs::new(cx.executor()), - provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))), - background: cx.executor(), - candidates, - string_candidates: Arc::new(Vec::new()), - filtered_entries: vec![ - ProfilePickerEntry::Profile(ProfileMatchEntry { - candidate_index: 0, - positions: Vec::new(), - }), - ProfilePickerEntry::Profile(ProfileMatchEntry { - candidate_index: 1, - positions: Vec::new(), - }), - ], - selected_index: 0, - query: String::new(), - cancel: None, - }; - - // Active profile should be found at index 0 - let active_index = delegate.index_of_profile(&AgentProfileId("write".into())); - assert_eq!(active_index, Some(0)); + cx.update(|cx| { + let focus_handle = cx.focus_handle(); + + let delegate = ProfilePickerDelegate { + fs: FakeFs::new(cx.background_executor().clone()), + provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))), + background: cx.background_executor().clone(), + candidates, + string_candidates: Arc::new(Vec::new()), + filtered_entries: vec![ + ProfilePickerEntry::Profile(ProfileMatchEntry { + candidate_index: 0, + positions: Vec::new(), + }), + ProfilePickerEntry::Profile(ProfileMatchEntry { + candidate_index: 1, + positions: Vec::new(), + }), + ], + selected_index: 0, + query: String::new(), + cancel: None, + focus_handle, + }; + + // Active profile should be found at index 0 + let active_index = delegate.index_of_profile(&AgentProfileId("write".into())); + assert_eq!(active_index, Some(0)); + }); } struct TestProfileProvider { diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 5317f100456748616dfec63819bc0373aaceb4c1..035f1ec0ac8d0c6490dc39637e03e377ee3d194b 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -322,7 +322,7 @@ impl ExampleInstance { thread.add_default_tools(Rc::new(EvalThreadEnvironment { project: project.clone(), }), cx); - thread.set_profile(meta.profile_id.clone()); + thread.set_profile(meta.profile_id.clone(), cx); thread.set_model( LanguageModelInterceptor::new( LanguageModelRegistry::read_global(cx).default_model().expect("Missing model").model.clone(), diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index c641f280e177669a2af14e91c844f2a5f059b648..425b5f05ff46fa705c073838dceab6c431c74bde 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -176,6 +176,8 @@ pub struct AgentProfileContent { pub enable_all_context_servers: Option, #[serde(default)] pub context_servers: IndexMap, ContextServerPresetContent>, + /// The default language model selected when using this profile. + pub default_model: Option, } #[skip_serializing_none]