diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 2ec9beb71bf08c90ea85b8752410405714d31537..a44bdd1f22478e92ace192c939561f855c2814bd 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -35,6 +35,7 @@ pub struct AcpConnection { auth_methods: Vec, agent_capabilities: acp::AgentCapabilities, default_mode: Option, + default_model: Option, root_dir: PathBuf, // NB: Don't move this into the wait_task, since we need to ensure the process is // killed on drop (setting kill_on_drop on the command seems to not always work). @@ -57,6 +58,7 @@ pub async fn connect( command: AgentServerCommand, root_dir: &Path, default_mode: Option, + default_model: Option, is_remote: bool, cx: &mut AsyncApp, ) -> Result> { @@ -66,6 +68,7 @@ pub async fn connect( command.clone(), root_dir, default_mode, + default_model, is_remote, cx, ) @@ -82,6 +85,7 @@ impl AcpConnection { command: AgentServerCommand, root_dir: &Path, default_mode: Option, + default_model: Option, is_remote: bool, cx: &mut AsyncApp, ) -> Result { @@ -207,6 +211,7 @@ impl AcpConnection { sessions, agent_capabilities: response.agent_capabilities, default_mode, + default_model, _io_task: io_task, _wait_task: wait_task, _stderr_task: stderr_task, @@ -245,6 +250,7 @@ impl AgentConnection for AcpConnection { let conn = self.connection.clone(); let sessions = self.sessions.clone(); let default_mode = self.default_mode.clone(); + let default_model = self.default_model.clone(); let cwd = cwd.to_path_buf(); let context_server_store = project.read(cx).context_server_store().read(cx); let mcp_servers = @@ -333,6 +339,7 @@ impl AgentConnection for AcpConnection { let default_mode = default_mode.clone(); let session_id = response.session_id.clone(); let modes = modes.clone(); + let conn = conn.clone(); async move |_| { let result = conn.set_session_mode(acp::SetSessionModeRequest { session_id, @@ -367,6 +374,53 @@ impl AgentConnection for AcpConnection { } } + if let Some(default_model) = default_model { + if let Some(models) = models.as_ref() { + let mut models_ref = models.borrow_mut(); + let has_model = models_ref.available_models.iter().any(|model| model.model_id == default_model); + + if has_model { + let initial_model_id = models_ref.current_model_id.clone(); + + cx.spawn({ + let default_model = default_model.clone(); + let session_id = response.session_id.clone(); + let models = models.clone(); + let conn = conn.clone(); + async move |_| { + let result = conn.set_session_model(acp::SetSessionModelRequest { + session_id, + model_id: default_model, + meta: None, + }) + .await.log_err(); + + if result.is_none() { + models.borrow_mut().current_model_id = initial_model_id; + } + } + }).detach(); + + models_ref.current_model_id = default_model; + } else { + let available_models = models_ref + .available_models + .iter() + .map(|model| format!("- `{}`: {}", model.model_id, model.name)) + .collect::>() + .join("\n"); + + log::warn!( + "`{default_model}` is not a valid {name} model. Available options:\n{available_models}", + ); + } + } else { + log::warn!( + "`{name}` does not support model selection, but `default_model` was set in settings.", + ); + } + } + let session_id = response.session_id; let action_log = cx.new(|_| ActionLog::new(project.clone()))?; let thread = cx.new(|cx| { diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index b44c2123fb5052e2487464d813936cd1edf9821a..cf03b71a78b358d7b110c450f769f9645094baaa 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -68,6 +68,18 @@ pub trait AgentServer: Send { ) { } + fn default_model(&self, _cx: &mut App) -> Option { + None + } + + fn set_default_model( + &self, + _model_id: Option, + _fs: Arc, + _cx: &mut App, + ) { + } + fn connect( &self, root_dir: Option<&Path>, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index cd3207824a7c05ddfaafeca965deea0918ccfb39..ac79ab7484de90a84ce3d6720f54bcec6addc6b5 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -55,6 +55,27 @@ impl AgentServer for ClaudeCode { }); } + fn default_model(&self, cx: &mut App) -> Option { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).claude.clone() + }); + + settings + .as_ref() + .and_then(|s| s.default_model.clone().map(|m| acp::ModelId(m.into()))) + } + + fn set_default_model(&self, model_id: Option, fs: Arc, cx: &mut App) { + update_settings_file(fs, cx, |settings, _| { + settings + .agent_servers + .get_or_insert_default() + .claude + .get_or_insert_default() + .default_model = model_id.map(|m| m.to_string()) + }); + } + fn connect( &self, root_dir: Option<&Path>, @@ -68,6 +89,7 @@ impl AgentServer for ClaudeCode { let store = delegate.store.downgrade(); let extra_env = load_proxy_env(cx); let default_mode = self.default_mode(cx); + let default_model = self.default_model(cx); cx.spawn(async move |cx| { let (command, root_dir, login) = store @@ -90,6 +112,7 @@ impl AgentServer for ClaudeCode { command, root_dir.as_ref(), default_mode, + default_model, is_remote, cx, ) diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index 95375ad412c31272dbfce9262b4b5fd38fe55c50..ec01cd4e523b5696b2f09b5e51e7137fcfb16c91 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -56,6 +56,27 @@ impl AgentServer for Codex { }); } + fn default_model(&self, cx: &mut App) -> Option { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).codex.clone() + }); + + settings + .as_ref() + .and_then(|s| s.default_model.clone().map(|m| acp::ModelId(m.into()))) + } + + fn set_default_model(&self, model_id: Option, fs: Arc, cx: &mut App) { + update_settings_file(fs, cx, |settings, _| { + settings + .agent_servers + .get_or_insert_default() + .codex + .get_or_insert_default() + .default_model = model_id.map(|m| m.to_string()) + }); + } + fn connect( &self, root_dir: Option<&Path>, @@ -69,6 +90,7 @@ impl AgentServer for Codex { let store = delegate.store.downgrade(); let extra_env = load_proxy_env(cx); let default_mode = self.default_mode(cx); + let default_model = self.default_model(cx); cx.spawn(async move |cx| { let (command, root_dir, login) = store @@ -92,6 +114,7 @@ impl AgentServer for Codex { command, root_dir.as_ref(), default_mode, + default_model, is_remote, cx, ) diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs index 7d36cc758389a828b819a822c91c9bb4b3444985..b417e2bdf30a7ed6b9e2ab4baa6211cee2a9a890 100644 --- a/crates/agent_servers/src/custom.rs +++ b/crates/agent_servers/src/custom.rs @@ -61,6 +61,34 @@ impl crate::AgentServer for CustomAgentServer { }); } + fn default_model(&self, cx: &mut App) -> Option { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings + .get::(None) + .custom + .get(&self.name()) + .cloned() + }); + + settings + .as_ref() + .and_then(|s| s.default_model.clone().map(|m| acp::ModelId(m.into()))) + } + + fn set_default_model(&self, model_id: Option, fs: Arc, cx: &mut App) { + let name = self.name(); + update_settings_file(fs, cx, move |settings, _| { + if let Some(settings) = settings + .agent_servers + .get_or_insert_default() + .custom + .get_mut(&name) + { + settings.default_model = model_id.map(|m| m.to_string()) + } + }); + } + fn connect( &self, root_dir: Option<&Path>, @@ -72,6 +100,7 @@ impl crate::AgentServer for CustomAgentServer { let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned()); let is_remote = delegate.project.read(cx).is_via_remote_server(); let default_mode = self.default_mode(cx); + let default_model = self.default_model(cx); let store = delegate.store.downgrade(); let extra_env = load_proxy_env(cx); @@ -98,6 +127,7 @@ impl crate::AgentServer for CustomAgentServer { command, root_dir.as_ref(), default_mode, + default_model, is_remote, cx, ) diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 7618625278121cc1426f06ed8626a68759f34995..824b999bdaff46cf3ad3a570b62fecd596612563 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -476,6 +476,7 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { env: None, ignore_system_version: None, default_mode: None, + default_model: None, }), gemini: Some(crate::gemini::tests::local_command().into()), codex: Some(BuiltinAgentServerSettings { @@ -484,6 +485,7 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { env: None, ignore_system_version: None, default_mode: None, + default_model: None, }), custom: collections::HashMap::default(), }, diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index feaa221cbccb789ed3a89bed9f23d544e1d3b5f7..c1b2efb081551f82752dc15a909eec64ff78d94e 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -37,6 +37,7 @@ impl AgentServer for Gemini { let store = delegate.store.downgrade(); let mut extra_env = load_proxy_env(cx); let default_mode = self.default_mode(cx); + let default_model = self.default_model(cx); cx.spawn(async move |cx| { extra_env.insert("SURFACE".to_owned(), "zed".to_owned()); @@ -69,6 +70,7 @@ impl AgentServer for Gemini { command, root_dir.as_ref(), default_mode, + default_model, is_remote, cx, ) diff --git a/crates/agent_ui/src/acp/mode_selector.rs b/crates/agent_ui/src/acp/mode_selector.rs index 83ab9c299976848b973af28192462fda4eb69409..2db031cafeb8a66e43120be9766debe3c16eb2d0 100644 --- a/crates/agent_ui/src/acp/mode_selector.rs +++ b/crates/agent_ui/src/acp/mode_selector.rs @@ -11,7 +11,7 @@ use ui::{ PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*, }; -use crate::{CycleModeSelector, ToggleProfileSelector}; +use crate::{CycleModeSelector, ToggleProfileSelector, ui::HoldForDefault}; pub struct ModeSelector { connection: Rc, @@ -108,36 +108,11 @@ impl ModeSelector { entry.documentation_aside(side, DocumentationEdge::Bottom, { let description = description.clone(); - move |cx| { + move |_| { v_flex() .gap_1() .child(Label::new(description.clone())) - .child( - h_flex() - .pt_1() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .gap_0p5() - .text_sm() - .text_color(Color::Muted.color(cx)) - .child("Hold") - .child(h_flex().flex_shrink_0().children( - ui::render_modifiers( - &gpui::Modifiers::secondary_key(), - PlatformStyle::platform(), - None, - Some(ui::TextSize::Default.rems(cx).into()), - true, - ), - )) - .child(div().map(|this| { - if is_default { - this.child("to also unset as default") - } else { - this.child("to also set as default") - } - })), - ) + .child(HoldForDefault::new(is_default)) .into_any_element() } }) diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index 91aacde2aebcd0a2d4c8098119bbc43342d3ef74..c60a3b6cb61970caba02df82506848b6efa90cc1 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -1,8 +1,10 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc}; use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use agent_servers::AgentServer; use anyhow::Result; use collections::IndexMap; +use fs::Fs; use futures::FutureExt; use fuzzy::{StringMatchCandidate, match_strings}; use gpui::{AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity}; @@ -14,14 +16,18 @@ use ui::{ }; use util::ResultExt; +use crate::ui::HoldForDefault; + pub type AcpModelSelector = Picker; pub fn acp_model_selector( selector: Rc, + agent_server: Rc, + fs: Arc, window: &mut Window, cx: &mut Context, ) -> AcpModelSelector { - let delegate = AcpModelPickerDelegate::new(selector, window, cx); + let delegate = AcpModelPickerDelegate::new(selector, agent_server, fs, window, cx); Picker::list(delegate, window, cx) .show_scrollbar(true) .width(rems(20.)) @@ -35,10 +41,12 @@ enum AcpModelPickerEntry { pub struct AcpModelPickerDelegate { selector: Rc, + agent_server: Rc, + fs: Arc, filtered_entries: Vec, models: Option, selected_index: usize, - selected_description: Option<(usize, SharedString)>, + selected_description: Option<(usize, SharedString, bool)>, selected_model: Option, _refresh_models_task: Task<()>, } @@ -46,6 +54,8 @@ pub struct AcpModelPickerDelegate { impl AcpModelPickerDelegate { fn new( selector: Rc, + agent_server: Rc, + fs: Arc, window: &mut Window, cx: &mut Context, ) -> Self { @@ -86,6 +96,8 @@ impl AcpModelPickerDelegate { Self { selector, + agent_server, + fs, filtered_entries: Vec::new(), models: None, selected_model: None, @@ -181,6 +193,21 @@ impl PickerDelegate for AcpModelPickerDelegate { if let Some(AcpModelPickerEntry::Model(model_info)) = self.filtered_entries.get(self.selected_index) { + if window.modifiers().secondary() { + let default_model = self.agent_server.default_model(cx); + let is_default = default_model.as_ref() == Some(&model_info.id); + + self.agent_server.set_default_model( + if is_default { + None + } else { + Some(model_info.id.clone()) + }, + self.fs.clone(), + cx, + ); + } + self.selector .select_model(model_info.id.clone(), cx) .detach_and_log_err(cx); @@ -225,6 +252,8 @@ impl PickerDelegate for AcpModelPickerDelegate { ), AcpModelPickerEntry::Model(model_info) => { let is_selected = Some(model_info) == self.selected_model.as_ref(); + let default_model = self.agent_server.default_model(cx); + let is_default = default_model.as_ref() == Some(&model_info.id); let model_icon_color = if is_selected { Color::Accent @@ -239,8 +268,8 @@ impl PickerDelegate for AcpModelPickerDelegate { this .on_hover(cx.listener(move |menu, hovered, _, cx| { if *hovered { - menu.delegate.selected_description = Some((ix, description.clone())); - } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) { + menu.delegate.selected_description = Some((ix, description.clone(), is_default)); + } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) { menu.delegate.selected_description = None; } cx.notify(); @@ -283,14 +312,24 @@ impl PickerDelegate for AcpModelPickerDelegate { _window: &mut Window, _cx: &mut Context>, ) -> Option { - self.selected_description.as_ref().map(|(_, description)| { - let description = description.clone(); - DocumentationAside::new( - DocumentationSide::Left, - DocumentationEdge::Top, - Rc::new(move |_| Label::new(description.clone()).into_any_element()), - ) - }) + self.selected_description + .as_ref() + .map(|(_, description, is_default)| { + let description = description.clone(); + let is_default = *is_default; + + DocumentationAside::new( + DocumentationSide::Left, + DocumentationEdge::Top, + Rc::new(move |_| { + v_flex() + .gap_1() + .child(Label::new(description.clone())) + .child(HoldForDefault::new(is_default)) + .into_any_element() + }), + ) + }) } } diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index 2e8ade95ffcb65d8c7742b60fa0facc70358ae1e..04e7e06a85aadf7c7fb1b69bfcaf81ec6ff6bf89 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -1,6 +1,9 @@ use std::rc::Rc; +use std::sync::Arc; use acp_thread::{AgentModelInfo, AgentModelSelector}; +use agent_servers::AgentServer; +use fs::Fs; use gpui::{Entity, FocusHandle}; use picker::popover_menu::PickerPopoverMenu; use ui::{ @@ -20,13 +23,15 @@ pub struct AcpModelSelectorPopover { impl AcpModelSelectorPopover { pub(crate) fn new( selector: Rc, + agent_server: Rc, + fs: Arc, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, window: &mut Window, cx: &mut Context, ) -> Self { Self { - selector: cx.new(move |cx| acp_model_selector(selector, window, cx)), + selector: cx.new(move |cx| acp_model_selector(selector, agent_server, fs, window, cx)), menu_handle, focus_handle, } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index c2d3e5262354b57ae3c7e6dbd10189dedefebfe6..784fef0b9f3862047c868dddf88a8fcd217c278d 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -591,9 +591,13 @@ impl AcpThreadView { .connection() .model_selector(thread.read(cx).session_id()) .map(|selector| { + let agent_server = this.agent.clone(); + let fs = this.project.read(cx).fs().clone(); cx.new(|cx| { AcpModelSelectorPopover::new( selector, + agent_server, + fs, PopoverMenuHandle::default(), this.focus_handle(cx), window, diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 8652f5cbd6c750da9260970ddc9ddcaef8337451..45ba29a595b59f4a1c329d46e43030a1b9c7ed14 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -1348,6 +1348,7 @@ async fn open_new_agent_servers_entry_in_settings_editor( args: vec![], env: Some(HashMap::default()), default_mode: None, + default_model: None, }, ); } diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index 5363949b904d74d3749c066357e0c60fef19d3b9..f556f8eece8efef77f4a6c286fee032cbfcb42df 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -4,6 +4,7 @@ mod burn_mode_tooltip; mod claude_code_onboarding_modal; mod context_pill; mod end_trial_upsell; +mod hold_for_default; mod onboarding_modal; mod unavailable_editing_tooltip; mod usage_callout; @@ -14,6 +15,7 @@ pub use burn_mode_tooltip::*; pub use claude_code_onboarding_modal::*; pub use context_pill::*; pub use end_trial_upsell::*; +pub use hold_for_default::*; pub use onboarding_modal::*; pub use unavailable_editing_tooltip::*; pub use usage_callout::*; diff --git a/crates/agent_ui/src/ui/hold_for_default.rs b/crates/agent_ui/src/ui/hold_for_default.rs new file mode 100644 index 0000000000000000000000000000000000000000..409e5d59707caa3a6bc62bbf470e33cb150183f5 --- /dev/null +++ b/crates/agent_ui/src/ui/hold_for_default.rs @@ -0,0 +1,40 @@ +use gpui::{App, IntoElement, Modifiers, RenderOnce, Window}; +use ui::{prelude::*, render_modifiers}; + +#[derive(IntoElement)] +pub struct HoldForDefault { + is_default: bool, +} + +impl HoldForDefault { + pub fn new(is_default: bool) -> Self { + Self { is_default } + } +} + +impl RenderOnce for HoldForDefault { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + h_flex() + .pt_1() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .gap_0p5() + .text_sm() + .text_color(Color::Muted.color(cx)) + .child("Hold") + .child(h_flex().flex_shrink_0().children(render_modifiers( + &Modifiers::secondary_key(), + PlatformStyle::platform(), + None, + Some(TextSize::Default.rems(cx).into()), + true, + ))) + .child(div().map(|this| { + if self.is_default { + this.child("to unset as default") + } else { + this.child("to set as default") + } + })) + } +} diff --git a/crates/project/src/agent_server_store.rs b/crates/project/src/agent_server_store.rs index f1fb210084fb118832f5ca8f5ffa78990c892aa1..944eb593185bd5016e397d1417ed834da3ee73ef 100644 --- a/crates/project/src/agent_server_store.rs +++ b/crates/project/src/agent_server_store.rs @@ -1777,6 +1777,7 @@ pub struct BuiltinAgentServerSettings { pub env: Option>, pub ignore_system_version: Option, pub default_mode: Option, + pub default_model: Option, } impl BuiltinAgentServerSettings { @@ -1799,6 +1800,7 @@ impl From for BuiltinAgentServerSettings { env: value.env, ignore_system_version: value.ignore_system_version, default_mode: value.default_mode, + default_model: value.default_model, } } } @@ -1823,6 +1825,12 @@ pub struct CustomAgentServerSettings { /// /// Default: None pub default_mode: Option, + /// The default model to use for this agent. + /// + /// This should be the model ID as reported by the agent. + /// + /// Default: None + pub default_model: Option, } impl From for CustomAgentServerSettings { @@ -1834,6 +1842,7 @@ impl From for CustomAgentServerSettings { env: value.env, }, default_mode: value.default_mode, + default_model: value.default_model, } } } @@ -2156,6 +2165,7 @@ mod extension_agent_tests { env: None, ignore_system_version: None, default_mode: None, + default_model: None, }; let BuiltinAgentServerSettings { path, .. } = settings.into(); @@ -2171,6 +2181,7 @@ mod extension_agent_tests { args: vec!["serve".into()], env: None, default_mode: None, + default_model: None, }; let CustomAgentServerSettings { diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index 425b5f05ff46fa705c073838dceab6c431c74bde..59b5a4e0f516387ce6316cd31376bb45c2c5cb94 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -332,6 +332,12 @@ pub struct BuiltinAgentServerSettings { /// /// Default: None pub default_mode: Option, + /// The default model to use for this agent. + /// + /// This should be the model ID as reported by the agent. + /// + /// Default: None + pub default_model: Option, } #[skip_serializing_none] @@ -348,4 +354,10 @@ pub struct CustomAgentServerSettings { /// /// Default: None pub default_mode: Option, + /// The default model to use for this agent. + /// + /// This should be the model ID as reported by the agent. + /// + /// Default: None + pub default_model: Option, }