tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  /// Opaque identifier for a tool added to a [`ToolWorkingSet`].
  ///
  /// Ids are handed out by [`ToolWorkingSet::insert`] and accepted by
  /// [`ToolWorkingSet::remove`]; the wrapped counter is private to this module.
  #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
  pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18#[derive(Default)]
 19struct WorkingSetState {
 20    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 21    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 22    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 23    next_tool_id: ToolId,
 24}
 25
 26impl ToolWorkingSet {
 27    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 28        self.state
 29            .lock()
 30            .context_server_tools_by_name
 31            .get(name)
 32            .cloned()
 33            .or_else(|| ToolRegistry::global(cx).tool(name))
 34    }
 35
 36    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 37        self.state.lock().tools(cx)
 38    }
 39
 40    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 41        self.state.lock().tools_by_source(cx)
 42    }
 43
 44    pub fn are_all_tools_enabled(&self) -> bool {
 45        let state = self.state.lock();
 46        state.disabled_tools_by_source.is_empty()
 47    }
 48
 49    pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
 50        let state = self.state.lock();
 51        !state.disabled_tools_by_source.contains_key(source)
 52    }
 53
 54    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 55        self.state.lock().enabled_tools(cx)
 56    }
 57
 58    pub fn enable_all_tools(&self) {
 59        let mut state = self.state.lock();
 60        state.disabled_tools_by_source.clear();
 61    }
 62
 63    pub fn disable_all_tools(&self, cx: &App) {
 64        let mut state = self.state.lock();
 65        state.disable_all_tools(cx);
 66    }
 67
 68    pub fn enable_source(&self, source: &ToolSource) {
 69        let mut state = self.state.lock();
 70        state.enable_source(source);
 71    }
 72
 73    pub fn disable_source(&self, source: ToolSource, cx: &App) {
 74        let mut state = self.state.lock();
 75        state.disable_source(source, cx);
 76    }
 77
 78    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 79        let mut state = self.state.lock();
 80        let tool_id = state.next_tool_id;
 81        state.next_tool_id.0 += 1;
 82        state
 83            .context_server_tools_by_id
 84            .insert(tool_id, tool.clone());
 85        state.tools_changed();
 86        tool_id
 87    }
 88
 89    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 90        self.state.lock().is_enabled(source, name)
 91    }
 92
 93    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 94        self.state.lock().is_disabled(source, name)
 95    }
 96
 97    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
 98        let mut state = self.state.lock();
 99        state.enable(source, tools_to_enable);
100    }
101
102    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
103        let mut state = self.state.lock();
104        state.disable(source, tools_to_disable);
105    }
106
107    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
108        let mut state = self.state.lock();
109        state
110            .context_server_tools_by_id
111            .retain(|id, _| !tool_ids_to_remove.contains(id));
112        state.tools_changed();
113    }
114}
115
116impl WorkingSetState {
117    fn tools_changed(&mut self) {
118        self.context_server_tools_by_name.clear();
119        self.context_server_tools_by_name.extend(
120            self.context_server_tools_by_id
121                .values()
122                .map(|tool| (tool.name(), tool.clone())),
123        );
124    }
125
126    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
127        let mut tools = ToolRegistry::global(cx).tools();
128        tools.extend(self.context_server_tools_by_id.values().cloned());
129
130        tools
131    }
132
133    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
134        let mut tools_by_source = IndexMap::default();
135
136        for tool in self.tools(cx) {
137            tools_by_source
138                .entry(tool.source())
139                .or_insert_with(Vec::new)
140                .push(tool);
141        }
142
143        for tools in tools_by_source.values_mut() {
144            tools.sort_by_key(|tool| tool.name());
145        }
146
147        tools_by_source.sort_unstable_keys();
148
149        tools_by_source
150    }
151
152    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
153        let all_tools = self.tools(cx);
154
155        all_tools
156            .into_iter()
157            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
158            .collect()
159    }
160
161    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
162        !self.is_disabled(source, name)
163    }
164
165    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
166        self.disabled_tools_by_source
167            .get(source)
168            .map_or(false, |disabled_tools| disabled_tools.contains(name))
169    }
170
171    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
172        self.disabled_tools_by_source
173            .entry(source)
174            .or_default()
175            .retain(|name| !tools_to_enable.contains(name));
176    }
177
178    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
179        self.disabled_tools_by_source
180            .entry(source)
181            .or_default()
182            .extend(tools_to_disable.into_iter().cloned());
183    }
184
185    fn enable_source(&mut self, source: &ToolSource) {
186        self.disabled_tools_by_source.remove(source);
187    }
188
189    fn disable_source(&mut self, source: ToolSource, cx: &App) {
190        let tools_by_source = self.tools_by_source(cx);
191        let Some(tools) = tools_by_source.get(&source) else {
192            return;
193        };
194
195        self.disabled_tools_by_source.insert(
196            source,
197            tools
198                .into_iter()
199                .map(|tool| tool.name().into())
200                .collect::<HashSet<_>>(),
201        );
202    }
203
204    fn disable_all_tools(&mut self, cx: &App) {
205        let tools = self.tools_by_source(cx);
206
207        for (source, tools) in tools {
208            let tool_names = tools
209                .into_iter()
210                .map(|tool| tool.name().into())
211                .collect::<Vec<_>>();
212
213            self.disable(source, &tool_names);
214        }
215    }
216}