1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Identifier returned by `ToolWorkingSet::insert` for a context-server tool.
///
/// Ids are allocated sequentially per working set, starting at zero, and are
/// the handle passed back to `ToolWorkingSet::remove`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18struct WorkingSetState {
19 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
20 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
21 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
22 is_scripting_tool_disabled: bool,
23 next_tool_id: ToolId,
24}
25
26impl Default for WorkingSetState {
27 fn default() -> Self {
28 Self {
29 context_server_tools_by_id: HashMap::default(),
30 context_server_tools_by_name: HashMap::default(),
31 disabled_tools_by_source: HashMap::default(),
32 is_scripting_tool_disabled: true,
33 next_tool_id: ToolId::default(),
34 }
35 }
36}
37
38impl ToolWorkingSet {
39 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
40 self.state
41 .lock()
42 .context_server_tools_by_name
43 .get(name)
44 .cloned()
45 .or_else(|| ToolRegistry::global(cx).tool(name))
46 }
47
48 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
49 let mut tools = ToolRegistry::global(cx).tools();
50 tools.extend(
51 self.state
52 .lock()
53 .context_server_tools_by_id
54 .values()
55 .cloned(),
56 );
57
58 tools
59 }
60
61 pub fn are_all_tools_enabled(&self) -> bool {
62 let state = self.state.lock();
63
64 state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
65 }
66
67 pub fn enable_all_tools(&self) {
68 let mut state = self.state.lock();
69
70 state.disabled_tools_by_source.clear();
71 state.is_scripting_tool_disabled = false;
72 }
73
74 pub fn disable_all_tools(&self, cx: &App) {
75 let tools = self.tools_by_source(cx);
76
77 for (source, tools) in tools {
78 let tool_names = tools
79 .into_iter()
80 .map(|tool| tool.name().into())
81 .collect::<Vec<_>>();
82
83 self.disable(source, &tool_names);
84 }
85
86 self.disable_scripting_tool();
87 }
88
89 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
90 let all_tools = self.tools(cx);
91
92 all_tools
93 .into_iter()
94 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
95 .collect()
96 }
97
98 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
99 let mut tools_by_source = IndexMap::default();
100
101 for tool in self.tools(cx) {
102 tools_by_source
103 .entry(tool.source())
104 .or_insert_with(Vec::new)
105 .push(tool);
106 }
107
108 for tools in tools_by_source.values_mut() {
109 tools.sort_by_key(|tool| tool.name());
110 }
111
112 tools_by_source.sort_unstable_keys();
113
114 tools_by_source
115 }
116
117 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
118 let mut state = self.state.lock();
119 let tool_id = state.next_tool_id;
120 state.next_tool_id.0 += 1;
121 state
122 .context_server_tools_by_id
123 .insert(tool_id, tool.clone());
124 state.tools_changed();
125 tool_id
126 }
127
128 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
129 !self.is_disabled(source, name)
130 }
131
132 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
133 let state = self.state.lock();
134 state
135 .disabled_tools_by_source
136 .get(source)
137 .map_or(false, |disabled_tools| disabled_tools.contains(name))
138 }
139
140 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
141 let mut state = self.state.lock();
142 state
143 .disabled_tools_by_source
144 .entry(source)
145 .or_default()
146 .retain(|name| !tools_to_enable.contains(name));
147 }
148
149 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
150 let mut state = self.state.lock();
151 state
152 .disabled_tools_by_source
153 .entry(source)
154 .or_default()
155 .extend(tools_to_disable.into_iter().cloned());
156 }
157
158 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
159 let mut state = self.state.lock();
160 state
161 .context_server_tools_by_id
162 .retain(|id, _| !tool_ids_to_remove.contains(id));
163 state.tools_changed();
164 }
165
166 pub fn is_scripting_tool_enabled(&self) -> bool {
167 let state = self.state.lock();
168 !state.is_scripting_tool_disabled
169 }
170
171 pub fn enable_scripting_tool(&self) {
172 let mut state = self.state.lock();
173 state.is_scripting_tool_disabled = false;
174 }
175
176 pub fn disable_scripting_tool(&self) {
177 let mut state = self.state.lock();
178 state.is_scripting_tool_disabled = true;
179 }
180}
181
182impl WorkingSetState {
183 fn tools_changed(&mut self) {
184 self.context_server_tools_by_name.clear();
185 self.context_server_tools_by_name.extend(
186 self.context_server_tools_by_id
187 .values()
188 .map(|tool| (tool.name(), tool.clone())),
189 );
190 }
191}