@@ -3,7 +3,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{anyhow, Result};
-use assistant_settings::AssistantSettings;
+use assistant_settings::{AgentProfile, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
@@ -187,35 +187,42 @@ impl ThreadStore {
})
}
- pub fn load_default_profile(&self, cx: &mut Context<Self>) {
+ fn load_default_profile(&self, cx: &Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
- if let Some(profile) = assistant_settings
- .profiles
- .get(&assistant_settings.default_profile)
- {
- self.tools.disable_source(ToolSource::Native, cx);
+ self.load_profile_by_id(&assistant_settings.default_profile, cx);
+ }
+
+ pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
+ let assistant_settings = AssistantSettings::get_global(cx);
+
+ if let Some(profile) = assistant_settings.profiles.get(profile_id) {
+ self.load_profile(profile);
+ }
+ }
+
+ pub fn load_profile(&self, profile: &AgentProfile) {
+ self.tools.disable_all_tools();
+ self.tools.enable(
+ ToolSource::Native,
+ &profile
+ .tools
+ .iter()
+ .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
+ .collect::<Vec<_>>(),
+ );
+
+ for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
- ToolSource::Native,
- &profile
+ ToolSource::ContextServer {
+ id: context_server_id.clone().into(),
+ },
+ &preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
- );
-
- for (context_server_id, preset) in &profile.context_servers {
- self.tools.enable(
- ToolSource::ContextServer {
- id: context_server_id.clone().into(),
- },
- &preset
- .tools
- .iter()
- .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
- .collect::<Vec<_>>(),
- )
- }
+ )
}
}
@@ -19,7 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
- disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
+ enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId,
}
@@ -41,38 +41,23 @@ impl ToolWorkingSet {
self.state.lock().tools_by_source(cx)
}
- pub fn are_all_tools_enabled(&self) -> bool {
- let state = self.state.lock();
- state.disabled_tools_by_source.is_empty()
- }
-
- pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
- let state = self.state.lock();
- !state.disabled_tools_by_source.contains_key(source)
- }
-
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
- pub fn enable_all_tools(&self) {
+ pub fn disable_all_tools(&self) {
let mut state = self.state.lock();
- state.disabled_tools_by_source.clear();
+ state.disable_all_tools();
}
- pub fn disable_all_tools(&self, cx: &App) {
+ pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
- state.disable_all_tools(cx);
+ state.enable_source(source, cx);
}
- pub fn enable_source(&self, source: &ToolSource) {
+ pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock();
- state.enable_source(source);
- }
-
- pub fn disable_source(&self, source: ToolSource, cx: &App) {
- let mut state = self.state.lock();
- state.disable_source(source, cx);
+ state.disable_source(source);
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
@@ -159,40 +144,36 @@ impl WorkingSetState {
}
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
- !self.is_disabled(source, name)
+ self.enabled_tools_by_source
+ .get(source)
+ .map_or(false, |enabled_tools| enabled_tools.contains(name))
}
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
- self.disabled_tools_by_source
- .get(source)
- .map_or(false, |disabled_tools| disabled_tools.contains(name))
+ !self.is_enabled(source, name)
}
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
- self.disabled_tools_by_source
+ self.enabled_tools_by_source
.entry(source)
.or_default()
- .retain(|name| !tools_to_enable.contains(name));
+ .extend(tools_to_enable.into_iter().cloned());
}
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
- self.disabled_tools_by_source
+ self.enabled_tools_by_source
.entry(source)
.or_default()
- .extend(tools_to_disable.into_iter().cloned());
- }
-
- fn enable_source(&mut self, source: &ToolSource) {
- self.disabled_tools_by_source.remove(source);
+ .retain(|name| !tools_to_disable.contains(name));
}
- fn disable_source(&mut self, source: ToolSource, cx: &App) {
+ fn enable_source(&mut self, source: ToolSource, cx: &App) {
let tools_by_source = self.tools_by_source(cx);
let Some(tools) = tools_by_source.get(&source) else {
return;
};
- self.disabled_tools_by_source.insert(
+ self.enabled_tools_by_source.insert(
source,
tools
.into_iter()
@@ -201,16 +182,11 @@ impl WorkingSetState {
);
}
- fn disable_all_tools(&mut self, cx: &App) {
- let tools = self.tools_by_source(cx);
-
- for (source, tools) in tools {
- let tool_names = tools
- .into_iter()
- .map(|tool| tool.name().into())
- .collect::<Vec<_>>();
+ fn disable_source(&mut self, source: &ToolSource) {
+ self.enabled_tools_by_source.remove(source);
+ }
- self.disable(source, &tool_names);
- }
+ fn disable_all_tools(&mut self) {
+ self.enabled_tools_by_source.clear();
}
}