Detailed changes
@@ -35,6 +35,7 @@ pub struct AcpConnection {
auth_methods: Vec<acp::AuthMethod>,
agent_capabilities: acp::AgentCapabilities,
default_mode: Option<acp::SessionModeId>,
+ default_model: Option<acp::ModelId>,
root_dir: PathBuf,
// NB: Don't move this into the wait_task, since we need to ensure the process is
// killed on drop (setting kill_on_drop on the command seems to not always work).
@@ -57,6 +58,7 @@ pub async fn connect(
command: AgentServerCommand,
root_dir: &Path,
default_mode: Option<acp::SessionModeId>,
+ default_model: Option<acp::ModelId>,
is_remote: bool,
cx: &mut AsyncApp,
) -> Result<Rc<dyn AgentConnection>> {
@@ -66,6 +68,7 @@ pub async fn connect(
command.clone(),
root_dir,
default_mode,
+ default_model,
is_remote,
cx,
)
@@ -82,6 +85,7 @@ impl AcpConnection {
command: AgentServerCommand,
root_dir: &Path,
default_mode: Option<acp::SessionModeId>,
+ default_model: Option<acp::ModelId>,
is_remote: bool,
cx: &mut AsyncApp,
) -> Result<Self> {
@@ -207,6 +211,7 @@ impl AcpConnection {
sessions,
agent_capabilities: response.agent_capabilities,
default_mode,
+ default_model,
_io_task: io_task,
_wait_task: wait_task,
_stderr_task: stderr_task,
@@ -245,6 +250,7 @@ impl AgentConnection for AcpConnection {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let default_mode = self.default_mode.clone();
+ let default_model = self.default_model.clone();
let cwd = cwd.to_path_buf();
let context_server_store = project.read(cx).context_server_store().read(cx);
let mcp_servers =
@@ -333,6 +339,7 @@ impl AgentConnection for AcpConnection {
let default_mode = default_mode.clone();
let session_id = response.session_id.clone();
let modes = modes.clone();
+ let conn = conn.clone();
async move |_| {
let result = conn.set_session_mode(acp::SetSessionModeRequest {
session_id,
@@ -367,6 +374,53 @@ impl AgentConnection for AcpConnection {
}
}
+ if let Some(default_model) = default_model {
+ if let Some(models) = models.as_ref() {
+ let mut models_ref = models.borrow_mut();
+ let has_model = models_ref.available_models.iter().any(|model| model.model_id == default_model);
+
+ if has_model {
+ let initial_model_id = models_ref.current_model_id.clone();
+
+ cx.spawn({
+ let default_model = default_model.clone();
+ let session_id = response.session_id.clone();
+ let models = models.clone();
+ let conn = conn.clone();
+ async move |_| {
+ let result = conn.set_session_model(acp::SetSessionModelRequest {
+ session_id,
+ model_id: default_model,
+ meta: None,
+ })
+ .await.log_err();
+
+ if result.is_none() {
+ models.borrow_mut().current_model_id = initial_model_id;
+ }
+ }
+ }).detach();
+
+ models_ref.current_model_id = default_model;
+ } else {
+ let available_models = models_ref
+ .available_models
+ .iter()
+ .map(|model| format!("- `{}`: {}", model.model_id, model.name))
+ .collect::<Vec<_>>()
+ .join("\n");
+
+ log::warn!(
+ "`{default_model}` is not a valid {name} model. Available options:\n{available_models}",
+ );
+ }
+ } else {
+ log::warn!(
+ "`{name}` does not support model selection, but `default_model` was set in settings.",
+ );
+ }
+ }
+
let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|cx| {
@@ -68,6 +68,18 @@ pub trait AgentServer: Send {
) {
}
+ fn default_model(&self, _cx: &mut App) -> Option<agent_client_protocol::ModelId> {
+ None
+ }
+
+ fn set_default_model(
+ &self,
+ _model_id: Option<agent_client_protocol::ModelId>,
+ _fs: Arc<dyn Fs>,
+ _cx: &mut App,
+ ) {
+ }
+
fn connect(
&self,
root_dir: Option<&Path>,
@@ -55,6 +55,27 @@ impl AgentServer for ClaudeCode {
});
}
+ fn default_model(&self, cx: &mut App) -> Option<acp::ModelId> {
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings.get::<AllAgentServersSettings>(None).claude.clone()
+ });
+
+ settings
+ .as_ref()
+ .and_then(|s| s.default_model.clone().map(|m| acp::ModelId(m.into())))
+ }
+
+ fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
+ update_settings_file(fs, cx, |settings, _| {
+ settings
+ .agent_servers
+ .get_or_insert_default()
+ .claude
+ .get_or_insert_default()
+ .default_model = model_id.map(|m| m.to_string())
+ });
+ }
+
fn connect(
&self,
root_dir: Option<&Path>,
@@ -68,6 +89,7 @@ impl AgentServer for ClaudeCode {
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
let default_mode = self.default_mode(cx);
+ let default_model = self.default_model(cx);
cx.spawn(async move |cx| {
let (command, root_dir, login) = store
@@ -90,6 +112,7 @@ impl AgentServer for ClaudeCode {
command,
root_dir.as_ref(),
default_mode,
+ default_model,
is_remote,
cx,
)
@@ -56,6 +56,27 @@ impl AgentServer for Codex {
});
}
+ fn default_model(&self, cx: &mut App) -> Option<acp::ModelId> {
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings.get::<AllAgentServersSettings>(None).codex.clone()
+ });
+
+ settings
+ .as_ref()
+ .and_then(|s| s.default_model.clone().map(|m| acp::ModelId(m.into())))
+ }
+
+ fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
+ update_settings_file(fs, cx, |settings, _| {
+ settings
+ .agent_servers
+ .get_or_insert_default()
+ .codex
+ .get_or_insert_default()
+ .default_model = model_id.map(|m| m.to_string())
+ });
+ }
+
fn connect(
&self,
root_dir: Option<&Path>,
@@ -69,6 +90,7 @@ impl AgentServer for Codex {
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
let default_mode = self.default_mode(cx);
+ let default_model = self.default_model(cx);
cx.spawn(async move |cx| {
let (command, root_dir, login) = store
@@ -92,6 +114,7 @@ impl AgentServer for Codex {
command,
root_dir.as_ref(),
default_mode,
+ default_model,
is_remote,
cx,
)
@@ -61,6 +61,34 @@ impl crate::AgentServer for CustomAgentServer {
});
}
+ fn default_model(&self, cx: &mut App) -> Option<acp::ModelId> {
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings
+ .get::<AllAgentServersSettings>(None)
+ .custom
+ .get(&self.name())
+ .cloned()
+ });
+
+ settings
+ .as_ref()
+ .and_then(|s| s.default_model.clone().map(|m| acp::ModelId(m.into())))
+ }
+
+ fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
+ let name = self.name();
+ update_settings_file(fs, cx, move |settings, _| {
+ if let Some(settings) = settings
+ .agent_servers
+ .get_or_insert_default()
+ .custom
+ .get_mut(&name)
+ {
+ settings.default_model = model_id.map(|m| m.to_string())
+ }
+ });
+ }
+
fn connect(
&self,
root_dir: Option<&Path>,
@@ -72,6 +100,7 @@ impl crate::AgentServer for CustomAgentServer {
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let default_mode = self.default_mode(cx);
+ let default_model = self.default_model(cx);
let store = delegate.store.downgrade();
let extra_env = load_proxy_env(cx);
@@ -98,6 +127,7 @@ impl crate::AgentServer for CustomAgentServer {
command,
root_dir.as_ref(),
default_mode,
+ default_model,
is_remote,
cx,
)
@@ -476,6 +476,7 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
env: None,
ignore_system_version: None,
default_mode: None,
+ default_model: None,
}),
gemini: Some(crate::gemini::tests::local_command().into()),
codex: Some(BuiltinAgentServerSettings {
@@ -484,6 +485,7 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
env: None,
ignore_system_version: None,
default_mode: None,
+ default_model: None,
}),
custom: collections::HashMap::default(),
},
@@ -37,6 +37,7 @@ impl AgentServer for Gemini {
let store = delegate.store.downgrade();
let mut extra_env = load_proxy_env(cx);
let default_mode = self.default_mode(cx);
+ let default_model = self.default_model(cx);
cx.spawn(async move |cx| {
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
@@ -69,6 +70,7 @@ impl AgentServer for Gemini {
command,
root_dir.as_ref(),
default_mode,
+ default_model,
is_remote,
cx,
)
@@ -11,7 +11,7 @@ use ui::{
PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*,
};
-use crate::{CycleModeSelector, ToggleProfileSelector};
+use crate::{CycleModeSelector, ToggleProfileSelector, ui::HoldForDefault};
pub struct ModeSelector {
connection: Rc<dyn AgentSessionModes>,
@@ -108,36 +108,11 @@ impl ModeSelector {
entry.documentation_aside(side, DocumentationEdge::Bottom, {
let description = description.clone();
- move |cx| {
+ move |_| {
v_flex()
.gap_1()
.child(Label::new(description.clone()))
- .child(
- h_flex()
- .pt_1()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- .gap_0p5()
- .text_sm()
- .text_color(Color::Muted.color(cx))
- .child("Hold")
- .child(h_flex().flex_shrink_0().children(
- ui::render_modifiers(
- &gpui::Modifiers::secondary_key(),
- PlatformStyle::platform(),
- None,
- Some(ui::TextSize::Default.rems(cx).into()),
- true,
- ),
- ))
- .child(div().map(|this| {
- if is_default {
- this.child("to also unset as default")
- } else {
- this.child("to also set as default")
- }
- })),
- )
+ .child(HoldForDefault::new(is_default))
.into_any_element()
}
})
@@ -1,8 +1,10 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
+use agent_servers::AgentServer;
use anyhow::Result;
use collections::IndexMap;
+use fs::Fs;
use futures::FutureExt;
use fuzzy::{StringMatchCandidate, match_strings};
use gpui::{AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
@@ -14,14 +16,18 @@ use ui::{
};
use util::ResultExt;
+use crate::ui::HoldForDefault;
+
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
pub fn acp_model_selector(
selector: Rc<dyn AgentModelSelector>,
+ agent_server: Rc<dyn AgentServer>,
+ fs: Arc<dyn Fs>,
window: &mut Window,
cx: &mut Context<AcpModelSelector>,
) -> AcpModelSelector {
- let delegate = AcpModelPickerDelegate::new(selector, window, cx);
+ let delegate = AcpModelPickerDelegate::new(selector, agent_server, fs, window, cx);
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
@@ -35,10 +41,12 @@ enum AcpModelPickerEntry {
pub struct AcpModelPickerDelegate {
selector: Rc<dyn AgentModelSelector>,
+ agent_server: Rc<dyn AgentServer>,
+ fs: Arc<dyn Fs>,
filtered_entries: Vec<AcpModelPickerEntry>,
models: Option<AgentModelList>,
selected_index: usize,
- selected_description: Option<(usize, SharedString)>,
+ selected_description: Option<(usize, SharedString, bool)>,
selected_model: Option<AgentModelInfo>,
_refresh_models_task: Task<()>,
}
@@ -46,6 +54,8 @@ pub struct AcpModelPickerDelegate {
impl AcpModelPickerDelegate {
fn new(
selector: Rc<dyn AgentModelSelector>,
+ agent_server: Rc<dyn AgentServer>,
+ fs: Arc<dyn Fs>,
window: &mut Window,
cx: &mut Context<AcpModelSelector>,
) -> Self {
@@ -86,6 +96,8 @@ impl AcpModelPickerDelegate {
Self {
selector,
+ agent_server,
+ fs,
filtered_entries: Vec::new(),
models: None,
selected_model: None,
@@ -181,6 +193,21 @@ impl PickerDelegate for AcpModelPickerDelegate {
if let Some(AcpModelPickerEntry::Model(model_info)) =
self.filtered_entries.get(self.selected_index)
{
+ if window.modifiers().secondary() {
+ let default_model = self.agent_server.default_model(cx);
+ let is_default = default_model.as_ref() == Some(&model_info.id);
+
+ self.agent_server.set_default_model(
+ if is_default {
+ None
+ } else {
+ Some(model_info.id.clone())
+ },
+ self.fs.clone(),
+ cx,
+ );
+ }
+
self.selector
.select_model(model_info.id.clone(), cx)
.detach_and_log_err(cx);
@@ -225,6 +252,8 @@ impl PickerDelegate for AcpModelPickerDelegate {
),
AcpModelPickerEntry::Model(model_info) => {
let is_selected = Some(model_info) == self.selected_model.as_ref();
+ let default_model = self.agent_server.default_model(cx);
+ let is_default = default_model.as_ref() == Some(&model_info.id);
let model_icon_color = if is_selected {
Color::Accent
@@ -239,8 +268,8 @@ impl PickerDelegate for AcpModelPickerDelegate {
this
.on_hover(cx.listener(move |menu, hovered, _, cx| {
if *hovered {
- menu.delegate.selected_description = Some((ix, description.clone()));
- } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) {
+ menu.delegate.selected_description = Some((ix, description.clone(), is_default));
+ } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
menu.delegate.selected_description = None;
}
cx.notify();
@@ -283,14 +312,24 @@ impl PickerDelegate for AcpModelPickerDelegate {
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> Option<ui::DocumentationAside> {
- self.selected_description.as_ref().map(|(_, description)| {
- let description = description.clone();
- DocumentationAside::new(
- DocumentationSide::Left,
- DocumentationEdge::Top,
- Rc::new(move |_| Label::new(description.clone()).into_any_element()),
- )
- })
+ self.selected_description
+ .as_ref()
+ .map(|(_, description, is_default)| {
+ let description = description.clone();
+ let is_default = *is_default;
+
+ DocumentationAside::new(
+ DocumentationSide::Left,
+ DocumentationEdge::Top,
+ Rc::new(move |_| {
+ v_flex()
+ .gap_1()
+ .child(Label::new(description.clone()))
+ .child(HoldForDefault::new(is_default))
+ .into_any_element()
+ }),
+ )
+ })
}
}
@@ -1,6 +1,9 @@
use std::rc::Rc;
+use std::sync::Arc;
use acp_thread::{AgentModelInfo, AgentModelSelector};
+use agent_servers::AgentServer;
+use fs::Fs;
use gpui::{Entity, FocusHandle};
use picker::popover_menu::PickerPopoverMenu;
use ui::{
@@ -20,13 +23,15 @@ pub struct AcpModelSelectorPopover {
impl AcpModelSelectorPopover {
pub(crate) fn new(
selector: Rc<dyn AgentModelSelector>,
+ agent_server: Rc<dyn AgentServer>,
+ fs: Arc<dyn Fs>,
menu_handle: PopoverMenuHandle<AcpModelSelector>,
focus_handle: FocusHandle,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
Self {
- selector: cx.new(move |cx| acp_model_selector(selector, window, cx)),
+ selector: cx.new(move |cx| acp_model_selector(selector, agent_server, fs, window, cx)),
menu_handle,
focus_handle,
}
@@ -591,9 +591,13 @@ impl AcpThreadView {
.connection()
.model_selector(thread.read(cx).session_id())
.map(|selector| {
+ let agent_server = this.agent.clone();
+ let fs = this.project.read(cx).fs().clone();
cx.new(|cx| {
AcpModelSelectorPopover::new(
selector,
+ agent_server,
+ fs,
PopoverMenuHandle::default(),
this.focus_handle(cx),
window,
@@ -1348,6 +1348,7 @@ async fn open_new_agent_servers_entry_in_settings_editor(
args: vec![],
env: Some(HashMap::default()),
default_mode: None,
+ default_model: None,
},
);
}
@@ -4,6 +4,7 @@ mod burn_mode_tooltip;
mod claude_code_onboarding_modal;
mod context_pill;
mod end_trial_upsell;
+mod hold_for_default;
mod onboarding_modal;
mod unavailable_editing_tooltip;
mod usage_callout;
@@ -14,6 +15,7 @@ pub use burn_mode_tooltip::*;
pub use claude_code_onboarding_modal::*;
pub use context_pill::*;
pub use end_trial_upsell::*;
+pub use hold_for_default::*;
pub use onboarding_modal::*;
pub use unavailable_editing_tooltip::*;
pub use usage_callout::*;
@@ -0,0 +1,40 @@
+use gpui::{App, IntoElement, Modifiers, RenderOnce, Window};
+use ui::{prelude::*, render_modifiers};
+
+#[derive(IntoElement)]
+pub struct HoldForDefault {
+ is_default: bool,
+}
+
+impl HoldForDefault {
+ pub fn new(is_default: bool) -> Self {
+ Self { is_default }
+ }
+}
+
+impl RenderOnce for HoldForDefault {
+ fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
+ h_flex()
+ .pt_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .gap_0p5()
+ .text_sm()
+ .text_color(Color::Muted.color(cx))
+ .child("Hold")
+ .child(h_flex().flex_shrink_0().children(render_modifiers(
+ &Modifiers::secondary_key(),
+ PlatformStyle::platform(),
+ None,
+ Some(TextSize::Default.rems(cx).into()),
+ true,
+ )))
+ .child(div().map(|this| {
+ if self.is_default {
+ this.child("to unset as default")
+ } else {
+ this.child("to set as default")
+ }
+ }))
+ }
+}
@@ -1777,6 +1777,7 @@ pub struct BuiltinAgentServerSettings {
pub env: Option<HashMap<String, String>>,
pub ignore_system_version: Option<bool>,
pub default_mode: Option<String>,
+ pub default_model: Option<String>,
}
impl BuiltinAgentServerSettings {
@@ -1799,6 +1800,7 @@ impl From<settings::BuiltinAgentServerSettings> for BuiltinAgentServerSettings {
env: value.env,
ignore_system_version: value.ignore_system_version,
default_mode: value.default_mode,
+ default_model: value.default_model,
}
}
}
@@ -1823,6 +1825,12 @@ pub struct CustomAgentServerSettings {
///
/// Default: None
pub default_mode: Option<String>,
+ /// The default model to use for this agent.
+ ///
+ /// This should be the model ID as reported by the agent.
+ ///
+ /// Default: None
+ pub default_model: Option<String>,
}
impl From<settings::CustomAgentServerSettings> for CustomAgentServerSettings {
@@ -1834,6 +1842,7 @@ impl From<settings::CustomAgentServerSettings> for CustomAgentServerSettings {
env: value.env,
},
default_mode: value.default_mode,
+ default_model: value.default_model,
}
}
}
@@ -2156,6 +2165,7 @@ mod extension_agent_tests {
env: None,
ignore_system_version: None,
default_mode: None,
+ default_model: None,
};
let BuiltinAgentServerSettings { path, .. } = settings.into();
@@ -2171,6 +2181,7 @@ mod extension_agent_tests {
args: vec!["serve".into()],
env: None,
default_mode: None,
+ default_model: None,
};
let CustomAgentServerSettings {
@@ -332,6 +332,12 @@ pub struct BuiltinAgentServerSettings {
///
/// Default: None
pub default_mode: Option<String>,
+ /// The default model to use for this agent.
+ ///
+ /// This should be the model ID as reported by the agent.
+ ///
+ /// Default: None
+ pub default_model: Option<String>,
}
#[skip_serializing_none]
@@ -348,4 +354,10 @@ pub struct CustomAgentServerSettings {
///
/// Default: None
pub default_mode: Option<String>,
+ /// The default model to use for this agent.
+ ///
+ /// This should be the model ID as reported by the agent.
+ ///
+ /// Default: None
+ pub default_model: Option<String>,
}