1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Identifier issued by [`ToolWorkingSet::insert`] for a tool added to a working set.
///
/// IDs start at the default value (`0`) and grow by one with every insertion.
/// Pass an ID back to [`ToolWorkingSet::remove`] to drop the corresponding tool.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18#[derive(Default)]
19struct WorkingSetState {
20 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
21 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
22 enabled_sources: HashSet<ToolSource>,
23 enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
24 next_tool_id: ToolId,
25}
26
27impl ToolWorkingSet {
28 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
29 self.state
30 .lock()
31 .context_server_tools_by_name
32 .get(name)
33 .cloned()
34 .or_else(|| ToolRegistry::global(cx).tool(name))
35 }
36
37 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
38 self.state.lock().tools(cx)
39 }
40
41 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
42 self.state.lock().tools_by_source(cx)
43 }
44
45 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
46 self.state.lock().enabled_tools(cx)
47 }
48
49 pub fn disable_all_tools(&self) {
50 let mut state = self.state.lock();
51 state.disable_all_tools();
52 }
53
54 pub fn enable_source(&self, source: ToolSource, cx: &App) {
55 let mut state = self.state.lock();
56 state.enable_source(source, cx);
57 }
58
59 pub fn disable_source(&self, source: &ToolSource) {
60 let mut state = self.state.lock();
61 state.disable_source(source);
62 }
63
64 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
65 let mut state = self.state.lock();
66 let tool_id = state.next_tool_id;
67 state.next_tool_id.0 += 1;
68 state
69 .context_server_tools_by_id
70 .insert(tool_id, tool.clone());
71 state.tools_changed();
72 tool_id
73 }
74
75 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
76 self.state.lock().is_enabled(source, name)
77 }
78
79 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
80 self.state.lock().is_disabled(source, name)
81 }
82
83 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
84 let mut state = self.state.lock();
85 state.enable(source, tools_to_enable);
86 }
87
88 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
89 let mut state = self.state.lock();
90 state.disable(source, tools_to_disable);
91 }
92
93 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
94 let mut state = self.state.lock();
95 state
96 .context_server_tools_by_id
97 .retain(|id, _| !tool_ids_to_remove.contains(id));
98 state.tools_changed();
99 }
100}
101
102impl WorkingSetState {
103 fn tools_changed(&mut self) {
104 self.context_server_tools_by_name.clear();
105 self.context_server_tools_by_name.extend(
106 self.context_server_tools_by_id
107 .values()
108 .map(|tool| (tool.name(), tool.clone())),
109 );
110 }
111
112 fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
113 let mut tools = ToolRegistry::global(cx).tools();
114 tools.extend(self.context_server_tools_by_id.values().cloned());
115
116 tools
117 }
118
119 fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
120 let mut tools_by_source = IndexMap::default();
121
122 for tool in self.tools(cx) {
123 tools_by_source
124 .entry(tool.source())
125 .or_insert_with(Vec::new)
126 .push(tool);
127 }
128
129 for tools in tools_by_source.values_mut() {
130 tools.sort_by_key(|tool| tool.name());
131 }
132
133 tools_by_source.sort_unstable_keys();
134
135 tools_by_source
136 }
137
138 fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
139 let all_tools = self.tools(cx);
140
141 all_tools
142 .into_iter()
143 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
144 .collect()
145 }
146
147 fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
148 self.enabled_tools_by_source
149 .get(source)
150 .map_or(false, |enabled_tools| enabled_tools.contains(name))
151 }
152
153 fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
154 !self.is_enabled(source, name)
155 }
156
157 fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
158 self.enabled_tools_by_source
159 .entry(source)
160 .or_default()
161 .extend(tools_to_enable.into_iter().cloned());
162 }
163
164 fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
165 self.enabled_tools_by_source
166 .entry(source)
167 .or_default()
168 .retain(|name| !tools_to_disable.contains(name));
169 }
170
171 fn enable_source(&mut self, source: ToolSource, cx: &App) {
172 self.enabled_sources.insert(source.clone());
173
174 let tools_by_source = self.tools_by_source(cx);
175 if let Some(tools) = tools_by_source.get(&source) {
176 self.enabled_tools_by_source.insert(
177 source,
178 tools
179 .into_iter()
180 .map(|tool| tool.name().into())
181 .collect::<HashSet<_>>(),
182 );
183 }
184 }
185
186 fn disable_source(&mut self, source: &ToolSource) {
187 self.enabled_sources.remove(source);
188 self.enabled_tools_by_source.remove(source);
189 }
190
191 fn disable_all_tools(&mut self) {
192 self.enabled_tools_by_source.clear();
193 }
194}