1use std::sync::Arc;
2
3use collections::{HashMap, HashSet, IndexMap};
4use gpui::{App, Context, EventEmitter};
5
6use crate::{Tool, ToolRegistry, ToolSource};
7
/// Identifier handed out by [`ToolWorkingSet::insert`] for a tool added to a working set.
///
/// Pass it back to [`ToolWorkingSet::remove`] to drop that tool again.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
10
11/// A working set of tools for use in one instance of the Assistant Panel.
12#[derive(Default)]
13pub struct ToolWorkingSet {
14 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
15 context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
16 enabled_sources: HashSet<ToolSource>,
17 enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
18 next_tool_id: ToolId,
19}
20
/// Events emitted by a [`ToolWorkingSet`] entity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolWorkingSetEvent {
    /// The set of enabled tools may have changed.
    ///
    /// Emitted unconditionally by the enable/disable methods, even when the call was a no-op.
    EnabledToolsChanged,
}
24
// Marker impl: it is what allows `ToolWorkingSet`'s methods to call
// `cx.emit(ToolWorkingSetEvent::EnabledToolsChanged)`.
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
26
27impl ToolWorkingSet {
28 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
29 self.context_server_tools_by_name
30 .get(name)
31 .cloned()
32 .or_else(|| ToolRegistry::global(cx).tool(name))
33 }
34
35 pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
36 let mut tools = ToolRegistry::global(cx).tools();
37 tools.extend(self.context_server_tools_by_id.values().cloned());
38 tools
39 }
40
41 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
42 let mut tools_by_source = IndexMap::default();
43
44 for tool in self.tools(cx) {
45 tools_by_source
46 .entry(tool.source())
47 .or_insert_with(Vec::new)
48 .push(tool);
49 }
50
51 for tools in tools_by_source.values_mut() {
52 tools.sort_by_key(|tool| tool.name());
53 }
54
55 tools_by_source.sort_unstable_keys();
56
57 tools_by_source
58 }
59
60 pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
61 let all_tools = self.tools(cx);
62
63 all_tools
64 .into_iter()
65 .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
66 .collect()
67 }
68
69 pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
70 self.enabled_tools_by_source.clear();
71 cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
72 }
73
74 pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
75 self.enabled_sources.insert(source.clone());
76
77 let tools_by_source = self.tools_by_source(cx);
78 if let Some(tools) = tools_by_source.get(&source) {
79 self.enabled_tools_by_source.insert(
80 source,
81 tools
82 .into_iter()
83 .map(|tool| tool.name().into())
84 .collect::<HashSet<_>>(),
85 );
86 }
87 cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
88 }
89
90 pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
91 self.enabled_sources.remove(source);
92 self.enabled_tools_by_source.remove(source);
93 cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
94 }
95
96 pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
97 let tool_id = self.next_tool_id;
98 self.next_tool_id.0 += 1;
99 self.context_server_tools_by_id
100 .insert(tool_id, tool.clone());
101 self.tools_changed();
102 tool_id
103 }
104
105 pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
106 self.enabled_tools_by_source
107 .get(source)
108 .map_or(false, |enabled_tools| enabled_tools.contains(name))
109 }
110
111 pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
112 !self.is_enabled(source, name)
113 }
114
115 pub fn enable(
116 &mut self,
117 source: ToolSource,
118 tools_to_enable: &[Arc<str>],
119 cx: &mut Context<Self>,
120 ) {
121 self.enabled_tools_by_source
122 .entry(source)
123 .or_default()
124 .extend(tools_to_enable.into_iter().cloned());
125 cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
126 }
127
128 pub fn disable(
129 &mut self,
130 source: ToolSource,
131 tools_to_disable: &[Arc<str>],
132 cx: &mut Context<Self>,
133 ) {
134 self.enabled_tools_by_source
135 .entry(source)
136 .or_default()
137 .retain(|name| !tools_to_disable.contains(name));
138 cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
139 }
140
141 pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
142 self.context_server_tools_by_id
143 .retain(|id, _| !tool_ids_to_remove.contains(id));
144 self.tools_changed();
145 }
146
147 fn tools_changed(&mut self) {
148 self.context_server_tools_by_name.clear();
149 self.context_server_tools_by_name.extend(
150 self.context_server_tools_by_id
151 .values()
152 .map(|tool| (tool.name(), tool.clone())),
153 );
154 }
155}