1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque identifier handed out by [`ToolWorkingSet::insert`] for a tool added
/// to a working set; pass it back to [`ToolWorkingSet::remove`] to drop the tool.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18struct WorkingSetState {
19 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
20 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
21 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
22 is_scripting_tool_disabled: bool,
23 next_tool_id: ToolId,
24}
25
26impl Default for WorkingSetState {
27 fn default() -> Self {
28 Self {
29 context_server_tools_by_id: Default::default(),
30 context_server_tools_by_name: Default::default(),
31 disabled_tools_by_source: Default::default(),
32 is_scripting_tool_disabled: true,
33 next_tool_id: Default::default(),
34 }
35 }
36}
37
38impl ToolWorkingSet {
39 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
40 self.state
41 .lock()
42 .context_server_tools_by_name
43 .get(name)
44 .cloned()
45 .or_else(|| ToolRegistry::global(cx).tool(name))
46 }
47
48 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
49 let mut tools = ToolRegistry::global(cx).tools();
50 tools.extend(
51 self.state
52 .lock()
53 .context_server_tools_by_id
54 .values()
55 .cloned(),
56 );
57
58 tools
59 }
60
61 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
62 let all_tools = self.tools(cx);
63
64 all_tools
65 .into_iter()
66 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
67 .collect()
68 }
69
70 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
71 let mut tools_by_source = IndexMap::default();
72
73 for tool in self.tools(cx) {
74 tools_by_source
75 .entry(tool.source())
76 .or_insert_with(Vec::new)
77 .push(tool);
78 }
79
80 for tools in tools_by_source.values_mut() {
81 tools.sort_by_key(|tool| tool.name());
82 }
83
84 tools_by_source.sort_unstable_keys();
85
86 tools_by_source
87 }
88
89 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
90 let mut state = self.state.lock();
91 let tool_id = state.next_tool_id;
92 state.next_tool_id.0 += 1;
93 state
94 .context_server_tools_by_id
95 .insert(tool_id, tool.clone());
96 state.tools_changed();
97 tool_id
98 }
99
100 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
101 !self.is_disabled(source, name)
102 }
103
104 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
105 let state = self.state.lock();
106 state
107 .disabled_tools_by_source
108 .get(source)
109 .map_or(false, |disabled_tools| disabled_tools.contains(name))
110 }
111
112 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
113 let mut state = self.state.lock();
114 state
115 .disabled_tools_by_source
116 .entry(source)
117 .or_default()
118 .retain(|name| !tools_to_enable.contains(name));
119 }
120
121 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
122 let mut state = self.state.lock();
123 state
124 .disabled_tools_by_source
125 .entry(source)
126 .or_default()
127 .extend(tools_to_disable.into_iter().cloned());
128 }
129
130 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
131 let mut state = self.state.lock();
132 state
133 .context_server_tools_by_id
134 .retain(|id, _| !tool_ids_to_remove.contains(id));
135 state.tools_changed();
136 }
137
138 pub fn is_scripting_tool_enabled(&self) -> bool {
139 let state = self.state.lock();
140 !state.is_scripting_tool_disabled
141 }
142
143 pub fn enable_scripting_tool(&self) {
144 let mut state = self.state.lock();
145 state.is_scripting_tool_disabled = false;
146 }
147
148 pub fn disable_scripting_tool(&self) {
149 let mut state = self.state.lock();
150 state.is_scripting_tool_disabled = true;
151 }
152}
153
154impl WorkingSetState {
155 fn tools_changed(&mut self) {
156 self.context_server_tools_by_name.clear();
157 self.context_server_tools_by_name.extend(
158 self.context_server_tools_by_id
159 .values()
160 .map(|tool| (tool.name(), tool.clone())),
161 );
162 }
163}