tool_working_set.rs

  1use std::{borrow::Borrow, sync::Arc};
  2
  3use crate::{Tool, ToolRegistry, ToolSource};
  4use collections::{HashMap, HashSet, IndexMap};
  5use gpui::{App, SharedString};
  6use util::debug_panic;
  7
  8#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
  9pub struct ToolId(usize);
 10
 11/// A unique identifier for a tool within a working set.
 12#[derive(Clone, PartialEq, Eq, Hash, Default)]
 13pub struct UniqueToolName(SharedString);
 14
 15impl Borrow<str> for UniqueToolName {
 16    fn borrow(&self) -> &str {
 17        &self.0
 18    }
 19}
 20
 21impl From<String> for UniqueToolName {
 22    fn from(value: String) -> Self {
 23        UniqueToolName(SharedString::new(value))
 24    }
 25}
 26
 27impl Into<String> for UniqueToolName {
 28    fn into(self) -> String {
 29        self.0.into()
 30    }
 31}
 32
 33impl std::fmt::Debug for UniqueToolName {
 34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 35        self.0.fmt(f)
 36    }
 37}
 38
 39impl std::fmt::Display for UniqueToolName {
 40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 41        write!(f, "{}", self.0.as_ref())
 42    }
 43}
 44
 45/// A working set of tools for use in one instance of the Assistant Panel.
 46#[derive(Default)]
 47pub struct ToolWorkingSet {
 48    context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
 49    context_server_tools_by_name: HashMap<UniqueToolName, Arc<dyn Tool>>,
 50    next_tool_id: ToolId,
 51}
 52
 53impl ToolWorkingSet {
 54    pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
 55        self.context_server_tools_by_name
 56            .get(name)
 57            .cloned()
 58            .or_else(|| ToolRegistry::global(cx).tool(name))
 59    }
 60
 61    pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
 62        let mut tools = ToolRegistry::global(cx)
 63            .tools()
 64            .into_iter()
 65            .map(|tool| (UniqueToolName(tool.name().into()), tool))
 66            .collect::<Vec<_>>();
 67        tools.extend(self.context_server_tools_by_name.clone());
 68        tools
 69    }
 70
 71    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
 72        let mut tools_by_source = IndexMap::default();
 73
 74        for (_, tool) in self.tools(cx) {
 75            tools_by_source
 76                .entry(tool.source())
 77                .or_insert_with(Vec::new)
 78                .push(tool);
 79        }
 80
 81        for tools in tools_by_source.values_mut() {
 82            tools.sort_by_key(|tool| tool.name());
 83        }
 84
 85        tools_by_source.sort_unstable_keys();
 86
 87        tools_by_source
 88    }
 89
 90    pub fn insert(&mut self, tool: Arc<dyn Tool>, cx: &App) -> ToolId {
 91        let tool_id = self.register_tool(tool);
 92        self.tools_changed(cx);
 93        tool_id
 94    }
 95
 96    pub fn extend(&mut self, tools: impl Iterator<Item = Arc<dyn Tool>>, cx: &App) -> Vec<ToolId> {
 97        let ids = tools.map(|tool| self.register_tool(tool)).collect();
 98        self.tools_changed(cx);
 99        ids
100    }
101
102    pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
103        self.context_server_tools_by_id
104            .retain(|id, _| !tool_ids_to_remove.contains(id));
105        self.tools_changed(cx);
106    }
107
108    fn register_tool(&mut self, tool: Arc<dyn Tool>) -> ToolId {
109        let tool_id = self.next_tool_id;
110        self.next_tool_id.0 += 1;
111        self.context_server_tools_by_id
112            .insert(tool_id, tool.clone());
113        tool_id
114    }
115
116    fn tools_changed(&mut self, cx: &App) {
117        self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
118            &self
119                .context_server_tools_by_id
120                .values()
121                .cloned()
122                .collect::<Vec<_>>(),
123            &ToolRegistry::global(cx).tools(),
124        );
125    }
126}
127
128fn resolve_context_server_tool_name_conflicts(
129    context_server_tools: &[Arc<dyn Tool>],
130    native_tools: &[Arc<dyn Tool>],
131) -> HashMap<UniqueToolName, Arc<dyn Tool>> {
132    fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
133        let mut tool_name = tool.name();
134        tool_name.truncate(MAX_TOOL_NAME_LENGTH);
135        tool_name
136    }
137
138    const MAX_TOOL_NAME_LENGTH: usize = 64;
139
140    let mut duplicated_tool_names = HashSet::default();
141    let mut seen_tool_names = HashSet::default();
142    seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
143    for tool in context_server_tools {
144        let tool_name = resolve_tool_name(tool);
145        if seen_tool_names.contains(&tool_name) {
146            debug_assert!(
147                tool.source() != ToolSource::Native,
148                "Expected MCP tool but got a native tool: {}",
149                tool_name
150            );
151            duplicated_tool_names.insert(tool_name);
152        } else {
153            seen_tool_names.insert(tool_name);
154        }
155    }
156
157    if duplicated_tool_names.is_empty() {
158        return context_server_tools
159            .iter()
160            .map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
161            .collect();
162    }
163
164    context_server_tools
165        .iter()
166        .filter_map(|tool| {
167            let mut tool_name = resolve_tool_name(tool);
168            if !duplicated_tool_names.contains(&tool_name) {
169                return Some((tool_name.into(), tool.clone()));
170            }
171            match tool.source() {
172                ToolSource::Native => {
173                    debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
174                    // Built-in tools always keep their original name
175                    Some((tool_name.into(), tool.clone()))
176                }
177                ToolSource::ContextServer { id } => {
178                    // Context server tools are prefixed with the context server ID, and truncated if necessary
179                    tool_name.insert(0, '_');
180                    if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
181                        let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
182                        let mut id = id.to_string();
183                        id.truncate(len);
184                        tool_name.insert_str(0, &id);
185                    } else {
186                        tool_name.insert_str(0, &id);
187                    }
188
189                    tool_name.truncate(MAX_TOOL_NAME_LENGTH);
190
191                    if seen_tool_names.contains(&tool_name) {
192                        log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
193                        None
194                    } else {
195                        Some((tool_name.into(), tool.clone()))
196                    }
197                }
198            }
199        })
200        .collect()
201}
#[cfg(test)]
mod tests {
    use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
    use language_model::{LanguageModel, LanguageModelRequest};
    use project::Project;

