1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque identifier for a tool added to a [`ToolWorkingSet`] via
/// [`ToolWorkingSet::insert`]; pass it to [`ToolWorkingSet::remove`] to drop the tool.
///
/// The default value is `ToolId(0)`, which is the first ID a working set hands out.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
/// A working set of tools for use in one instance of the Assistant Panel.
///
/// Combines the tools from the global `ToolRegistry` with tools inserted directly
/// into this set, and tracks which of them are enabled. All state lives behind a
/// mutex, so every method takes `&self`.
#[derive(Default)]
pub struct ToolWorkingSet {
    /// Shared mutable state; each public method locks it for the duration of the call.
    state: Mutex<WorkingSetState>,
}
17
/// The mutable state guarded by [`ToolWorkingSet`]'s mutex.
#[derive(Default)]
struct WorkingSetState {
    /// Tools added through `ToolWorkingSet::insert`, keyed by the ID they were given.
    /// (Named "context server" tools here — presumably they come from context servers;
    /// confirm against callers.)
    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
    /// Name index over `context_server_tools_by_id`, rebuilt by `tools_changed`
    /// whenever a tool is inserted or removed.
    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
    /// Names of enabled tools, grouped by source. A source with no entry has no
    /// enabled tools.
    enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
    /// The ID the next inserted tool will receive; starts at `ToolId::default()`.
    next_tool_id: ToolId,
}
25
26impl ToolWorkingSet {
27 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
28 self.state
29 .lock()
30 .context_server_tools_by_name
31 .get(name)
32 .cloned()
33 .or_else(|| ToolRegistry::global(cx).tool(name))
34 }
35
36 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
37 self.state.lock().tools(cx)
38 }
39
40 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
41 self.state.lock().tools_by_source(cx)
42 }
43
44 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
45 self.state.lock().enabled_tools(cx)
46 }
47
48 pub fn disable_all_tools(&self) {
49 let mut state = self.state.lock();
50 state.disable_all_tools();
51 }
52
53 pub fn enable_source(&self, source: ToolSource, cx: &App) {
54 let mut state = self.state.lock();
55 state.enable_source(source, cx);
56 }
57
58 pub fn disable_source(&self, source: &ToolSource) {
59 let mut state = self.state.lock();
60 state.disable_source(source);
61 }
62
63 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
64 let mut state = self.state.lock();
65 let tool_id = state.next_tool_id;
66 state.next_tool_id.0 += 1;
67 state
68 .context_server_tools_by_id
69 .insert(tool_id, tool.clone());
70 state.tools_changed();
71 tool_id
72 }
73
74 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
75 self.state.lock().is_enabled(source, name)
76 }
77
78 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
79 self.state.lock().is_disabled(source, name)
80 }
81
82 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
83 let mut state = self.state.lock();
84 state.enable(source, tools_to_enable);
85 }
86
87 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
88 let mut state = self.state.lock();
89 state.disable(source, tools_to_disable);
90 }
91
92 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
93 let mut state = self.state.lock();
94 state
95 .context_server_tools_by_id
96 .retain(|id, _| !tool_ids_to_remove.contains(id));
97 state.tools_changed();
98 }
99}
100
101impl WorkingSetState {
102 fn tools_changed(&mut self) {
103 self.context_server_tools_by_name.clear();
104 self.context_server_tools_by_name.extend(
105 self.context_server_tools_by_id
106 .values()
107 .map(|tool| (tool.name(), tool.clone())),
108 );
109 }
110
111 fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
112 let mut tools = ToolRegistry::global(cx).tools();
113 tools.extend(self.context_server_tools_by_id.values().cloned());
114
115 tools
116 }
117
118 fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
119 let mut tools_by_source = IndexMap::default();
120
121 for tool in self.tools(cx) {
122 tools_by_source
123 .entry(tool.source())
124 .or_insert_with(Vec::new)
125 .push(tool);
126 }
127
128 for tools in tools_by_source.values_mut() {
129 tools.sort_by_key(|tool| tool.name());
130 }
131
132 tools_by_source.sort_unstable_keys();
133
134 tools_by_source
135 }
136
137 fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
138 let all_tools = self.tools(cx);
139
140 all_tools
141 .into_iter()
142 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
143 .collect()
144 }
145
146 fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
147 self.enabled_tools_by_source
148 .get(source)
149 .map_or(false, |enabled_tools| enabled_tools.contains(name))
150 }
151
152 fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
153 !self.is_enabled(source, name)
154 }
155
156 fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
157 self.enabled_tools_by_source
158 .entry(source)
159 .or_default()
160 .extend(tools_to_enable.into_iter().cloned());
161 }
162
163 fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
164 self.enabled_tools_by_source
165 .entry(source)
166 .or_default()
167 .retain(|name| !tools_to_disable.contains(name));
168 }
169
170 fn enable_source(&mut self, source: ToolSource, cx: &App) {
171 let tools_by_source = self.tools_by_source(cx);
172 let Some(tools) = tools_by_source.get(&source) else {
173 return;
174 };
175
176 self.enabled_tools_by_source.insert(
177 source,
178 tools
179 .into_iter()
180 .map(|tool| tool.name().into())
181 .collect::<HashSet<_>>(),
182 );
183 }
184
185 fn disable_source(&mut self, source: &ToolSource) {
186 self.enabled_tools_by_source.remove(source);
187 }
188
189 fn disable_all_tools(&mut self) {
190 self.enabled_tools_by_source.clear();
191 }
192}