agent: Use default tool behavior for subagents (#49706)

Ben Brandt created

Looking at this more, the map of tools is just for what is available and
all of the filtering happens at runtime. So we can just rely on the
current behavior for the list of tools (it was already a matching set)
and we can simplify all of the code paths again where we were adding
this filtered list.

Release Notes:

- N/A

Change summary

crates/agent/src/agent.rs                       |  12 
crates/agent/src/tests/edit_file_thread_test.rs |  28 +-
crates/agent/src/tests/mod.rs                   | 164 +++++-------------
crates/agent/src/thread.rs                      | 133 ++++-----------
crates/eval/src/instance.rs                     |   2 
5 files changed, 99 insertions(+), 240 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -336,13 +336,12 @@ impl NativeAgent {
             )
         });
 
-        self.register_session(thread, None, cx)
+        self.register_session(thread, cx)
     }
 
     fn register_session(
         &mut self,
         thread_handle: Entity<Thread>,
-        allowed_tool_names: Option<Vec<SharedString>>,
         cx: &mut Context<Self>,
     ) -> Entity<AcpThread> {
         let connection = Rc::new(NativeAgentConnection(cx.entity()));
@@ -374,7 +373,6 @@ impl NativeAgent {
         thread_handle.update(cx, |thread, cx| {
             thread.set_summarization_model(summarization_model, cx);
             thread.add_default_tools(
-                allowed_tool_names,
                 Rc::new(NativeThreadEnvironment {
                     acp_thread: acp_thread.downgrade(),
                     agent: weak,
@@ -804,9 +802,8 @@ impl NativeAgent {
         let task = self.load_thread(id, cx);
         cx.spawn(async move |this, cx| {
             let thread = task.await?;
-            let acp_thread = this.update(cx, |this, cx| {
-                this.register_session(thread.clone(), None, cx)
-            })?;
+            let acp_thread =
+                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
             let events = thread.update(cx, |thread, cx| thread.replay(cx));
             cx.update(|cx| {
                 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
@@ -1601,7 +1598,6 @@ impl NativeThreadEnvironment {
                 MAX_SUBAGENT_DEPTH
             ));
         }
-        let allowed_tool_names = Some(parent_thread.tools.keys().cloned().collect::<Vec<_>>());
 
         let subagent_thread: Entity<Thread> = cx.new(|cx| {
             let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
@@ -1612,7 +1608,7 @@ impl NativeThreadEnvironment {
         let session_id = subagent_thread.read(cx).id().clone();
 
         let acp_thread = agent.update(cx, |agent, cx| {
-            agent.register_session(subagent_thread.clone(), allowed_tool_names, cx)
+            agent.register_session(subagent_thread.clone(), cx)
         })?;
 
         parent_thread_entity.update(cx, |parent_thread, _cx| {

crates/agent/src/tests/edit_file_thread_test.rs 🔗

@@ -49,23 +49,17 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
         );
         // Add just the tools we need for this test
         let language_registry = project.read(cx).languages().clone();
-        thread.add_tool(
-            crate::ReadFileTool::new(
-                cx.weak_entity(),
-                project.clone(),
-                thread.action_log().clone(),
-            ),
-            None,
-        );
-        thread.add_tool(
-            crate::EditFileTool::new(
-                project.clone(),
-                cx.weak_entity(),
-                language_registry,
-                crate::Templates::new(),
-            ),
-            None,
-        );
+        thread.add_tool(crate::ReadFileTool::new(
+            cx.weak_entity(),
+            project.clone(),
+            thread.action_log().clone(),
+        ));
+        thread.add_tool(crate::EditFileTool::new(
+            project.clone(),
+            cx.weak_entity(),
+            language_registry,
+            crate::Templates::new(),
+        ));
         thread
     });
 

crates/agent/src/tests/mod.rs 🔗

@@ -464,7 +464,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
     project_context.update(cx, |project_context, _cx| {
         project_context.shell = "test-shell".into()
     });
-    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
+    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
     thread
         .update(cx, |thread, cx| {
             thread.send(UserMessageId::new(), ["abc"], cx)
@@ -600,7 +600,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
     cx.run_until_parked();
 
     // Simulate a tool call and verify that the latest tool result is cached
-    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
+    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
     thread
         .update(cx, |thread, cx| {
             thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
@@ -686,7 +686,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
     // Test a tool call that's likely to complete *before* streaming stops.
     let events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(EchoTool, None);
+            thread.add_tool(EchoTool);
             thread.send(
                 UserMessageId::new(),
                 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
@@ -702,7 +702,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
     let events = thread
         .update(cx, |thread, cx| {
             thread.remove_tool(&EchoTool::NAME);
-            thread.add_tool(DelayTool, None);
+            thread.add_tool(DelayTool);
             thread.send(
                 UserMessageId::new(),
                 [
@@ -746,7 +746,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
     // Test a tool call that's likely to complete *before* streaming stops.
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(WordListTool, None);
+            thread.add_tool(WordListTool);
             thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
         })
         .unwrap();
@@ -797,7 +797,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(ToolRequiringPermission, None);
+            thread.add_tool(ToolRequiringPermission);
             thread.send(UserMessageId::new(), ["abc"], cx)
         })
         .unwrap();
@@ -1207,7 +1207,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
     // Test concurrent tool calls with different delay times
     let events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(DelayTool, None);
+            thread.add_tool(DelayTool);
             thread.send(
                 UserMessageId::new(),
                 [
@@ -1252,9 +1252,9 @@ async fn test_profiles(cx: &mut TestAppContext) {
     let fake_model = model.as_fake();
 
     thread.update(cx, |thread, _cx| {
-        thread.add_tool(DelayTool, None);
-        thread.add_tool(EchoTool, None);
-        thread.add_tool(InfiniteTool, None);
+        thread.add_tool(DelayTool);
+        thread.add_tool(EchoTool);
+        thread.add_tool(InfiniteTool);
     });
 
     // Override profiles and wait for settings to be loaded.
@@ -1420,7 +1420,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
 
     // Send again after adding the echo tool, ensuring the name collision is resolved.
     let events = thread.update(cx, |thread, cx| {
-        thread.add_tool(EchoTool, None);
+        thread.add_tool(EchoTool);
         thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
     });
     cx.run_until_parked();
@@ -1711,11 +1711,11 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
 
     thread.update(cx, |thread, cx| {
         thread.set_profile(AgentProfileId("test".into()), cx);
-        thread.add_tool(EchoTool, None);
-        thread.add_tool(DelayTool, None);
-        thread.add_tool(WordListTool, None);
-        thread.add_tool(ToolRequiringPermission, None);
-        thread.add_tool(InfiniteTool, None);
+        thread.add_tool(EchoTool);
+        thread.add_tool(DelayTool);
+        thread.add_tool(WordListTool);
+        thread.add_tool(ToolRequiringPermission);
+        thread.add_tool(InfiniteTool);
     });
 
     // Set up multiple context servers with some overlapping tool names
@@ -1863,8 +1863,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(InfiniteTool, None);
-            thread.add_tool(EchoTool, None);
+            thread.add_tool(InfiniteTool);
+            thread.add_tool(EchoTool);
             thread.send(
                 UserMessageId::new(),
                 ["Call the echo tool, then call the infinite tool, then explain their output"],
@@ -1955,10 +1955,10 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(
-                crate::TerminalTool::new(thread.project().clone(), environment),
-                None,
-            );
+            thread.add_tool(crate::TerminalTool::new(
+                thread.project().clone(),
+                environment,
+            ));
             thread.send(UserMessageId::new(), ["run a command"], cx)
         })
         .unwrap();
@@ -2052,7 +2052,7 @@ async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppC
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(tool, None);
+            thread.add_tool(tool);
             thread.send(
                 UserMessageId::new(),
                 ["call the cancellation aware tool"],
@@ -2238,10 +2238,10 @@ async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
     let message_id = UserMessageId::new();
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(
-                crate::TerminalTool::new(thread.project().clone(), environment),
-                None,
-            );
+            thread.add_tool(crate::TerminalTool::new(
+                thread.project().clone(),
+                environment,
+            ));
             thread.send(message_id.clone(), ["run a command"], cx)
         })
         .unwrap();
@@ -2302,10 +2302,10 @@ async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext)
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(
-                crate::TerminalTool::new(thread.project().clone(), environment.clone()),
-                None,
-            );
+            thread.add_tool(crate::TerminalTool::new(
+                thread.project().clone(),
+                environment.clone(),
+            ));
             thread.send(UserMessageId::new(), ["run multiple commands"], cx)
         })
         .unwrap();
@@ -2415,10 +2415,10 @@ async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppCon
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(
-                crate::TerminalTool::new(thread.project().clone(), environment),
-                None,
-            );
+            thread.add_tool(crate::TerminalTool::new(
+                thread.project().clone(),
+                environment,
+            ));
             thread.send(UserMessageId::new(), ["run a command"], cx)
         })
         .unwrap();
@@ -2509,10 +2509,10 @@ async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(
-                crate::TerminalTool::new(thread.project().clone(), environment),
-                None,
-            );
+            thread.add_tool(crate::TerminalTool::new(
+                thread.project().clone(),
+                environment,
+            ));
             thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
         })
         .unwrap();
@@ -2997,8 +2997,8 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
 
     let _events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(ToolRequiringPermission, None);
-            thread.add_tool(EchoTool, None);
+            thread.add_tool(ToolRequiringPermission);
+            thread.add_tool(EchoTool);
             thread.send(UserMessageId::new(), ["Hey!"], cx)
         })
         .unwrap();
@@ -3204,7 +3204,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
 #[gpui::test]
 async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
     let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
-    thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
+    thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
     let fake_model = model.as_fake();
 
     let mut events = thread
@@ -3394,7 +3394,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
 
     let events = thread
         .update(cx, |thread, cx| {
-            thread.add_tool(EchoTool, None);
+            thread.add_tool(EchoTool);
             thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
         })
         .unwrap();
@@ -4490,7 +4490,7 @@ async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAp
             Some(model),
             cx,
         );
-        thread.add_default_tools(None, environment, cx);
+        thread.add_default_tools(environment, cx);
         thread
     });
 
@@ -4582,7 +4582,7 @@ async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppCont
     });
     let deep_subagent_thread = cx.new(|cx| {
         let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
-        thread.add_default_tools(None, environment, cx);
+        thread.add_default_tools(environment, cx);
         thread
     });
 
@@ -4736,78 +4736,6 @@ async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceede
     );
 }
 
-#[gpui::test]
-async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
-    init_test(cx);
-
-    always_allow_tools(cx);
-
-    cx.update(|cx| {
-        cx.update_flags(true, vec!["subagents".to_string()]);
-    });
-
-    let fs = FakeFs::new(cx.executor());
-    fs.insert_tree(path!("/test"), json!({})).await;
-    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
-    let project_context = cx.new(|_cx| ProjectContext::default());
-    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
-    let context_server_registry =
-        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
-    cx.update(LanguageModelRegistry::test);
-    let model = Arc::new(FakeLanguageModel::default());
-    let thread_store = cx.new(|cx| ThreadStore::new(cx));
-    let native_agent = NativeAgent::new(
-        project.clone(),
-        thread_store,
-        Templates::new(),
-        None,
-        fs,
-        &mut cx.to_async(),
-    )
-    .await
-    .unwrap();
-    let parent_thread = cx.new(|cx| {
-        let mut thread = Thread::new(
-            project.clone(),
-            project_context,
-            context_server_registry,
-            Templates::new(),
-            Some(model.clone()),
-            cx,
-        );
-        thread.add_tool(ListDirectoryTool::new(project.clone()), None);
-        thread.add_tool(GrepTool::new(project.clone()), None);
-        thread
-    });
-
-    let _subagent_handle = cx
-        .update(|cx| {
-            NativeThreadEnvironment::create_subagent_thread(
-                native_agent.downgrade(),
-                parent_thread.clone(),
-                "some title".to_string(),
-                "task prompt".to_string(),
-                Some(Duration::from_millis(10)),
-                cx,
-            )
-        })
-        .expect("Failed to create subagent");
-
-    cx.run_until_parked();
-
-    let tools = model
-        .pending_completions()
-        .last()
-        .unwrap()
-        .tools
-        .iter()
-        .map(|tool| tool.name.clone())
-        .collect::<Vec<_>>();
-    assert_eq!(tools.len(), 2);
-    assert!(tools.contains(&"grep".to_string()));
-    assert!(tools.contains(&"list_directory".to_string()));
-}
-
 #[gpui::test]
 async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
     init_test(cx);
@@ -5458,7 +5386,7 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
 
     // Add a tool so we can simulate tool calls
     thread.update(cx, |thread, _cx| {
-        thread.add_tool(EchoTool, None);
+        thread.add_tool(EchoTool);
     });
 
     // Start a turn by sending a message

crates/agent/src/thread.rs 🔗

@@ -1326,111 +1326,52 @@ impl Thread {
 
     pub fn add_default_tools(
         &mut self,
-        allowed_tool_names: Option<Vec<SharedString>>,
         environment: Rc<dyn ThreadEnvironment>,
         cx: &mut Context<Self>,
     ) {
         let language_registry = self.project.read(cx).languages().clone();
-        self.add_tool(
-            CopyPathTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            CreateDirectoryTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            DeletePathTool::new(self.project.clone(), self.action_log.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            DiagnosticsTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            EditFileTool::new(
-                self.project.clone(),
-                cx.weak_entity(),
-                language_registry.clone(),
-                Templates::new(),
-            ),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            StreamingEditFileTool::new(
-                self.project.clone(),
-                cx.weak_entity(),
-                language_registry,
-                Templates::new(),
-            ),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            FetchTool::new(self.project.read(cx).client().http_client()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            FindPathTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            GrepTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            ListDirectoryTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            MovePathTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(NowTool, allowed_tool_names.as_ref());
-        self.add_tool(
-            OpenTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            ReadFileTool::new(
-                cx.weak_entity(),
-                self.project.clone(),
-                self.action_log.clone(),
-            ),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            SaveFileTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            RestoreFileFromDiskTool::new(self.project.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(
-            TerminalTool::new(self.project.clone(), environment.clone()),
-            allowed_tool_names.as_ref(),
-        );
-        self.add_tool(WebSearchTool, allowed_tool_names.as_ref());
+        self.add_tool(CopyPathTool::new(self.project.clone()));
+        self.add_tool(CreateDirectoryTool::new(self.project.clone()));
+        self.add_tool(DeletePathTool::new(
+            self.project.clone(),
+            self.action_log.clone(),
+        ));
+        self.add_tool(DiagnosticsTool::new(self.project.clone()));
+        self.add_tool(EditFileTool::new(
+            self.project.clone(),
+            cx.weak_entity(),
+            language_registry.clone(),
+            Templates::new(),
+        ));
+        self.add_tool(StreamingEditFileTool::new(
+            self.project.clone(),
+            cx.weak_entity(),
+            language_registry,
+            Templates::new(),
+        ));
+        self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
+        self.add_tool(FindPathTool::new(self.project.clone()));
+        self.add_tool(GrepTool::new(self.project.clone()));
+        self.add_tool(ListDirectoryTool::new(self.project.clone()));
+        self.add_tool(MovePathTool::new(self.project.clone()));
+        self.add_tool(NowTool);
+        self.add_tool(OpenTool::new(self.project.clone()));
+        self.add_tool(ReadFileTool::new(
+            cx.weak_entity(),
+            self.project.clone(),
+            self.action_log.clone(),
+        ));
+        self.add_tool(SaveFileTool::new(self.project.clone()));
+        self.add_tool(RestoreFileFromDiskTool::new(self.project.clone()));
+        self.add_tool(TerminalTool::new(self.project.clone(), environment.clone()));
+        self.add_tool(WebSearchTool);
 
         if cx.has_flag::<SubagentsFeatureFlag>() && self.depth() < MAX_SUBAGENT_DEPTH {
-            self.add_tool(
-                SubagentTool::new(cx.weak_entity(), environment),
-                allowed_tool_names.as_ref(),
-            );
+            self.add_tool(SubagentTool::new(cx.weak_entity(), environment));
         }
     }
 
-    pub fn add_tool<T: AgentTool>(
-        &mut self,
-        tool: T,
-        allowed_tool_names: Option<&Vec<SharedString>>,
-    ) {
-        if allowed_tool_names
-            .is_some_and(|tool_names| !tool_names.iter().any(|x| x.as_str() == T::NAME))
-        {
-            return;
-        }
-
+    pub fn add_tool<T: AgentTool>(&mut self, tool: T) {
         debug_assert!(
             !self.tools.contains_key(T::NAME),
             "Duplicate tool name: {}",

crates/eval/src/instance.rs 🔗

@@ -323,7 +323,7 @@ impl ExampleInstance {
                 };
 
                 thread.update(cx, |thread, cx| {
-                    thread.add_default_tools(None, Rc::new(EvalThreadEnvironment {
+                    thread.add_default_tools(Rc::new(EvalThreadEnvironment {
                         project: project.clone(),
                     }), cx);
                     thread.set_profile(meta.profile_id.clone(), cx);