From 709523bf363b48206eb30a8e9bef9a5ce4603c48 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Fri, 6 Jun 2025 14:05:27 +0200 Subject: [PATCH] Store profile per thread (#31907) This allows storing the profile per thread, as well as moving the logic of which tools are enabled or not to the profile itself. This makes it much easier to switch between profiles, means there is less global state being changed on every profile change. Release Notes: - agent panel: allow saving the profile per thread --------- Co-authored-by: Ben Kunkle --- Cargo.lock | 3 +- assets/settings/default.json | 1 - crates/agent/Cargo.toml | 1 + crates/agent/src/active_thread.rs | 4 + crates/agent/src/agent.rs | 1 + .../manage_profiles_modal.rs | 66 +--- .../src/agent_configuration/tool_picker.rs | 54 +-- crates/agent/src/agent_diff.rs | 3 +- crates/agent/src/agent_profile.rs | 334 ++++++++++++++++++ crates/agent/src/message_editor.rs | 14 +- crates/agent/src/profile_selector.rs | 64 ++-- crates/agent/src/thread.rs | 122 +++++-- crates/agent/src/thread_store.rs | 102 +----- crates/agent/src/tool_compatibility.rs | 22 +- crates/agent_settings/Cargo.toml | 1 - crates/agent_settings/src/agent_profile.rs | 25 +- crates/agent_settings/src/agent_settings.rs | 14 +- crates/assistant_tool/src/tool_working_set.rs | 84 +---- crates/collab/Cargo.toml | 1 - crates/eval/src/example.rs | 1 + crates/eval/src/instance.rs | 10 +- 21 files changed, 557 insertions(+), 370 deletions(-) create mode 100644 crates/agent/src/agent_profile.rs diff --git a/Cargo.lock b/Cargo.lock index 9554c46aacf2f24c2c8eb4b4ccf31f0a313d0aa5..af14e424300ef57a3be8fe7c031989e52bc6f248 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,7 @@ dependencies = [ "assistant_slash_command", "assistant_slash_commands", "assistant_tool", + "assistant_tools", "async-watch", "audio", "buffer_diff", @@ -147,7 +148,6 @@ dependencies = [ "deepseek", "fs", "gpui", - "indexmap", "language_model", "lmstudio", "log", @@ -2987,7 +2987,6 @@ dependencies = [ "anyhow", "assistant_context_editor", "assistant_slash_command", - "assistant_tool", "async-stripe", "async-trait", "async-tungstenite", diff --git a/assets/settings/default.json b/assets/settings/default.json index 8d8c65884cdc7e593b49d22895da8c8a1b3e0e47..fbcde696c3bfc6e7e5f9c026f029de2ac3f8933e 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -771,7 +771,6 @@ "tools": { "copy_path": true, "create_directory": true, - "create_file": true, "delete_path": true, "diagnostics": true, "edit_file": true, diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index c1f9d9a3fafe24cfbf0e59bc6a8b65c80bcddbc9..1b07d9460519b7d619449d6a5e64966a0fe855a9 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -102,6 +102,7 @@ zed_llm_client.workspace = true zstd.workspace = true [dev-dependencies] +assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index a983d43690006b07b7daeba128b2c7fe5e7581d5..8eda04c60fed9b31f99698b6cf223611ab5860a3 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1144,6 +1144,10 @@ impl ActiveThread { cx, ); } + ThreadEvent::ProfileChanged => { + self.save_thread(cx); + cx.notify(); + } } } diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index db458b771e93ed4996ebf189767cb2ab34c685c7..0ac78699205ace84ee6090f55abad5148ae4fb43 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -3,6 +3,7 @@ mod agent_configuration; mod agent_diff; mod agent_model_selector; mod agent_panel; +mod agent_profile; mod buffer_codegen; mod context; mod context_picker; diff --git a/crates/agent/src/agent_configuration/manage_profiles_modal.rs b/crates/agent/src/agent_configuration/manage_profiles_modal.rs index 8cb7d4dfe2973e7dc25a7e38ab73c99f62b079be..feb0a8e53f61171a2245e3d53176479973de0912 100644 --- a/crates/agent/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent/src/agent_configuration/manage_profiles_modal.rs @@ -2,25 +2,21 @@ mod profile_modal_header; use std::sync::Arc; -use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, builtin_profiles}; +use agent_settings::{AgentProfileId, AgentSettings, builtin_profiles}; use assistant_tool::ToolWorkingSet; -use convert_case::{Case, Casing as _}; use editor::Editor; use fs::Fs; -use gpui::{ - DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, WeakEntity, - prelude::*, -}; -use settings::{Settings as _, update_settings_file}; +use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, prelude::*}; +use settings::Settings as _; use ui::{ KeyBinding, ListItem, ListItemSpacing, ListSeparator, Navigable, NavigableEntry, prelude::*, }; -use util::ResultExt as _; use workspace::{ModalView, Workspace}; use crate::agent_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader; use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate}; -use crate::{AgentPanel, ManageProfiles, ThreadStore}; +use crate::agent_profile::AgentProfile; +use crate::{AgentPanel, ManageProfiles}; use super::tool_picker::ToolPickerMode; @@ -103,7 +99,6 @@ pub struct NewProfileMode { pub struct ManageProfilesModal { fs: Arc, tools: Entity, - thread_store: WeakEntity, focus_handle: FocusHandle, mode: Mode, } @@ -119,9 +114,8 @@ impl ManageProfilesModal { let fs = workspace.app_state().fs.clone(); let thread_store = panel.read(cx).thread_store(); let tools = thread_store.read(cx).tools(); - let thread_store = thread_store.downgrade(); workspace.toggle_modal(window, cx, |window, cx| { - let mut this = Self::new(fs, tools, thread_store, window, cx); + let mut this = Self::new(fs, tools, window, cx); if let Some(profile_id) = action.customize_tools.clone() { this.configure_builtin_tools(profile_id, window, cx); @@ -136,7 +130,6 @@ impl ManageProfilesModal { pub fn new( fs: Arc, tools: Entity, - thread_store: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -145,7 +138,6 @@ impl ManageProfilesModal { Self { fs, tools, - thread_store, focus_handle, mode: Mode::choose_profile(window, cx), } @@ -206,7 +198,6 @@ impl ManageProfilesModal { ToolPickerMode::McpTools, self.fs.clone(), self.tools.clone(), - self.thread_store.clone(), profile_id.clone(), profile, cx, @@ -244,7 +235,6 @@ impl ManageProfilesModal { ToolPickerMode::BuiltinTools, self.fs.clone(), self.tools.clone(), - self.thread_store.clone(), profile_id.clone(), profile, cx, @@ -270,32 +260,10 @@ impl ManageProfilesModal { match &self.mode { Mode::ChooseProfile { .. } => {} Mode::NewProfile(mode) => { - let settings = AgentSettings::get_global(cx); - - let base_profile = mode - .base_profile_id - .as_ref() - .and_then(|profile_id| settings.profiles.get(profile_id).cloned()); - let name = mode.name_editor.read(cx).text(cx); - let profile_id = AgentProfileId(name.to_case(Case::Kebab).into()); - - let profile = AgentProfile { - name: name.into(), - tools: base_profile - .as_ref() - .map(|profile| profile.tools.clone()) - .unwrap_or_default(), - enable_all_context_servers: base_profile - .as_ref() - .map(|profile| profile.enable_all_context_servers) - .unwrap_or_default(), - context_servers: base_profile - .map(|profile| profile.context_servers) - .unwrap_or_default(), - }; - - self.create_profile(profile_id.clone(), profile, cx); + + let profile_id = + AgentProfile::create(name, mode.base_profile_id.clone(), self.fs.clone(), cx); self.view_profile(profile_id, window, cx); } Mode::ViewProfile(_) => {} @@ -325,19 +293,6 @@ impl ManageProfilesModal { } } } - - fn create_profile( - &self, - profile_id: AgentProfileId, - profile: AgentProfile, - cx: &mut Context, - ) { - update_settings_file::(self.fs.clone(), cx, { - move |settings, _cx| { - settings.create_profile(profile_id, profile).log_err(); - } - }); - } } impl ModalView for ManageProfilesModal {} @@ -520,14 +475,13 @@ impl ManageProfilesModal { ) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = &settings.default_profile; let profile_name = settings .profiles .get(&mode.profile_id) .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let icon = match profile_id.as_str() { + let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, "ask" => IconName::MessageBubbles, _ => IconName::UserRoundPen, diff --git a/crates/agent/src/agent_configuration/tool_picker.rs b/crates/agent/src/agent_configuration/tool_picker.rs index 5ac2d4496b53528e145a9fa92be8ebc42a35e960..7c3d20457e2b9138e49f3c61e867b2f15b54bb84 100644 --- a/crates/agent/src/agent_configuration/tool_picker.rs +++ b/crates/agent/src/agent_configuration/tool_picker.rs @@ -1,19 +1,17 @@ use std::{collections::BTreeMap, sync::Arc}; use agent_settings::{ - AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent, + AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent, ContextServerPresetContent, }; use assistant_tool::{ToolSource, ToolWorkingSet}; use fs::Fs; use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window}; use picker::{Picker, PickerDelegate}; -use settings::{Settings as _, update_settings_file}; +use settings::update_settings_file; use ui::{ListItem, ListItemSpacing, prelude::*}; use util::ResultExt as _; -use crate::ThreadStore; - pub struct ToolPicker { picker: Entity>, } @@ -71,11 +69,10 @@ pub enum PickerItem { pub struct ToolPickerDelegate { tool_picker: WeakEntity, - thread_store: WeakEntity, fs: Arc, items: Arc>, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, filtered_items: Vec, selected_index: usize, mode: ToolPickerMode, @@ -86,20 +83,18 @@ impl ToolPickerDelegate { mode: ToolPickerMode, fs: Arc, tool_set: Entity, - thread_store: WeakEntity, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, cx: &mut Context, ) -> Self { let items = Arc::new(Self::resolve_items(mode, &tool_set, cx)); Self { tool_picker: cx.entity().downgrade(), - thread_store, fs, items, profile_id, - profile, + profile_settings, filtered_items: Vec::new(), selected_index: 0, mode, @@ -249,28 +244,31 @@ impl PickerDelegate for ToolPickerDelegate { }; let is_currently_enabled = if let Some(server_id) = server_id.clone() { - let preset = self.profile.context_servers.entry(server_id).or_default(); + let preset = self + .profile_settings + .context_servers + .entry(server_id) + .or_default(); let is_enabled = *preset.tools.entry(tool_name.clone()).or_default(); *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled; is_enabled } else { - let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default(); - *self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled; + let is_enabled = *self + .profile_settings + .tools + .entry(tool_name.clone()) + .or_default(); + *self + .profile_settings + .tools + .entry(tool_name.clone()) + .or_default() = !is_enabled; is_enabled }; - let active_profile_id = &AgentSettings::get_global(cx).default_profile; - if active_profile_id == &self.profile_id { - self.thread_store - .update(cx, |this, cx| { - this.load_profile(self.profile.clone(), cx); - }) - .log_err(); - } - update_settings_file::(self.fs.clone(), cx, { let profile_id = self.profile_id.clone(); - let default_profile = self.profile.clone(); + let default_profile = self.profile_settings.clone(); let server_id = server_id.clone(); let tool_name = tool_name.clone(); move |settings: &mut AgentSettingsContent, _cx| { @@ -348,14 +346,18 @@ impl PickerDelegate for ToolPickerDelegate { ), PickerItem::Tool { name, server_id } => { let is_enabled = if let Some(server_id) = server_id { - self.profile + self.profile_settings .context_servers .get(server_id.as_ref()) .and_then(|preset| preset.tools.get(name)) .copied() - .unwrap_or(self.profile.enable_all_context_servers) + .unwrap_or(self.profile_settings.enable_all_context_servers) } else { - self.profile.tools.get(name).copied().unwrap_or(false) + self.profile_settings + .tools + .get(name) + .copied() + .unwrap_or(false) }; Some( diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index b620d53c786011396e9e4dba860fe681561919ae..34ff249e95777bb02f87e755aa337e5e89710f12 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -1378,7 +1378,8 @@ impl AgentDiff { | ThreadEvent::CheckpointChanged | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached - | ThreadEvent::CancelEditing => {} + | ThreadEvent::CancelEditing + | ThreadEvent::ProfileChanged => {} } } diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs new file mode 100644 index 0000000000000000000000000000000000000000..5cd69bd3249f8422de7a7fede6c27674b3a24c97 --- /dev/null +++ b/crates/agent/src/agent_profile.rs @@ -0,0 +1,334 @@ +use std::sync::Arc; + +use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings}; +use assistant_tool::{Tool, ToolSource, ToolWorkingSet}; +use collections::IndexMap; +use convert_case::{Case, Casing}; +use fs::Fs; +use gpui::{App, Entity}; +use settings::{Settings, update_settings_file}; +use ui::SharedString; +use util::ResultExt; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AgentProfile { + id: AgentProfileId, + tool_set: Entity, +} + +pub type AvailableProfiles = IndexMap; + +impl AgentProfile { + pub fn new(id: AgentProfileId, tool_set: Entity) -> Self { + Self { id, tool_set } + } + + /// Saves a new profile to the settings. + pub fn create( + name: String, + base_profile_id: Option, + fs: Arc, + cx: &App, + ) -> AgentProfileId { + let id = AgentProfileId(name.to_case(Case::Kebab).into()); + + let base_profile = + base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned()); + + let profile_settings = AgentProfileSettings { + name: name.into(), + tools: base_profile + .as_ref() + .map(|profile| profile.tools.clone()) + .unwrap_or_default(), + enable_all_context_servers: base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(), + context_servers: base_profile + .map(|profile| profile.context_servers) + .unwrap_or_default(), + }; + + update_settings_file::(fs, cx, { + let id = id.clone(); + move |settings, _cx| { + settings.create_profile(id, profile_settings).log_err(); + } + }); + + id + } + + /// Returns a map of AgentProfileIds to their names + pub fn available_profiles(cx: &App) -> AvailableProfiles { + let mut profiles = AvailableProfiles::default(); + for (id, profile) in AgentSettings::get_global(cx).profiles.iter() { + profiles.insert(id.clone(), profile.name.clone()); + } + profiles + } + + pub fn id(&self) -> &AgentProfileId { + &self.id + } + + pub fn enabled_tools(&self, cx: &App) -> Vec> { + let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { + return Vec::new(); + }; + + self.tool_set + .read(cx) + .tools(cx) + .into_iter() + .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name())) + .collect() + } + + fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool { + match source { + ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false), + ToolSource::ContextServer { id } => { + if settings.enable_all_context_servers { + return true; + } + + let Some(preset) = settings.context_servers.get(id.as_ref()) else { + return false; + }; + *preset.tools.get(name.as_str()).unwrap_or(&false) + } + } + } +} + +#[cfg(test)] +mod tests { + use agent_settings::ContextServerPreset; + use assistant_tool::ToolRegistry; + use collections::IndexMap; + use gpui::{AppContext, TestAppContext}; + use http_client::FakeHttpClient; + use project::Project; + use settings::{Settings, SettingsStore}; + use ui::SharedString; + + use super::*; + + #[gpui::test] + async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId::default(); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings + .tools + .into_iter() + .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) + // Provider dependent + .filter(|tool| tool != "web_search") + .collect::>(); + // Plus all registered MCP tools + expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + #[gpui::test] + async fn test_custom_mcp_settings(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId("custom_mcp".into()); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings.context_servers["mcp"] + .tools + .iter() + .filter_map(|(key, enabled)| enabled.then(|| key.to_string())) + .collect::>(); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + #[gpui::test] + async fn test_only_built_in(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId("write_minus_mcp".into()); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings + .tools + .into_iter() + .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) + // Provider dependent + .filter(|tool| tool != "web_search") + .collect::>(); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + AgentSettings::register(cx); + language_model::init_settings(cx); + ToolRegistry::default_global(cx); + assistant_tools::init(FakeHttpClient::with_404_response(), cx); + }); + + cx.update(|cx| { + let mut agent_settings = AgentSettings::get_global(cx).clone(); + agent_settings.profiles.insert( + AgentProfileId("write_minus_mcp".into()), + AgentProfileSettings { + name: "write_minus_mcp".into(), + enable_all_context_servers: false, + ..agent_settings.profiles[&AgentProfileId::default()].clone() + }, + ); + agent_settings.profiles.insert( + AgentProfileId("custom_mcp".into()), + AgentProfileSettings { + name: "mcp".into(), + tools: IndexMap::default(), + enable_all_context_servers: false, + context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]), + }, + ); + AgentSettings::override_global(agent_settings, cx); + }) + } + + fn context_server_preset() -> ContextServerPreset { + ContextServerPreset { + tools: IndexMap::from_iter([ + ("enabled_mcp_tool".into(), true), + ("disabled_mcp_tool".into(), false), + ]), + } + } + + fn default_tool_set(cx: &mut TestAppContext) -> Entity { + cx.new(|_| { + let mut tool_set = ToolWorkingSet::default(); + tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp"))); + tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp"))); + tool_set + }) + } + + struct FakeTool { + name: String, + source: SharedString, + } + + impl FakeTool { + fn new(name: impl Into, source: impl Into) -> Self { + Self { + name: name.into(), + source: source.into(), + } + } + } + + impl Tool for FakeTool { + fn name(&self) -> String { + self.name.clone() + } + + fn source(&self) -> ToolSource { + ToolSource::ContextServer { + id: self.source.clone(), + } + } + + fn description(&self) -> String { + unimplemented!() + } + + fn icon(&self) -> ui::IconName { + unimplemented!() + } + + fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + unimplemented!() + } + + fn ui_text(&self, _input: &serde_json::Value) -> String { + unimplemented!() + } + + fn run( + self: Arc, + _input: serde_json::Value, + _request: Arc, + _project: Entity, + _action_log: Entity, + _model: Arc, + _window: Option, + _cx: &mut App, + ) -> assistant_tool::ToolResult { + unimplemented!() + } + + fn may_perform_edits(&self) -> bool { + unimplemented!() + } + } +} diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 0ae326bd44f162df86fab2913deba32de35c20a9..a3958d9acbd8e651192d98ef45ac8e29ff315601 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -175,8 +175,7 @@ impl MessageEditor { ) }); - let incompatible_tools = - cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx)); + let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx)); let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), @@ -204,15 +203,8 @@ impl MessageEditor { ) }); - let profile_selector = cx.new(|cx| { - ProfileSelector::new( - fs, - thread.clone(), - thread_store, - editor.focus_handle(cx), - cx, - ) - }); + let profile_selector = + cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); Self { editor: editor.clone(), diff --git a/crates/agent/src/profile_selector.rs b/crates/agent/src/profile_selector.rs index a51440ddb94296ff3ac4710eb4ccce21f396171c..7a42e45fa4f817a90b004e906fa88d0c3c55c40d 100644 --- a/crates/agent/src/profile_selector.rs +++ b/crates/agent/src/profile_selector.rs @@ -1,26 +1,24 @@ use std::sync::Arc; -use agent_settings::{ - AgentDockPosition, AgentProfile, AgentProfileId, AgentSettings, GroupedAgentProfiles, - builtin_profiles, -}; +use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; use fs::Fs; -use gpui::{Action, Empty, Entity, FocusHandle, Subscription, WeakEntity, prelude::*}; +use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*}; use language_model::LanguageModelRegistry; use settings::{Settings as _, SettingsStore, update_settings_file}; use ui::{ ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::ResultExt as _; -use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector}; +use crate::{ + ManageProfiles, Thread, ToggleProfileSelector, + agent_profile::{AgentProfile, AvailableProfiles}, +}; pub struct ProfileSelector { - profiles: GroupedAgentProfiles, + profiles: AvailableProfiles, fs: Arc, thread: Entity, - thread_store: WeakEntity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, _subscriptions: Vec, @@ -30,7 +28,6 @@ impl ProfileSelector { pub fn new( fs: Arc, thread: Entity, - thread_store: WeakEntity, focus_handle: FocusHandle, cx: &mut Context, ) -> Self { @@ -39,10 +36,9 @@ impl ProfileSelector { }); Self { - profiles: GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)), + profiles: AgentProfile::available_profiles(cx), fs, thread, - thread_store, menu_handle: PopoverMenuHandle::default(), focus_handle, _subscriptions: vec![settings_subscription], @@ -54,7 +50,7 @@ impl ProfileSelector { } fn refresh_profiles(&mut self, cx: &mut Context) { - self.profiles = GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)); + self.profiles = AgentProfile::available_profiles(cx); } fn build_context_menu( @@ -64,21 +60,30 @@ impl ProfileSelector { ) -> Entity { ContextMenu::build(window, cx, |mut menu, _window, cx| { let settings = AgentSettings::get_global(cx); - for (profile_id, profile) in self.profiles.builtin.iter() { + + let mut found_non_builtin = false; + for (profile_id, profile_name) in self.profiles.iter() { + if !builtin_profiles::is_builtin(profile_id) { + found_non_builtin = true; + continue; + } menu = menu.item(self.menu_entry_for_profile( profile_id.clone(), - profile, + profile_name, settings, cx, )); } - if !self.profiles.custom.is_empty() { + if found_non_builtin { menu = menu.separator().header("Custom Profiles"); - for (profile_id, profile) in self.profiles.custom.iter() { + for (profile_id, profile_name) in self.profiles.iter() { + if builtin_profiles::is_builtin(profile_id) { + continue; + } menu = menu.item(self.menu_entry_for_profile( profile_id.clone(), - profile, + profile_name, settings, cx, )); @@ -99,19 +104,20 @@ impl ProfileSelector { fn menu_entry_for_profile( &self, profile_id: AgentProfileId, - profile: &AgentProfile, + profile_name: &SharedString, settings: &AgentSettings, - _cx: &App, + cx: &App, ) -> ContextMenuEntry { - let documentation = match profile.name.to_lowercase().as_str() { + let documentation = match profile_name.to_lowercase().as_str() { builtin_profiles::WRITE => Some("Get help to write anything."), builtin_profiles::ASK => Some("Chat about your codebase."), builtin_profiles::MINIMAL => Some("Chat about anything with no tools."), _ => None, }; + let thread_profile_id = self.thread.read(cx).profile().id(); - let entry = ContextMenuEntry::new(profile.name.clone()) - .toggleable(IconPosition::End, profile_id == settings.default_profile); + let entry = ContextMenuEntry::new(profile_name.clone()) + .toggleable(IconPosition::End, &profile_id == thread_profile_id); let entry = if let Some(doc_text) = documentation { entry.documentation_aside(documentation_side(settings.dock), move |_| { @@ -123,7 +129,7 @@ impl ProfileSelector { entry.handler({ let fs = self.fs.clone(); - let thread_store = self.thread_store.clone(); + let thread = self.thread.clone(); let profile_id = profile_id.clone(); move |_window, cx| { update_settings_file::(fs.clone(), cx, { @@ -133,11 +139,9 @@ impl ProfileSelector { } }); - thread_store - .update(cx, |this, cx| { - this.load_profile_by_id(profile_id.clone(), cx); - }) - .log_err(); + thread.update(cx, |this, cx| { + this.set_profile(profile_id.clone(), cx); + }); } }) } @@ -146,7 +150,7 @@ impl ProfileSelector { impl Render for ProfileSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = &settings.default_profile; + let profile_id = self.thread.read(cx).profile().id(); let profile = settings.profiles.get(profile_id); let selected_profile = profile diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f857557271cc08bb3d90949ab0c6ffd6c4c41d87..bb8cc706bb898e7d631e12ea41523c98f61980ea 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -4,7 +4,7 @@ use std::ops::Range; use std::sync::Arc; use std::time::Instant; -use agent_settings::{AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; @@ -41,6 +41,7 @@ use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; use crate::ThreadStore; +use crate::agent_profile::AgentProfile; use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, @@ -360,6 +361,7 @@ pub struct Thread { >, remaining_turns: u32, configured_model: Option, + profile: AgentProfile, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -407,6 +409,7 @@ impl Thread { ) -> Self { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let configured_model = LanguageModelRegistry::read_global(cx).default_model(); + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { id: ThreadId::new(), @@ -449,6 +452,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + profile: AgentProfile::new(profile_id, tools), } } @@ -495,6 +499,9 @@ impl Thread { let completion_mode = serialized .completion_mode .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode); + let profile_id = serialized + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); Self { id, @@ -554,7 +561,7 @@ impl Thread { pending_checkpoint: None, project: project.clone(), prompt_builder, - tools, + tools: tools.clone(), tool_use, action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), @@ -570,6 +577,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + profile: AgentProfile::new(profile_id, tools), } } @@ -585,6 +593,17 @@ impl Thread { &self.id } + pub fn profile(&self) -> &AgentProfile { + &self.profile + } + + pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { + if &id != self.profile.id() { + self.profile = AgentProfile::new(id, self.tools.clone()); + cx.emit(ThreadEvent::ProfileChanged); + } + } + pub fn is_empty(&self) -> bool { self.messages.is_empty() } @@ -919,8 +938,7 @@ impl Thread { model: Arc, ) -> Vec { if model.supports_tools() { - self.tools() - .read(cx) + self.profile .enabled_tools(cx) .into_iter() .filter_map(|tool| { @@ -1180,6 +1198,7 @@ impl Thread { }), completion_mode: Some(this.completion_mode), tool_use_limit_reached: this.tool_use_limit_reached, + profile: Some(this.profile.id().clone()), }) }) } @@ -2121,7 +2140,7 @@ impl Thread { window: Option, cx: &mut Context, ) { - let available_tools = self.tools.read(cx).enabled_tools(cx); + let available_tools = self.profile.enabled_tools(cx); let tool_list = available_tools .iter() @@ -2213,19 +2232,15 @@ impl Thread { ) -> Task<()> { let tool_name: Arc = tool.name().into(); - let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) { - Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into() - } else { - tool.run( - input, - request, - self.project.clone(), - self.action_log.clone(), - model, - window, - cx, - ) - }; + let tool_result = tool.run( + input, + request, + self.project.clone(), + self.action_log.clone(), + model, + window, + cx, + ); // Store the card separately if it exists if let Some(card) = tool_result.card.clone() { @@ -2344,8 +2359,7 @@ impl Thread { let client = self.project.read(cx).client(); let enabled_tool_names: Vec = self - .tools() - .read(cx) + .profile .enabled_tools(cx) .iter() .map(|tool| tool.name()) @@ -2858,6 +2872,7 @@ pub enum ThreadEvent { ToolUseLimitReached, CancelEditing, CompletionCanceled, + ProfileChanged, } impl EventEmitter for Thread {} @@ -2872,7 +2887,7 @@ struct PendingCompletion { mod tests { use super::*; use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; - use agent_settings::{AgentSettings, LanguageModelParameters}; + use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::ToolRegistry; use editor::EditorSettings; use gpui::TestAppContext; @@ -3285,6 +3300,71 @@ fn main() {{ ); } + #[gpui::test] + async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, thread_store, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Check that we are starting with the default profile + let profile = cx.read(|cx| thread.read(cx).profile.clone()); + let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + assert_eq!( + profile, + AgentProfile::new(AgentProfileId::default(), tool_set) + ); + } + + #[gpui::test] + async fn test_serializing_thread_profile(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, thread_store, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Profile gets serialized with default values + let serialized = thread + .update(cx, |thread, cx| thread.serialize(cx)) + .await + .unwrap(); + + assert_eq!(serialized.profile, Some(AgentProfileId::default())); + + let deserialized = cx.update(|cx| { + thread.update(cx, |thread, cx| { + Thread::deserialize( + thread.id.clone(), + serialized, + thread.project.clone(), + thread.tools.clone(), + thread.prompt_builder.clone(), + thread.project_context.clone(), + None, + cx, + ) + }) + }); + let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + + assert_eq!( + deserialized.profile, + AgentProfile::new(AgentProfileId::default(), tool_set) + ); + } + #[gpui::test] async fn test_temperature_setting(cx: &mut TestAppContext) { init_test_settings(cx); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 964cb8d75e0488943e17a0699fe2dcf9ef00ff85..504280fac405970ce0a710ca9a07eb7bab2da984 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -3,9 +3,9 @@ use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::{Arc, Mutex}; -use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; +use assistant_tool::{ToolId, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; use context_server::ContextServerId; @@ -25,7 +25,6 @@ use prompt_store::{ UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; -use settings::{Settings as _, SettingsStore}; use ui::Window; use util::ResultExt as _; @@ -147,12 +146,7 @@ impl ThreadStore { prompt_store: Option>, cx: &mut Context, ) -> (Self, oneshot::Receiver<()>) { - let mut subscriptions = vec![ - cx.observe_global::(move |this: &mut Self, cx| { - this.load_default_profile(cx); - }), - cx.subscribe(&project, Self::handle_project_event), - ]; + let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe( @@ -200,7 +194,6 @@ impl ThreadStore { _reload_system_prompt_task: reload_system_prompt_task, _subscriptions: subscriptions, }; - this.load_default_profile(cx); this.register_context_server_handlers(cx); this.reload(cx).detach_and_log_err(cx); (this, ready_rx) @@ -520,86 +513,6 @@ impl ThreadStore { }) } - fn load_default_profile(&self, cx: &mut Context) { - let assistant_settings = AgentSettings::get_global(cx); - - self.load_profile_by_id(assistant_settings.default_profile.clone(), cx); - } - - pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context) { - let assistant_settings = AgentSettings::get_global(cx); - - if let Some(profile) = assistant_settings.profiles.get(&profile_id) { - self.load_profile(profile.clone(), cx); - } - } - - pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context) { - self.tools.update(cx, |tools, cx| { - tools.disable_all_tools(cx); - tools.enable( - ToolSource::Native, - &profile - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool)) - .collect::>(), - cx, - ); - }); - - if profile.enable_all_context_servers { - for context_server_id in self - .project - .read(cx) - .context_server_store() - .read(cx) - .all_server_ids() - { - self.tools.update(cx, |tools, cx| { - tools.enable_source( - ToolSource::ContextServer { - id: context_server_id.0.into(), - }, - cx, - ); - }); - } - // Enable all the tools from all context servers, but disable the ones that are explicitly disabled - for (context_server_id, preset) in profile.context_servers { - self.tools.update(cx, |tools, cx| { - tools.disable( - ToolSource::ContextServer { - id: context_server_id.into(), - }, - &preset - .tools - .into_iter() - .filter_map(|(tool, enabled)| (!enabled).then(|| tool)) - .collect::>(), - cx, - ) - }) - } - } else { - for (context_server_id, preset) in profile.context_servers { - self.tools.update(cx, |tools, cx| { - tools.enable( - ToolSource::ContextServer { - id: context_server_id.into(), - }, - &preset - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool)) - .collect::>(), - cx, - ) - }) - } - } - } - fn register_context_server_handlers(&self, cx: &mut Context) { cx.subscribe( &self.project.read(cx).context_server_store(), @@ -618,6 +531,7 @@ impl ThreadStore { match event { project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { + ContextServerStatus::Starting => {} ContextServerStatus::Running => { if let Some(server) = context_server_store.read(cx).get_running_server(server_id) @@ -656,10 +570,9 @@ impl ThreadStore { .log_err(); if let Some(tool_ids) = tool_ids { - this.update(cx, |this, cx| { + this.update(cx, |this, _| { this.context_server_tool_ids .insert(server_id, tool_ids); - this.load_default_profile(cx); }) .log_err(); } @@ -675,10 +588,8 @@ impl ThreadStore { tool_working_set.update(cx, |tool_working_set, _| { tool_working_set.remove(&tool_ids); }); - self.load_default_profile(cx); } } - _ => {} } } } @@ -714,6 +625,8 @@ pub struct SerializedThread { pub completion_mode: Option, #[serde(default)] pub tool_use_limit_reached: bool, + #[serde(default)] + pub profile: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -856,6 +769,7 @@ impl LegacySerializedThread { model: None, completion_mode: None, tool_use_limit_reached: false, + profile: None, } } } diff --git a/crates/agent/src/tool_compatibility.rs b/crates/agent/src/tool_compatibility.rs index 141d87c96fb2867608fa1a5165beae611b7775e1..6193b0929d775f2cd4246de7fb7d15ddaa61aa3a 100644 --- a/crates/agent/src/tool_compatibility.rs +++ b/crates/agent/src/tool_compatibility.rs @@ -1,30 +1,33 @@ use std::sync::Arc; -use assistant_tool::{Tool, ToolSource, ToolWorkingSet, ToolWorkingSetEvent}; +use assistant_tool::{Tool, ToolSource}; use collections::HashMap; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; use language_model::{LanguageModel, LanguageModelToolSchemaFormat}; use ui::prelude::*; +use crate::{Thread, ThreadEvent}; + pub struct IncompatibleToolsState { cache: HashMap>>, - tool_working_set: Entity, - _tool_working_set_subscription: Subscription, + thread: Entity, + _thread_subscription: Subscription, } impl IncompatibleToolsState { - pub fn new(tool_working_set: Entity, cx: &mut Context) -> Self { + pub fn new(thread: Entity, cx: &mut Context) -> Self { let _tool_working_set_subscription = - cx.subscribe(&tool_working_set, |this, _, event, _| match event { - ToolWorkingSetEvent::EnabledToolsChanged => { + cx.subscribe(&thread, |this, _, event, _| match event { + ThreadEvent::ProfileChanged => { this.cache.clear(); } + _ => {} }); Self { cache: HashMap::default(), - tool_working_set, - _tool_working_set_subscription, + thread, + _thread_subscription: _tool_working_set_subscription, } } @@ -36,8 +39,9 @@ impl IncompatibleToolsState { self.cache .entry(model.tool_input_format()) .or_insert_with(|| { - self.tool_working_set + self.thread .read(cx) + .profile() .enabled_tools(cx) .iter() .filter(|tool| tool.input_schema(model.tool_input_format()).is_err()) diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 200c531c3c6c95b4dea0d1a653c68539993ba246..c6a4bedbb5e848d48a03b1d7cbb4329322d1c99b 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -16,7 +16,6 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true collections.workspace = true gpui.workspace = true -indexmap.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index 599932114a9f3901f8f5a5680b25337da892d28a..a6b8633b34d1e969e8cc8952dbc932e54d38a49f 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -17,29 +17,6 @@ pub mod builtin_profiles { } } -#[derive(Default)] -pub struct GroupedAgentProfiles { - pub builtin: IndexMap, - pub custom: IndexMap, -} - -impl GroupedAgentProfiles { - pub fn from_settings(settings: &crate::AgentSettings) -> Self { - let mut builtin = IndexMap::default(); - let mut custom = IndexMap::default(); - - for (profile_id, profile) in settings.profiles.clone() { - if builtin_profiles::is_builtin(&profile_id) { - builtin.insert(profile_id, profile); - } else { - custom.insert(profile_id, profile); - } - } - - Self { builtin, custom } - } -} - #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentProfileId(pub Arc); @@ -63,7 +40,7 @@ impl Default for AgentProfileId { /// A profile for the Zed Agent that controls its behavior. #[derive(Debug, Clone)] -pub struct AgentProfile { +pub struct AgentProfileSettings { /// The name of the profile. pub name: SharedString, pub tools: IndexMap, bool>, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 36480f30d5a4d4a2c25e215fae7c1efb213b2c98..9e8fd0c699ff47fdadc069f5fcaee8408b11495d 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -102,7 +102,7 @@ pub struct AgentSettings { pub using_outdated_settings_version: bool, pub default_profile: AgentProfileId, pub default_view: DefaultView, - pub profiles: IndexMap, + pub profiles: IndexMap, pub always_allow_tool_actions: bool, pub notify_when_agent_waiting: NotifyWhenAgentWaiting, pub play_sound_when_agent_done: bool, @@ -531,7 +531,7 @@ impl AgentSettingsContent { pub fn create_profile( &mut self, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, ) -> Result<()> { self.v2_setting(|settings| { let profiles = settings.profiles.get_or_insert_default(); @@ -542,10 +542,10 @@ impl AgentSettingsContent { profiles.insert( profile_id, AgentProfileContent { - name: profile.name.into(), - tools: profile.tools, - enable_all_context_servers: Some(profile.enable_all_context_servers), - context_servers: profile + name: profile_settings.name.into(), + tools: profile_settings.tools, + enable_all_context_servers: Some(profile_settings.enable_all_context_servers), + context_servers: profile_settings .context_servers .into_iter() .map(|(server_id, preset)| { @@ -910,7 +910,7 @@ impl Settings for AgentSettings { .extend(profiles.into_iter().map(|(id, profile)| { ( id, - AgentProfile { + AgentProfileSettings { name: profile.name.into(), tools: profile.tools, enable_all_context_servers: profile diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index c7e20d3517ad6bb559961f6d211339fc6781d06a..c72c52ba7a668ca31c91242872d7ef0c4834fb17 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use collections::{HashMap, HashSet, IndexMap}; -use gpui::{App, Context, EventEmitter}; +use collections::{HashMap, IndexMap}; +use gpui::App; use crate::{Tool, ToolRegistry, ToolSource}; @@ -13,17 +13,9 @@ pub struct ToolId(usize); pub struct ToolWorkingSet { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, - enabled_sources: HashSet, - enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } -pub enum ToolWorkingSetEvent { - EnabledToolsChanged, -} - -impl EventEmitter for ToolWorkingSet {} - impl ToolWorkingSet { pub fn tool(&self, name: &str, cx: &App) -> Option> { self.context_server_tools_by_name @@ -57,42 +49,6 @@ impl ToolWorkingSet { tools_by_source } - pub fn enabled_tools(&self, cx: &App) -> Vec> { - let all_tools = self.tools(cx); - - all_tools - .into_iter() - .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into())) - .collect() - } - - pub fn disable_all_tools(&mut self, cx: &mut Context) { - self.enabled_tools_by_source.clear(); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context) { - self.enabled_sources.insert(source.clone()); - - let tools_by_source = self.tools_by_source(cx); - if let Some(tools) = tools_by_source.get(&source) { - self.enabled_tools_by_source.insert( - source, - tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(), - ); - } - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context) { - self.enabled_sources.remove(source); - self.enabled_tools_by_source.remove(source); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - pub fn insert(&mut self, tool: Arc) -> ToolId { let tool_id = self.next_tool_id; self.next_tool_id.0 += 1; @@ -102,42 +58,6 @@ impl ToolWorkingSet { tool_id } - pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.enabled_tools_by_source - .get(source) - .map_or(false, |enabled_tools| enabled_tools.contains(name)) - } - - pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_enabled(source, name) - } - - pub fn enable( - &mut self, - source: ToolSource, - tools_to_enable: &[Arc], - cx: &mut Context, - ) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .extend(tools_to_enable.into_iter().cloned()); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn disable( - &mut self, - source: ToolSource, - tools_to_disable: &[Arc], - cx: &mut Context, - ) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .retain(|name| !tools_to_disable.contains(name)); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) { self.context_server_tools_by_id .retain(|id, _| !tool_ids_to_remove.contains(id)); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 020aedbc57220fa44954bdd2fb55139a4622c78e..a91fdac992a6f69768f5324cdb4ed88d5c47620e 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -80,7 +80,6 @@ zed_llm_client.workspace = true agent_settings.workspace = true assistant_context_editor.workspace = true assistant_slash_command.workspace = true -assistant_tool.workspace = true async-trait.workspace = true audio.workspace = true buffer_diff.workspace = true diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index dc384668c33fcee4f1d0ba4e6634787dc50a1b6f..85af49e3397ab93bd2ab62ccd4996a2de3698575 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -294,6 +294,7 @@ impl ExampleContext { | ThreadEvent::MessageDeleted(_) | ThreadEvent::SummaryChanged | ThreadEvent::SummaryGenerated + | ThreadEvent::ProfileChanged | ThreadEvent::ReceivedTextChunk | ThreadEvent::StreamedToolUse { .. } | ThreadEvent::CheckpointChanged diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 94fdaf90bf76401dde61d03a22447bda7e4b1efd..f28165e859be017b28e26e359d0df9e5b1f63391 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -306,17 +306,19 @@ impl ExampleInstance { let thread_store = thread_store.await?; - let profile_id = meta.profile_id.clone(); - thread_store.update(cx, |thread_store, cx| thread_store.load_profile_by_id(profile_id, cx)).expect("Failed to load profile"); let thread = thread_store.update(cx, |thread_store, cx| { - if let Some(json) = &meta.existing_thread_json { + let thread = if let Some(json) = &meta.existing_thread_json { let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread"); thread_store.create_thread_from_serialized(serialized, cx) } else { thread_store.create_thread(cx) - } + }; + thread.update(cx, |thread, cx| { + thread.set_profile(meta.profile_id.clone(), cx); + }); + thread })?;