tool_working_set.rs

 1use assistant_tool::{Tool, ToolRegistry};
 2use collections::HashMap;
 3use gpui::AppContext;
 4use parking_lot::Mutex;
 5use std::sync::Arc;
 6
 7#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 8pub struct ToolId(usize);
 9
10/// A working set of tools for use in one instance of the Assistant Panel.
11#[derive(Default)]
12pub struct ToolWorkingSet {
13    state: Mutex<WorkingSetState>,
14}
15
16#[derive(Default)]
17struct WorkingSetState {
18    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
19    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
20    next_tool_id: ToolId,
21}
22
23impl ToolWorkingSet {
24    pub fn tool(&self, name: &str, cx: &AppContext) -> Option<Arc<dyn Tool>> {
25        self.state
26            .lock()
27            .context_server_tools_by_name
28            .get(name)
29            .cloned()
30            .or_else(|| ToolRegistry::global(cx).tool(name))
31    }
32
33    pub fn tools(&self, cx: &AppContext) -> Vec<Arc<dyn Tool>> {
34        let mut tools = ToolRegistry::global(cx).tools();
35        tools.extend(
36            self.state
37                .lock()
38                .context_server_tools_by_id
39                .values()
40                .cloned(),
41        );
42
43        tools
44    }
45
46    pub fn insert(&self, command: Arc<dyn Tool>) -> ToolId {
47        let mut state = self.state.lock();
48        let command_id = state.next_tool_id;
49        state.next_tool_id.0 += 1;
50        state
51            .context_server_tools_by_id
52            .insert(command_id, command.clone());
53        state.tools_changed();
54        command_id
55    }
56
57    pub fn remove(&self, command_ids_to_remove: &[ToolId]) {
58        let mut state = self.state.lock();
59        state
60            .context_server_tools_by_id
61            .retain(|id, _| !command_ids_to_remove.contains(id));
62        state.tools_changed();
63    }
64}
65
66impl WorkingSetState {
67    fn tools_changed(&mut self) {
68        self.context_server_tools_by_name.clear();
69        self.context_server_tools_by_name.extend(
70            self.context_server_tools_by_id
71                .values()
72                .map(|command| (command.name(), command.clone())),
73        );
74    }
75}