tool_working_set.rs

 1use std::sync::Arc;
 2
 3use collections::HashMap;
 4use gpui::App;
 5use parking_lot::Mutex;
 6
 7use crate::{Tool, ToolRegistry};
 8
 9#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
10pub struct ToolId(usize);
11
12/// A working set of tools for use in one instance of the Assistant Panel.
13#[derive(Default)]
14pub struct ToolWorkingSet {
15    state: Mutex<WorkingSetState>,
16}
17
18#[derive(Default)]
19struct WorkingSetState {
20    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
21    context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
22    next_tool_id: ToolId,
23}
24
25impl ToolWorkingSet {
26    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
27        self.state
28            .lock()
29            .context_server_tools_by_name
30            .get(name)
31            .cloned()
32            .or_else(|| ToolRegistry::global(cx).tool(name))
33    }
34
35    pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
36        let mut tools = ToolRegistry::global(cx).tools();
37        tools.extend(
38            self.state
39                .lock()
40                .context_server_tools_by_id
41                .values()
42                .cloned(),
43        );
44
45        tools
46    }
47
48    pub fn insert(&self, command: Arc<dyn Tool>) -> ToolId {
49        let mut state = self.state.lock();
50        let command_id = state.next_tool_id;
51        state.next_tool_id.0 += 1;
52        state
53            .context_server_tools_by_id
54            .insert(command_id, command.clone());
55        state.tools_changed();
56        command_id
57    }
58
59    pub fn remove(&self, command_ids_to_remove: &[ToolId]) {
60        let mut state = self.state.lock();
61        state
62            .context_server_tools_by_id
63            .retain(|id, _| !command_ids_to_remove.contains(id));
64        state.tools_changed();
65    }
66}
67
68impl WorkingSetState {
69    fn tools_changed(&mut self) {
70        self.context_server_tools_by_name.clear();
71        self.context_server_tools_by_name.extend(
72            self.context_server_tools_by_id
73                .values()
74                .map(|command| (command.name(), command.clone())),
75        );
76    }
77}