tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18struct WorkingSetState {
 19    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 20    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 21    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 22    is_scripting_tool_disabled: bool,
 23    next_tool_id: ToolId,
 24}
 25
 26impl Default for WorkingSetState {
 27    fn default() -> Self {
 28        Self {
 29            context_server_tools_by_id: HashMap::default(),
 30            context_server_tools_by_name: HashMap::default(),
 31            disabled_tools_by_source: HashMap::default(),
 32            is_scripting_tool_disabled: true,
 33            next_tool_id: ToolId::default(),
 34        }
 35    }
 36}
 37
 38impl ToolWorkingSet {
 39    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 40        self.state
 41            .lock()
 42            .context_server_tools_by_name
 43            .get(name)
 44            .cloned()
 45            .or_else(|| ToolRegistry::global(cx).tool(name))
 46    }
 47
 48    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 49        self.state.lock().tools(cx)
 50    }
 51
 52    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 53        self.state.lock().tools_by_source(cx)
 54    }
 55
 56    pub fn are_all_tools_enabled(&self) -> bool {
 57        let state = self.state.lock();
 58        state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
 59    }
 60
 61    pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
 62        let state = self.state.lock();
 63        !state.disabled_tools_by_source.contains_key(source)
 64    }
 65
 66    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 67        self.state.lock().enabled_tools(cx)
 68    }
 69
 70    pub fn enable_all_tools(&self) {
 71        let mut state = self.state.lock();
 72        state.disabled_tools_by_source.clear();
 73        state.enable_scripting_tool();
 74    }
 75
 76    pub fn disable_all_tools(&self, cx: &App) {
 77        let mut state = self.state.lock();
 78        state.disable_all_tools(cx);
 79    }
 80
 81    pub fn enable_source(&self, source: &ToolSource) {
 82        let mut state = self.state.lock();
 83        state.enable_source(source);
 84    }
 85
 86    pub fn disable_source(&self, source: ToolSource, cx: &App) {
 87        let mut state = self.state.lock();
 88        state.disable_source(source, cx);
 89    }
 90
 91    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 92        let mut state = self.state.lock();
 93        let tool_id = state.next_tool_id;
 94        state.next_tool_id.0 += 1;
 95        state
 96            .context_server_tools_by_id
 97            .insert(tool_id, tool.clone());
 98        state.tools_changed();
 99        tool_id
100    }
101
102    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
103        self.state.lock().is_enabled(source, name)
104    }
105
106    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
107        self.state.lock().is_disabled(source, name)
108    }
109
110    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
111        let mut state = self.state.lock();
112        state.enable(source, tools_to_enable);
113    }
114
115    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
116        let mut state = self.state.lock();
117        state.disable(source, tools_to_disable);
118    }
119
120    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
121        let mut state = self.state.lock();
122        state
123            .context_server_tools_by_id
124            .retain(|id, _| !tool_ids_to_remove.contains(id));
125        state.tools_changed();
126    }
127
128    pub fn is_scripting_tool_enabled(&self) -> bool {
129        let state = self.state.lock();
130        !state.is_scripting_tool_disabled
131    }
132
133    pub fn enable_scripting_tool(&self) {
134        let mut state = self.state.lock();
135        state.enable_scripting_tool();
136    }
137
138    pub fn disable_scripting_tool(&self) {
139        let mut state = self.state.lock();
140        state.disable_scripting_tool();
141    }
142}
143
144impl WorkingSetState {
145    fn tools_changed(&mut self) {
146        self.context_server_tools_by_name.clear();
147        self.context_server_tools_by_name.extend(
148            self.context_server_tools_by_id
149                .values()
150                .map(|tool| (tool.name(), tool.clone())),
151        );
152    }
153
154    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
155        let mut tools = ToolRegistry::global(cx).tools();
156        tools.extend(self.context_server_tools_by_id.values().cloned());
157
158        tools
159    }
160
161    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
162        let mut tools_by_source = IndexMap::default();
163
164        for tool in self.tools(cx) {
165            tools_by_source
166                .entry(tool.source())
167                .or_insert_with(Vec::new)
168                .push(tool);
169        }
170
171        for tools in tools_by_source.values_mut() {
172            tools.sort_by_key(|tool| tool.name());
173        }
174
175        tools_by_source.sort_unstable_keys();
176
177        tools_by_source
178    }
179
180    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
181        let all_tools = self.tools(cx);
182
183        all_tools
184            .into_iter()
185            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
186            .collect()
187    }
188
189    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
190        !self.is_disabled(source, name)
191    }
192
193    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
194        self.disabled_tools_by_source
195            .get(source)
196            .map_or(false, |disabled_tools| disabled_tools.contains(name))
197    }
198
199    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
200        self.disabled_tools_by_source
201            .entry(source)
202            .or_default()
203            .retain(|name| !tools_to_enable.contains(name));
204    }
205
206    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
207        self.disabled_tools_by_source
208            .entry(source)
209            .or_default()
210            .extend(tools_to_disable.into_iter().cloned());
211    }
212
213    fn enable_source(&mut self, source: &ToolSource) {
214        self.disabled_tools_by_source.remove(source);
215    }
216
217    fn disable_source(&mut self, source: ToolSource, cx: &App) {
218        let tools_by_source = self.tools_by_source(cx);
219        let Some(tools) = tools_by_source.get(&source) else {
220            return;
221        };
222
223        self.disabled_tools_by_source.insert(
224            source,
225            tools
226                .into_iter()
227                .map(|tool| tool.name().into())
228                .collect::<HashSet<_>>(),
229        );
230    }
231
232    fn disable_all_tools(&mut self, cx: &App) {
233        let tools = self.tools_by_source(cx);
234
235        for (source, tools) in tools {
236            let tool_names = tools
237                .into_iter()
238                .map(|tool| tool.name().into())
239                .collect::<Vec<_>>();
240
241            self.disable(source, &tool_names);
242        }
243
244        self.disable_scripting_tool();
245    }
246
247    fn enable_scripting_tool(&mut self) {
248        self.is_scripting_tool_disabled = false;
249    }
250
251    fn disable_scripting_tool(&mut self) {
252        self.is_scripting_tool_disabled = true;
253    }
254}