tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::App;
  5use parking_lot::Mutex;
  6
  7use crate::{Tool, ToolRegistry, ToolSource};
  8
  9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 10pub struct ToolId(usize);
 11
 12/// A working set of tools for use in one instance of the Assistant Panel.
 13#[derive(Default)]
 14pub struct ToolWorkingSet {
 15    state: Mutex<WorkingSetState>,
 16}
 17
 18#[derive(Default)]
 19struct WorkingSetState {
 20    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 21    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
 22    enabled_sources: HashSet<ToolSource>,
 23    enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
 24    next_tool_id: ToolId,
 25}
 26
 27impl ToolWorkingSet {
 28    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 29        self.state
 30            .lock()
 31            .context_server_tools_by_name
 32            .get(name)
 33            .cloned()
 34            .or_else(|| ToolRegistry::global(cx).tool(name))
 35    }
 36
 37    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 38        self.state.lock().tools(cx)
 39    }
 40
 41    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 42        self.state.lock().tools_by_source(cx)
 43    }
 44
 45    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 46        self.state.lock().enabled_tools(cx)
 47    }
 48
 49    pub fn disable_all_tools(&self) {
 50        let mut state = self.state.lock();
 51        state.disable_all_tools();
 52    }
 53
 54    pub fn enable_source(&self, source: ToolSource, cx: &App) {
 55        let mut state = self.state.lock();
 56        state.enable_source(source, cx);
 57    }
 58
 59    pub fn disable_source(&self, source: &ToolSource) {
 60        let mut state = self.state.lock();
 61        state.disable_source(source);
 62    }
 63
 64    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
 65        let mut state = self.state.lock();
 66        let tool_id = state.next_tool_id;
 67        state.next_tool_id.0 += 1;
 68        state
 69            .context_server_tools_by_id
 70            .insert(tool_id, tool.clone());
 71        state.tools_changed();
 72        tool_id
 73    }
 74
 75    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 76        self.state.lock().is_enabled(source, name)
 77    }
 78
 79    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
 80        self.state.lock().is_disabled(source, name)
 81    }
 82
 83    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
 84        let mut state = self.state.lock();
 85        state.enable(source, tools_to_enable);
 86    }
 87
 88    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
 89        let mut state = self.state.lock();
 90        state.disable(source, tools_to_disable);
 91    }
 92
 93    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
 94        let mut state = self.state.lock();
 95        state
 96            .context_server_tools_by_id
 97            .retain(|id, _| !tool_ids_to_remove.contains(id));
 98        state.tools_changed();
 99    }
100}
101
102impl WorkingSetState {
103    fn tools_changed(&mut self) {
104        self.context_server_tools_by_name.clear();
105        self.context_server_tools_by_name.extend(
106            self.context_server_tools_by_id
107                .values()
108                .map(|tool| (tool.name(), tool.clone())),
109        );
110    }
111
112    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
113        let mut tools = ToolRegistry::global(cx).tools();
114        tools.extend(self.context_server_tools_by_id.values().cloned());
115
116        tools
117    }
118
119    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
120        let mut tools_by_source = IndexMap::default();
121
122        for tool in self.tools(cx) {
123            tools_by_source
124                .entry(tool.source())
125                .or_insert_with(Vec::new)
126                .push(tool);
127        }
128
129        for tools in tools_by_source.values_mut() {
130            tools.sort_by_key(|tool| tool.name());
131        }
132
133        tools_by_source.sort_unstable_keys();
134
135        tools_by_source
136    }
137
138    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
139        let all_tools = self.tools(cx);
140
141        all_tools
142            .into_iter()
143            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
144            .collect()
145    }
146
147    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
148        self.enabled_tools_by_source
149            .get(source)
150            .map_or(false, |enabled_tools| enabled_tools.contains(name))
151    }
152
153    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
154        !self.is_enabled(source, name)
155    }
156
157    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
158        self.enabled_tools_by_source
159            .entry(source)
160            .or_default()
161            .extend(tools_to_enable.into_iter().cloned());
162    }
163
164    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
165        self.enabled_tools_by_source
166            .entry(source)
167            .or_default()
168            .retain(|name| !tools_to_disable.contains(name));
169    }
170
171    fn enable_source(&mut self, source: ToolSource, cx: &App) {
172        self.enabled_sources.insert(source.clone());
173
174        let tools_by_source = self.tools_by_source(cx);
175        if let Some(tools) = tools_by_source.get(&source) {
176            self.enabled_tools_by_source.insert(
177                source,
178                tools
179                    .into_iter()
180                    .map(|tool| tool.name().into())
181                    .collect::<HashSet<_>>(),
182            );
183        }
184    }
185
186    fn disable_source(&mut self, source: &ToolSource) {
187        self.enabled_sources.remove(source);
188        self.enabled_tools_by_source.remove(source);
189    }
190
191    fn disable_all_tools(&mut self) {
192        self.enabled_tools_by_source.clear();
193    }
194}