diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 07f53f114cf38bd2f5a2b1db6140b71dc06d0680..8cea712c19cebb3261573184349987d67492ac62 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -336,13 +336,12 @@ impl NativeAgent { ) }); - self.register_session(thread, None, cx) + self.register_session(thread, cx) } fn register_session( &mut self, thread_handle: Entity, - allowed_tool_names: Option>, cx: &mut Context, ) -> Entity { let connection = Rc::new(NativeAgentConnection(cx.entity())); @@ -374,7 +373,6 @@ impl NativeAgent { thread_handle.update(cx, |thread, cx| { thread.set_summarization_model(summarization_model, cx); thread.add_default_tools( - allowed_tool_names, Rc::new(NativeThreadEnvironment { acp_thread: acp_thread.downgrade(), agent: weak, @@ -804,9 +802,8 @@ impl NativeAgent { let task = self.load_thread(id, cx); cx.spawn(async move |this, cx| { let thread = task.await?; - let acp_thread = this.update(cx, |this, cx| { - this.register_session(thread.clone(), None, cx) - })?; + let acp_thread = + this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?; let events = thread.update(cx, |thread, cx| thread.replay(cx)); cx.update(|cx| { NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) @@ -1601,7 +1598,6 @@ impl NativeThreadEnvironment { MAX_SUBAGENT_DEPTH )); } - let allowed_tool_names = Some(parent_thread.tools.keys().cloned().collect::>()); let subagent_thread: Entity = cx.new(|cx| { let mut thread = Thread::new_subagent(&parent_thread_entity, cx); @@ -1612,7 +1608,7 @@ impl NativeThreadEnvironment { let session_id = subagent_thread.read(cx).id().clone(); let acp_thread = agent.update(cx, |agent, cx| { - agent.register_session(subagent_thread.clone(), allowed_tool_names, cx) + agent.register_session(subagent_thread.clone(), cx) })?; parent_thread_entity.update(cx, |parent_thread, _cx| { diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index 068c0270cf7057790d3665f7f1fac59d1d3f1d07..069bf0349299e6f4952f673cbf7607e52d48d9c5 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -49,23 +49,17 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { ); // Add just the tools we need for this test let language_registry = project.read(cx).languages().clone(); - thread.add_tool( - crate::ReadFileTool::new( - cx.weak_entity(), - project.clone(), - thread.action_log().clone(), - ), - None, - ); - thread.add_tool( - crate::EditFileTool::new( - project.clone(), - cx.weak_entity(), - language_registry, - crate::Templates::new(), - ), - None, - ); + thread.add_tool(crate::ReadFileTool::new( + cx.weak_entity(), + project.clone(), + thread.action_log().clone(), + )); + thread.add_tool(crate::EditFileTool::new( + project.clone(), + cx.weak_entity(), + language_registry, + crate::Templates::new(), + )); thread }); diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index ce6037406383147491f20df2288abb4b33a27224..6f8cd32dbcde72c6262b3c386926fced224043a7 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -464,7 +464,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) { project_context.update(cx, |project_context, _cx| { project_context.shell = "test-shell".into() }); - thread.update(cx, |thread, _| thread.add_tool(EchoTool, None)); + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread .update(cx, |thread, cx| { thread.send(UserMessageId::new(), ["abc"], cx) @@ -600,7 +600,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { cx.run_until_parked(); // Simulate a tool call and verify that the latest tool result is cached - thread.update(cx, |thread, _| thread.add_tool(EchoTool, None)); + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread .update(cx, |thread, cx| { thread.send(UserMessageId::new(), ["Use the echo tool"], cx) @@ -686,7 +686,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { // Test a tool call that's likely to complete *before* streaming stops. let events = thread .update(cx, |thread, cx| { - thread.add_tool(EchoTool, None); + thread.add_tool(EchoTool); thread.send( UserMessageId::new(), ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], @@ -702,7 +702,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.remove_tool(&EchoTool::NAME); - thread.add_tool(DelayTool, None); + thread.add_tool(DelayTool); thread.send( UserMessageId::new(), [ @@ -746,7 +746,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { // Test a tool call that's likely to complete *before* streaming stops. let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(WordListTool, None); + thread.add_tool(WordListTool); thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) }) .unwrap(); @@ -797,7 +797,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(ToolRequiringPermission, None); + thread.add_tool(ToolRequiringPermission); thread.send(UserMessageId::new(), ["abc"], cx) }) .unwrap(); @@ -1207,7 +1207,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { // Test concurrent tool calls with different delay times let events = thread .update(cx, |thread, cx| { - thread.add_tool(DelayTool, None); + thread.add_tool(DelayTool); thread.send( UserMessageId::new(), [ @@ -1252,9 +1252,9 @@ async fn test_profiles(cx: &mut TestAppContext) { let fake_model = model.as_fake(); thread.update(cx, |thread, _cx| { - thread.add_tool(DelayTool, None); - thread.add_tool(EchoTool, None); - thread.add_tool(InfiniteTool, None); + thread.add_tool(DelayTool); + thread.add_tool(EchoTool); + thread.add_tool(InfiniteTool); }); // Override profiles and wait for settings to be loaded. @@ -1420,7 +1420,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { // Send again after adding the echo tool, ensuring the name collision is resolved. let events = thread.update(cx, |thread, cx| { - thread.add_tool(EchoTool, None); + thread.add_tool(EchoTool); thread.send(UserMessageId::new(), ["Go"], cx).unwrap() }); cx.run_until_parked(); @@ -1711,11 +1711,11 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { thread.update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test".into()), cx); - thread.add_tool(EchoTool, None); - thread.add_tool(DelayTool, None); - thread.add_tool(WordListTool, None); - thread.add_tool(ToolRequiringPermission, None); - thread.add_tool(InfiniteTool, None); + thread.add_tool(EchoTool); + thread.add_tool(DelayTool); + thread.add_tool(WordListTool); + thread.add_tool(ToolRequiringPermission); + thread.add_tool(InfiniteTool); }); // Set up multiple context servers with some overlapping tool names @@ -1863,8 +1863,8 @@ async fn test_cancellation(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(InfiniteTool, None); - thread.add_tool(EchoTool, None); + thread.add_tool(InfiniteTool); + thread.add_tool(EchoTool); thread.send( UserMessageId::new(), ["Call the echo tool, then call the infinite tool, then explain their output"], @@ -1955,10 +1955,10 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext let mut events = thread .update(cx, |thread, cx| { - thread.add_tool( - crate::TerminalTool::new(thread.project().clone(), environment), - None, - ); + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); thread.send(UserMessageId::new(), ["run a command"], cx) }) .unwrap(); @@ -2052,7 +2052,7 @@ async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppC let mut events = thread .update(cx, |thread, cx| { - thread.add_tool(tool, None); + thread.add_tool(tool); thread.send( UserMessageId::new(), ["call the cancellation aware tool"], @@ -2238,10 +2238,10 @@ async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) { let message_id = UserMessageId::new(); let mut events = thread .update(cx, |thread, cx| { - thread.add_tool( - crate::TerminalTool::new(thread.project().clone(), environment), - None, - ); + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); thread.send(message_id.clone(), ["run a command"], cx) }) .unwrap(); @@ -2302,10 +2302,10 @@ async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) let mut events = thread .update(cx, |thread, cx| { - thread.add_tool( - crate::TerminalTool::new(thread.project().clone(), environment.clone()), - None, - ); + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment.clone(), + )); thread.send(UserMessageId::new(), ["run multiple commands"], cx) }) .unwrap(); @@ -2415,10 +2415,10 @@ async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppCon let mut events = thread .update(cx, |thread, cx| { - thread.add_tool( - crate::TerminalTool::new(thread.project().clone(), environment), - None, - ); + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); thread.send(UserMessageId::new(), ["run a command"], cx) }) .unwrap(); @@ -2509,10 +2509,10 @@ async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.add_tool( - crate::TerminalTool::new(thread.project().clone(), environment), - None, - ); + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); thread.send(UserMessageId::new(), ["run a command with timeout"], cx) }) .unwrap(); @@ -2997,8 +2997,8 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { let _events = thread .update(cx, |thread, cx| { - thread.add_tool(ToolRequiringPermission, None); - thread.add_tool(EchoTool, None); + thread.add_tool(ToolRequiringPermission); + thread.add_tool(EchoTool); thread.send(UserMessageId::new(), ["Hey!"], cx) }) .unwrap(); @@ -3204,7 +3204,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { #[gpui::test] async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; - thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None)); + thread.update(cx, |thread, _cx| thread.add_tool(EchoTool)); let fake_model = model.as_fake(); let mut events = thread @@ -3394,7 +3394,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { - thread.add_tool(EchoTool, None); + thread.add_tool(EchoTool); thread.send(UserMessageId::new(), ["Call the echo tool!"], cx) }) .unwrap(); @@ -4490,7 +4490,7 @@ async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAp Some(model), cx, ); - thread.add_default_tools(None, environment, cx); + thread.add_default_tools(environment, cx); thread }); @@ -4582,7 +4582,7 @@ async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppCont }); let deep_subagent_thread = cx.new(|cx| { let mut thread = Thread::new_subagent(&deep_parent_thread, cx); - thread.add_default_tools(None, environment, cx); + thread.add_default_tools(environment, cx); thread }); @@ -4736,78 +4736,6 @@ async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceede ); } -#[gpui::test] -async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) { - init_test(cx); - - always_allow_tools(cx); - - cx.update(|cx| { - cx.update_flags(true, vec!["subagents".to_string()]); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; - let project_context = cx.new(|_cx| ProjectContext::default()); - let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx)); - cx.update(LanguageModelRegistry::test); - let model = Arc::new(FakeLanguageModel::default()); - let thread_store = cx.new(|cx| ThreadStore::new(cx)); - let native_agent = NativeAgent::new( - project.clone(), - thread_store, - Templates::new(), - None, - fs, - &mut cx.to_async(), - ) - .await - .unwrap(); - let parent_thread = cx.new(|cx| { - let mut thread = Thread::new( - project.clone(), - project_context, - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ); - thread.add_tool(ListDirectoryTool::new(project.clone()), None); - thread.add_tool(GrepTool::new(project.clone()), None); - thread - }); - - let _subagent_handle = cx - .update(|cx| { - NativeThreadEnvironment::create_subagent_thread( - native_agent.downgrade(), - parent_thread.clone(), - "some title".to_string(), - "task prompt".to_string(), - Some(Duration::from_millis(10)), - cx, - ) - }) - .expect("Failed to create subagent"); - - cx.run_until_parked(); - - let tools = model - .pending_completions() - .last() - .unwrap() - .tools - .iter() - .map(|tool| tool.name.clone()) - .collect::>(); - assert_eq!(tools.len(), 2); - assert!(tools.contains(&"grep".to_string())); - assert!(tools.contains(&"list_directory".to_string())); -} - #[gpui::test] async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) { init_test(cx); @@ -5458,7 +5386,7 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) { // Add a tool so we can simulate tool calls thread.update(cx, |thread, _cx| { - thread.add_tool(EchoTool, None); + thread.add_tool(EchoTool); }); // Start a turn by sending a message diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f48e11d6c09da2a3ed3286b7f68b6b59f03fdb51..578f10554947567bd86a73d9abc02666daa06a84 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1326,111 +1326,52 @@ impl Thread { pub fn add_default_tools( &mut self, - allowed_tool_names: Option>, environment: Rc, cx: &mut Context, ) { let language_registry = self.project.read(cx).languages().clone(); - self.add_tool( - CopyPathTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - CreateDirectoryTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - DeletePathTool::new(self.project.clone(), self.action_log.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - DiagnosticsTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - EditFileTool::new( - self.project.clone(), - cx.weak_entity(), - language_registry.clone(), - Templates::new(), - ), - allowed_tool_names.as_ref(), - ); - self.add_tool( - StreamingEditFileTool::new( - self.project.clone(), - cx.weak_entity(), - language_registry, - Templates::new(), - ), - allowed_tool_names.as_ref(), - ); - self.add_tool( - FetchTool::new(self.project.read(cx).client().http_client()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - FindPathTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - GrepTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - ListDirectoryTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - MovePathTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool(NowTool, allowed_tool_names.as_ref()); - self.add_tool( - OpenTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - ReadFileTool::new( - cx.weak_entity(), - self.project.clone(), - self.action_log.clone(), - ), - allowed_tool_names.as_ref(), - ); - self.add_tool( - SaveFileTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - RestoreFileFromDiskTool::new(self.project.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool( - TerminalTool::new(self.project.clone(), environment.clone()), - allowed_tool_names.as_ref(), - ); - self.add_tool(WebSearchTool, allowed_tool_names.as_ref()); + self.add_tool(CopyPathTool::new(self.project.clone())); + self.add_tool(CreateDirectoryTool::new(self.project.clone())); + self.add_tool(DeletePathTool::new( + self.project.clone(), + self.action_log.clone(), + )); + self.add_tool(DiagnosticsTool::new(self.project.clone())); + self.add_tool(EditFileTool::new( + self.project.clone(), + cx.weak_entity(), + language_registry.clone(), + Templates::new(), + )); + self.add_tool(StreamingEditFileTool::new( + self.project.clone(), + cx.weak_entity(), + language_registry, + Templates::new(), + )); + self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); + self.add_tool(FindPathTool::new(self.project.clone())); + self.add_tool(GrepTool::new(self.project.clone())); + self.add_tool(ListDirectoryTool::new(self.project.clone())); + self.add_tool(MovePathTool::new(self.project.clone())); + self.add_tool(NowTool); + self.add_tool(OpenTool::new(self.project.clone())); + self.add_tool(ReadFileTool::new( + cx.weak_entity(), + self.project.clone(), + self.action_log.clone(), + )); + self.add_tool(SaveFileTool::new(self.project.clone())); + self.add_tool(RestoreFileFromDiskTool::new(self.project.clone())); + self.add_tool(TerminalTool::new(self.project.clone(), environment.clone())); + self.add_tool(WebSearchTool); if cx.has_flag::() && self.depth() < MAX_SUBAGENT_DEPTH { - self.add_tool( - SubagentTool::new(cx.weak_entity(), environment), - allowed_tool_names.as_ref(), - ); + self.add_tool(SubagentTool::new(cx.weak_entity(), environment)); } } - pub fn add_tool( - &mut self, - tool: T, - allowed_tool_names: Option<&Vec>, - ) { - if allowed_tool_names - .is_some_and(|tool_names| !tool_names.iter().any(|x| x.as_str() == T::NAME)) - { - return; - } - + pub fn add_tool(&mut self, tool: T) { debug_assert!( !self.tools.contains_key(T::NAME), "Duplicate tool name: {}", diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index cda045cee6d0005e6cb5703ee82820d5aac38cbe..3fc8d23a69c82ebba6e42e7bc0651a16d0ec62d8 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -323,7 +323,7 @@ impl ExampleInstance { }; thread.update(cx, |thread, cx| { - thread.add_default_tools(None, Rc::new(EvalThreadEnvironment { + thread.add_default_tools(Rc::new(EvalThreadEnvironment { project: project.clone(), }), cx); thread.set_profile(meta.profile_id.clone(), cx);