1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque handle for a tool that was added to a [`ToolWorkingSet`].
///
/// Values are returned by `ToolWorkingSet::insert` and accepted by
/// `ToolWorkingSet::remove`. The default value wraps `0`.
#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18#[derive(Default)]
19struct WorkingSetState {
20 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
21 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
22 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
23 next_tool_id: ToolId,
24}
25
26impl ToolWorkingSet {
27 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
28 self.state
29 .lock()
30 .context_server_tools_by_name
31 .get(name)
32 .cloned()
33 .or_else(|| ToolRegistry::global(cx).tool(name))
34 }
35
36 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
37 self.state.lock().tools(cx)
38 }
39
40 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
41 self.state.lock().tools_by_source(cx)
42 }
43
44 pub fn are_all_tools_enabled(&self) -> bool {
45 let state = self.state.lock();
46 state.disabled_tools_by_source.is_empty()
47 }
48
49 pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
50 let state = self.state.lock();
51 !state.disabled_tools_by_source.contains_key(source)
52 }
53
54 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
55 self.state.lock().enabled_tools(cx)
56 }
57
58 pub fn enable_all_tools(&self) {
59 let mut state = self.state.lock();
60 state.disabled_tools_by_source.clear();
61 }
62
63 pub fn disable_all_tools(&self, cx: &App) {
64 let mut state = self.state.lock();
65 state.disable_all_tools(cx);
66 }
67
68 pub fn enable_source(&self, source: &ToolSource) {
69 let mut state = self.state.lock();
70 state.enable_source(source);
71 }
72
73 pub fn disable_source(&self, source: ToolSource, cx: &App) {
74 let mut state = self.state.lock();
75 state.disable_source(source, cx);
76 }
77
78 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
79 let mut state = self.state.lock();
80 let tool_id = state.next_tool_id;
81 state.next_tool_id.0 += 1;
82 state
83 .context_server_tools_by_id
84 .insert(tool_id, tool.clone());
85 state.tools_changed();
86 tool_id
87 }
88
89 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
90 self.state.lock().is_enabled(source, name)
91 }
92
93 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
94 self.state.lock().is_disabled(source, name)
95 }
96
97 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
98 let mut state = self.state.lock();
99 state.enable(source, tools_to_enable);
100 }
101
102 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
103 let mut state = self.state.lock();
104 state.disable(source, tools_to_disable);
105 }
106
107 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
108 let mut state = self.state.lock();
109 state
110 .context_server_tools_by_id
111 .retain(|id, _| !tool_ids_to_remove.contains(id));
112 state.tools_changed();
113 }
114}
115
116impl WorkingSetState {
117 fn tools_changed(&mut self) {
118 self.context_server_tools_by_name.clear();
119 self.context_server_tools_by_name.extend(
120 self.context_server_tools_by_id
121 .values()
122 .map(|tool| (tool.name(), tool.clone())),
123 );
124 }
125
126 fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
127 let mut tools = ToolRegistry::global(cx).tools();
128 tools.extend(self.context_server_tools_by_id.values().cloned());
129
130 tools
131 }
132
133 fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
134 let mut tools_by_source = IndexMap::default();
135
136 for tool in self.tools(cx) {
137 tools_by_source
138 .entry(tool.source())
139 .or_insert_with(Vec::new)
140 .push(tool);
141 }
142
143 for tools in tools_by_source.values_mut() {
144 tools.sort_by_key(|tool| tool.name());
145 }
146
147 tools_by_source.sort_unstable_keys();
148
149 tools_by_source
150 }
151
152 fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
153 let all_tools = self.tools(cx);
154
155 all_tools
156 .into_iter()
157 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
158 .collect()
159 }
160
161 fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
162 !self.is_disabled(source, name)
163 }
164
165 fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
166 self.disabled_tools_by_source
167 .get(source)
168 .map_or(false, |disabled_tools| disabled_tools.contains(name))
169 }
170
171 fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
172 self.disabled_tools_by_source
173 .entry(source)
174 .or_default()
175 .retain(|name| !tools_to_enable.contains(name));
176 }
177
178 fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
179 self.disabled_tools_by_source
180 .entry(source)
181 .or_default()
182 .extend(tools_to_disable.into_iter().cloned());
183 }
184
185 fn enable_source(&mut self, source: &ToolSource) {
186 self.disabled_tools_by_source.remove(source);
187 }
188
189 fn disable_source(&mut self, source: ToolSource, cx: &App) {
190 let tools_by_source = self.tools_by_source(cx);
191 let Some(tools) = tools_by_source.get(&source) else {
192 return;
193 };
194
195 self.disabled_tools_by_source.insert(
196 source,
197 tools
198 .into_iter()
199 .map(|tool| tool.name().into())
200 .collect::<HashSet<_>>(),
201 );
202 }
203
204 fn disable_all_tools(&mut self, cx: &App) {
205 let tools = self.tools_by_source(cx);
206
207 for (source, tools) in tools {
208 let tool_names = tools
209 .into_iter()
210 .map(|tool| tool.name().into())
211 .collect::<Vec<_>>();
212
213 self.disable(source, &tool_names);
214 }
215 }
216}