tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  /// Opaque handle identifying a tool that was added to a [`ToolWorkingSet`].
///
/// Handed out by `ToolWorkingSet::insert` and accepted by `ToolWorkingSet::remove`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18struct WorkingSetState {
 19    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 20    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 21    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 22    is_scripting_tool_disabled: bool,
 23    next_tool_id: ToolId,
 24}
 25
 26impl Default for WorkingSetState {
 27    fn default() -> Self {
 28        Self {
 29            context_server_tools_by_id: HashMap::default(),
 30            context_server_tools_by_name: HashMap::default(),
 31            disabled_tools_by_source: HashMap::default(),
 32            is_scripting_tool_disabled: true,
 33            next_tool_id: ToolId::default(),
 34        }
 35    }
 36}
 37
 38impl ToolWorkingSet {
 39    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 40        self.state
 41            .lock()
 42            .context_server_tools_by_name
 43            .get(name)
 44            .cloned()
 45            .or_else(|| ToolRegistry::global(cx).tool(name))
 46    }
 47
 48    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 49        self.state.lock().tools(cx)
 50    }
 51
 52    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 53        self.state.lock().tools_by_source(cx)
 54    }
 55
 56    pub fn are_all_tools_enabled(&self) -> bool {
 57        let state = self.state.lock();
 58        state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
 59    }
 60
 61    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 62        self.state.lock().enabled_tools(cx)
 63    }
 64
 65    pub fn enable_all_tools(&self) {
 66        let mut state = self.state.lock();
 67        state.disabled_tools_by_source.clear();
 68        state.enable_scripting_tool();
 69    }
 70
 71    pub fn disable_all_tools(&self, cx: &App) {
 72        let mut state = self.state.lock();
 73        state.disable_all_tools(cx);
 74    }
 75
 76    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 77        let mut state = self.state.lock();
 78        let tool_id = state.next_tool_id;
 79        state.next_tool_id.0 += 1;
 80        state
 81            .context_server_tools_by_id
 82            .insert(tool_id, tool.clone());
 83        state.tools_changed();
 84        tool_id
 85    }
 86
 87    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 88        self.state.lock().is_enabled(source, name)
 89    }
 90
 91    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 92        self.state.lock().is_disabled(source, name)
 93    }
 94
 95    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
 96        let mut state = self.state.lock();
 97        state.enable(source, tools_to_enable);
 98    }
 99
100    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
101        let mut state = self.state.lock();
102        state.disable(source, tools_to_disable);
103    }
104
105    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
106        let mut state = self.state.lock();
107        state
108            .context_server_tools_by_id
109            .retain(|id, _| !tool_ids_to_remove.contains(id));
110        state.tools_changed();
111    }
112
113    pub fn is_scripting_tool_enabled(&self) -> bool {
114        let state = self.state.lock();
115        !state.is_scripting_tool_disabled
116    }
117
118    pub fn enable_scripting_tool(&self) {
119        let mut state = self.state.lock();
120        state.enable_scripting_tool();
121    }
122
123    pub fn disable_scripting_tool(&self) {
124        let mut state = self.state.lock();
125        state.disable_scripting_tool();
126    }
127}
128
129impl WorkingSetState {
130    fn tools_changed(&mut self) {
131        self.context_server_tools_by_name.clear();
132        self.context_server_tools_by_name.extend(
133            self.context_server_tools_by_id
134                .values()
135                .map(|tool| (tool.name(), tool.clone())),
136        );
137    }
138
139    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
140        let mut tools = ToolRegistry::global(cx).tools();
141        tools.extend(self.context_server_tools_by_id.values().cloned());
142
143        tools
144    }
145
146    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
147        let mut tools_by_source = IndexMap::default();
148
149        for tool in self.tools(cx) {
150            tools_by_source
151                .entry(tool.source())
152                .or_insert_with(Vec::new)
153                .push(tool);
154        }
155
156        for tools in tools_by_source.values_mut() {
157            tools.sort_by_key(|tool| tool.name());
158        }
159
160        tools_by_source.sort_unstable_keys();
161
162        tools_by_source
163    }
164
165    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
166        let all_tools = self.tools(cx);
167
168        all_tools
169            .into_iter()
170            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
171            .collect()
172    }
173
174    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
175        !self.is_disabled(source, name)
176    }
177
178    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
179        self.disabled_tools_by_source
180            .get(source)
181            .map_or(false, |disabled_tools| disabled_tools.contains(name))
182    }
183
184    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
185        self.disabled_tools_by_source
186            .entry(source)
187            .or_default()
188            .retain(|name| !tools_to_enable.contains(name));
189    }
190
191    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
192        self.disabled_tools_by_source
193            .entry(source)
194            .or_default()
195            .extend(tools_to_disable.into_iter().cloned());
196    }
197
198    fn disable_all_tools(&mut self, cx: &App) {
199        let tools = self.tools_by_source(cx);
200
201        for (source, tools) in tools {
202            let tool_names = tools
203                .into_iter()
204                .map(|tool| tool.name().into())
205                .collect::<Vec<_>>();
206
207            self.disable(source, &tool_names);
208        }
209
210        self.disable_scripting_tool();
211    }
212
213    fn enable_scripting_tool(&mut self) {
214        self.is_scripting_tool_disabled = false;
215    }
216
217    fn disable_scripting_tool(&mut self) {
218        self.is_scripting_tool_disabled = true;
219    }
220}