@@ -637,6 +637,8 @@
"profiles": {
"ask": {
"name": "Ask",
+ // We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default.
+ // "enable_all_context_servers": true,
"tools": {
"diagnostics": true,
"fetch": true,
@@ -650,6 +652,7 @@
},
"write": {
"name": "Write",
+ "enable_all_context_servers": true,
"tools": {
"bash": true,
"batch-tool": true,
@@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate {
let active_profile_id = &AssistantSettings::get_global(cx).default_profile;
if active_profile_id == &self.profile_id {
self.thread_store
- .update(cx, |this, _cx| {
- this.load_profile(&self.profile);
+ .update(cx, |this, cx| {
+ this.load_profile(&self.profile, cx);
})
.log_err();
}
@@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate {
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
+ enable_all_context_servers: Some(
+ default_profile.enable_all_context_servers,
+ ),
context_servers: default_profile
.context_servers
.into_iter()
@@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared};
use gpui::{
- App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*,
+ App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
+ prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
@@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::Project;
use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
-use settings::Settings as _;
+use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
@@ -36,6 +37,7 @@ pub struct ThreadStore {
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
+ _subscriptions: Vec<Subscription>,
}
impl ThreadStore {
@@ -50,6 +52,10 @@ impl ThreadStore {
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
+ let settings_subscription =
+ cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
+ this.load_default_profile(cx);
+ });
let this = Self {
project,
@@ -58,6 +64,7 @@ impl ThreadStore {
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
+ _subscriptions: vec![settings_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
@@ -197,11 +204,11 @@ impl ThreadStore {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
- self.load_profile(profile);
+ self.load_profile(profile, cx);
}
}
- pub fn load_profile(&self, profile: &AgentProfile) {
+ pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
@@ -212,17 +219,28 @@ impl ThreadStore {
.collect::<Vec<_>>(),
);
- for (context_server_id, preset) in &profile.context_servers {
- self.tools.enable(
- ToolSource::ContextServer {
- id: context_server_id.clone().into(),
- },
- &preset
- .tools
- .iter()
- .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
- .collect::<Vec<_>>(),
- )
+ if profile.enable_all_context_servers {
+ for context_server in self.context_server_manager.read(cx).all_servers() {
+ self.tools.enable_source(
+ ToolSource::ContextServer {
+ id: context_server.id().into(),
+ },
+ cx,
+ );
+ }
+ } else {
+ for (context_server_id, preset) in &profile.context_servers {
+ self.tools.enable(
+ ToolSource::ContextServer {
+ id: context_server_id.clone().into(),
+ },
+ &preset
+ .tools
+ .iter()
+ .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
+ .collect::<Vec<_>>(),
+ )
+ }
}
}
@@ -273,8 +291,9 @@ impl ThreadStore {
})
.collect::<Vec<_>>();
- this.update(cx, |this, _cx| {
+ this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
+ this.load_default_profile(cx);
})
.log_err();
}
@@ -287,6 +306,7 @@ impl ThreadStore {
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
+ self.load_default_profile(cx);
}
}
}
@@ -352,6 +352,7 @@ impl AssistantSettingsContent {
AgentProfileContent {
name: profile.name.into(),
tools: profile.tools,
+ enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile
.context_servers
.into_iter()
@@ -485,6 +486,8 @@ impl Default for LanguageModelSelection {
pub struct AgentProfileContent {
pub name: Arc<str>,
pub tools: IndexMap<Arc<str>, bool>,
+ /// Whether all context servers are enabled by default.
+ pub enable_all_context_servers: Option<bool>,
#[serde(default)]
pub context_servers: IndexMap<Arc<str>, ContextServerPresetContent>,
}
@@ -607,6 +610,9 @@ impl Settings for AssistantSettings {
AgentProfile {
name: profile.name.into(),
tools: profile.tools,
+ enable_all_context_servers: profile
+ .enable_all_context_servers
+ .unwrap_or_default(),
context_servers: profile
.context_servers
.into_iter()
@@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
+ enabled_sources: HashSet<ToolSource>,
enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId,
}
@@ -168,21 +169,22 @@ impl WorkingSetState {
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
+ self.enabled_sources.insert(source.clone());
+
let tools_by_source = self.tools_by_source(cx);
- let Some(tools) = tools_by_source.get(&source) else {
- return;
- };
-
- self.enabled_tools_by_source.insert(
- source,
- tools
- .into_iter()
- .map(|tool| tool.name().into())
- .collect::<HashSet<_>>(),
- );
+ if let Some(tools) = tools_by_source.get(&source) {
+ self.enabled_tools_by_source.insert(
+ source,
+ tools
+ .into_iter()
+ .map(|tool| tool.name().into())
+ .collect::<HashSet<_>>(),
+ );
+ }
}
fn disable_source(&mut self, source: &ToolSource) {
+ self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source);
}