From f69ab1d04e53577b4fc5e873ebb0ec0bbd3283e2 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Fri, 6 Mar 2026 15:39:08 +0100 Subject: [PATCH] agent: Fail faster in case streaming tool call fails (#50834) If a streaming tool call (e.g. edit file) returns an error during streaming, we would wait until we received the whole input. Release Notes: - N/A --------- Co-authored-by: Ben Brandt --- crates/agent/src/tests/mod.rs | 198 ++++++++++++++++++++++++++- crates/agent/src/tests/test_tools.rs | 73 +++++++++- crates/agent/src/thread.rs | 85 +++++++++--- 3 files changed, 332 insertions(+), 24 deletions(-) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 0993b43a13ced62000692bf2b0b35d3ab7fb68e7..79e8a5e24592d746675de670ca3288771e5eb5f4 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -3616,7 +3616,7 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input( let fake_model = model.as_fake(); thread.update(cx, |thread, _cx| { - thread.add_tool(StreamingEchoTool); + thread.add_tool(StreamingEchoTool::new()); }); let _events = thread @@ -3768,7 +3768,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { InfiniteTool::NAME: true, CancellationAwareTool::NAME: true, StreamingEchoTool::NAME: true, - (TerminalTool::NAME): true, + StreamingFailingEchoTool::NAME: true, + TerminalTool::NAME: true, } } } @@ -6335,3 +6336,196 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) { ); }); } + +#[gpui::test] +async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestAppContext) { + init_test(cx); + always_allow_tools(cx); + + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(StreamingFailingEchoTool { + receive_chunks_until_failure: 1, + }); + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Use the streaming_failing_echo tool"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "call_1".into(), + name: StreamingFailingEchoTool::NAME.into(), + raw_input: "hello".into(), + input: json!({}), + is_input_complete: false, + thought_signature: None, + }; + + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); + + cx.run_until_parked(); + + let completions = fake_model.pending_completions(); + let last_completion = completions.last().unwrap(); + + assert_eq!( + last_completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: Role::User, + content: vec!["Use the streaming_failing_echo tool".into()], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![language_model::MessageContent::ToolUse(tool_use.clone())], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![language_model::MessageContent::ToolResult( + LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: tool_use.name, + is_error: true, + content: "failed".into(), + output: Some("failed".into()), + } + )], + cache: true, + reasoning_details: None, + }, + ] + ); +} + +#[gpui::test] +async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut TestAppContext) { + init_test(cx); + always_allow_tools(cx); + + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let (complete_streaming_echo_tool_call_tx, complete_streaming_echo_tool_call_rx) = + oneshot::channel(); + + thread.update(cx, |thread, _cx| { + thread.add_tool( + StreamingEchoTool::new().with_wait_until_complete(complete_streaming_echo_tool_call_rx), + ); + thread.add_tool(StreamingFailingEchoTool { + receive_chunks_until_failure: 1, + }); + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Use the streaming_echo tool and the streaming_failing_echo tool"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "call_1".into(), + name: StreamingEchoTool::NAME.into(), + raw_input: "hello".into(), + input: json!({ "text": "hello" }), + is_input_complete: false, + thought_signature: None, + }, + )); + let first_tool_use = LanguageModelToolUse { + id: "call_1".into(), + name: StreamingEchoTool::NAME.into(), + raw_input: "hello world".into(), + input: json!({ "text": "hello world" }), + is_input_complete: true, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + first_tool_use.clone(), + )); + let second_tool_use = LanguageModelToolUse { + name: StreamingFailingEchoTool::NAME.into(), + raw_input: "hello".into(), + input: json!({ "text": "hello" }), + is_input_complete: false, + thought_signature: None, + id: "call_2".into(), + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + second_tool_use.clone(), + )); + + cx.run_until_parked(); + + complete_streaming_echo_tool_call_tx.send(()).unwrap(); + + cx.run_until_parked(); + + let completions = fake_model.pending_completions(); + let last_completion = completions.last().unwrap(); + + assert_eq!( + last_completion.messages[1..], + vec![ + LanguageModelRequestMessage { + role: 
Role::User, + content: vec![ + "Use the streaming_echo tool and the streaming_failing_echo tool".into() + ], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![ + language_model::MessageContent::ToolUse(first_tool_use.clone()), + language_model::MessageContent::ToolUse(second_tool_use.clone()) + ], + cache: false, + reasoning_details: None, + }, + LanguageModelRequestMessage { + role: Role::User, + content: vec![ + language_model::MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: second_tool_use.id.clone(), + tool_name: second_tool_use.name, + is_error: true, + content: "failed".into(), + output: Some("failed".into()), + }), + language_model::MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: first_tool_use.id.clone(), + tool_name: first_tool_use.name, + is_error: false, + content: "hello world".into(), + output: Some("hello world".into()), + }), + ], + cache: true, + reasoning_details: None, + }, + ] + ); +} diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index ac179c590a93824813afa338d9deed16b4d00ebd..f36549a6c42f9e810c7794d8ec683613b6ae6933 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -2,6 +2,7 @@ use super::*; use agent_settings::AgentSettings; use gpui::{App, SharedString, Task}; use std::future; +use std::sync::Mutex; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; @@ -14,7 +15,22 @@ pub struct StreamingEchoToolInput { pub text: String, } -pub struct StreamingEchoTool; +pub struct StreamingEchoTool { + wait_until_complete_rx: Mutex>>, +} + +impl StreamingEchoTool { + pub fn new() -> Self { + Self { + wait_until_complete_rx: Mutex::new(None), + } + } + + pub fn with_wait_until_complete(mut self, receiver: oneshot::Receiver<()>) -> Self { + self.wait_until_complete_rx = Mutex::new(Some(receiver)); + self + } +} impl AgentTool for 
StreamingEchoTool { type Input = StreamingEchoToolInput; @@ -44,17 +60,72 @@ impl AgentTool for StreamingEchoTool { _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { + let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take(); cx.spawn(async move |_cx| { while input.recv_partial().await.is_some() {} let input = input .recv() .await .map_err(|e| format!("Failed to receive tool input: {e}"))?; + if let Some(rx) = wait_until_complete_rx { + rx.await.ok(); + } Ok(input.text) }) } } +/// A streaming tool that fails while its input is still streaming: it +/// awaits `receive_chunks_until_failure` partial inputs and then returns +/// an error without waiting for the final input. +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct StreamingFailingEchoToolInput { + /// The text to echo. + pub text: String, +} + +pub struct StreamingFailingEchoTool { + pub receive_chunks_until_failure: usize, +} + +impl AgentTool for StreamingFailingEchoTool { + type Input = StreamingFailingEchoToolInput; + + type Output = String; + + const NAME: &'static str = "streaming_failing_echo"; + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn supports_input_streaming() -> bool { + true + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "echo".into() + } + + fn run( + self: Arc, + mut input: ToolInput, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |_cx| { + for _ in 0..self.receive_chunks_until_failure { + let _ = input.recv_partial().await; + } + Err("failed".into()) + }) + } +} + /// A tool that echoes its input #[derive(JsonSchema, Serialize, Deserialize)] pub struct EchoToolInput { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 73102929ac58caaf96b06e6ab74ded698cbe86e3..e61a395e71f93d49d63d378355c89e44359db835 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1846,12 +1846,37 @@ impl Thread { Ok(events) => 
(events.fuse(), None), Err(err) => (stream::empty().boxed().fuse(), Some(err)), }; - let mut tool_results = FuturesUnordered::new(); + let mut tool_results: FuturesUnordered> = + FuturesUnordered::new(); + let mut early_tool_results: Vec = Vec::new(); let mut cancelled = false; loop { - // Race between getting the first event and cancellation + // Race between getting the first event, tool completion, and cancellation. let first_event = futures::select! { event = events.next().fuse() => event, + tool_result = futures::StreamExt::select_next_some(&mut tool_results) => { + let is_error = tool_result.is_error; + let is_still_streaming = this + .read_with(cx, |this, _cx| { + this.running_turn + .as_ref() + .and_then(|turn| turn.streaming_tool_inputs.get(&tool_result.tool_use_id)) + .map_or(false, |inputs| !inputs.has_received_final()) + }) + .unwrap_or(false); + + early_tool_results.push(tool_result); + + // Only break if the tool errored and we are still + // streaming the input of the tool. If the tool errored + // but we are no longer streaming its input (i.e. there + // are parallel tool calls) we want to continue + // processing those tool inputs. 
+ if is_error && is_still_streaming { + break; + } + continue; + } _ = cancellation_rx.changed().fuse() => { if *cancellation_rx.borrow() { cancelled = true; @@ -1931,26 +1956,13 @@ impl Thread { } })?; - let end_turn = tool_results.is_empty(); - while let Some(tool_result) = tool_results.next().await { - log::debug!("Tool finished {:?}", tool_result); + let end_turn = tool_results.is_empty() && early_tool_results.is_empty(); - event_stream.update_tool_call_fields( - &tool_result.tool_use_id, - acp::ToolCallUpdateFields::new() - .status(if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }) - .raw_output(tool_result.output.clone()), - None, - ); - this.update(cx, |this, _cx| { - this.pending_message() - .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); - })?; + for tool_result in early_tool_results { + Self::process_tool_result(this, event_stream, cx, tool_result)?; + } + while let Some(tool_result) = tool_results.next().await { + Self::process_tool_result(this, event_stream, cx, tool_result)?; } this.update(cx, |this, cx| { @@ -2004,6 +2016,33 @@ impl Thread { } } + fn process_tool_result( + this: &WeakEntity, + event_stream: &ThreadEventStream, + cx: &mut AsyncApp, + tool_result: LanguageModelToolResult, + ) -> Result<(), anyhow::Error> { + log::debug!("Tool finished {:?}", tool_result); + + event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields::new() + .status(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }) + .raw_output(tool_result.output.clone()), + None, + ); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + })?; + Ok(()) + } + fn handle_completion_error( &mut self, error: LanguageModelCompletionError, @@ -3072,6 +3111,10 @@ impl ToolInputSender { (sender, input) } + pub(crate) fn has_received_final(&self) -> bool 
{ + self.final_tx.is_none() + } + pub(crate) fn send_partial(&self, value: serde_json::Value) { self.partial_tx.unbounded_send(value).ok(); }