tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18struct WorkingSetState {
 19    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 20    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 21    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 22    is_scripting_tool_disabled: bool,
 23    next_tool_id: ToolId,
 24}
 25
 26impl Default for WorkingSetState {
 27    fn default() -> Self {
 28        Self {
 29            context_server_tools_by_id: HashMap::default(),
 30            context_server_tools_by_name: HashMap::default(),
 31            disabled_tools_by_source: HashMap::default(),
 32            is_scripting_tool_disabled: true,
 33            next_tool_id: ToolId::default(),
 34        }
 35    }
 36}
 37
 38impl ToolWorkingSet {
 39    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 40        self.state
 41            .lock()
 42            .context_server_tools_by_name
 43            .get(name)
 44            .cloned()
 45            .or_else(|| ToolRegistry::global(cx).tool(name))
 46    }
 47
 48    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 49        let mut tools = ToolRegistry::global(cx).tools();
 50        tools.extend(
 51            self.state
 52                .lock()
 53                .context_server_tools_by_id
 54                .values()
 55                .cloned(),
 56        );
 57
 58        tools
 59    }
 60
 61    pub fn are_all_tools_enabled(&self) -> bool {
 62        let state = self.state.lock();
 63
 64        state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
 65    }
 66
 67    pub fn enable_all_tools(&self) {
 68        let mut state = self.state.lock();
 69
 70        state.disabled_tools_by_source.clear();
 71        state.is_scripting_tool_disabled = false;
 72    }
 73
 74    pub fn disable_all_tools(&self, cx: &App) {
 75        let tools = self.tools_by_source(cx);
 76
 77        for (source, tools) in tools {
 78            let tool_names = tools
 79                .into_iter()
 80                .map(|tool| tool.name().into())
 81                .collect::<Vec<_>>();
 82
 83            self.disable(source, &tool_names);
 84        }
 85
 86        self.disable_scripting_tool();
 87    }
 88
 89    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 90        let all_tools = self.tools(cx);
 91
 92        all_tools
 93            .into_iter()
 94            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
 95            .collect()
 96    }
 97
 98    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 99        let mut tools_by_source = IndexMap::default();
100
101        for tool in self.tools(cx) {
102            tools_by_source
103                .entry(tool.source())
104                .or_insert_with(Vec::new)
105                .push(tool);
106        }
107
108        for tools in tools_by_source.values_mut() {
109            tools.sort_by_key(|tool| tool.name());
110        }
111
112        tools_by_source.sort_unstable_keys();
113
114        tools_by_source
115    }
116
117    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
118        let mut state = self.state.lock();
119        let tool_id = state.next_tool_id;
120        state.next_tool_id.0 += 1;
121        state
122            .context_server_tools_by_id
123            .insert(tool_id, tool.clone());
124        state.tools_changed();
125        tool_id
126    }
127
128    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
129        !self.is_disabled(source, name)
130    }
131
132    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
133        let state = self.state.lock();
134        state
135            .disabled_tools_by_source
136            .get(source)
137            .map_or(false, |disabled_tools| disabled_tools.contains(name))
138    }
139
140    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
141        let mut state = self.state.lock();
142        state
143            .disabled_tools_by_source
144            .entry(source)
145            .or_default()
146            .retain(|name| !tools_to_enable.contains(name));
147    }
148
149    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
150        let mut state = self.state.lock();
151        state
152            .disabled_tools_by_source
153            .entry(source)
154            .or_default()
155            .extend(tools_to_disable.into_iter().cloned());
156    }
157
158    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
159        let mut state = self.state.lock();
160        state
161            .context_server_tools_by_id
162            .retain(|id, _| !tool_ids_to_remove.contains(id));
163        state.tools_changed();
164    }
165
166    pub fn is_scripting_tool_enabled(&self) -> bool {
167        let state = self.state.lock();
168        !state.is_scripting_tool_disabled
169    }
170
171    pub fn enable_scripting_tool(&self) {
172        let mut state = self.state.lock();
173        state.is_scripting_tool_disabled = false;
174    }
175
176    pub fn disable_scripting_tool(&self) {
177        let mut state = self.state.lock();
178        state.is_scripting_tool_disabled = true;
179    }
180}
181
182impl WorkingSetState {
183    fn tools_changed(&mut self) {
184        self.context_server_tools_by_name.clear();
185        self.context_server_tools_by_name.extend(
186            self.context_server_tools_by_id
187                .values()
188                .map(|tool| (tool.name(), tool.clone())),
189        );
190    }
191}