1use std::borrow::Borrow;
2
3use crate::{AnyTool, ToolRegistry, ToolSource};
4use collections::{HashMap, HashSet, IndexMap};
5use gpui::{App, SharedString};
6use util::debug_panic;
7
/// Identifier handed out by a [`ToolWorkingSet`] for each tool registered in it.
///
/// Ids are taken from a counter owned by the working set, so an id is only
/// meaningful for the working set that returned it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
10
11/// A unique identifier for a tool within a working set.
12#[derive(Clone, PartialEq, Eq, Hash, Default)]
13pub struct UniqueToolName(SharedString);
14
15impl Borrow<str> for UniqueToolName {
16 fn borrow(&self) -> &str {
17 &self.0
18 }
19}
20
21impl From<String> for UniqueToolName {
22 fn from(value: String) -> Self {
23 UniqueToolName(SharedString::new(value))
24 }
25}
26
27impl Into<String> for UniqueToolName {
28 fn into(self) -> String {
29 self.0.into()
30 }
31}
32
33impl std::fmt::Debug for UniqueToolName {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 self.0.fmt(f)
36 }
37}
38
39impl std::fmt::Display for UniqueToolName {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0.as_ref())
42 }
43}
44
45/// A working set of tools for use in one instance of the Assistant Panel.
46#[derive(Default)]
47pub struct ToolWorkingSet {
48 context_server_tools_by_id: HashMap<ToolId, AnyTool>,
49 context_server_tools_by_name: HashMap<UniqueToolName, AnyTool>,
50 next_tool_id: ToolId,
51}
52
53impl ToolWorkingSet {
54 pub fn tool(&self, name: &str, cx: &App) -> Option<AnyTool> {
55 self.context_server_tools_by_name
56 .get(name)
57 .cloned()
58 .or_else(|| ToolRegistry::global(cx).tool(name))
59 }
60
61 pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> {
62 let mut tools = ToolRegistry::global(cx)
63 .tools()
64 .into_iter()
65 .map(|tool| (UniqueToolName(tool.name().into()), tool))
66 .collect::<Vec<_>>();
67 tools.extend(self.context_server_tools_by_name.clone());
68 tools
69 }
70
71 pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<AnyTool>> {
72 let mut tools_by_source = IndexMap::default();
73
74 for (_, tool) in self.tools(cx) {
75 tools_by_source
76 .entry(tool.source())
77 .or_insert_with(Vec::new)
78 .push(tool);
79 }
80
81 for tools in tools_by_source.values_mut() {
82 tools.sort_by_key(|tool| tool.name());
83 }
84
85 tools_by_source.sort_unstable_keys();
86
87 tools_by_source
88 }
89
90 pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId {
91 let tool_id = self.register_tool(tool);
92 self.tools_changed(cx);
93 tool_id
94 }
95
96 pub fn extend(&mut self, tools: impl Iterator<Item = AnyTool>, cx: &App) -> Vec<ToolId> {
97 let ids = tools.map(|tool| self.register_tool(tool)).collect();
98 self.tools_changed(cx);
99 ids
100 }
101
102 pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
103 self.context_server_tools_by_id
104 .retain(|id, _| !tool_ids_to_remove.contains(id));
105 self.tools_changed(cx);
106 }
107
108 fn register_tool(&mut self, tool: AnyTool) -> ToolId {
109 let tool_id = self.next_tool_id;
110 self.next_tool_id.0 += 1;
111 self.context_server_tools_by_id
112 .insert(tool_id, tool.clone());
113 tool_id
114 }
115
116 fn tools_changed(&mut self, cx: &App) {
117 self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
118 &self
119 .context_server_tools_by_id
120 .values()
121 .cloned()
122 .collect::<Vec<_>>(),
123 &ToolRegistry::global(cx).tools(),
124 );
125 }
126}
127
128fn resolve_context_server_tool_name_conflicts(
129 context_server_tools: &[AnyTool],
130 native_tools: &[AnyTool],
131) -> HashMap<UniqueToolName, AnyTool> {
132 fn resolve_tool_name(tool: &AnyTool) -> String {
133 let mut tool_name = tool.name();
134 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
135 tool_name
136 }
137
138 const MAX_TOOL_NAME_LENGTH: usize = 64;
139
140 let mut duplicated_tool_names = HashSet::default();
141 let mut seen_tool_names = HashSet::default();
142 seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
143 for tool in context_server_tools {
144 let tool_name = resolve_tool_name(tool);
145 if seen_tool_names.contains(&tool_name) {
146 debug_assert!(
147 tool.source() != ToolSource::Native,
148 "Expected MCP tool but got a native tool: {}",
149 tool_name
150 );
151 duplicated_tool_names.insert(tool_name);
152 } else {
153 seen_tool_names.insert(tool_name);
154 }
155 }
156
157 if duplicated_tool_names.is_empty() {
158 return context_server_tools
159 .into_iter()
160 .map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
161 .collect();
162 }
163
164 context_server_tools
165 .into_iter()
166 .filter_map(|tool| {
167 let mut tool_name = resolve_tool_name(tool);
168 if !duplicated_tool_names.contains(&tool_name) {
169 return Some((tool_name.into(), tool.clone()));
170 }
171 match tool.source() {
172 ToolSource::Native => {
173 debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
174 // Built-in tools always keep their original name
175 Some((tool_name.into(), tool.clone()))
176 }
177 ToolSource::ContextServer { id } => {
178 // Context server tools are prefixed with the context server ID, and truncated if necessary
179 tool_name.insert(0, '_');
180 if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
181 let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
182 let mut id = id.to_string();
183 id.truncate(len);
184 tool_name.insert_str(0, &id);
185 } else {
186 tool_name.insert_str(0, &id);
187 }
188
189 tool_name.truncate(MAX_TOOL_NAME_LENGTH);
190
191 if seen_tool_names.contains(&tool_name) {
192 log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
193 None
194 } else {
195 Some((tool_name.into(), tool.clone()))
196 }
197 }
198 }
199 })
200 .collect()
201}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
    use language_model::{LanguageModel, LanguageModelRequest};
    use project::Project;

    use crate::{ActionLog, Tool, ToolResult};

    use super::*;

    /// Context server tools whose name clashes with a native tool must be
    /// reachable under `<server id>_<name>`, while the native tool keeps its
    /// plain name.
    #[gpui::test]
    fn test_unique_tool_names(cx: &mut TestAppContext) {
        fn assert_tool(
            tool_working_set: &ToolWorkingSet,
            unique_name: &str,
            expected_name: &str,
            expected_source: ToolSource,
            cx: &App,
        ) {
            let tool = tool_working_set.tool(unique_name, cx).unwrap();
            assert_eq!(tool.name(), expected_name);
            assert_eq!(tool.source(), expected_source);
        }

        let tool_registry = cx.update(ToolRegistry::default_global);
        tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
        tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));

        let mut tool_working_set = ToolWorkingSet::default();
        cx.update(|cx| {
            tool_working_set.extend(
                vec![
                    Arc::new(TestTool::new(
                        "tool2",
                        ToolSource::ContextServer { id: "mcp-1".into() },
                    ))
                    .into(),
                    Arc::new(TestTool::new(
                        "tool2",
                        ToolSource::ContextServer { id: "mcp-2".into() },
                    ))
                    .into(),
                ]
                .into_iter(),
                cx,
            );
        });

        cx.update(|cx| {
            assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
            assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
            assert_tool(
                &tool_working_set,
                "mcp-1_tool2",
                "tool2",
                ToolSource::ContextServer { id: "mcp-1".into() },
                cx,
            );
            assert_tool(
                &tool_working_set,
                "mcp-2_tool2",
                "tool2",
                ToolSource::ContextServer { id: "mcp-2".into() },
                cx,
            );
        })
    }

    #[gpui::test]
    fn test_resolve_context_server_tool_name_conflicts() {
        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
            ],
            vec![TestTool::new(
                "tool3",
                ToolSource::ContextServer { id: "mcp-1".into() },
            )],
            vec!["tool3"],
        );

        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
            ],
            vec![
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["mcp-1_tool3", "mcp-2_tool3"],
        );

        assert_resolve_context_server_tool_name_conflicts(
            vec![
                TestTool::new("tool1", ToolSource::Native),
                TestTool::new("tool2", ToolSource::Native),
                TestTool::new("tool3", ToolSource::Native),
            ],
            vec![
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
                TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
            ],
            vec!["mcp-1_tool3", "mcp-2_tool3"],
        );

        // Test deduplication of tools with very long names, in this case the mcp server name should be truncated
        assert_resolve_context_server_tool_name_conflicts(
            vec![TestTool::new(
                "tool-with-very-very-very-long-name",
                ToolSource::Native,
            )],
            vec![TestTool::new(
                "tool-with-very-very-very-long-name",
                ToolSource::ContextServer {
                    id: "mcp-with-very-very-very-long-name".into(),
                },
            )],
            vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
        );

        /// Resolves `context_server_tools` against `builtin_tools` and checks
        /// that exactly the `expected` unique names come out.
        fn assert_resolve_context_server_tool_name_conflicts(
            builtin_tools: Vec<TestTool>,
            context_server_tools: Vec<TestTool>,
            mut expected: Vec<&'static str>,
        ) {
            let context_server_tools: Vec<AnyTool> = context_server_tools
                .into_iter()
                .map(|t| Arc::new(t).into())
                .collect();
            let builtin_tools: Vec<AnyTool> = builtin_tools
                .into_iter()
                .map(|t| Arc::new(t).into())
                .collect();
            let tools =
                resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);

            // The resolver returns a hash map, whose iteration order is
            // unspecified, so compare the names as sorted lists instead of
            // position by position.
            let mut actual: Vec<String> = tools.keys().map(|name| name.to_string()).collect();
            actual.sort();
            expected.sort();
            assert_eq!(actual, expected);
        }
    }

    /// Minimal [`Tool`] with a configurable name and source.
    struct TestTool {
        name: String,
        source: ToolSource,
    }

    impl TestTool {
        fn new(name: impl Into<String>, source: ToolSource) -> Self {
            Self {
                name: name.into(),
                source,
            }
        }
    }

    impl Tool for TestTool {
        type Input = ();

        fn name(&self) -> String {
            self.name.clone()
        }

        fn icon(&self) -> icons::IconName {
            icons::IconName::Ai
        }

        fn may_perform_edits(&self) -> bool {
            false
        }

        fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool {
            true
        }

        fn source(&self) -> ToolSource {
            self.source.clone()
        }

        fn description(&self) -> String {
            "Test tool".to_string()
        }

        fn ui_text(&self, _input: &Self::Input) -> String {
            "Test tool".to_string()
        }

        fn run(
            self: Arc<Self>,
            _input: Self::Input,
            _request: Arc<LanguageModelRequest>,
            _project: Entity<Project>,
            _action_log: Entity<ActionLog>,
            _model: Arc<dyn LanguageModel>,
            _window: Option<AnyWindowHandle>,
            _cx: &mut App,
        ) -> ToolResult {
            ToolResult {
                output: Task::ready(Err(anyhow::anyhow!("No content"))),
                card: None,
            }
        }
    }
}