tool_working_set.rs

  1use std::sync::Arc;
  2
  3use collections::{HashMap, HashSet, IndexMap};
  4use gpui::{App, Context, EventEmitter};
  5
  6use crate::{Tool, ToolRegistry, ToolSource};
  7
  8#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
  9pub struct ToolId(usize);
 10
/// A working set of tools for use in one instance of the Assistant Panel.
///
/// Combines the tools of the global `ToolRegistry` with tools added directly to this
/// set via `insert`, and tracks which tool names are currently enabled per source.
#[derive(Default)]
pub struct ToolWorkingSet {
    /// Tools added through `insert`, keyed by the `ToolId` returned to the caller.
    /// NOTE(review): the field name suggests these come from context servers — confirm
    /// against callers.
    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
    /// Name index over `context_server_tools_by_id`; rebuilt from scratch by
    /// `tools_changed` after every `insert`/`remove`.
    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
    /// Sources added by `enable_source` and removed by `disable_source`.
    enabled_sources: HashSet<ToolSource>,
    /// For each source, the names of its tools that are enabled. A source without an
    /// entry has no enabled tools (see `is_enabled`).
    enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
    /// Id that the next call to `insert` will hand out; incremented on every insert.
    next_tool_id: ToolId,
}
 20
 21pub enum ToolWorkingSetEvent {
 22    EnabledToolsChanged,
 23}
 24
// Marker impl: lets `ToolWorkingSet` methods call `cx.emit(ToolWorkingSetEvent::..)`.
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
 26
 27impl ToolWorkingSet {
 28    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 29        self.context_server_tools_by_name
 30            .get(name)
 31            .cloned()
 32            .or_else(|| ToolRegistry::global(cx).tool(name))
 33    }
 34
 35    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 36        let mut tools = ToolRegistry::global(cx).tools();
 37        tools.extend(self.context_server_tools_by_id.values().cloned());
 38        tools
 39    }
 40
 41    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 42        let mut tools_by_source = IndexMap::default();
 43
 44        for tool in self.tools(cx) {
 45            tools_by_source
 46                .entry(tool.source())
 47                .or_insert_with(Vec::new)
 48                .push(tool);
 49        }
 50
 51        for tools in tools_by_source.values_mut() {
 52            tools.sort_by_key(|tool| tool.name());
 53        }
 54
 55        tools_by_source.sort_unstable_keys();
 56
 57        tools_by_source
 58    }
 59
 60    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
 61        let all_tools = self.tools(cx);
 62
 63        all_tools
 64            .into_iter()
 65            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
 66            .collect()
 67    }
 68
 69    pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
 70        self.enabled_tools_by_source.clear();
 71        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
 72    }
 73
 74    pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
 75        self.enabled_sources.insert(source.clone());
 76
 77        let tools_by_source = self.tools_by_source(cx);
 78        if let Some(tools) = tools_by_source.get(&source) {
 79            self.enabled_tools_by_source.insert(
 80                source,
 81                tools
 82                    .into_iter()
 83                    .map(|tool| tool.name().into())
 84                    .collect::<HashSet<_>>(),
 85            );
 86        }
 87        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
 88    }
 89
 90    pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
 91        self.enabled_sources.remove(source);
 92        self.enabled_tools_by_source.remove(source);
 93        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
 94    }
 95
 96    pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
 97        let tool_id = self.next_tool_id;
 98        self.next_tool_id.0 += 1;
 99        self.context_server_tools_by_id
100            .insert(tool_id, tool.clone());
101        self.tools_changed();
102        tool_id
103    }
104
105    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
106        self.enabled_tools_by_source
107            .get(source)
108            .map_or(false, |enabled_tools| enabled_tools.contains(name))
109    }
110
111    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
112        !self.is_enabled(source, name)
113    }
114
115    pub fn enable(
116        &mut self,
117        source: ToolSource,
118        tools_to_enable: &[Arc<str>],
119        cx: &mut Context<Self>,
120    ) {
121        self.enabled_tools_by_source
122            .entry(source)
123            .or_default()
124            .extend(tools_to_enable.into_iter().cloned());
125        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
126    }
127
128    pub fn disable(
129        &mut self,
130        source: ToolSource,
131        tools_to_disable: &[Arc<str>],
132        cx: &mut Context<Self>,
133    ) {
134        self.enabled_tools_by_source
135            .entry(source)
136            .or_default()
137            .retain(|name| !tools_to_disable.contains(name));
138        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
139    }
140
141    pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
142        self.context_server_tools_by_id
143            .retain(|id, _| !tool_ids_to_remove.contains(id));
144        self.tools_changed();
145    }
146
147    fn tools_changed(&mut self) {
148        self.context_server_tools_by_name.clear();
149        self.context_server_tools_by_name.extend(
150            self.context_server_tools_by_id
151                .values()
152                .map(|tool| (tool.name(), tool.clone())),
153        );
154    }
155}