tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18#[derive(Default)]
 19struct WorkingSetState {
 20    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 21    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 22    enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 23    next_tool_id: ToolId,
 24}
 25
 26impl ToolWorkingSet {
 27    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 28        self.state
 29            .lock()
 30            .context_server_tools_by_name
 31            .get(name)
 32            .cloned()
 33            .or_else(|| ToolRegistry::global(cx).tool(name))
 34    }
 35
 36    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 37        self.state.lock().tools(cx)
 38    }
 39
 40    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 41        self.state.lock().tools_by_source(cx)
 42    }
 43
 44    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 45        self.state.lock().enabled_tools(cx)
 46    }
 47
 48    pub fn disable_all_tools(&self) {
 49        let mut state = self.state.lock();
 50        state.disable_all_tools();
 51    }
 52
 53    pub fn enable_source(&self, source: ToolSource, cx: &App) {
 54        let mut state = self.state.lock();
 55        state.enable_source(source, cx);
 56    }
 57
 58    pub fn disable_source(&self, source: &ToolSource) {
 59        let mut state = self.state.lock();
 60        state.disable_source(source);
 61    }
 62
 63    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 64        let mut state = self.state.lock();
 65        let tool_id = state.next_tool_id;
 66        state.next_tool_id.0 += 1;
 67        state
 68            .context_server_tools_by_id
 69            .insert(tool_id, tool.clone());
 70        state.tools_changed();
 71        tool_id
 72    }
 73
 74    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 75        self.state.lock().is_enabled(source, name)
 76    }
 77
 78    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 79        self.state.lock().is_disabled(source, name)
 80    }
 81
 82    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
 83        let mut state = self.state.lock();
 84        state.enable(source, tools_to_enable);
 85    }
 86
 87    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
 88        let mut state = self.state.lock();
 89        state.disable(source, tools_to_disable);
 90    }
 91
 92    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
 93        let mut state = self.state.lock();
 94        state
 95            .context_server_tools_by_id
 96            .retain(|id, _| !tool_ids_to_remove.contains(id));
 97        state.tools_changed();
 98    }
 99}
100
101impl WorkingSetState {
102    fn tools_changed(&mut self) {
103        self.context_server_tools_by_name.clear();
104        self.context_server_tools_by_name.extend(
105            self.context_server_tools_by_id
106                .values()
107                .map(|tool| (tool.name(), tool.clone())),
108        );
109    }
110
111    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
112        let mut tools = ToolRegistry::global(cx).tools();
113        tools.extend(self.context_server_tools_by_id.values().cloned());
114
115        tools
116    }
117
118    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
119        let mut tools_by_source = IndexMap::default();
120
121        for tool in self.tools(cx) {
122            tools_by_source
123                .entry(tool.source())
124                .or_insert_with(Vec::new)
125                .push(tool);
126        }
127
128        for tools in tools_by_source.values_mut() {
129            tools.sort_by_key(|tool| tool.name());
130        }
131
132        tools_by_source.sort_unstable_keys();
133
134        tools_by_source
135    }
136
137    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
138        let all_tools = self.tools(cx);
139
140        all_tools
141            .into_iter()
142            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
143            .collect()
144    }
145
146    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
147        self.enabled_tools_by_source
148            .get(source)
149            .map_or(false, |enabled_tools| enabled_tools.contains(name))
150    }
151
152    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
153        !self.is_enabled(source, name)
154    }
155
156    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
157        self.enabled_tools_by_source
158            .entry(source)
159            .or_default()
160            .extend(tools_to_enable.into_iter().cloned());
161    }
162
163    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
164        self.enabled_tools_by_source
165            .entry(source)
166            .or_default()
167            .retain(|name| !tools_to_disable.contains(name));
168    }
169
170    fn enable_source(&mut self, source: ToolSource, cx: &App) {
171        let tools_by_source = self.tools_by_source(cx);
172        let Some(tools) = tools_by_source.get(&source) else {
173            return;
174        };
175
176        self.enabled_tools_by_source.insert(
177            source,
178            tools
179                .into_iter()
180                .map(|tool| tool.name().into())
181                .collect::<HashSet<_>>(),
182        );
183    }
184
185    fn disable_source(&mut self, source: &ToolSource) {
186        self.enabled_tools_by_source.remove(source);
187    }
188
189    fn disable_all_tools(&mut self) {
190        self.enabled_tools_by_source.clear();
191    }
192}