tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18struct WorkingSetState {
 19    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 20    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 21    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 22    is_scripting_tool_disabled: bool,
 23    next_tool_id: ToolId,
 24}
 25
 26impl Default for WorkingSetState {
 27    fn default() -> Self {
 28        Self {
 29            context_server_tools_by_id: Default::default(),
 30            context_server_tools_by_name: Default::default(),
 31            disabled_tools_by_source: Default::default(),
 32            is_scripting_tool_disabled: true,
 33            next_tool_id: Default::default(),
 34        }
 35    }
 36}
 37
 38impl ToolWorkingSet {
 39    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 40        self.state
 41            .lock()
 42            .context_server_tools_by_name
 43            .get(name)
 44            .cloned()
 45            .or_else(|| ToolRegistry::global(cx).tool(name))
 46    }
 47
 48    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 49        let mut tools = ToolRegistry::global(cx).tools();
 50        tools.extend(
 51            self.state
 52                .lock()
 53                .context_server_tools_by_id
 54                .values()
 55                .cloned(),
 56        );
 57
 58        tools
 59    }
 60
 61    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 62        let all_tools = self.tools(cx);
 63
 64        all_tools
 65            .into_iter()
 66            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
 67            .collect()
 68    }
 69
 70    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 71        let mut tools_by_source = IndexMap::default();
 72
 73        for tool in self.tools(cx) {
 74            tools_by_source
 75                .entry(tool.source())
 76                .or_insert_with(Vec::new)
 77                .push(tool);
 78        }
 79
 80        for tools in tools_by_source.values_mut() {
 81            tools.sort_by_key(|tool| tool.name());
 82        }
 83
 84        tools_by_source.sort_unstable_keys();
 85
 86        tools_by_source
 87    }
 88
 89    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 90        let mut state = self.state.lock();
 91        let tool_id = state.next_tool_id;
 92        state.next_tool_id.0 += 1;
 93        state
 94            .context_server_tools_by_id
 95            .insert(tool_id, tool.clone());
 96        state.tools_changed();
 97        tool_id
 98    }
 99
100    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
101        !self.is_disabled(source, name)
102    }
103
104    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
105        let state = self.state.lock();
106        state
107            .disabled_tools_by_source
108            .get(source)
109            .map_or(false, |disabled_tools| disabled_tools.contains(name))
110    }
111
112    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
113        let mut state = self.state.lock();
114        state
115            .disabled_tools_by_source
116            .entry(source)
117            .or_default()
118            .retain(|name| !tools_to_enable.contains(name));
119    }
120
121    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
122        let mut state = self.state.lock();
123        state
124            .disabled_tools_by_source
125            .entry(source)
126            .or_default()
127            .extend(tools_to_disable.into_iter().cloned());
128    }
129
130    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
131        let mut state = self.state.lock();
132        state
133            .context_server_tools_by_id
134            .retain(|id, _| !tool_ids_to_remove.contains(id));
135        state.tools_changed();
136    }
137
138    pub fn is_scripting_tool_enabled(&self) -> bool {
139        let state = self.state.lock();
140        !state.is_scripting_tool_disabled
141    }
142
143    pub fn enable_scripting_tool(&self) {
144        let mut state = self.state.lock();
145        state.is_scripting_tool_disabled = false;
146    }
147
148    pub fn disable_scripting_tool(&self) {
149        let mut state = self.state.lock();
150        state.is_scripting_tool_disabled = true;
151    }
152}
153
154impl WorkingSetState {
155    fn tools_changed(&mut self) {
156        self.context_server_tools_by_name.clear();
157        self.context_server_tools_by_name.extend(
158            self.context_server_tools_by_id
159                .values()
160                .map(|tool| (tool.name(), tool.clone())),
161        );
162    }
163}