1use std::{borrow::Borrow, sync::Arc};
2
3use crate::{Tool, ToolRegistry, ToolSource};
4use collections::{HashMap, HashSet, IndexMap};
5use gpui::{App, SharedString};
6use util::debug_panic;
7
/// Identifies a tool added to a [`ToolWorkingSet`].
///
/// Ids are returned by `ToolWorkingSet::insert` / `ToolWorkingSet::extend`
/// and accepted by `ToolWorkingSet::remove`; they are handed out sequentially
/// starting from the default value (`0`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
10
11/// A unique identifier for a tool within a working set.
12#[derive(Clone, PartialEq, Eq, Hash, Default)]
13pub struct UniqueToolName(SharedString);
14
15impl Borrow<str> for UniqueToolName {
16 fn borrow(&self) -> &str {
17 &self.0
18 }
19}
20
21impl From<String> for UniqueToolName {
22 fn from(value: String) -> Self {
23 UniqueToolName(SharedString::new(value))
24 }
25}
26
27impl Into<String> for UniqueToolName {
28 fn into(self) -> String {
29 self.0.into()
30 }
31}
32
33impl std::fmt::Debug for UniqueToolName {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 self.0.fmt(f)
36 }
37}
38
39impl std::fmt::Display for UniqueToolName {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0.as_ref())
42 }
43}
44
45/// A working set of tools for use in one instance of the Assistant Panel.
46#[derive(Default)]
47pub struct ToolWorkingSet {
48 context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
49 context_server_tools_by_name: HashMap<UniqueToolName, Arc<dyn Tool>>,
50 next_tool_id: ToolId,
51}
52
53impl ToolWorkingSet {
54 pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
55 self.context_server_tools_by_name
56 .get(name)
57 .cloned()
58 .or_else(|| ToolRegistry::global(cx).tool(name))
59 }
60
61 pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
62 let mut tools = ToolRegistry::global(cx)
63 .tools()
64 .into_iter()
65 .map(|tool| (UniqueToolName(tool.name().into()), tool))
66 .collect::<Vec<_>>();
67 tools.extend(self.context_server_tools_by_name.clone());
68 tools
69 }
70
71 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
72 let mut tools_by_source = IndexMap::default();
73
74 for (_, tool) in self.tools(cx) {
75 tools_by_source
76 .entry(tool.source())
77 .or_insert_with(Vec::new)
78 .push(tool);
79 }
80
81 for tools in tools_by_source.values_mut() {
82 tools.sort_by_key(|tool| tool.name());
83 }
84
85 tools_by_source.sort_unstable_keys();
86
87 tools_by_source
88 }
89
90 pub fn insert(&mut self, tool: Arc<dyn Tool>, cx: &App) -> ToolId {
91 let tool_id = self.register_tool(tool);
92 self.tools_changed(cx);
93 tool_id
94 }
95
96 pub fn extend(&mut self, tools: impl Iterator<Item = Arc<dyn Tool>>, cx: &App) -> Vec<ToolId> {
97 let ids = tools.map(|tool| self.register_tool(tool)).collect();
98 self.tools_changed(cx);
99 ids
100 }
101
102 pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
103 self.context_server_tools_by_id
104 .retain(|id, _| !tool_ids_to_remove.contains(id));
105 self.tools_changed(cx);
106 }
107
108 fn register_tool(&mut self, tool: Arc<dyn Tool>) -> ToolId {
109 let tool_id = self.next_tool_id;
110 self.next_tool_id.0 += 1;
111 self.context_server_tools_by_id
112 .insert(tool_id, tool.clone());
113 tool_id
114 }
115
116 fn tools_changed(&mut self, cx: &App) {
117 self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
118 &self
119 .context_server_tools_by_id
120 .values()
121 .cloned()
122 .collect::<Vec<_>>(),
123 &ToolRegistry::global(cx).tools(),
124 );
125 }
126}
127
128fn resolve_context_server_tool_name_conflicts(
129 context_server_tools: &[Arc<dyn Tool>],
130 native_tools: &[Arc<dyn Tool>],
131) -> HashMap<UniqueToolName, Arc<dyn Tool>> {
132 fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
133 let mut tool_name = tool.name();
134 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
135 tool_name
136 }
137
138 const MAX_TOOL_NAME_LENGTH: usize = 64;
139
140 let mut duplicated_tool_names = HashSet::default();
141 let mut seen_tool_names = HashSet::default();
142 seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
143 for tool in context_server_tools {
144 let tool_name = resolve_tool_name(tool);
145 if seen_tool_names.contains(&tool_name) {
146 debug_assert!(
147 tool.source() != ToolSource::Native,
148 "Expected MCP tool but got a native tool: {}",
149 tool_name
150 );
151 duplicated_tool_names.insert(tool_name);
152 } else {
153 seen_tool_names.insert(tool_name);
154 }
155 }
156
157 if duplicated_tool_names.is_empty() {
158 return context_server_tools
159 .iter()
160 .map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
161 .collect();
162 }
163
164 context_server_tools
165 .iter()
166 .filter_map(|tool| {
167 let mut tool_name = resolve_tool_name(tool);
168 if !duplicated_tool_names.contains(&tool_name) {
169 return Some((tool_name.into(), tool.clone()));
170 }
171 match tool.source() {
172 ToolSource::Native => {
173 debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
174 // Built-in tools always keep their original name
175 Some((tool_name.into(), tool.clone()))
176 }
177 ToolSource::ContextServer { id } => {
178 // Context server tools are prefixed with the context server ID, and truncated if necessary
179 tool_name.insert(0, '_');
180 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
181 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
182 let mut id = id.to_string();
183 id.truncate(len);
184 tool_name.insert_str(0, &id);
185 } else {
186 tool_name.insert_str(0, &id);
187 }
188
189 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
190
191 if seen_tool_names.contains(&tool_name) {
192 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
193 None
194 } else {
195 Some((tool_name.into(), tool.clone()))
196 }
197 }
198 }
199 })
200 .collect()
201}
#[cfg(test)]
mod tests {
    use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
    use language_model::{LanguageModel, LanguageModelRequest};
    use project::Project;

