1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque handle for a context-server tool inserted into a [`ToolWorkingSet`].
///
/// Values are handed out sequentially by [`ToolWorkingSet::insert`] and are used
/// to remove the tool again via [`ToolWorkingSet::remove`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18#[derive(Default)]
19struct WorkingSetState {
20 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
21 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
22 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
23 is_scripting_tool_disabled: bool,
24 next_tool_id: ToolId,
25}
26
27impl ToolWorkingSet {
28 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
29 self.state
30 .lock()
31 .context_server_tools_by_name
32 .get(name)
33 .cloned()
34 .or_else(|| ToolRegistry::global(cx).tool(name))
35 }
36
37 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
38 let mut tools = ToolRegistry::global(cx).tools();
39 tools.extend(
40 self.state
41 .lock()
42 .context_server_tools_by_id
43 .values()
44 .cloned(),
45 );
46
47 tools
48 }
49
50 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
51 let all_tools = self.tools(cx);
52
53 all_tools
54 .into_iter()
55 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
56 .collect()
57 }
58
59 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
60 let mut tools_by_source = IndexMap::default();
61
62 for tool in self.tools(cx) {
63 tools_by_source
64 .entry(tool.source())
65 .or_insert_with(Vec::new)
66 .push(tool);
67 }
68
69 for tools in tools_by_source.values_mut() {
70 tools.sort_by_key(|tool| tool.name());
71 }
72
73 tools_by_source.sort_unstable_keys();
74
75 tools_by_source
76 }
77
78 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
79 let mut state = self.state.lock();
80 let tool_id = state.next_tool_id;
81 state.next_tool_id.0 += 1;
82 state
83 .context_server_tools_by_id
84 .insert(tool_id, tool.clone());
85 state.tools_changed();
86 tool_id
87 }
88
89 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
90 !self.is_disabled(source, name)
91 }
92
93 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
94 let state = self.state.lock();
95 state
96 .disabled_tools_by_source
97 .get(source)
98 .map_or(false, |disabled_tools| disabled_tools.contains(name))
99 }
100
101 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
102 let mut state = self.state.lock();
103 state
104 .disabled_tools_by_source
105 .entry(source)
106 .or_default()
107 .retain(|name| !tools_to_enable.contains(name));
108 }
109
110 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
111 let mut state = self.state.lock();
112 state
113 .disabled_tools_by_source
114 .entry(source)
115 .or_default()
116 .extend(tools_to_disable.into_iter().cloned());
117 }
118
119 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
120 let mut state = self.state.lock();
121 state
122 .context_server_tools_by_id
123 .retain(|id, _| !tool_ids_to_remove.contains(id));
124 state.tools_changed();
125 }
126
127 pub fn is_scripting_tool_enabled(&self) -> bool {
128 let state = self.state.lock();
129 !state.is_scripting_tool_disabled
130 }
131
132 pub fn enable_scripting_tool(&self) {
133 let mut state = self.state.lock();
134 state.is_scripting_tool_disabled = false;
135 }
136
137 pub fn disable_scripting_tool(&self) {
138 let mut state = self.state.lock();
139 state.is_scripting_tool_disabled = true;
140 }
141}
142
143impl WorkingSetState {
144 fn tools_changed(&mut self) {
145 self.context_server_tools_by_name.clear();
146 self.context_server_tools_by_name.extend(
147 self.context_server_tools_by_id
148 .values()
149 .map(|tool| (tool.name(), tool.clone())),
150 );
151 }
152}