tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18#[derive(Default)]
 19struct WorkingSetState {
 20    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 21    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 22    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 23    is_scripting_tool_disabled: bool,
 24    next_tool_id: ToolId,
 25}
 26
 27impl ToolWorkingSet {
 28    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 29        self.state
 30            .lock()
 31            .context_server_tools_by_name
 32            .get(name)
 33            .cloned()
 34            .or_else(|| ToolRegistry::global(cx).tool(name))
 35    }
 36
 37    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 38        let mut tools = ToolRegistry::global(cx).tools();
 39        tools.extend(
 40            self.state
 41                .lock()
 42                .context_server_tools_by_id
 43                .values()
 44                .cloned(),
 45        );
 46
 47        tools
 48    }
 49
 50    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 51        let all_tools = self.tools(cx);
 52
 53        all_tools
 54            .into_iter()
 55            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
 56            .collect()
 57    }
 58
 59    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 60        let mut tools_by_source = IndexMap::default();
 61
 62        for tool in self.tools(cx) {
 63            tools_by_source
 64                .entry(tool.source())
 65                .or_insert_with(Vec::new)
 66                .push(tool);
 67        }
 68
 69        for tools in tools_by_source.values_mut() {
 70            tools.sort_by_key(|tool| tool.name());
 71        }
 72
 73        tools_by_source.sort_unstable_keys();
 74
 75        tools_by_source
 76    }
 77
 78    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 79        let mut state = self.state.lock();
 80        let tool_id = state.next_tool_id;
 81        state.next_tool_id.0 += 1;
 82        state
 83            .context_server_tools_by_id
 84            .insert(tool_id, tool.clone());
 85        state.tools_changed();
 86        tool_id
 87    }
 88
 89    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 90        !self.is_disabled(source, name)
 91    }
 92
 93    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 94        let state = self.state.lock();
 95        state
 96            .disabled_tools_by_source
 97            .get(source)
 98            .map_or(false, |disabled_tools| disabled_tools.contains(name))
 99    }
100
101    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
102        let mut state = self.state.lock();
103        state
104            .disabled_tools_by_source
105            .entry(source)
106            .or_default()
107            .retain(|name| !tools_to_enable.contains(name));
108    }
109
110    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
111        let mut state = self.state.lock();
112        state
113            .disabled_tools_by_source
114            .entry(source)
115            .or_default()
116            .extend(tools_to_disable.into_iter().cloned());
117    }
118
119    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
120        let mut state = self.state.lock();
121        state
122            .context_server_tools_by_id
123            .retain(|id, _| !tool_ids_to_remove.contains(id));
124        state.tools_changed();
125    }
126
127    pub fn is_scripting_tool_enabled(&self) -> bool {
128        let state = self.state.lock();
129        !state.is_scripting_tool_disabled
130    }
131
132    pub fn enable_scripting_tool(&self) {
133        let mut state = self.state.lock();
134        state.is_scripting_tool_disabled = false;
135    }
136
137    pub fn disable_scripting_tool(&self) {
138        let mut state = self.state.lock();
139        state.is_scripting_tool_disabled = true;
140    }
141}
142
143impl WorkingSetState {
144    fn tools_changed(&mut self) {
145        self.context_server_tools_by_name.clear();
146        self.context_server_tools_by_name.extend(
147            self.context_server_tools_by_id
148                .values()
149                .map(|tool| (tool.name(), tool.clone())),
150        );
151    }
152}