tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18#[derive(Default)]
 19struct WorkingSetState {
 20    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 21    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 22    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 23    next_tool_id: ToolId,
 24}
 25
 26impl ToolWorkingSet {
 27    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 28        self.state
 29            .lock()
 30            .context_server_tools_by_name
 31            .get(name)
 32            .cloned()
 33            .or_else(|| ToolRegistry::global(cx).tool(name))
 34    }
 35
 36    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 37        let mut tools = ToolRegistry::global(cx).tools();
 38        tools.extend(
 39            self.state
 40                .lock()
 41                .context_server_tools_by_id
 42                .values()
 43                .cloned(),
 44        );
 45
 46        tools
 47    }
 48
 49    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 50        let all_tools = self.tools(cx);
 51
 52        all_tools
 53            .into_iter()
 54            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
 55            .collect()
 56    }
 57
 58    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 59        let mut tools_by_source = IndexMap::default();
 60
 61        for tool in self.tools(cx) {
 62            tools_by_source
 63                .entry(tool.source())
 64                .or_insert_with(Vec::new)
 65                .push(tool);
 66        }
 67
 68        for tools in tools_by_source.values_mut() {
 69            tools.sort_by_key(|tool| tool.name());
 70        }
 71
 72        tools_by_source.sort_unstable_keys();
 73
 74        tools_by_source
 75    }
 76
 77    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 78        let mut state = self.state.lock();
 79        let tool_id = state.next_tool_id;
 80        state.next_tool_id.0 += 1;
 81        state
 82            .context_server_tools_by_id
 83            .insert(tool_id, tool.clone());
 84        state.tools_changed();
 85        tool_id
 86    }
 87
 88    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 89        !self.is_disabled(source, name)
 90    }
 91
 92    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 93        let state = self.state.lock();
 94        state
 95            .disabled_tools_by_source
 96            .get(source)
 97            .map_or(false, |disabled_tools| disabled_tools.contains(name))
 98    }
 99
100    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
101        let mut state = self.state.lock();
102        state
103            .disabled_tools_by_source
104            .entry(source)
105            .or_default()
106            .retain(|name| !tools_to_enable.contains(name));
107    }
108
109    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
110        let mut state = self.state.lock();
111        state
112            .disabled_tools_by_source
113            .entry(source)
114            .or_default()
115            .extend(tools_to_disable.into_iter().cloned());
116    }
117
118    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
119        let mut state = self.state.lock();
120        state
121            .context_server_tools_by_id
122            .retain(|id, _| !tool_ids_to_remove.contains(id));
123        state.tools_changed();
124    }
125}
126
127impl WorkingSetState {
128    fn tools_changed(&mut self) {
129        self.context_server_tools_by_name.clear();
130        self.context_server_tools_by_name.extend(
131            self.context_server_tools_by_id
132                .values()
133                .map(|tool| (tool.name(), tool.clone())),
134        );
135    }
136}