    use crate::{ActionLog, ToolResult};

    use super::*;

    #[gpui::test]
    fn test_unique_tool_names(cx: &mut TestAppContext) {
        fn assert_tool(
            tool_working_set: &ToolWorkingSet,
            unique_name: &str,
            expected_name: &str,
            expected_source: ToolSource,
            cx: &App,
        ) {
            let tool = tool_working_set.tool(unique_name, cx).unwrap();
            assert_eq!(tool.name(), expected_name);
            assert_eq!(tool.source(), expected_source);
        }

        let tool_registry = cx.update(ToolRegistry::default_global);
        tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
        tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));

        let mut tool_working_set = ToolWorkingSet::default();
        cx.update(|cx| {
            tool_working_set.extend(
                vec![
                    Arc::new(TestTool::new(
                        "tool2",
                        ToolSource::ContextServer { id: "mcp-1".into() },
                    )) as Arc<dyn Tool>,
                    Arc::new(TestTool::new(
                        "tool2",
                        ToolSource::ContextServer { id: "mcp-2".into() },
                    )) as Arc<dyn Tool>,
                ]
                .into_iter(),
                cx,
            );
        });

        cx.update(|cx| {
            assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
            assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
            assert_tool(
                &tool_working_set,
                "mcp-1_tool2",
                "tool2",
                ToolSource::ContextServer { id: "mcp-1".into() },
                cx,
            );
            assert_tool(
                &tool_working_set,
                "mcp-2_tool2",
                "tool2",
                ToolSource::ContextServer { id: "mcp-2".into() },
                cx,
            );
        })
    }

    #[gpui::test]
    fn test_resolve_context_server_tool_name_conflicts() {
        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
            ],
            vec![TestTool::new(
                "tool3",
                ToolSource::ContextServer { id: "mcp-1".into() },
            )],
            vec!["tool3"],
        );

        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
            ],
            vec![
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["mcp-1_tool3", "mcp-2_tool3"],
        );

        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
            ],
            vec![
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_context_server_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-very-very-very-long-name",
                ToolSource::Native,
            )],
            vec![TestTool::new(
                "tool-with-very-very-very-long-name",
                ToolSource::ContextServer {
                    id: "mcp-with-very-very-very-long-name".into(),
                },
            )],
            vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
        );

        /// Resolves `context_server_tools` against `builtin_tools` and checks that
        /// the resulting unique names are exactly `expected`, ignoring order.
        /// The result is a hash map, so its iteration order is not meaningful.
        fn assert_resolve_context_server_tool_name_conflicts(
            builtin_tools: Vec<TestTool>,
            context_server_tools: Vec<TestTool>,
            mut expected: Vec<&'static str>,
        ) {
            let context_server_tools: Vec<Arc<dyn Tool>> = context_server_tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let builtin_tools: Vec<Arc<dyn Tool>> = builtin_tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools =
                resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
            let mut actual = tools
                .keys()
                .map(|name| name.to_string())
                .collect::<Vec<_>>();
            actual.sort();
            expected.sort();
            assert_eq!(actual, expected);
        }
    }

    struct TestTool {
        name: String,
        source: ToolSource,
    }

    impl TestTool {
        fn new(name: impl Into<String>, source: ToolSource) -> Self {
            Self {
                name: name.into(),
                source,
            }
        }
    }

    impl Tool for TestTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn icon(&self) -> icons::IconName {
            icons::IconName::Ai
        }

        fn may_perform_edits(&self) -> bool {
            false
        }

        fn needs_confirmation(
            &self,
            _input: &serde_json::Value,
            _project: &Entity<Project>,
            _cx: &App,
        ) -> bool {
            true
        }

        fn source(&self) -> ToolSource {
            self.source.clone()
        }

        fn description(&self) -> String {
            "Test tool".to_string()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            "Test tool".to_string()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<ActionLog>,
            _model: Arc<dyn LanguageModel>,
            _window: Option<AnyWindowHandle>,
            _cx: &mut App,
        ) -> ToolResult {
            ToolResult {
                output: Task::ready(Err(anyhow::anyhow!("No content"))),
                card: None,
            }
        }
    }
}