// tool_working_set.rs

 1use std::sync::Arc;
 2
 3use collections::{HashMap, IndexMap};
 4use gpui::App;
 5
 6use crate::{Tool, ToolRegistry, ToolSource};
 7
 8#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 9pub struct ToolId(usize);
10
11/// A working set of tools for use in one instance of the Assistant Panel.
12#[derive(Default)]
13pub struct ToolWorkingSet {
14    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
15    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
16    next_tool_id: ToolId,
17}
18
19impl ToolWorkingSet {
20    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
21        self.context_server_tools_by_name
22            .get(name)
23            .cloned()
24            .or_else(|| ToolRegistry::global(cx).tool(name))
25    }
26
27    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
28        let mut tools = ToolRegistry::global(cx).tools();
29        tools.extend(self.context_server_tools_by_id.values().cloned());
30        tools
31    }
32
33    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
34        let mut tools_by_source = IndexMap::default();
35
36        for tool in self.tools(cx) {
37            tools_by_source
38                .entry(tool.source())
39                .or_insert_with(Vec::new)
40                .push(tool);
41        }
42
43        for tools in tools_by_source.values_mut() {
44            tools.sort_by_key(|tool| tool.name());
45        }
46
47        tools_by_source.sort_unstable_keys();
48
49        tools_by_source
50    }
51
52    pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
53        let tool_id = self.next_tool_id;
54        self.next_tool_id.0 += 1;
55        self.context_server_tools_by_id
56            .insert(tool_id, tool.clone());
57        self.tools_changed();
58        tool_id
59    }
60
61    pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
62        self.context_server_tools_by_id
63            .retain(|id, _| !tool_ids_to_remove.contains(id));
64        self.tools_changed();
65    }
66
67    fn tools_changed(&mut self) {
68        self.context_server_tools_by_name.clear();
69        self.context_server_tools_by_name.extend(
70            self.context_server_tools_by_id
71                .values()
72                .map(|tool| (tool.name(), tool.clone())),
73        );
74    }
75}