1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque handle returned by `ToolWorkingSet::insert`, used later to remove the tool again.
///
/// Ids are handed out sequentially starting from `ToolId::default()` (zero).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18struct WorkingSetState {
19 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
20 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
21 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
22 is_scripting_tool_disabled: bool,
23 next_tool_id: ToolId,
24}
25
26impl Default for WorkingSetState {
27 fn default() -> Self {
28 Self {
29 context_server_tools_by_id: HashMap::default(),
30 context_server_tools_by_name: HashMap::default(),
31 disabled_tools_by_source: HashMap::default(),
32 is_scripting_tool_disabled: true,
33 next_tool_id: ToolId::default(),
34 }
35 }
36}
37
38impl ToolWorkingSet {
39 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
40 self.state
41 .lock()
42 .context_server_tools_by_name
43 .get(name)
44 .cloned()
45 .or_else(|| ToolRegistry::global(cx).tool(name))
46 }
47
48 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
49 self.state.lock().tools(cx)
50 }
51
52 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
53 self.state.lock().tools_by_source(cx)
54 }
55
56 pub fn are_all_tools_enabled(&self) -> bool {
57 let state = self.state.lock();
58 state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
59 }
60
61 pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
62 let state = self.state.lock();
63 !state.disabled_tools_by_source.contains_key(source)
64 }
65
66 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
67 self.state.lock().enabled_tools(cx)
68 }
69
70 pub fn enable_all_tools(&self) {
71 let mut state = self.state.lock();
72 state.disabled_tools_by_source.clear();
73 state.enable_scripting_tool();
74 }
75
76 pub fn disable_all_tools(&self, cx: &App) {
77 let mut state = self.state.lock();
78 state.disable_all_tools(cx);
79 }
80
81 pub fn enable_source(&self, source: &ToolSource) {
82 let mut state = self.state.lock();
83 state.enable_source(source);
84 }
85
86 pub fn disable_source(&self, source: ToolSource, cx: &App) {
87 let mut state = self.state.lock();
88 state.disable_source(source, cx);
89 }
90
91 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
92 let mut state = self.state.lock();
93 let tool_id = state.next_tool_id;
94 state.next_tool_id.0 += 1;
95 state
96 .context_server_tools_by_id
97 .insert(tool_id, tool.clone());
98 state.tools_changed();
99 tool_id
100 }
101
102 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
103 self.state.lock().is_enabled(source, name)
104 }
105
106 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
107 self.state.lock().is_disabled(source, name)
108 }
109
110 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
111 let mut state = self.state.lock();
112 state.enable(source, tools_to_enable);
113 }
114
115 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
116 let mut state = self.state.lock();
117 state.disable(source, tools_to_disable);
118 }
119
120 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
121 let mut state = self.state.lock();
122 state
123 .context_server_tools_by_id
124 .retain(|id, _| !tool_ids_to_remove.contains(id));
125 state.tools_changed();
126 }
127
128 pub fn is_scripting_tool_enabled(&self) -> bool {
129 let state = self.state.lock();
130 !state.is_scripting_tool_disabled
131 }
132
133 pub fn enable_scripting_tool(&self) {
134 let mut state = self.state.lock();
135 state.enable_scripting_tool();
136 }
137
138 pub fn disable_scripting_tool(&self) {
139 let mut state = self.state.lock();
140 state.disable_scripting_tool();
141 }
142}
143
144impl WorkingSetState {
145 fn tools_changed(&mut self) {
146 self.context_server_tools_by_name.clear();
147 self.context_server_tools_by_name.extend(
148 self.context_server_tools_by_id
149 .values()
150 .map(|tool| (tool.name(), tool.clone())),
151 );
152 }
153
154 fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
155 let mut tools = ToolRegistry::global(cx).tools();
156 tools.extend(self.context_server_tools_by_id.values().cloned());
157
158 tools
159 }
160
161 fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
162 let mut tools_by_source = IndexMap::default();
163
164 for tool in self.tools(cx) {
165 tools_by_source
166 .entry(tool.source())
167 .or_insert_with(Vec::new)
168 .push(tool);
169 }
170
171 for tools in tools_by_source.values_mut() {
172 tools.sort_by_key(|tool| tool.name());
173 }
174
175 tools_by_source.sort_unstable_keys();
176
177 tools_by_source
178 }
179
180 fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
181 let all_tools = self.tools(cx);
182
183 all_tools
184 .into_iter()
185 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
186 .collect()
187 }
188
189 fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
190 !self.is_disabled(source, name)
191 }
192
193 fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
194 self.disabled_tools_by_source
195 .get(source)
196 .map_or(false, |disabled_tools| disabled_tools.contains(name))
197 }
198
199 fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
200 self.disabled_tools_by_source
201 .entry(source)
202 .or_default()
203 .retain(|name| !tools_to_enable.contains(name));
204 }
205
206 fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
207 self.disabled_tools_by_source
208 .entry(source)
209 .or_default()
210 .extend(tools_to_disable.into_iter().cloned());
211 }
212
213 fn enable_source(&mut self, source: &ToolSource) {
214 self.disabled_tools_by_source.remove(source);
215 }
216
217 fn disable_source(&mut self, source: ToolSource, cx: &App) {
218 let tools_by_source = self.tools_by_source(cx);
219 let Some(tools) = tools_by_source.get(&source) else {
220 return;
221 };
222
223 self.disabled_tools_by_source.insert(
224 source,
225 tools
226 .into_iter()
227 .map(|tool| tool.name().into())
228 .collect::<HashSet<_>>(),
229 );
230 }
231
232 fn disable_all_tools(&mut self, cx: &App) {
233 let tools = self.tools_by_source(cx);
234
235 for (source, tools) in tools {
236 let tool_names = tools
237 .into_iter()
238 .map(|tool| tool.name().into())
239 .collect::<Vec<_>>();
240
241 self.disable(source, &tool_names);
242 }
243
244 self.disable_scripting_tool();
245 }
246
247 fn enable_scripting_tool(&mut self) {
248 self.is_scripting_tool_disabled = false;
249 }
250
251 fn disable_scripting_tool(&mut self) {
252 self.is_scripting_tool_disabled = true;
253 }
254}