From 848a99c605035b6bf26a3330319659b93fcd24f0 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 26 Mar 2025 16:26:26 -0400 Subject: [PATCH] assistant2: Rework enabled tool representation (#27527) This PR reworks how we store enabled tools in the `ToolWorkingSet`. We now track them based on which tools are explicitly enabled, rather than by the tools that have been disabled. Also fixed an issue where switching profiles wouldn't properly set the right tools. Release Notes: - N/A --- crates/assistant2/src/profile_selector.rs | 2 +- crates/assistant2/src/thread_store.rs | 53 ++++++++------- crates/assistant_tool/src/tool_working_set.rs | 68 ++++++------------- 3 files changed, 53 insertions(+), 70 deletions(-) diff --git a/crates/assistant2/src/profile_selector.rs b/crates/assistant2/src/profile_selector.rs index b9b2c45773e966ca52b3d453fa855237c78245a6..c0ba19016c6b34ef7a758929d6d96240f8a3690b 100644 --- a/crates/assistant2/src/profile_selector.rs +++ b/crates/assistant2/src/profile_selector.rs @@ -78,7 +78,7 @@ impl ProfileSelector { thread_store .update(cx, |this, cx| { - this.load_default_profile(cx); + this.load_profile_by_id(&profile_id, cx); }) .log_err(); } diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index b2ace1a2ace234e7d3b9d49066d1dee372181a99..aa2e0aa97153b498a3d1a040d0c2ad3c4145ee4b 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{anyhow, Result}; -use assistant_settings::AssistantSettings; +use assistant_settings::{AgentProfile, AssistantSettings}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; @@ -187,35 +187,42 @@ impl ThreadStore { }) } - pub fn load_default_profile(&self, cx: &mut Context) { + fn load_default_profile(&self, cx: &Context) { let assistant_settings = AssistantSettings::get_global(cx); - if let Some(profile) = assistant_settings - .profiles - .get(&assistant_settings.default_profile) - { - self.tools.disable_source(ToolSource::Native, cx); + self.load_profile_by_id(&assistant_settings.default_profile, cx); + } + + pub fn load_profile_by_id(&self, profile_id: &Arc, cx: &Context) { + let assistant_settings = AssistantSettings::get_global(cx); + + if let Some(profile) = assistant_settings.profiles.get(profile_id) { + self.load_profile(profile); + } + } + + pub fn load_profile(&self, profile: &AgentProfile) { + self.tools.disable_all_tools(); + self.tools.enable( + ToolSource::Native, + &profile + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + ); + + for (context_server_id, preset) in &profile.context_servers { self.tools.enable( - ToolSource::Native, - &profile + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset .tools .iter() .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) .collect::>(), - ); - - for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) - } + ) } } diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 2c1acad053eb770ce0629e31e9a75245591d1e40..82a4455920267de7dc377155b07f49b5569c8360 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -19,7 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, - disabled_tools_by_source: HashMap>>, + enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -41,38 +41,23 @@ impl ToolWorkingSet { self.state.lock().tools_by_source(cx) } - pub fn are_all_tools_enabled(&self) -> bool { - let state = self.state.lock(); - state.disabled_tools_by_source.is_empty() - } - - pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool { - let state = self.state.lock(); - !state.disabled_tools_by_source.contains_key(source) - } - pub fn enabled_tools(&self, cx: &App) -> Vec> { self.state.lock().enabled_tools(cx) } - pub fn enable_all_tools(&self) { + pub fn disable_all_tools(&self) { let mut state = self.state.lock(); - state.disabled_tools_by_source.clear(); + state.disable_all_tools(); } - pub fn disable_all_tools(&self, cx: &App) { + pub fn enable_source(&self, source: ToolSource, cx: &App) { let mut state = self.state.lock(); - state.disable_all_tools(cx); + state.enable_source(source, cx); } - pub fn enable_source(&self, source: &ToolSource) { + pub fn disable_source(&self, source: &ToolSource) { let mut state = self.state.lock(); - state.enable_source(source); - } - - pub fn disable_source(&self, source: ToolSource, cx: &App) { - let mut state = self.state.lock(); - state.disable_source(source, cx); + state.disable_source(source); } pub fn insert(&self, tool: Arc) -> ToolId { @@ -159,40 +144,36 @@ impl WorkingSetState { } fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_disabled(source, name) + self.enabled_tools_by_source + .get(source) + .map_or(false, |enabled_tools| enabled_tools.contains(name)) } fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.disabled_tools_by_source - .get(source) - .map_or(false, |disabled_tools| disabled_tools.contains(name)) + !self.is_enabled(source, name) } fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc]) { - self.disabled_tools_by_source + self.enabled_tools_by_source .entry(source) .or_default() - .retain(|name| !tools_to_enable.contains(name)); + .extend(tools_to_enable.into_iter().cloned()); } fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc]) { - self.disabled_tools_by_source + self.enabled_tools_by_source .entry(source) .or_default() - .extend(tools_to_disable.into_iter().cloned()); - } - - fn enable_source(&mut self, source: &ToolSource) { - self.disabled_tools_by_source.remove(source); + .retain(|name| !tools_to_disable.contains(name)); } - fn disable_source(&mut self, source: ToolSource, cx: &App) { + fn enable_source(&mut self, source: ToolSource, cx: &App) { let tools_by_source = self.tools_by_source(cx); let Some(tools) = tools_by_source.get(&source) else { return; }; - self.disabled_tools_by_source.insert( + self.enabled_tools_by_source.insert( source, tools .into_iter() @@ -201,16 +182,11 @@ impl WorkingSetState { ); } - fn disable_all_tools(&mut self, cx: &App) { - let tools = self.tools_by_source(cx); - - for (source, tools) in tools { - let tool_names = tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(); + fn disable_source(&mut self, source: &ToolSource) { + self.enabled_tools_by_source.remove(source); + } - self.disable(source, &tool_names); - } + fn disable_all_tools(&mut self) { + self.enabled_tools_by_source.clear(); } }