1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque handle identifying a tool inserted into a [`ToolWorkingSet`].
///
/// Returned by [`ToolWorkingSet::insert`] and accepted by [`ToolWorkingSet::remove`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18struct WorkingSetState {
19 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
20 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
21 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
22 is_scripting_tool_disabled: bool,
23 next_tool_id: ToolId,
24}
25
26impl Default for WorkingSetState {
27 fn default() -> Self {
28 Self {
29 context_server_tools_by_id: HashMap::default(),
30 context_server_tools_by_name: HashMap::default(),
31 disabled_tools_by_source: HashMap::default(),
32 is_scripting_tool_disabled: true,
33 next_tool_id: ToolId::default(),
34 }
35 }
36}
37
38impl ToolWorkingSet {
39 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
40 self.state
41 .lock()
42 .context_server_tools_by_name
43 .get(name)
44 .cloned()
45 .or_else(|| ToolRegistry::global(cx).tool(name))
46 }
47
48 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
49 self.state.lock().tools(cx)
50 }
51
52 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
53 self.state.lock().tools_by_source(cx)
54 }
55
56 pub fn are_all_tools_enabled(&self) -> bool {
57 let state = self.state.lock();
58 state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
59 }
60
61 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
62 self.state.lock().enabled_tools(cx)
63 }
64
65 pub fn enable_all_tools(&self) {
66 let mut state = self.state.lock();
67 state.disabled_tools_by_source.clear();
68 state.enable_scripting_tool();
69 }
70
71 pub fn disable_all_tools(&self, cx: &App) {
72 let mut state = self.state.lock();
73 state.disable_all_tools(cx);
74 }
75
76 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
77 let mut state = self.state.lock();
78 let tool_id = state.next_tool_id;
79 state.next_tool_id.0 += 1;
80 state
81 .context_server_tools_by_id
82 .insert(tool_id, tool.clone());
83 state.tools_changed();
84 tool_id
85 }
86
87 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
88 self.state.lock().is_enabled(source, name)
89 }
90
91 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
92 self.state.lock().is_disabled(source, name)
93 }
94
95 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
96 let mut state = self.state.lock();
97 state.enable(source, tools_to_enable);
98 }
99
100 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
101 let mut state = self.state.lock();
102 state.disable(source, tools_to_disable);
103 }
104
105 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
106 let mut state = self.state.lock();
107 state
108 .context_server_tools_by_id
109 .retain(|id, _| !tool_ids_to_remove.contains(id));
110 state.tools_changed();
111 }
112
113 pub fn is_scripting_tool_enabled(&self) -> bool {
114 let state = self.state.lock();
115 !state.is_scripting_tool_disabled
116 }
117
118 pub fn enable_scripting_tool(&self) {
119 let mut state = self.state.lock();
120 state.enable_scripting_tool();
121 }
122
123 pub fn disable_scripting_tool(&self) {
124 let mut state = self.state.lock();
125 state.disable_scripting_tool();
126 }
127}
128
129impl WorkingSetState {
130 fn tools_changed(&mut self) {
131 self.context_server_tools_by_name.clear();
132 self.context_server_tools_by_name.extend(
133 self.context_server_tools_by_id
134 .values()
135 .map(|tool| (tool.name(), tool.clone())),
136 );
137 }
138
139 fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
140 let mut tools = ToolRegistry::global(cx).tools();
141 tools.extend(self.context_server_tools_by_id.values().cloned());
142
143 tools
144 }
145
146 fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
147 let mut tools_by_source = IndexMap::default();
148
149 for tool in self.tools(cx) {
150 tools_by_source
151 .entry(tool.source())
152 .or_insert_with(Vec::new)
153 .push(tool);
154 }
155
156 for tools in tools_by_source.values_mut() {
157 tools.sort_by_key(|tool| tool.name());
158 }
159
160 tools_by_source.sort_unstable_keys();
161
162 tools_by_source
163 }
164
165 fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
166 let all_tools = self.tools(cx);
167
168 all_tools
169 .into_iter()
170 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
171 .collect()
172 }
173
174 fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
175 !self.is_disabled(source, name)
176 }
177
178 fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
179 self.disabled_tools_by_source
180 .get(source)
181 .map_or(false, |disabled_tools| disabled_tools.contains(name))
182 }
183
184 fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
185 self.disabled_tools_by_source
186 .entry(source)
187 .or_default()
188 .retain(|name| !tools_to_enable.contains(name));
189 }
190
191 fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
192 self.disabled_tools_by_source
193 .entry(source)
194 .or_default()
195 .extend(tools_to_disable.into_iter().cloned());
196 }
197
198 fn disable_all_tools(&mut self, cx: &App) {
199 let tools = self.tools_by_source(cx);
200
201 for (source, tools) in tools {
202 let tool_names = tools
203 .into_iter()
204 .map(|tool| tool.name().into())
205 .collect::<Vec<_>>();
206
207 self.disable(source, &tool_names);
208 }
209
210 self.disable_scripting_tool();
211 }
212
213 fn enable_scripting_tool(&mut self) {
214 self.is_scripting_tool_disabled = false;
215 }
216
217 fn disable_scripting_tool(&mut self) {
218 self.is_scripting_tool_disabled = true;
219 }
220}