    use crate::{ActionLog, ToolResult};

    use super::*;

    #[gpui::test]
    fn test_unique_tool_names(cx: &mut TestAppContext) {
        fn assert_tool(
            tool_working_set: &ToolWorkingSet,
            unique_name: &str,
            expected_name: &str,
            expected_source: ToolSource,
            cx: &App,
        ) {
            let tool = tool_working_set.tool(unique_name, cx).unwrap();
            assert_eq!(tool.name(), expected_name);
            assert_eq!(tool.source(), expected_source);
        }

        let tool_registry = cx.update(ToolRegistry::default_global);
        tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
        tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));

        let mut tool_working_set = ToolWorkingSet::default();
        cx.update(|cx| {
            tool_working_set.extend(
                vec![
                    Arc::new(TestTool::new(
                        "tool2",
                        ToolSource::ContextServer { id: "mcp-1".into() },
                    )) as Arc<dyn Tool>,
                    Arc::new(TestTool::new(
                        "tool2",
                        ToolSource::ContextServer { id: "mcp-2".into() },
                    )) as Arc<dyn Tool>,
                ]
                .into_iter(),
                cx,
            );
        });

        cx.update(|cx| {
            assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
            assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
            assert_tool(
                &tool_working_set,
                "mcp-1_tool2",
                "tool2",
                ToolSource::ContextServer { id: "mcp-1".into() },
                cx,
            );
            assert_tool(
                &tool_working_set,
                "mcp-2_tool2",
                "tool2",
                ToolSource::ContextServer { id: "mcp-2".into() },
                cx,
            );
        })
    }

    #[gpui::test]
    fn test_resolve_context_server_tool_name_conflicts() {
        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
            ],
            vec![TestTool::new(
                "tool3",
                ToolSource::ContextServer { id: "mcp-1".into() },
            )],
            vec!["tool3"],
        );

        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
            ],
            vec![
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["mcp-1_tool3", "mcp-2_tool3"],
        );

        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
            ],
            vec![
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_context_server_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-very-very-very-long-name",
                ToolSource::Native,
            )],
            vec![TestTool::new(
                "tool-with-very-very-very-long-name",
                ToolSource::ContextServer {
                    id: "mcp-with-very-very-very-long-name".into(),
                },
            )],
            vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
        );

        /// Resolves name conflicts and checks that exactly the `expected`
        /// names are produced. The result is a hash map, so the comparison
        /// ignores iteration order.
        fn assert_resolve_context_server_tool_name_conflicts(
            builtin_tools: Vec<TestTool>,
            context_server_tools: Vec<TestTool>,
            expected: Vec<&'static str>,
        ) {
            let context_server_tools: Vec<Arc<dyn Tool>> = context_server_tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let builtin_tools: Vec<Arc<dyn Tool>> = builtin_tools
                .into_iter()
                .map(|t| Arc::new(t) as Arc<dyn Tool>)
                .collect();
            let tools =
                resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);

            let mut actual_names: Vec<String> =
                tools.keys().map(|name| name.to_string()).collect();
            actual_names.sort();
            let mut expected_names: Vec<String> =
                expected.iter().map(|name| name.to_string()).collect();
            expected_names.sort();

            assert_eq!(actual_names, expected_names);
        }
    }

    struct TestTool {
        name: String,
        source: ToolSource,
    }

    impl TestTool {
        fn new(name: impl Into<String>, source: ToolSource) -> Self {
            Self {
                name: name.into(),
                source,
            }
        }
    }

    impl Tool for TestTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn icon(&self) -> icons::IconName {
            icons::IconName::Ai
        }

        fn may_perform_edits(&self) -> bool {
            false
        }

        fn needs_confirmation(
            &self,
            _input: &serde_json::Value,
            _project: &Entity<Project>,
            _cx: &App,
        ) -> bool {
            true
        }

        fn source(&self) -> ToolSource {
            self.source.clone()
        }

        fn description(&self) -> String {
            "Test tool".to_string()
        }

        fn ui_text(&self, _input: &serde_json::Value) -> String {
            "Test tool".to_string()
        }

        fn run(
            self: Arc<Self>,
            _input: serde_json::Value,
            _request: Arc<LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<ActionLog>,
            _model: Arc<dyn LanguageModel>,
            _window: Option<AnyWindowHandle>,
            _cx: &mut App,
        ) -> ToolResult {
            ToolResult {
                output: Task::ready(Err(anyhow::anyhow!("No content"))),
                card: None,
            }
        }
    }
}