From 12037dc2c64d8789357d5f3f50fcabd98f0898db Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 1 Apr 2025 11:25:23 -0400 Subject: [PATCH] assistant2: Allow profiles to enable all context servers (#27847) This PR adds a new `enable_all_context_servers` field to agent profiles to allow them to enable all context servers without having to opt into them individually. The "Write" profile will now have all context servers enabled out of the box. Release Notes: - N/A --- assets/settings/default.json | 3 ++ .../manage_profiles_modal.rs | 4 ++ .../assistant_configuration/tool_picker.rs | 7 ++- crates/assistant2/src/thread_store.rs | 52 +++++++++++++------ .../assistant_settings/src/agent_profile.rs | 1 + .../src/assistant_settings.rs | 6 +++ crates/assistant_tool/src/tool_working_set.rs | 24 +++++---- 7 files changed, 68 insertions(+), 29 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index 46dae6ccecf5ec05f340109842107a456ed237aa..515ca6746decb8e680e3a1c9f94d1a3a3e8bc9ee 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -637,6 +637,8 @@ "profiles": { "ask": { "name": "Ask", + // We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default. + // "enable_all_context_servers": true, "tools": { "diagnostics": true, "fetch": true, @@ -650,6 +652,7 @@ }, "write": { "name": "Write", + "enable_all_context_servers": true, "tools": { "bash": true, "batch-tool": true, diff --git a/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs b/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs index c05c351a0d2c8bfd1a283dd47f44e1e514465387..3cdff034403985534a6d628f0ab9d88828b44e6a 100644 --- a/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs +++ b/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs @@ -227,6 +227,10 @@ impl ManageProfilesModal { .as_ref() .map(|profile| profile.tools.clone()) .unwrap_or_default(), + enable_all_context_servers: base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(), context_servers: base_profile .map(|profile| profile.context_servers) .unwrap_or_default(), diff --git a/crates/assistant2/src/assistant_configuration/tool_picker.rs b/crates/assistant2/src/assistant_configuration/tool_picker.rs index 5c1ca4bed587a372f5a545c8ca8b23e79efbbd1a..d8bbf449a5d6c5b58539cd05ac2c50ad236cf9aa 100644 --- a/crates/assistant2/src/assistant_configuration/tool_picker.rs +++ b/crates/assistant2/src/assistant_configuration/tool_picker.rs @@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate { let active_profile_id = &AssistantSettings::get_global(cx).default_profile; if active_profile_id == &self.profile_id { self.thread_store - .update(cx, |this, _cx| { - this.load_profile(&self.profile); + .update(cx, |this, cx| { + this.load_profile(&self.profile, cx); }) .log_err(); } @@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate { .or_insert_with(|| AgentProfileContent { name: default_profile.name.into(), tools: default_profile.tools, + enable_all_context_servers: Some( + default_profile.enable_all_context_servers, + ), context_servers: default_profile .context_servers .into_iter() diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index 86471215f83e8c479461c723d5d6df4f2abd2a17..bd540d5d26c3dbd93adbf97df9b8b74968f3e962 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool}; use futures::FutureExt as _; use futures::future::{self, BoxFuture, Shared}; use gpui::{ - App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*, + App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task, + prelude::*, }; use heed::Database; use heed::types::SerdeBincode; @@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::Project; use prompt_store::PromptBuilder; use serde::{Deserialize, Serialize}; -use settings::Settings as _; +use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId}; @@ -36,6 +37,7 @@ pub struct ThreadStore { context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, threads: Vec, + _subscriptions: Vec, } impl ThreadStore { @@ -50,6 +52,10 @@ impl ThreadStore { let context_server_manager = cx.new(|cx| { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); + let settings_subscription = + cx.observe_global::(move |this: &mut Self, cx| { + this.load_default_profile(cx); + }); let this = Self { project, @@ -58,6 +64,7 @@ impl ThreadStore { context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), + _subscriptions: vec![settings_subscription], }; this.load_default_profile(cx); this.register_context_server_handlers(cx); @@ -197,11 +204,11 @@ impl ThreadStore { let assistant_settings = AssistantSettings::get_global(cx); if let Some(profile) = assistant_settings.profiles.get(profile_id) { - self.load_profile(profile); + self.load_profile(profile, cx); } } - pub fn load_profile(&self, profile: &AgentProfile) { + pub fn load_profile(&self, profile: &AgentProfile, cx: &Context) { self.tools.disable_all_tools(); self.tools.enable( ToolSource::Native, @@ -212,17 +219,28 @@ impl ThreadStore { .collect::>(), ); - for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) + if profile.enable_all_context_servers { + for context_server in self.context_server_manager.read(cx).all_servers() { + self.tools.enable_source( + ToolSource::ContextServer { + id: context_server.id().into(), + }, + cx, + ); + } + } else { + for (context_server_id, preset) in &profile.context_servers { + self.tools.enable( + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + ) + } } } @@ -273,8 +291,9 @@ impl ThreadStore { }) .collect::>(); - this.update(cx, |this, _cx| { + this.update(cx, |this, cx| { this.context_server_tool_ids.insert(server_id, tool_ids); + this.load_default_profile(cx); }) .log_err(); } @@ -287,6 +306,7 @@ impl ThreadStore { context_server::manager::Event::ServerStopped { server_id } => { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { tool_working_set.remove(&tool_ids); + self.load_default_profile(cx); } } } diff --git a/crates/assistant_settings/src/agent_profile.rs b/crates/assistant_settings/src/agent_profile.rs index 0e9459fcf09ce9296bd4a955ce6fd97f9c3f3e7b..00a7fbcedb96da2fcb3c3138112eae7bc962af35 100644 --- a/crates/assistant_settings/src/agent_profile.rs +++ b/crates/assistant_settings/src/agent_profile.rs @@ -9,6 +9,7 @@ pub struct AgentProfile { /// The name of the profile. pub name: SharedString, pub tools: IndexMap, bool>, + pub enable_all_context_servers: bool, pub context_servers: IndexMap, ContextServerPreset>, } diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 102fad134f77f21d2e9c7603a360b479108ea0a7..0a5af98ab57163d6ed56967d5a66e6b716c77527 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -352,6 +352,7 @@ impl AssistantSettingsContent { AgentProfileContent { name: profile.name.into(), tools: profile.tools, + enable_all_context_servers: Some(profile.enable_all_context_servers), context_servers: profile .context_servers .into_iter() @@ -485,6 +486,8 @@ impl Default for LanguageModelSelection { pub struct AgentProfileContent { pub name: Arc, pub tools: IndexMap, bool>, + /// Whether all context servers are enabled by default. + pub enable_all_context_servers: Option, #[serde(default)] pub context_servers: IndexMap, ContextServerPresetContent>, } @@ -607,6 +610,9 @@ impl Settings for AssistantSettings { AgentProfile { name: profile.name.into(), tools: profile.tools, + enable_all_context_servers: profile + .enable_all_context_servers + .unwrap_or_default(), context_servers: profile .context_servers .into_iter() diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 82a4455920267de7dc377155b07f49b5569c8360..97060cfdad1b9c5fbf62761b9730d15ebdbff94e 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -19,6 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, + enabled_sources: HashSet, enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -168,21 +169,22 @@ impl WorkingSetState { } fn enable_source(&mut self, source: ToolSource, cx: &App) { + self.enabled_sources.insert(source.clone()); + let tools_by_source = self.tools_by_source(cx); - let Some(tools) = tools_by_source.get(&source) else { - return; - }; - - self.enabled_tools_by_source.insert( - source, - tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(), - ); + if let Some(tools) = tools_by_source.get(&source) { + self.enabled_tools_by_source.insert( + source, + tools + .into_iter() + .map(|tool| tool.name().into()) + .collect::>(), + ); + } } fn disable_source(&mut self, source: &ToolSource) { + self.enabled_sources.remove(source); self.enabled_tools_by_source.remove(source); }