agent: Subagent low context warnings (#49902)

Ben Brandt created

Allow the parent agent to handle cases where the subagent is running on
of context window. Also communicates if it has completely out of
context.

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs |   6 
crates/agent/src/agent.rs           | 130 +++++++--
crates/agent/src/tests/mod.rs       | 431 ++++++++++++++++++++++++++++++
3 files changed, 533 insertions(+), 34 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -895,15 +895,17 @@ pub struct TokenUsage {
     pub max_output_tokens: Option<u64>,
 }
 
+pub const TOKEN_USAGE_WARNING_THRESHOLD: f32 = 0.8;
+
 impl TokenUsage {
     pub fn ratio(&self) -> TokenUsageRatio {
         #[cfg(debug_assertions)]
         let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
-            .unwrap_or("0.8".to_string())
+            .unwrap_or(TOKEN_USAGE_WARNING_THRESHOLD.to_string())
             .parse()
             .unwrap();
         #[cfg(not(debug_assertions))]
-        let warning_threshold: f32 = 0.8;
+        let warning_threshold: f32 = TOKEN_USAGE_WARNING_THRESHOLD;
 
         // When the maximum is unknown because there is no selected model,
         // avoid showing the token limit warning.

crates/agent/src/agent.rs 🔗

@@ -25,7 +25,7 @@ pub use tools::*;
 
 use acp_thread::{
     AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
-    AgentSessionListResponse, UserMessageId,
+    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
 };
 use agent_client_protocol as acp;
 use anyhow::{Context as _, Result, anyhow};
@@ -1652,33 +1652,14 @@ impl NativeThreadEnvironment {
         prompt: String,
         cx: &mut App,
     ) -> Result<Rc<dyn SubagentHandle>> {
-        parent_thread_entity.update(cx, |parent_thread, _cx| {
-            parent_thread.register_running_subagent(subagent_thread.downgrade())
-        });
-
-        let task = acp_thread.update(cx, |acp_thread, cx| {
-            acp_thread.send(vec![prompt.into()], cx)
-        });
-
-        let wait_for_prompt_to_complete = cx
-            .background_spawn(async move {
-                let response = task.await.log_err().flatten();
-                if response
-                    .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
-                {
-                    SubagentInitialPromptResult::Cancelled
-                } else {
-                    SubagentInitialPromptResult::Completed
-                }
-            })
-            .shared();
-
-        Ok(Rc::new(NativeSubagentHandle {
+        Ok(Rc::new(NativeSubagentHandle::new(
             session_id,
             subagent_thread,
-            parent_thread: parent_thread_entity.downgrade(),
-            wait_for_prompt_to_complete,
-        }) as _)
+            acp_thread,
+            parent_thread_entity,
+            prompt,
+            cx,
+        )) as _)
     }
 }
 
@@ -1749,17 +1730,95 @@ impl ThreadEnvironment for NativeThreadEnvironment {
     }
 }
 
-#[derive(Debug, Clone, Copy)]
-enum SubagentInitialPromptResult {
+#[derive(Debug, Clone)]
+enum SubagentPromptResult {
     Completed,
     Cancelled,
+    ContextWindowWarning,
+    Error(String),
 }
 
 pub struct NativeSubagentHandle {
     session_id: acp::SessionId,
     parent_thread: WeakEntity<Thread>,
     subagent_thread: Entity<Thread>,
-    wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
+    wait_for_prompt_to_complete: Shared<Task<SubagentPromptResult>>,
+    _subscription: Subscription,
+}
+
+impl NativeSubagentHandle {
+    fn new(
+        session_id: acp::SessionId,
+        subagent_thread: Entity<Thread>,
+        acp_thread: Entity<acp_thread::AcpThread>,
+        parent_thread_entity: Entity<Thread>,
+        prompt: String,
+        cx: &mut App,
+    ) -> Self {
+        let ratio_before_prompt = subagent_thread
+            .read(cx)
+            .latest_token_usage()
+            .map(|usage| usage.ratio());
+
+        parent_thread_entity.update(cx, |parent_thread, _cx| {
+            parent_thread.register_running_subagent(subagent_thread.downgrade())
+        });
+
+        let task = acp_thread.update(cx, |acp_thread, cx| {
+            acp_thread.send(vec![prompt.into()], cx)
+        });
+
+        let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
+        let mut token_limit_tx = Some(token_limit_tx);
+
+        let subscription = cx.subscribe(
+            &subagent_thread,
+            move |_thread, event: &TokenUsageUpdated, _cx| {
+                if let Some(usage) = &event.0 {
+                    let old_ratio = ratio_before_prompt
+                        .clone()
+                        .unwrap_or(TokenUsageRatio::Normal);
+                    let new_ratio = usage.ratio();
+                    if old_ratio == TokenUsageRatio::Normal && new_ratio == TokenUsageRatio::Warning
+                    {
+                        if let Some(tx) = token_limit_tx.take() {
+                            tx.send(()).ok();
+                        }
+                    }
+                }
+            },
+        );
+
+        let wait_for_prompt_to_complete = cx
+            .background_spawn(async move {
+                futures::select! {
+                    response = task.fuse() => match response {
+                        Ok(Some(response)) =>{
+                            match response.stop_reason {
+                                acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
+                                acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
+                                acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
+                                acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
+                                acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
+                            }
+
+                        }
+                        Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
+                        Err(error) => SubagentPromptResult::Error(error.to_string()),
+                    },
+                    _ = token_limit_rx.fuse() =>  SubagentPromptResult::ContextWindowWarning,
+                }
+            })
+            .shared();
+
+        NativeSubagentHandle {
+            session_id,
+            subagent_thread,
+            parent_thread: parent_thread_entity.downgrade(),
+            wait_for_prompt_to_complete,
+            _subscription: subscription,
+        }
+    }
 }
 
 impl SubagentHandle for NativeSubagentHandle {
@@ -1776,13 +1835,22 @@ impl SubagentHandle for NativeSubagentHandle {
 
         cx.spawn(async move |cx| {
             let result = match wait_for_prompt.await {
-                SubagentInitialPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
+                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
                     thread
                         .last_message()
                         .map(|m| m.to_markdown())
                         .context("No response from subagent")
                 }),
-                SubagentInitialPromptResult::Cancelled => Err(anyhow!("User cancelled")),
+                SubagentPromptResult::Cancelled => Err(anyhow!("User cancelled")),
+                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
+                SubagentPromptResult::ContextWindowWarning => {
+                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
+                    Err(anyhow!(
+                        "The agent is nearing the end of its context window and has been \
+                         stopped. You can prompt the thread again to have the agent wrap up \
+                         or hand off its work."
+                    ))
+                }
             };
 
             parent_thread

crates/agent/src/tests/mod.rs 🔗

@@ -29,7 +29,8 @@ use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
+    LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
+    fake_provider::FakeLanguageModel,
 };
 use pretty_assertions::assert_eq;
 use project::{
@@ -4830,6 +4831,434 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
+    init_test(cx);
+    cx.update(|cx| {
+        LanguageModelRegistry::test(cx);
+    });
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/",
+        json!({
+            "a": {
+                "b.md": "Lorem"
+            }
+        }),
+    )
+    .await;
+    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+    let thread_store = cx.new(|cx| ThreadStore::new(cx));
+    let agent = NativeAgent::new(
+        project.clone(),
+        thread_store.clone(),
+        Templates::new(),
+        None,
+        fs.clone(),
+        &mut cx.to_async(),
+    )
+    .await
+    .unwrap();
+    let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+    let acp_thread = cx
+        .update(|cx| {
+            connection
+                .clone()
+                .new_session(project.clone(), Path::new(""), cx)
+        })
+        .await
+        .unwrap();
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+    let thread = agent.read_with(cx, |agent, _| {
+        agent.sessions.get(&session_id).unwrap().thread.clone()
+    });
+    let model = Arc::new(FakeLanguageModel::default());
+
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model.clone(), cx);
+    });
+    cx.run_until_parked();
+
+    // Start the parent turn
+    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("spawning subagent");
+    let subagent_tool_input = SpawnAgentToolInput {
+        label: "label".to_string(),
+        message: "subagent task prompt".to_string(),
+        session_id: None,
+    };
+    let subagent_tool_use = LanguageModelToolUse {
+        id: "subagent_1".into(),
+        name: SpawnAgentTool::NAME.into(),
+        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+        input: serde_json::to_value(&subagent_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        subagent_tool_use,
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    // Verify subagent is running
+    let subagent_session_id = thread.read_with(cx, |thread, cx| {
+        thread
+            .running_subagent_ids(cx)
+            .get(0)
+            .expect("subagent thread should be running")
+            .clone()
+    });
+
+    // Send a usage update that crosses the warning threshold (80% of 1,000,000)
+    model.send_last_completion_stream_text_chunk("partial work");
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        TokenUsage {
+            input_tokens: 850_000,
+            output_tokens: 0,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+
+    cx.run_until_parked();
+
+    // The subagent should no longer be running
+    thread.read_with(cx, |thread, cx| {
+        assert!(
+            thread.running_subagent_ids(cx).is_empty(),
+            "subagent should be stopped after context window warning"
+        );
+    });
+
+    // The parent model should get a new completion request to respond to the tool error
+    model.send_last_completion_stream_text_chunk("Response after warning");
+    model.end_last_completion_stream();
+
+    send.await.unwrap();
+
+    // Verify the parent thread shows the warning error in the tool call
+    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
+    assert!(
+        markdown.contains("nearing the end of its context window"),
+        "tool output should contain context window warning message, got:\n{markdown}"
+    );
+    assert!(
+        markdown.contains("Status: Failed"),
+        "tool call should have Failed status, got:\n{markdown}"
+    );
+
+    // Verify the subagent session still exists (can be resumed)
+    agent.read_with(cx, |agent, _cx| {
+        assert!(
+            agent.sessions.contains_key(&subagent_session_id),
+            "subagent session should still exist for potential resume"
+        );
+    });
+}
+
+#[gpui::test]
+async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
+    init_test(cx);
+    cx.update(|cx| {
+        LanguageModelRegistry::test(cx);
+    });
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/",
+        json!({
+            "a": {
+                "b.md": "Lorem"
+            }
+        }),
+    )
+    .await;
+    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+    let thread_store = cx.new(|cx| ThreadStore::new(cx));
+    let agent = NativeAgent::new(
+        project.clone(),
+        thread_store.clone(),
+        Templates::new(),
+        None,
+        fs.clone(),
+        &mut cx.to_async(),
+    )
+    .await
+    .unwrap();
+    let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+    let acp_thread = cx
+        .update(|cx| {
+            connection
+                .clone()
+                .new_session(project.clone(), Path::new(""), cx)
+        })
+        .await
+        .unwrap();
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+    let thread = agent.read_with(cx, |agent, _| {
+        agent.sessions.get(&session_id).unwrap().thread.clone()
+    });
+    let model = Arc::new(FakeLanguageModel::default());
+
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model.clone(), cx);
+    });
+    cx.run_until_parked();
+
+    // === First turn: create subagent, trigger context window warning ===
+    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("spawning subagent");
+    let subagent_tool_input = SpawnAgentToolInput {
+        label: "initial task".to_string(),
+        message: "do the first task".to_string(),
+        session_id: None,
+    };
+    let subagent_tool_use = LanguageModelToolUse {
+        id: "subagent_1".into(),
+        name: SpawnAgentTool::NAME.into(),
+        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+        input: serde_json::to_value(&subagent_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        subagent_tool_use,
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    let subagent_session_id = thread.read_with(cx, |thread, cx| {
+        thread
+            .running_subagent_ids(cx)
+            .get(0)
+            .expect("subagent thread should be running")
+            .clone()
+    });
+
+    // Subagent sends a usage update that crosses the warning threshold.
+    // This triggers Normal→Warning, stopping the subagent.
+    model.send_last_completion_stream_text_chunk("partial work");
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        TokenUsage {
+            input_tokens: 850_000,
+            output_tokens: 0,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+
+    cx.run_until_parked();
+
+    // Verify the first turn was stopped with a context window warning
+    thread.read_with(cx, |thread, cx| {
+        assert!(
+            thread.running_subagent_ids(cx).is_empty(),
+            "subagent should be stopped after context window warning"
+        );
+    });
+
+    // Parent model responds to complete first turn
+    model.send_last_completion_stream_text_chunk("First response");
+    model.end_last_completion_stream();
+
+    send.await.unwrap();
+
+    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
+    assert!(
+        markdown.contains("nearing the end of its context window"),
+        "first turn should have context window warning, got:\n{markdown}"
+    );
+
+    // === Second turn: resume the same subagent (now at Warning level) ===
+    let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("resuming subagent");
+    let resume_tool_input = SpawnAgentToolInput {
+        label: "follow-up task".to_string(),
+        message: "do the follow-up task".to_string(),
+        session_id: Some(subagent_session_id.clone()),
+    };
+    let resume_tool_use = LanguageModelToolUse {
+        id: "subagent_2".into(),
+        name: SpawnAgentTool::NAME.into(),
+        raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
+        input: serde_json::to_value(&resume_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    // Subagent responds with tokens still at warning level (no worse).
+    // Since ratio_before_prompt was already Warning, this should NOT
+    // trigger the context window warning again.
+    model.send_last_completion_stream_text_chunk("follow-up task response");
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+        TokenUsage {
+            input_tokens: 870_000,
+            output_tokens: 0,
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+        },
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    // Parent model responds to complete second turn
+    model.send_last_completion_stream_text_chunk("Second response");
+    model.end_last_completion_stream();
+
+    send2.await.unwrap();
+
+    // The resumed subagent should have completed normally since the ratio
+    // didn't transition (it was Warning before and stayed at Warning)
+    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
+    assert!(
+        markdown.contains("follow-up task response"),
+        "resumed subagent should complete normally when already at warning, got:\n{markdown}"
+    );
+    // The second tool call should NOT have a context window warning
+    let second_tool_pos = markdown
+        .find("follow-up task")
+        .expect("should find follow-up tool call");
+    let after_second_tool = &markdown[second_tool_pos..];
+    assert!(
+        !after_second_tool.contains("nearing the end of its context window"),
+        "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
+    );
+}
+
+#[gpui::test]
+async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
+    init_test(cx);
+    cx.update(|cx| {
+        LanguageModelRegistry::test(cx);
+    });
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["subagents".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/",
+        json!({
+            "a": {
+                "b.md": "Lorem"
+            }
+        }),
+    )
+    .await;
+    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+    let thread_store = cx.new(|cx| ThreadStore::new(cx));
+    let agent = NativeAgent::new(
+        project.clone(),
+        thread_store.clone(),
+        Templates::new(),
+        None,
+        fs.clone(),
+        &mut cx.to_async(),
+    )
+    .await
+    .unwrap();
+    let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+    let acp_thread = cx
+        .update(|cx| {
+            connection
+                .clone()
+                .new_session(project.clone(), Path::new(""), cx)
+        })
+        .await
+        .unwrap();
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+    let thread = agent.read_with(cx, |agent, _| {
+        agent.sessions.get(&session_id).unwrap().thread.clone()
+    });
+    let model = Arc::new(FakeLanguageModel::default());
+
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model.clone(), cx);
+    });
+    cx.run_until_parked();
+
+    // Start the parent turn
+    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("spawning subagent");
+    let subagent_tool_input = SpawnAgentToolInput {
+        label: "label".to_string(),
+        message: "subagent task prompt".to_string(),
+        session_id: None,
+    };
+    let subagent_tool_use = LanguageModelToolUse {
+        id: "subagent_1".into(),
+        name: SpawnAgentTool::NAME.into(),
+        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
+        input: serde_json::to_value(&subagent_tool_input).unwrap(),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        subagent_tool_use,
+    ));
+    model.end_last_completion_stream();
+
+    cx.run_until_parked();
+
+    // Verify subagent is running
+    thread.read_with(cx, |thread, cx| {
+        assert!(
+            !thread.running_subagent_ids(cx).is_empty(),
+            "subagent should be running"
+        );
+    });
+
+    // The subagent's model returns a non-retryable error
+    model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
+        tokens: None,
+    });
+
+    cx.run_until_parked();
+
+    // The subagent should no longer be running
+    thread.read_with(cx, |thread, cx| {
+        assert!(
+            thread.running_subagent_ids(cx).is_empty(),
+            "subagent should not be running after error"
+        );
+    });
+
+    // The parent model should get a new completion request to respond to the tool error
+    model.send_last_completion_stream_text_chunk("Response after error");
+    model.end_last_completion_stream();
+
+    send.await.unwrap();
+
+    // Verify the parent thread shows the error in the tool call
+    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
+    assert!(
+        markdown.contains("Status: Failed"),
+        "tool call should have Failed status after model error, got:\n{markdown}"
+    );
+}
+
 #[gpui::test]
 async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
     init_test(cx);