Detailed changes
@@ -933,7 +933,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
// Test that test-1 profile (default) has echo and delay tools
thread
.update(cx, |thread, cx| {
- thread.set_profile(AgentProfileId("test-1".into()));
+ thread.set_profile(AgentProfileId("test-1".into()), cx);
thread.send(UserMessageId::new(), ["test"], cx)
})
.unwrap();
@@ -953,7 +953,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
// Switch to test-2 profile, and verify that it has only the infinite tool.
thread
.update(cx, |thread, cx| {
- thread.set_profile(AgentProfileId("test-2".into()));
+ thread.set_profile(AgentProfileId("test-2".into()), cx);
thread.send(UserMessageId::new(), ["test2"], cx)
})
.unwrap();
@@ -1002,8 +1002,8 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
)
.await;
cx.run_until_parked();
- thread.update(cx, |thread, _| {
- thread.set_profile(AgentProfileId("test".into()))
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("test".into()), cx)
});
let mut mcp_tool_calls = setup_context_server(
@@ -1169,8 +1169,8 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
.await;
cx.run_until_parked();
- thread.update(cx, |thread, _| {
- thread.set_profile(AgentProfileId("test".into()));
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("test".into()), cx);
thread.add_tool(EchoTool);
thread.add_tool(DelayTool);
thread.add_tool(WordListTool);
@@ -30,16 +30,17 @@ use gpui::{
};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
- LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
- LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
+ LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
+ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
+ ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
-use settings::{Settings, update_settings_file};
+use settings::{LanguageModelSelection, Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{
collections::BTreeMap,
@@ -798,7 +799,8 @@ impl Thread {
let profile_id = db_thread
.profile
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
- let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+
+ let mut model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
db_thread
.model
.and_then(|model| {
@@ -811,6 +813,16 @@ impl Thread {
.or_else(|| registry.default_model())
.map(|model| model.model)
});
+
+ if model.is_none() {
+ model = Self::resolve_profile_model(&profile_id, cx);
+ }
+ if model.is_none() {
+ model = LanguageModelRegistry::global(cx).update(cx, |registry, _cx| {
+ registry.default_model().map(|model| model.model)
+ });
+ }
+
let (prompt_capabilities_tx, prompt_capabilities_rx) =
watch::channel(Self::prompt_capabilities(model.as_deref()));
@@ -1007,8 +1019,17 @@ impl Thread {
&self.profile_id
}
- pub fn set_profile(&mut self, profile_id: AgentProfileId) {
+ pub fn set_profile(&mut self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
+ if self.profile_id == profile_id {
+ return;
+ }
+
self.profile_id = profile_id;
+
+ // Swap to the profile's preferred model when available.
+ if let Some(model) = Self::resolve_profile_model(&self.profile_id, cx) {
+ self.set_model(model, cx);
+ }
}
pub fn cancel(&mut self, cx: &mut Context<Self>) {
@@ -1065,6 +1086,35 @@ impl Thread {
})
}
+ /// Look up the active profile and resolve its preferred model if one is configured.
+ fn resolve_profile_model(
+ profile_id: &AgentProfileId,
+ cx: &mut Context<Self>,
+ ) -> Option<Arc<dyn LanguageModel>> {
+ let selection = AgentSettings::get_global(cx)
+ .profiles
+ .get(profile_id)?
+ .default_model
+ .clone()?;
+ Self::resolve_model_from_selection(&selection, cx)
+ }
+
+ /// Translate a stored model selection into the configured model from the registry.
+ fn resolve_model_from_selection(
+ selection: &LanguageModelSelection,
+ cx: &mut Context<Self>,
+ ) -> Option<Arc<dyn LanguageModel>> {
+ let selected = SelectedModel {
+ provider: LanguageModelProviderId::from(selection.provider.0.clone()),
+ model: LanguageModelId::from(selection.model.clone()),
+ };
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry
+ .select_model(&selected, cx)
+ .map(|configured| configured.model)
+ })
+ }
+
pub fn resume(
&mut self,
cx: &mut Context<Self>,
@@ -6,8 +6,8 @@ use convert_case::{Case, Casing as _};
use fs::Fs;
use gpui::{App, SharedString};
use settings::{
- AgentProfileContent, ContextServerPresetContent, Settings as _, SettingsContent,
- update_settings_file,
+ AgentProfileContent, ContextServerPresetContent, LanguageModelSelection, Settings as _,
+ SettingsContent, update_settings_file,
};
use util::ResultExt as _;
@@ -53,19 +53,30 @@ impl AgentProfile {
let base_profile =
base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
+ // Copy toggles from the base profile so the new profile starts with familiar defaults.
+ let tools = base_profile
+ .as_ref()
+ .map(|profile| profile.tools.clone())
+ .unwrap_or_default();
+ let enable_all_context_servers = base_profile
+ .as_ref()
+ .map(|profile| profile.enable_all_context_servers)
+ .unwrap_or_default();
+ let context_servers = base_profile
+ .as_ref()
+ .map(|profile| profile.context_servers.clone())
+ .unwrap_or_default();
+ // Preserve the base profile's model preference when cloning into a new profile.
+ let default_model = base_profile
+ .as_ref()
+ .and_then(|profile| profile.default_model.clone());
+
let profile_settings = AgentProfileSettings {
name: name.into(),
- tools: base_profile
- .as_ref()
- .map(|profile| profile.tools.clone())
- .unwrap_or_default(),
- enable_all_context_servers: base_profile
- .as_ref()
- .map(|profile| profile.enable_all_context_servers)
- .unwrap_or_default(),
- context_servers: base_profile
- .map(|profile| profile.context_servers)
- .unwrap_or_default(),
+ tools,
+ enable_all_context_servers,
+ context_servers,
+ default_model,
};
update_settings_file(fs, cx, {
@@ -96,6 +107,8 @@ pub struct AgentProfileSettings {
pub tools: IndexMap<Arc<str>, bool>,
pub enable_all_context_servers: bool,
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
+ /// Default language model to apply when this profile becomes active.
+ pub default_model: Option<LanguageModelSelection>,
}
impl AgentProfileSettings {
@@ -144,6 +157,7 @@ impl AgentProfileSettings {
)
})
.collect(),
+ default_model: self.default_model.clone(),
},
);
@@ -153,15 +167,23 @@ impl AgentProfileSettings {
impl From<AgentProfileContent> for AgentProfileSettings {
fn from(content: AgentProfileContent) -> Self {
+ let AgentProfileContent {
+ name,
+ tools,
+ enable_all_context_servers,
+ context_servers,
+ default_model,
+ } = content;
+
Self {
- name: content.name.into(),
- tools: content.tools,
- enable_all_context_servers: content.enable_all_context_servers.unwrap_or_default(),
- context_servers: content
- .context_servers
+ name: name.into(),
+ tools,
+ enable_all_context_servers: enable_all_context_servers.unwrap_or_default(),
+ context_servers: context_servers
.into_iter()
.map(|(server_id, preset)| (server_id, preset.into()))
.collect(),
+ default_model,
}
}
}
@@ -125,8 +125,9 @@ impl ProfileProvider for Entity<agent::Thread> {
}
fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) {
- self.update(cx, |thread, _cx| {
- thread.set_profile(profile_id);
+ self.update(cx, |thread, cx| {
+ // Apply the profile and let the thread swap to its default model.
+ thread.set_profile(profile_id, cx);
});
}
@@ -314,6 +314,7 @@ impl PickerDelegate for ToolPickerDelegate {
)
})
.collect(),
+ default_model: default_profile.default_model.clone(),
});
if let Some(server_id) = server_id {
@@ -15,8 +15,8 @@ use std::{
sync::{Arc, atomic::AtomicBool},
};
use ui::{
- DocumentationAside, DocumentationEdge, DocumentationSide, HighlightedLabel, LabelSize,
- ListItem, ListItemSpacing, PopoverMenuHandle, TintColor, Tooltip, prelude::*,
+ DocumentationAside, DocumentationEdge, DocumentationSide, HighlightedLabel, KeyBinding,
+ LabelSize, ListItem, ListItemSpacing, PopoverMenuHandle, TintColor, Tooltip, prelude::*,
};
/// Trait for types that can provide and manage agent profiles
@@ -81,6 +81,7 @@ impl ProfileSelector {
self.provider.clone(),
self.profiles.clone(),
cx.background_executor().clone(),
+ self.focus_handle.clone(),
cx,
);
@@ -207,6 +208,7 @@ pub(crate) struct ProfilePickerDelegate {
selected_index: usize,
query: String,
cancel: Option<Arc<AtomicBool>>,
+ focus_handle: FocusHandle,
}
impl ProfilePickerDelegate {
@@ -215,6 +217,7 @@ impl ProfilePickerDelegate {
provider: Arc<dyn ProfileProvider>,
profiles: AvailableProfiles,
background: BackgroundExecutor,
+ focus_handle: FocusHandle,
cx: &mut Context<ProfileSelector>,
) -> Self {
let candidates = Self::candidates_from(profiles);
@@ -231,6 +234,7 @@ impl ProfilePickerDelegate {
selected_index: 0,
query: String::new(),
cancel: None,
+ focus_handle,
};
this.selected_index = this
@@ -594,20 +598,26 @@ impl PickerDelegate for ProfilePickerDelegate {
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<gpui::AnyElement> {
+ let focus_handle = self.focus_handle.clone();
+
Some(
h_flex()
.w_full()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
- .p_1()
- .gap_4()
- .justify_between()
+ .p_1p5()
.child(
Button::new("configure", "Configure")
- .icon(IconName::Settings)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .full_width()
+ .style(ButtonStyle::Outlined)
+ .key_binding(
+ KeyBinding::for_action_in(
+ &ManageProfiles::default(),
+ &focus_handle,
+ cx,
+ )
+ .map(|kb| kb.size(rems_from_px(12.))),
+ )
.on_click(|_, window, cx| {
window.dispatch_action(ManageProfiles::default().boxed_clone(), cx);
}),
@@ -659,20 +669,25 @@ mod tests {
is_builtin: true,
}];
- let delegate = ProfilePickerDelegate {
- fs: FakeFs::new(cx.executor()),
- provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))),
- background: cx.executor(),
- candidates,
- string_candidates: Arc::new(Vec::new()),
- filtered_entries: Vec::new(),
- selected_index: 0,
- query: String::new(),
- cancel: None,
- };
-
- let matches = Vec::new(); // No matches
- let _entries = delegate.entries_from_matches(matches);
+ cx.update(|cx| {
+ let focus_handle = cx.focus_handle();
+
+ let delegate = ProfilePickerDelegate {
+ fs: FakeFs::new(cx.background_executor().clone()),
+ provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))),
+ background: cx.background_executor().clone(),
+ candidates,
+ string_candidates: Arc::new(Vec::new()),
+ filtered_entries: Vec::new(),
+ selected_index: 0,
+ query: String::new(),
+ cancel: None,
+ focus_handle,
+ };
+
+ let matches = Vec::new(); // No matches
+ let _entries = delegate.entries_from_matches(matches);
+ });
}
#[gpui::test]
@@ -690,30 +705,35 @@ mod tests {
},
];
- let delegate = ProfilePickerDelegate {
- fs: FakeFs::new(cx.executor()),
- provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))),
- background: cx.executor(),
- candidates,
- string_candidates: Arc::new(Vec::new()),
- filtered_entries: vec![
- ProfilePickerEntry::Profile(ProfileMatchEntry {
- candidate_index: 0,
- positions: Vec::new(),
- }),
- ProfilePickerEntry::Profile(ProfileMatchEntry {
- candidate_index: 1,
- positions: Vec::new(),
- }),
- ],
- selected_index: 0,
- query: String::new(),
- cancel: None,
- };
-
- // Active profile should be found at index 0
- let active_index = delegate.index_of_profile(&AgentProfileId("write".into()));
- assert_eq!(active_index, Some(0));
+ cx.update(|cx| {
+ let focus_handle = cx.focus_handle();
+
+ let delegate = ProfilePickerDelegate {
+ fs: FakeFs::new(cx.background_executor().clone()),
+ provider: Arc::new(TestProfileProvider::new(AgentProfileId("write".into()))),
+ background: cx.background_executor().clone(),
+ candidates,
+ string_candidates: Arc::new(Vec::new()),
+ filtered_entries: vec![
+ ProfilePickerEntry::Profile(ProfileMatchEntry {
+ candidate_index: 0,
+ positions: Vec::new(),
+ }),
+ ProfilePickerEntry::Profile(ProfileMatchEntry {
+ candidate_index: 1,
+ positions: Vec::new(),
+ }),
+ ],
+ selected_index: 0,
+ query: String::new(),
+ cancel: None,
+ focus_handle,
+ };
+
+ // Active profile should be found at index 0
+ let active_index = delegate.index_of_profile(&AgentProfileId("write".into()));
+ assert_eq!(active_index, Some(0));
+ });
}
struct TestProfileProvider {
@@ -322,7 +322,7 @@ impl ExampleInstance {
thread.add_default_tools(Rc::new(EvalThreadEnvironment {
project: project.clone(),
}), cx);
- thread.set_profile(meta.profile_id.clone());
+ thread.set_profile(meta.profile_id.clone(), cx);
thread.set_model(
LanguageModelInterceptor::new(
LanguageModelRegistry::read_global(cx).default_model().expect("Missing model").model.clone(),
@@ -176,6 +176,8 @@ pub struct AgentProfileContent {
pub enable_all_context_servers: Option<bool>,
#[serde(default)]
pub context_servers: IndexMap<Arc<str>, ContextServerPresetContent>,
+ /// The default language model selected when using this profile.
+ pub default_model: Option<LanguageModelSelection>,
}
#[skip_serializing_none]