From 8294fbb75beff01336e9561517c392f1b94ed72f Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Mon, 23 Feb 2026 17:03:32 +0100 Subject: [PATCH] agent: Subagent low context warnings (#49902) Allow the parent agent to handle cases where the subagent is running out of its context window. Also communicates if it has completely run out of context. Release Notes: - N/A --- crates/acp_thread/src/acp_thread.rs | 6 +- crates/agent/src/agent.rs | 130 +++++++-- crates/agent/src/tests/mod.rs | 431 +++++++++++++++++++++++++++- 3 files changed, 533 insertions(+), 34 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 8da004b2cb967b19ec27d03ce573778bf301fcd9..83645226e5eb9cba3d19b37b587d15d1d80087c1 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -895,15 +895,17 @@ pub struct TokenUsage { pub max_output_tokens: Option, } +pub const TOKEN_USAGE_WARNING_THRESHOLD: f32 = 0.8; + impl TokenUsage { pub fn ratio(&self) -> TokenUsageRatio { #[cfg(debug_assertions)] let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") - .unwrap_or("0.8".to_string()) + .unwrap_or(TOKEN_USAGE_WARNING_THRESHOLD.to_string()) .parse() .unwrap(); #[cfg(not(debug_assertions))] - let warning_threshold: f32 = 0.8; + let warning_threshold: f32 = TOKEN_USAGE_WARNING_THRESHOLD; // When the maximum is unknown because there is no selected model, // avoid showing the token limit warning. 
diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 3906da36056709bdded6ffca92a85a390ec81f44..759c6e3b9c8c228a6ae6bea5330819b97200b603 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -25,7 +25,7 @@ pub use tools::*; use acp_thread::{ AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest, - AgentSessionListResponse, UserMessageId, + AgentSessionListResponse, TokenUsageRatio, UserMessageId, }; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; @@ -1652,33 +1652,14 @@ impl NativeThreadEnvironment { prompt: String, cx: &mut App, ) -> Result> { - parent_thread_entity.update(cx, |parent_thread, _cx| { - parent_thread.register_running_subagent(subagent_thread.downgrade()) - }); - - let task = acp_thread.update(cx, |acp_thread, cx| { - acp_thread.send(vec![prompt.into()], cx) - }); - - let wait_for_prompt_to_complete = cx - .background_spawn(async move { - let response = task.await.log_err().flatten(); - if response - .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled) - { - SubagentInitialPromptResult::Cancelled - } else { - SubagentInitialPromptResult::Completed - } - }) - .shared(); - - Ok(Rc::new(NativeSubagentHandle { + Ok(Rc::new(NativeSubagentHandle::new( session_id, subagent_thread, - parent_thread: parent_thread_entity.downgrade(), - wait_for_prompt_to_complete, - }) as _) + acp_thread, + parent_thread_entity, + prompt, + cx, + )) as _) } } @@ -1749,17 +1730,95 @@ impl ThreadEnvironment for NativeThreadEnvironment { } } -#[derive(Debug, Clone, Copy)] -enum SubagentInitialPromptResult { +#[derive(Debug, Clone)] +enum SubagentPromptResult { Completed, Cancelled, + ContextWindowWarning, + Error(String), } pub struct NativeSubagentHandle { session_id: acp::SessionId, parent_thread: WeakEntity, subagent_thread: Entity, - wait_for_prompt_to_complete: Shared>, + wait_for_prompt_to_complete: Shared>, + _subscription: Subscription, +} + 
+impl NativeSubagentHandle { + fn new( + session_id: acp::SessionId, + subagent_thread: Entity, + acp_thread: Entity, + parent_thread_entity: Entity, + prompt: String, + cx: &mut App, + ) -> Self { + let ratio_before_prompt = subagent_thread + .read(cx) + .latest_token_usage() + .map(|usage| usage.ratio()); + + parent_thread_entity.update(cx, |parent_thread, _cx| { + parent_thread.register_running_subagent(subagent_thread.downgrade()) + }); + + let task = acp_thread.update(cx, |acp_thread, cx| { + acp_thread.send(vec![prompt.into()], cx) + }); + + let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>(); + let mut token_limit_tx = Some(token_limit_tx); + + let subscription = cx.subscribe( + &subagent_thread, + move |_thread, event: &TokenUsageUpdated, _cx| { + if let Some(usage) = &event.0 { + let old_ratio = ratio_before_prompt + .clone() + .unwrap_or(TokenUsageRatio::Normal); + let new_ratio = usage.ratio(); + if old_ratio == TokenUsageRatio::Normal && new_ratio == TokenUsageRatio::Warning + { + if let Some(tx) = token_limit_tx.take() { + tx.send(()).ok(); + } + } + } + }, + ); + + let wait_for_prompt_to_complete = cx + .background_spawn(async move { + futures::select! { + response = task.fuse() => match response { + Ok(Some(response)) =>{ + match response.stop_reason { + acp::StopReason::Cancelled => SubagentPromptResult::Cancelled, + acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()), + acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()), + acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()), + acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed, + } + + } + Ok(None) => SubagentPromptResult::Error("No response from the agent. 
You can try messaging again.".into()), + Err(error) => SubagentPromptResult::Error(error.to_string()), + }, + _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning, + } + }) + .shared(); + + NativeSubagentHandle { + session_id, + subagent_thread, + parent_thread: parent_thread_entity.downgrade(), + wait_for_prompt_to_complete, + _subscription: subscription, + } + } } impl SubagentHandle for NativeSubagentHandle { @@ -1776,13 +1835,22 @@ impl SubagentHandle for NativeSubagentHandle { cx.spawn(async move |cx| { let result = match wait_for_prompt.await { - SubagentInitialPromptResult::Completed => thread.read_with(cx, |thread, _cx| { + SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| { thread .last_message() .map(|m| m.to_markdown()) .context("No response from subagent") }), - SubagentInitialPromptResult::Cancelled => Err(anyhow!("User cancelled")), + SubagentPromptResult::Cancelled => Err(anyhow!("User cancelled")), + SubagentPromptResult::Error(message) => Err(anyhow!("{message}")), + SubagentPromptResult::ContextWindowWarning => { + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + Err(anyhow!( + "The agent is nearing the end of its context window and has been \ + stopped. You can prompt the thread again to have the agent wrap up \ + or hand off its work." 
+ )) + } }; parent_thread diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 2c412561bae29b87d96e2a4a016283015dfc4e15..139242fdee9da968986b3fc9537bf9e5292b7dc5 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -29,7 +29,8 @@ use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat, - LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel, + LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage, + fake_provider::FakeLanguageModel, }; use pretty_assertions::assert_eq; use project::{ @@ -4830,6 +4831,434 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_subagent_context_window_warning(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + LanguageModelRegistry::test(cx); + }); + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = NativeAgent::new( + project.clone(), + thread_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + let model 
= Arc::new(FakeLanguageModel::default()); + + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + cx.run_until_parked(); + + // Start the parent turn + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("spawning subagent"); + let subagent_tool_input = SpawnAgentToolInput { + label: "label".to_string(), + message: "subagent task prompt".to_string(), + session_id: None, + }; + let subagent_tool_use = LanguageModelToolUse { + id: "subagent_1".into(), + name: SpawnAgentTool::NAME.into(), + raw_input: serde_json::to_string(&subagent_tool_input).unwrap(), + input: serde_json::to_value(&subagent_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + subagent_tool_use, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + // Verify subagent is running + let subagent_session_id = thread.read_with(cx, |thread, cx| { + thread + .running_subagent_ids(cx) + .get(0) + .expect("subagent thread should be running") + .clone() + }); + + // Send a usage update that crosses the warning threshold (80% of 1,000,000) + model.send_last_completion_stream_text_chunk("partial work"); + model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + TokenUsage { + input_tokens: 850_000, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + + cx.run_until_parked(); + + // The subagent should no longer be running + thread.read_with(cx, |thread, cx| { + assert!( + thread.running_subagent_ids(cx).is_empty(), + "subagent should be stopped after context window warning" + ); + }); + + // The parent model should get a new completion request to respond to the tool error + model.send_last_completion_stream_text_chunk("Response after warning"); + model.end_last_completion_stream(); + + 
send.await.unwrap(); + + // Verify the parent thread shows the warning error in the tool call + let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); + assert!( + markdown.contains("nearing the end of its context window"), + "tool output should contain context window warning message, got:\n{markdown}" + ); + assert!( + markdown.contains("Status: Failed"), + "tool call should have Failed status, got:\n{markdown}" + ); + + // Verify the subagent session still exists (can be resumed) + agent.read_with(cx, |agent, _cx| { + assert!( + agent.sessions.contains_key(&subagent_session_id), + "subagent session should still exist for potential resume" + ); + }); +} + +#[gpui::test] +async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + LanguageModelRegistry::test(cx); + }); + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = NativeAgent::new( + project.clone(), + thread_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + let model = Arc::new(FakeLanguageModel::default()); + + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + cx.run_until_parked(); + + // === First turn: create subagent, 
trigger context window warning === + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("spawning subagent"); + let subagent_tool_input = SpawnAgentToolInput { + label: "initial task".to_string(), + message: "do the first task".to_string(), + session_id: None, + }; + let subagent_tool_use = LanguageModelToolUse { + id: "subagent_1".into(), + name: SpawnAgentTool::NAME.into(), + raw_input: serde_json::to_string(&subagent_tool_input).unwrap(), + input: serde_json::to_value(&subagent_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + subagent_tool_use, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + let subagent_session_id = thread.read_with(cx, |thread, cx| { + thread + .running_subagent_ids(cx) + .get(0) + .expect("subagent thread should be running") + .clone() + }); + + // Subagent sends a usage update that crosses the warning threshold. + // This triggers Normal→Warning, stopping the subagent. 
+ model.send_last_completion_stream_text_chunk("partial work"); + model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + TokenUsage { + input_tokens: 850_000, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + + cx.run_until_parked(); + + // Verify the first turn was stopped with a context window warning + thread.read_with(cx, |thread, cx| { + assert!( + thread.running_subagent_ids(cx).is_empty(), + "subagent should be stopped after context window warning" + ); + }); + + // Parent model responds to complete first turn + model.send_last_completion_stream_text_chunk("First response"); + model.end_last_completion_stream(); + + send.await.unwrap(); + + let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); + assert!( + markdown.contains("nearing the end of its context window"), + "first turn should have context window warning, got:\n{markdown}" + ); + + // === Second turn: resume the same subagent (now at Warning level) === + let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("resuming subagent"); + let resume_tool_input = SpawnAgentToolInput { + label: "follow-up task".to_string(), + message: "do the follow-up task".to_string(), + session_id: Some(subagent_session_id.clone()), + }; + let resume_tool_use = LanguageModelToolUse { + id: "subagent_2".into(), + name: SpawnAgentTool::NAME.into(), + raw_input: serde_json::to_string(&resume_tool_input).unwrap(), + input: serde_json::to_value(&resume_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use)); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + // Subagent responds with tokens still at warning level (no worse). 
+ // Since ratio_before_prompt was already Warning, this should NOT + // trigger the context window warning again. + model.send_last_completion_stream_text_chunk("follow-up task response"); + model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + TokenUsage { + input_tokens: 870_000, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + // Parent model responds to complete second turn + model.send_last_completion_stream_text_chunk("Second response"); + model.end_last_completion_stream(); + + send2.await.unwrap(); + + // The resumed subagent should have completed normally since the ratio + // didn't transition (it was Warning before and stayed at Warning) + let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); + assert!( + markdown.contains("follow-up task response"), + "resumed subagent should complete normally when already at warning, got:\n{markdown}" + ); + // The second tool call should NOT have a context window warning + let second_tool_pos = markdown + .find("follow-up task") + .expect("should find follow-up tool call"); + let after_second_tool = &markdown[second_tool_pos..]; + assert!( + !after_second_tool.contains("nearing the end of its context window"), + "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}" + ); +} + +#[gpui::test] +async fn test_subagent_error_propagation(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + LanguageModelRegistry::test(cx); + }); + cx.update(|cx| { + cx.update_flags(true, vec!["subagents".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = NativeAgent::new( + 
project.clone(), + thread_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + let model = Arc::new(FakeLanguageModel::default()); + + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + cx.run_until_parked(); + + // Start the parent turn + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx)); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("spawning subagent"); + let subagent_tool_input = SpawnAgentToolInput { + label: "label".to_string(), + message: "subagent task prompt".to_string(), + session_id: None, + }; + let subagent_tool_use = LanguageModelToolUse { + id: "subagent_1".into(), + name: SpawnAgentTool::NAME.into(), + raw_input: serde_json::to_string(&subagent_tool_input).unwrap(), + input: serde_json::to_value(&subagent_tool_input).unwrap(), + is_input_complete: true, + thought_signature: None, + }; + model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + subagent_tool_use, + )); + model.end_last_completion_stream(); + + cx.run_until_parked(); + + // Verify subagent is running + thread.read_with(cx, |thread, cx| { + assert!( + !thread.running_subagent_ids(cx).is_empty(), + "subagent should be running" + ); + }); + + // The subagent's model returns a non-retryable error + model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge { + tokens: None, + }); + + cx.run_until_parked(); + + // The subagent should no longer be running + thread.read_with(cx, |thread, cx| { + assert!( + 
thread.running_subagent_ids(cx).is_empty(), + "subagent should not be running after error" + ); + }); + + // The parent model should get a new completion request to respond to the tool error + model.send_last_completion_stream_text_chunk("Response after error"); + model.end_last_completion_stream(); + + send.await.unwrap(); + + // Verify the parent thread shows the error in the tool call + let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)); + assert!( + markdown.contains("Status: Failed"), + "tool call should have Failed status after model error, got:\n{markdown}" + ); +} + #[gpui::test] async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) { init_test(cx);