1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::App;
5use parking_lot::Mutex;
6
7use crate::{Tool, ToolRegistry, ToolSource};
8
/// Opaque handle for a tool added to a [`ToolWorkingSet`].
///
/// Returned by [`ToolWorkingSet::insert`] and accepted by [`ToolWorkingSet::remove`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15 state: Mutex<WorkingSetState>,
16}
17
18#[derive(Default)]
19struct WorkingSetState {
20 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
21 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
22 disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
23 next_tool_id: ToolId,
24}
25
26impl ToolWorkingSet {
27 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
28 self.state
29 .lock()
30 .context_server_tools_by_name
31 .get(name)
32 .cloned()
33 .or_else(|| ToolRegistry::global(cx).tool(name))
34 }
35
36 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
37 let mut tools = ToolRegistry::global(cx).tools();
38 tools.extend(
39 self.state
40 .lock()
41 .context_server_tools_by_id
42 .values()
43 .cloned(),
44 );
45
46 tools
47 }
48
49 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
50 let all_tools = self.tools(cx);
51
52 all_tools
53 .into_iter()
54 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
55 .collect()
56 }
57
58 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
59 let mut tools_by_source = IndexMap::default();
60
61 for tool in self.tools(cx) {
62 tools_by_source
63 .entry(tool.source())
64 .or_insert_with(Vec::new)
65 .push(tool);
66 }
67
68 for tools in tools_by_source.values_mut() {
69 tools.sort_by_key(|tool| tool.name());
70 }
71
72 tools_by_source.sort_unstable_keys();
73
74 tools_by_source
75 }
76
77 pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
78 let mut state = self.state.lock();
79 let tool_id = state.next_tool_id;
80 state.next_tool_id.0 += 1;
81 state
82 .context_server_tools_by_id
83 .insert(tool_id, tool.clone());
84 state.tools_changed();
85 tool_id
86 }
87
88 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
89 !self.is_disabled(source, name)
90 }
91
92 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
93 let state = self.state.lock();
94 state
95 .disabled_tools_by_source
96 .get(source)
97 .map_or(false, |disabled_tools| disabled_tools.contains(name))
98 }
99
100 pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
101 let mut state = self.state.lock();
102 state
103 .disabled_tools_by_source
104 .entry(source)
105 .or_default()
106 .retain(|name| !tools_to_enable.contains(name));
107 }
108
109 pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
110 let mut state = self.state.lock();
111 state
112 .disabled_tools_by_source
113 .entry(source)
114 .or_default()
115 .extend(tools_to_disable.into_iter().cloned());
116 }
117
118 pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
119 let mut state = self.state.lock();
120 state
121 .context_server_tools_by_id
122 .retain(|id, _| !tool_ids_to_remove.contains(id));
123 state.tools_changed();
124 }
125}
126
127impl WorkingSetState {
128 fn tools_changed(&mut self) {
129 self.context_server_tools_by_name.clear();
130 self.context_server_tools_by_name.extend(
131 self.context_server_tools_by_id
132 .values()
133 .map(|tool| (tool.name(), tool.clone())),
134 );
135 }
136}