@@ -7,6 +7,7 @@ use client::{Client, UserStore};
use cloud_llm_client::CompletionIntent;
use collections::IndexMap;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
+use feature_flags::FeatureFlagAppExt as _;
use fs::{FakeFs, Fs};
use futures::{
FutureExt as _, StreamExt,
@@ -343,7 +344,9 @@ async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAp
"expected tool call update to include terminal content"
);
- smol::Timer::after(Duration::from_millis(25)).await;
+ cx.background_executor
+ .timer(Duration::from_millis(25))
+ .await;
assert!(
!handle.was_killed(),
@@ -2989,10 +2992,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
acp::ToolCall::new("1", "Thinking")
.kind(acp::ToolKind::Think)
.raw_input(json!({}))
- .meta(acp::Meta::from_iter([(
- "tool_name".into(),
- "thinking".into()
- )]))
+ .meta(acp_thread::meta_with_tool_name("thinking"))
);
let update = expect_tool_call_update_fields(&mut events).await;
assert_eq!(
@@ -3927,6 +3927,1163 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
}
}
+#[gpui::test]
+async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment { handle });
+
+ let thread = cx.new(|cx| {
+ let mut thread = Thread::new(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ Some(model),
+ cx,
+ );
+ thread.add_default_tools(environment, cx);
+ thread
+ });
+
+ thread.read_with(cx, |thread, _| {
+ assert!(
+ thread.has_registered_tool("subagent"),
+ "subagent tool should be present when feature flag is enabled"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_tool_is_absent_when_feature_flag_disabled(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(false, vec![]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment { handle });
+
+ let thread = cx.new(|cx| {
+ let mut thread = Thread::new(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ Some(model),
+ cx,
+ );
+ thread.add_default_tools(environment, cx);
+ thread
+ });
+
+ thread.read_with(cx, |thread, _| {
+ assert!(
+ !thread.has_registered_tool("subagent"),
+ "subagent tool should not be present when feature flag is disabled"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(thread.is_subagent());
+ assert_eq!(thread.depth(), 1);
+ assert!(thread.model().is_some());
+ });
+}
+
+#[gpui::test]
+async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: MAX_SUBAGENT_DEPTH,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment { handle });
+
+ let deep_subagent = cx.new(|cx| {
+ let mut thread = Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ );
+ thread.add_default_tools(environment, cx);
+ thread
+ });
+
+ deep_subagent.read_with(cx, |thread, _| {
+ assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
+ assert!(
+ !thread.has_registered_tool("subagent"),
+ "subagent tool should not be present at max depth"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize your work".to_string(),
+ context_low_prompt: "Context low, wrap up".to_string(),
+ };
+
+ let project = thread.read_with(cx, |t, _| t.project.clone());
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ let task_prompt = "Find all TODO comments in the codebase";
+ subagent
+ .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ let pending = fake_model.pending_completions();
+ assert_eq!(pending.len(), 1, "should have one pending completion");
+
+ let messages = &pending[0].messages;
+ let user_messages: Vec<_> = messages
+ .iter()
+ .filter(|m| m.role == language_model::Role::User)
+ .collect();
+ assert_eq!(user_messages.len(), 1, "should have one user message");
+
+ let content = &user_messages[0].content[0];
+ assert!(
+ content.to_str().unwrap().contains("TODO"),
+ "task prompt should be in user message"
+ );
+}
+
+#[gpui::test]
+async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Please summarize what you found".to_string(),
+ context_low_prompt: "Context low, wrap up".to_string(),
+ };
+
+ let project = thread.read_with(cx, |t, _| t.project.clone());
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ subagent
+ .update(cx, |thread, cx| {
+ thread.submit_user_message("Do some work", cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("I did the work");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ subagent
+ .update(cx, |thread, cx| thread.request_final_summary(cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "should have pending completion for summary"
+ );
+
+ let messages = &pending.last().unwrap().messages;
+ let user_messages: Vec<_> = messages
+ .iter()
+ .filter(|m| m.role == language_model::Role::User)
+ .collect();
+
+ let last_user = user_messages.last().unwrap();
+ assert!(
+ last_user.content[0].to_str().unwrap().contains("summarize"),
+ "summary prompt should be sent"
+ );
+}
+
+#[gpui::test]
+async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let subagent = cx.new(|cx| {
+ let mut thread = Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ );
+ thread.add_tool(EchoTool);
+ thread.add_tool(DelayTool);
+ thread.add_tool(WordListTool);
+ thread
+ });
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(thread.has_registered_tool("echo"));
+ assert!(thread.has_registered_tool("delay"));
+ assert!(thread.has_registered_tool("word_list"));
+ });
+
+ let allowed: collections::HashSet<gpui::SharedString> =
+ vec!["echo".into()].into_iter().collect();
+
+ subagent.update(cx, |thread, _cx| {
+ thread.restrict_tools(&allowed);
+ });
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.has_registered_tool("echo"),
+ "echo should still be available"
+ );
+ assert!(
+ !thread.has_registered_tool("delay"),
+ "delay should be removed"
+ );
+ assert!(
+ !thread.has_registered_tool("word_list"),
+ "word_list should be removed"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let parent = cx.new(|cx| {
+ Thread::new(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ Some(model.clone()),
+ cx,
+ )
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ parent.update(cx, |thread, _cx| {
+ thread.register_running_subagent(subagent.downgrade());
+ });
+
+ subagent
+ .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(!thread.is_turn_complete(), "subagent should be running");
+ });
+
+ parent.update(cx, |thread, cx| {
+ thread.cancel(cx).detach();
+ });
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.is_turn_complete(),
+ "subagent should be cancelled when parent cancels"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let project = thread.read_with(cx, |t, _| t.project.clone());
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ subagent
+ .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(!thread.is_turn_complete(), "turn should be in progress");
+ });
+
+ fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
+ provider: LanguageModelProviderName::from("Fake".to_string()),
+ });
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.is_turn_complete(),
+ "turn should be complete after non-retryable error"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize your work".to_string(),
+ context_low_prompt: "Context low, stop and summarize".to_string(),
+ };
+
+ let project = thread.read_with(cx, |t, _| t.project.clone());
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ model.clone(),
+ subagent_context.clone(),
+ cx,
+ )
+ });
+
+ subagent.update(cx, |thread, _| {
+ thread.add_tool(EchoTool);
+ });
+
+ subagent
+ .update(cx, |thread, cx| {
+ thread.submit_user_message("Do some work", cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Working on it...");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
+ assert!(
+ interrupt_result.is_ok(),
+ "interrupt_for_summary should succeed"
+ );
+
+ cx.run_until_parked();
+
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "should have pending completion for interrupted summary"
+ );
+
+ let messages = &pending.last().unwrap().messages;
+ let user_messages: Vec<_> = messages
+ .iter()
+ .filter(|m| m.role == language_model::Role::User)
+ .collect();
+
+ let last_user = user_messages.last().unwrap();
+ let content_str = last_user.content[0].to_str().unwrap();
+ assert!(
+ content_str.contains("Context low") || content_str.contains("stop and summarize"),
+ "context_low_prompt should be sent when interrupting: got {:?}",
+ content_str
+ );
+}
+
+#[gpui::test]
+async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let project = thread.read_with(cx, |t, _| t.project.clone());
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ subagent
+ .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ let max_tokens = model.max_token_count();
+ let high_usage = language_model::TokenUsage {
+ input_tokens: (max_tokens as f64 * 0.80) as u64,
+ output_tokens: 0,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ };
+
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
+ fake_model.send_last_completion_stream_text_chunk("Working...");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
+ assert!(usage.is_some(), "should have token usage after completion");
+
+ let usage = usage.unwrap();
+ let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
+ assert!(
+ remaining_ratio <= 0.25,
+ "remaining ratio should be at or below 25% (got {}%), indicating context is low",
+ remaining_ratio * 100.0
+ );
+}
+
+#[gpui::test]
+async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let parent = cx.new(|cx| {
+ let mut thread = Thread::new(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ Some(model.clone()),
+ cx,
+ );
+ thread.add_tool(EchoTool);
+ thread
+ });
+
+ let parent_tool_names: Vec<gpui::SharedString> = vec!["echo".into()];
+
+ let tool = Arc::new(SubagentTool::new(
+ parent.downgrade(),
+ project,
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ 0,
+ parent_tool_names,
+ ));
+
+ let result = tool.validate_allowed_tools(&Some(vec!["nonexistent_tool".to_string()]));
+ assert!(result.is_err(), "should reject unknown tool");
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("nonexistent_tool"),
+ "error should mention the invalid tool name: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("not available"),
+ "error should explain the tool is not available: {}",
+ err_msg
+ );
+}
+
+#[gpui::test]
+async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let project = thread.read_with(cx, |t, _| t.project.clone());
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ subagent
+ .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.is_turn_complete(),
+ "turn should complete even with empty response"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let depth_1_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let depth_1_subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ model.clone(),
+ depth_1_context,
+ cx,
+ )
+ });
+
+ depth_1_subagent.read_with(cx, |thread, _| {
+ assert_eq!(thread.depth(), 1);
+ assert!(thread.is_subagent());
+ });
+
+ let depth_2_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
+ depth: 2,
+ summary_prompt: "Summarize depth 2".to_string(),
+ context_low_prompt: "Context low depth 2".to_string(),
+ };
+
+ let depth_2_subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ model.clone(),
+ depth_2_context,
+ cx,
+ )
+ });
+
+ depth_2_subagent.read_with(cx, |thread, _| {
+ assert_eq!(thread.depth(), 2);
+ assert!(thread.is_subagent());
+ });
+
+ depth_2_subagent
+ .update(cx, |thread, cx| {
+ thread.submit_user_message("Nested task", cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let pending = model.as_fake().pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "depth-2 subagent should be able to submit messages"
+ );
+}
+
+#[gpui::test]
+async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
+ init_test(cx);
+ always_allow_tools(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let fake_model = model.as_fake();
+
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
+ depth: 1,
+ summary_prompt: "Summarize what you did".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let subagent = cx.new(|cx| {
+ let mut thread = Thread::new_subagent(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ );
+ thread.add_tool(EchoTool);
+ thread
+ });
+
+ subagent.read_with(cx, |thread, _| {
+ assert!(
+ thread.has_registered_tool("echo"),
+ "subagent should have echo tool"
+ );
+ });
+
+ subagent
+ .update(cx, |thread, cx| {
+ thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let tool_use = LanguageModelToolUse {
+ id: "tool_call_1".into(),
+ name: EchoTool::name().into(),
+ raw_input: json!({"text": "hello world"}).to_string(),
+ input: json!({"text": "hello world"}),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "should have pending completion after tool use"
+ );
+
+ let last_completion = pending.last().unwrap();
+ let has_tool_result = last_completion.messages.iter().any(|m| {
+ m.content
+ .iter()
+ .any(|c| matches!(c, MessageContent::ToolResult(_)))
+ });
+ assert!(
+ has_tool_result,
+ "tool result should be in the messages sent back to the model"
+ );
+}
+
+#[gpui::test]
+async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ let parent = cx.new(|cx| {
+ Thread::new(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ Some(model.clone()),
+ cx,
+ )
+ });
+
+ let mut subagents = Vec::new();
+ for i in 0..MAX_PARALLEL_SUBAGENTS {
+ let subagent_context = SubagentContext {
+ parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
+ tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
+ depth: 1,
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ };
+
+ let subagent = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ model.clone(),
+ subagent_context,
+ cx,
+ )
+ });
+
+ parent.update(cx, |thread, _cx| {
+ thread.register_running_subagent(subagent.downgrade());
+ });
+ subagents.push(subagent);
+ }
+
+ parent.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.running_subagent_count(),
+ MAX_PARALLEL_SUBAGENTS,
+ "should have MAX_PARALLEL_SUBAGENTS registered"
+ );
+ });
+
+ let parent_tool_names: Vec<gpui::SharedString> = vec![];
+
+ let tool = Arc::new(SubagentTool::new(
+ parent.downgrade(),
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ 0,
+ parent_tool_names,
+ ));
+
+ let (event_stream, _rx) = crate::ToolCallEventStream::test();
+
+ let result = cx.update(|cx| {
+ tool.run(
+ SubagentToolInput {
+ label: "Test".to_string(),
+ task_prompt: "Do something".to_string(),
+ summary_prompt: "Summarize".to_string(),
+ context_low_prompt: "Context low".to_string(),
+ timeout_ms: None,
+ allowed_tools: None,
+ },
+ event_stream,
+ cx,
+ )
+ });
+
+ let err = result.await.unwrap_err();
+ assert!(
+ err.to_string().contains("Maximum parallel subagents"),
+ "should reject when max parallel subagents reached: {}",
+ err
+ );
+
+ drop(subagents);
+}
+
+#[gpui::test]
+async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
+ init_test(cx);
+ always_allow_tools(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["subagents".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let fake_model = model.as_fake();
+
+ let parent = cx.new(|cx| {
+ let mut thread = Thread::new(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ Templates::new(),
+ Some(model.clone()),
+ cx,
+ );
+ thread.add_tool(EchoTool);
+ thread
+ });
+
+ let parent_tool_names: Vec<gpui::SharedString> = vec!["echo".into()];
+
+ let tool = Arc::new(SubagentTool::new(
+ parent.downgrade(),
+ project.clone(),
+ project_context,
+ context_server_registry,
+ Templates::new(),
+ 0,
+ parent_tool_names,
+ ));
+
+ let (event_stream, _rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ SubagentToolInput {
+ label: "Research task".to_string(),
+ task_prompt: "Find all TODOs in the codebase".to_string(),
+ summary_prompt: "Summarize what you found".to_string(),
+ context_low_prompt: "Context low, wrap up".to_string(),
+ timeout_ms: None,
+ allowed_tools: None,
+ },
+ event_stream,
+ cx,
+ )
+ });
+
+ cx.run_until_parked();
+
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "subagent should have started and sent a completion request"
+ );
+
+ let first_completion = &pending[0];
+ let has_task_prompt = first_completion.messages.iter().any(|m| {
+ m.role == language_model::Role::User
+ && m.content
+ .iter()
+ .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
+ });
+ assert!(has_task_prompt, "task prompt should be sent to subagent");
+
+ fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let pending = fake_model.pending_completions();
+ assert!(
+ !pending.is_empty(),
+ "should have pending completion for summary request"
+ );
+
+ let last_completion = pending.last().unwrap();
+ let has_summary_prompt = last_completion.messages.iter().any(|m| {
+ m.role == language_model::Role::User
+ && m.content.iter().any(|c| {
+ c.to_str()
+ .map(|s| s.contains("Summarize") || s.contains("summarize"))
+ .unwrap_or(false)
+ })
+ });
+ assert!(
+ has_summary_prompt,
+ "summary prompt should be sent after task completion"
+ );
+
+ fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let result = task.await;
+ assert!(result.is_ok(), "subagent tool should complete successfully");
+
+ let summary = result.unwrap();
+ assert!(
+ summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
+ "summary should contain subagent's response: {}",
+ summary
+ );
+}
+
#[gpui::test]
async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
init_test(cx);
@@ -58,6 +58,27 @@ use uuid::Uuid;
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
pub const MAX_TOOL_NAME_LENGTH: usize = 64;
+pub const MAX_SUBAGENT_DEPTH: u8 = 4;
+pub const MAX_PARALLEL_SUBAGENTS: usize = 8;
+
+/// Context passed to a subagent thread for lifecycle management
+#[derive(Clone)]
+pub struct SubagentContext {
+ /// ID of the parent thread
+ pub parent_thread_id: acp::SessionId,
+
+ /// ID of the tool call that spawned this subagent
+ pub tool_use_id: LanguageModelToolUseId,
+
+ /// Current depth level (0 = root agent, 1 = first-level subagent, etc.)
+ pub depth: u8,
+
+ /// Prompt to send when subagent completes successfully
+ pub summary_prompt: String,
+
+ /// Prompt to send when context is running low (≤25% remaining)
+ pub context_low_prompt: String,
+}
/// The ID of the user prompt that initiated a request.
///
@@ -626,6 +647,10 @@ pub struct Thread {
pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
/// True if this thread was imported from a shared thread and can be synced.
imported: bool,
+ /// If this is a subagent thread, contains context about the parent
+ subagent_context: Option<SubagentContext>,
+ /// Weak references to running subagent threads for cancellation propagation
+ running_subagents: Vec<WeakEntity<Thread>>,
}
impl Thread {
@@ -683,6 +708,56 @@ impl Thread {
action_log,
file_read_times: HashMap::default(),
imported: false,
+ subagent_context: None,
+ running_subagents: Vec::new(),
+ }
+ }
+
+ pub fn new_subagent(
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ templates: Arc<Templates>,
+ model: Arc<dyn LanguageModel>,
+ subagent_context: SubagentContext,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let profile_id = AgentSettings::get_global(cx).default_profile.clone();
+ let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
+ let (prompt_capabilities_tx, prompt_capabilities_rx) =
+ watch::channel(Self::prompt_capabilities(Some(model.as_ref())));
+ Self {
+ id: acp::SessionId::new(uuid::Uuid::new_v4().to_string()),
+ prompt_id: PromptId::new(),
+ updated_at: Utc::now(),
+ title: None,
+ pending_title_generation: None,
+ pending_summary_generation: None,
+ summary: None,
+ messages: Vec::new(),
+ user_store: project.read(cx).user_store(),
+ completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
+ running_turn: None,
+ pending_message: None,
+ tools: BTreeMap::default(),
+ tool_use_limit_reached: false,
+ request_token_usage: HashMap::default(),
+ cumulative_token_usage: TokenUsage::default(),
+ initial_project_snapshot: Task::ready(None).shared(),
+ context_server_registry,
+ profile_id,
+ project_context,
+ templates,
+ model: Some(model),
+ summarization_model: None,
+ prompt_capabilities_tx,
+ prompt_capabilities_rx,
+ project,
+ action_log,
+ file_read_times: HashMap::default(),
+ imported: false,
+ subagent_context: Some(subagent_context),
+ running_subagents: Vec::new(),
}
}
@@ -880,6 +955,8 @@ impl Thread {
prompt_capabilities_rx,
file_read_times: HashMap::default(),
imported: db_thread.imported,
+ subagent_context: None,
+ running_subagents: Vec::new(),
}
}
@@ -984,7 +1061,6 @@ impl Thread {
cx.notify()
}
- #[cfg(any(test, feature = "test-support"))]
pub fn last_message(&self) -> Option<Message> {
if let Some(message) = self.pending_message.clone() {
Some(Message::Agent(message))
@@ -1030,8 +1106,17 @@ impl Thread {
self.add_tool(ThinkingTool);
self.add_tool(WebSearchTool);
- if cx.has_flag::<SubagentsFeatureFlag>() {
- self.add_tool(SubagentTool::new());
+ if cx.has_flag::<SubagentsFeatureFlag>() && self.depth() < MAX_SUBAGENT_DEPTH {
+ let tool_names = self.registered_tool_names();
+ self.add_tool(SubagentTool::new(
+ cx.weak_entity(),
+ self.project.clone(),
+ self.project_context.clone(),
+ self.context_server_registry.clone(),
+ self.templates.clone(),
+ self.depth(),
+ tool_names,
+ ));
}
}
@@ -1043,6 +1128,10 @@ impl Thread {
self.tools.remove(name).is_some()
}
+ pub fn restrict_tools(&mut self, allowed: &collections::HashSet<SharedString>) {
+ self.tools.retain(|name, _| allowed.contains(name));
+ }
+
pub fn profile(&self) -> &AgentProfileId {
&self.profile_id
}
@@ -1061,6 +1150,12 @@ impl Thread {
}
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
+ for subagent in self.running_subagents.drain(..) {
+ if let Some(subagent) = subagent.upgrade() {
+ subagent.update(cx, |thread, cx| thread.cancel(cx)).detach();
+ }
+ }
+
let Some(running_turn) = self.running_turn.take() else {
self.flush_pending_message(cx);
return Task::ready(());
@@ -2138,6 +2233,82 @@ impl Thread {
.is_some_and(|turn| turn.tools.contains_key(name))
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn has_registered_tool(&self, name: &str) -> bool {
+ self.tools.contains_key(name)
+ }
+
+ pub fn registered_tool_names(&self) -> Vec<SharedString> {
+ self.tools.keys().cloned().collect()
+ }
+
+ pub fn register_running_subagent(&mut self, subagent: WeakEntity<Thread>) {
+ self.running_subagents.push(subagent);
+ }
+
+ pub fn unregister_running_subagent(&mut self, subagent: &WeakEntity<Thread>) {
+ self.running_subagents
+ .retain(|s| s.entity_id() != subagent.entity_id());
+ }
+
+ pub fn running_subagent_count(&self) -> usize {
+ self.running_subagents
+ .iter()
+ .filter(|s| s.upgrade().is_some())
+ .count()
+ }
+
+ pub fn is_subagent(&self) -> bool {
+ self.subagent_context.is_some()
+ }
+
+ pub fn depth(&self) -> u8 {
+ self.subagent_context.as_ref().map(|c| c.depth).unwrap_or(0)
+ }
+
+ pub fn is_turn_complete(&self) -> bool {
+ self.running_turn.is_none()
+ }
+
+ pub fn submit_user_message(
+ &mut self,
+ content: impl Into<String>,
+ cx: &mut Context<Self>,
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
+ let content = content.into();
+ self.messages.push(Message::User(UserMessage {
+ id: UserMessageId::new(),
+ content: vec![UserMessageContent::Text(content)],
+ }));
+ cx.notify();
+ self.send_existing(cx)
+ }
+
+ pub fn interrupt_for_summary(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
+ let context = self
+ .subagent_context
+ .as_ref()
+ .context("Not a subagent thread")?;
+ let prompt = context.context_low_prompt.clone();
+ self.cancel(cx).detach();
+ self.submit_user_message(prompt, cx)
+ }
+
+ pub fn request_final_summary(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
+ let context = self
+ .subagent_context
+ .as_ref()
+ .context("Not a subagent thread")?;
+ let prompt = context.summary_prompt.clone();
+ self.submit_user_message(prompt, cx)
+ }
+
fn build_request_messages(
&self,
available_tools: Vec<SharedString>,
@@ -2546,10 +2717,7 @@ impl ThreadEventStream {
acp::ToolCall::new(id.to_string(), title)
.kind(kind)
.raw_input(input)
- .meta(acp::Meta::from_iter([(
- "tool_name".into(),
- tool_name.into(),
- )]))
+ .meta(acp_thread::meta_with_tool_name(tool_name))
}
fn update_tool_call_fields(
@@ -2645,6 +2813,10 @@ impl ToolCallEventStream {
*self.cancellation_rx.clone().borrow()
}
+ pub fn tool_use_id(&self) -> &LanguageModelToolUseId {
+ &self.tool_use_id
+ }
+
pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
self.stream
.update_tool_call_fields(&self.tool_use_id, fields);
@@ -2663,6 +2835,19 @@ impl ToolCallEventStream {
.ok();
}
+ pub fn update_subagent_thread(&self, thread: Entity<acp_thread::AcpThread>) {
+ self.stream
+ .0
+ .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
+ acp_thread::ToolCallUpdateSubagentThread {
+ id: acp::ToolCallId::new(self.tool_use_id.to_string()),
+ thread,
+ }
+ .into(),
+ )))
+ .ok();
+ }
+
pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
return Task::ready(Ok(()));
@@ -1,11 +1,31 @@
+use acp_thread::{AcpThread, AgentConnection, UserMessageId};
+use action_log::ActionLog;
use agent_client_protocol as acp;
-use anyhow::Result;
-use gpui::{App, SharedString, Task};
+use anyhow::{Result, anyhow};
+use collections::HashSet;
+use futures::channel::mpsc;
+use gpui::{App, AppContext, AsyncApp, Entity, SharedString, Task, WeakEntity};
+use project::Project;
+use prompt_store::ProjectContext;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use smol::stream::StreamExt;
+use std::any::Any;
+use std::path::Path;
+use std::rc::Rc;
use std::sync::Arc;
+use std::time::Duration;
+use util::ResultExt;
+use watch;
-use crate::{AgentTool, ToolCallEventStream};
+use crate::{
+ AgentTool, ContextServerRegistry, MAX_PARALLEL_SUBAGENTS, MAX_SUBAGENT_DEPTH, SubagentContext,
+ Templates, Thread, ThreadEvent, ToolCallAuthorization, ToolCallEventStream,
+};
+
+/// When a subagent's remaining context window falls below this fraction (25%),
+/// the "context running out" prompt is sent to encourage the subagent to wrap up.
+const CONTEXT_LOW_THRESHOLD: f32 = 0.25;
/// Spawns a subagent with its own context window to perform a delegated task.
///
@@ -63,11 +83,50 @@ pub struct SubagentToolInput {
pub allowed_tools: Option<Vec<String>>,
}
-pub struct SubagentTool;
+pub struct SubagentTool {
+ parent_thread: WeakEntity<Thread>,
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ templates: Arc<Templates>,
+ current_depth: u8,
+ parent_tool_names: HashSet<SharedString>,
+}
impl SubagentTool {
- pub fn new() -> Self {
- Self
+ pub fn new(
+ parent_thread: WeakEntity<Thread>,
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ templates: Arc<Templates>,
+ current_depth: u8,
+ parent_tool_names: Vec<SharedString>,
+ ) -> Self {
+ Self {
+ parent_thread,
+ project,
+ project_context,
+ context_server_registry,
+ templates,
+ current_depth,
+ parent_tool_names: parent_tool_names.into_iter().collect(),
+ }
+ }
+
+ pub fn validate_allowed_tools(&self, allowed_tools: &Option<Vec<String>>) -> Result<()> {
+ if let Some(tools) = allowed_tools {
+ for tool in tools {
+ if !self.parent_tool_names.contains(tool.as_str()) {
+ return Err(anyhow!(
+ "Tool '{}' is not available to the parent agent. Available tools: {:?}",
+ tool,
+ self.parent_tool_names.iter().collect::<Vec<_>>()
+ ));
+ }
+ }
+ }
+ Ok(())
}
}
@@ -76,7 +135,7 @@ impl AgentTool for SubagentTool {
type Output = String;
fn name() -> &'static str {
- "subagent"
+ acp_thread::SUBAGENT_TOOL_NAME
}
fn kind() -> acp::ToolKind {
@@ -88,22 +147,405 @@ impl AgentTool for SubagentTool {
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
- match input {
- Ok(input) => format!("Subagent: {}", input.label).into(),
- Err(_) => "Subagent".into(),
- }
+ input
+ .map(|i| i.label.into())
+ .unwrap_or_else(|_| "Subagent".into())
}
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
- _cx: &mut App,
+ cx: &mut App,
) -> Task<Result<String>> {
- event_stream.update_fields(
- acp::ToolCallUpdateFields::new()
- .content(vec![format!("Starting subagent: {}", input.label).into()]),
+ if self.current_depth >= MAX_SUBAGENT_DEPTH {
+ return Task::ready(Err(anyhow!(
+ "Maximum subagent depth ({}) reached",
+ MAX_SUBAGENT_DEPTH
+ )));
+ }
+
+ if let Err(e) = self.validate_allowed_tools(&input.allowed_tools) {
+ return Task::ready(Err(e));
+ }
+
+ let Some(parent_thread) = self.parent_thread.upgrade() else {
+ return Task::ready(Err(anyhow!(
+ "Parent thread no longer exists (subagent depth={})",
+ self.current_depth + 1
+ )));
+ };
+
+ let running_count = parent_thread.read(cx).running_subagent_count();
+ if running_count >= MAX_PARALLEL_SUBAGENTS {
+ return Task::ready(Err(anyhow!(
+ "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
+ MAX_PARALLEL_SUBAGENTS
+ )));
+ }
+
+ let parent_thread_id = parent_thread.read(cx).id().clone();
+ let parent_model = parent_thread.read(cx).model().cloned();
+ let tool_use_id = event_stream.tool_use_id().clone();
+
+ let Some(model) = parent_model else {
+ return Task::ready(Err(anyhow!("No model configured")));
+ };
+
+ let subagent_context = SubagentContext {
+ parent_thread_id,
+ tool_use_id,
+ depth: self.current_depth + 1,
+ summary_prompt: input.summary_prompt.clone(),
+ context_low_prompt: input.context_low_prompt.clone(),
+ };
+
+ let project = self.project.clone();
+ let project_context = self.project_context.clone();
+ let context_server_registry = self.context_server_registry.clone();
+ let templates = self.templates.clone();
+ let task_prompt = input.task_prompt;
+ let timeout_ms = input.timeout_ms;
+ let allowed_tools: Option<HashSet<SharedString>> = input
+ .allowed_tools
+ .map(|tools| tools.into_iter().map(SharedString::from).collect());
+
+ let parent_thread = self.parent_thread.clone();
+
+ cx.spawn(async move |cx| {
+ let subagent_thread: Entity<Thread> = cx.new(|cx| {
+ Thread::new_subagent(
+ project.clone(),
+ project_context.clone(),
+ context_server_registry.clone(),
+ templates.clone(),
+ model,
+ subagent_context,
+ cx,
+ )
+ });
+
+ let subagent_weak = subagent_thread.downgrade();
+
+ let acp_thread: Entity<AcpThread> = cx.new(|cx| {
+ let session_id = subagent_thread.read(cx).id().clone();
+ let action_log: Entity<ActionLog> = cx.new(|_| ActionLog::new(project.clone()));
+ let connection: Rc<dyn AgentConnection> = Rc::new(SubagentDisplayConnection);
+ AcpThread::new(
+ "Subagent",
+ connection,
+ project.clone(),
+ action_log,
+ session_id,
+ watch::Receiver::constant(acp::PromptCapabilities::new()),
+ cx,
+ )
+ });
+
+ event_stream.update_subagent_thread(acp_thread.clone());
+
+ if let Some(parent) = parent_thread.upgrade() {
+ parent.update(cx, |thread, _cx| {
+ thread.register_running_subagent(subagent_weak.clone());
+ });
+ }
+
+ let result = run_subagent(
+ &subagent_thread,
+ &acp_thread,
+ allowed_tools,
+ task_prompt,
+ timeout_ms,
+ cx,
+ )
+ .await;
+
+ if let Some(parent) = parent_thread.upgrade() {
+ let _ = parent.update(cx, |thread, _cx| {
+ thread.unregister_running_subagent(&subagent_weak);
+ });
+ }
+
+ result
+ })
+ }
+}
+
+async fn run_subagent(
+ subagent_thread: &Entity<Thread>,
+ acp_thread: &Entity<AcpThread>,
+ allowed_tools: Option<HashSet<SharedString>>,
+ task_prompt: String,
+ timeout_ms: Option<u64>,
+ cx: &mut AsyncApp,
+) -> Result<String> {
+ if let Some(ref allowed) = allowed_tools {
+ subagent_thread.update(cx, |thread, _cx| {
+ thread.restrict_tools(allowed);
+ });
+ }
+
+ let mut events_rx =
+ subagent_thread.update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))?;
+
+ let acp_thread_weak = acp_thread.downgrade();
+
+ let timed_out = if let Some(timeout) = timeout_ms {
+ forward_events_with_timeout(
+ &mut events_rx,
+ &acp_thread_weak,
+ Duration::from_millis(timeout),
+ cx,
+ )
+ .await
+ } else {
+ forward_events_until_stop(&mut events_rx, &acp_thread_weak, cx).await;
+ false
+ };
+
+ let should_interrupt =
+ timed_out || check_context_low(subagent_thread, CONTEXT_LOW_THRESHOLD, cx);
+
+ if should_interrupt {
+ let mut summary_rx =
+ subagent_thread.update(cx, |thread, cx| thread.interrupt_for_summary(cx))?;
+ forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
+ } else {
+ let mut summary_rx =
+ subagent_thread.update(cx, |thread, cx| thread.request_final_summary(cx))?;
+ forward_events_until_stop(&mut summary_rx, &acp_thread_weak, cx).await;
+ }
+
+ Ok(extract_last_message(subagent_thread, cx))
+}
+
+async fn forward_events_until_stop(
+ events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
+ acp_thread: &WeakEntity<AcpThread>,
+ cx: &mut AsyncApp,
+) {
+ while let Some(event) = events_rx.next().await {
+ match event {
+ Ok(ThreadEvent::Stop(_)) => break,
+ Ok(event) => {
+ forward_event_to_acp_thread(event, acp_thread, cx);
+ }
+ Err(_) => break,
+ }
+ }
+}
+
+async fn forward_events_with_timeout(
+ events_rx: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
+ acp_thread: &WeakEntity<AcpThread>,
+ timeout: Duration,
+ cx: &mut AsyncApp,
+) -> bool {
+ use futures::future::{self, Either};
+
+ let deadline = std::time::Instant::now() + timeout;
+
+ loop {
+ let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+ if remaining.is_zero() {
+ return true;
+ }
+
+ let timeout_future = cx.background_executor().timer(remaining);
+ let event_future = events_rx.next();
+
+ match future::select(event_future, timeout_future).await {
+ Either::Left((event, _)) => match event {
+ Some(Ok(ThreadEvent::Stop(_))) => return false,
+ Some(Ok(event)) => {
+ forward_event_to_acp_thread(event, acp_thread, cx);
+ }
+ Some(Err(_)) => return false,
+ None => return false,
+ },
+ Either::Right((_, _)) => return true,
+ }
+ }
+}
+
+fn forward_event_to_acp_thread(
+ event: ThreadEvent,
+ acp_thread: &WeakEntity<AcpThread>,
+ cx: &mut AsyncApp,
+) {
+ match event {
+ ThreadEvent::UserMessage(message) => {
+ acp_thread
+ .update(cx, |thread, cx| {
+ for content in message.content {
+ thread.push_user_content_block(
+ Some(message.id.clone()),
+ content.into(),
+ cx,
+ );
+ }
+ })
+ .log_err();
+ }
+ ThreadEvent::AgentText(text) => {
+ acp_thread
+ .update(cx, |thread, cx| {
+ thread.push_assistant_content_block(text.into(), false, cx)
+ })
+ .log_err();
+ }
+ ThreadEvent::AgentThinking(text) => {
+ acp_thread
+ .update(cx, |thread, cx| {
+ thread.push_assistant_content_block(text.into(), true, cx)
+ })
+ .log_err();
+ }
+ ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
+ tool_call,
+ options,
+ response,
+ }) => {
+ let outcome_task = acp_thread.update(cx, |thread, cx| {
+ thread.request_tool_call_authorization(tool_call, options, true, cx)
+ });
+ if let Ok(Ok(task)) = outcome_task {
+ cx.background_spawn(async move {
+ if let acp::RequestPermissionOutcome::Selected(
+ acp::SelectedPermissionOutcome { option_id, .. },
+ ) = task.await
+ {
+ response.send(option_id).ok();
+ }
+ })
+ .detach();
+ }
+ }
+ ThreadEvent::ToolCall(tool_call) => {
+ acp_thread
+ .update(cx, |thread, cx| thread.upsert_tool_call(tool_call, cx))
+ .log_err();
+ }
+ ThreadEvent::ToolCallUpdate(update) => {
+ acp_thread
+ .update(cx, |thread, cx| thread.update_tool_call(update, cx))
+ .log_err();
+ }
+ ThreadEvent::Retry(status) => {
+ acp_thread
+ .update(cx, |thread, cx| thread.update_retry_status(status, cx))
+ .log_err();
+ }
+ ThreadEvent::Stop(_) => {}
+ }
+}
+
+fn check_context_low(thread: &Entity<Thread>, threshold: f32, cx: &mut AsyncApp) -> bool {
+ thread.read_with(cx, |thread, _| {
+ if let Some(usage) = thread.latest_token_usage() {
+ let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
+ remaining_ratio <= threshold
+ } else {
+ false
+ }
+ })
+}
+
+fn extract_last_message(thread: &Entity<Thread>, cx: &mut AsyncApp) -> String {
+ thread.read_with(cx, |thread, _| {
+ thread
+ .last_message()
+ .map(|m| m.to_markdown())
+ .unwrap_or_else(|| "No response from subagent".to_string())
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use language_model::LanguageModelToolSchemaFormat;
+
+ #[test]
+ fn test_subagent_tool_input_json_schema_is_valid() {
+ let schema = SubagentTool::input_schema(LanguageModelToolSchemaFormat::JsonSchema);
+ let schema_json = serde_json::to_value(&schema).expect("schema should serialize to JSON");
+
+ assert!(
+ schema_json.get("properties").is_some(),
+ "schema should have properties"
+ );
+ let properties = schema_json.get("properties").unwrap();
+
+ assert!(properties.get("label").is_some(), "should have label field");
+ assert!(
+ properties.get("task_prompt").is_some(),
+ "should have task_prompt field"
+ );
+ assert!(
+ properties.get("summary_prompt").is_some(),
+ "should have summary_prompt field"
+ );
+ assert!(
+ properties.get("context_low_prompt").is_some(),
+ "should have context_low_prompt field"
);
- Task::ready(Ok("Subagent tool not yet implemented.".to_string()))
+ assert!(
+ properties.get("timeout_ms").is_some(),
+ "should have timeout_ms field"
+ );
+ assert!(
+ properties.get("allowed_tools").is_some(),
+ "should have allowed_tools field"
+ );
+ }
+
+ #[test]
+ fn test_subagent_tool_name() {
+ assert_eq!(SubagentTool::name(), "subagent");
+ }
+
+ #[test]
+ fn test_subagent_tool_kind() {
+ assert_eq!(SubagentTool::kind(), acp::ToolKind::Other);
+ }
+}
+
+struct SubagentDisplayConnection;
+
+impl AgentConnection for SubagentDisplayConnection {
+ fn telemetry_id(&self) -> SharedString {
+ "subagent".into()
+ }
+
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &[]
+ }
+
+ fn new_thread(
+ self: Rc<Self>,
+ _project: Entity<Project>,
+ _cwd: &Path,
+ _cx: &mut App,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ unimplemented!("SubagentDisplayConnection does not support new_thread")
+ }
+
+ fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
+ unimplemented!("SubagentDisplayConnection does not support authenticate")
+ }
+
+ fn prompt(
+ &self,
+ _id: Option<UserMessageId>,
+ _params: acp::PromptRequest,
+ _cx: &mut App,
+ ) -> Task<Result<acp::PromptResponse>> {
+ unimplemented!("SubagentDisplayConnection does not support prompt")
+ }
+
+ fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
}
}