agent: Fail faster in case streaming tool call fails (#50834)

Bennet Bo Fenner and Ben Brandt created

If a streaming tool call (e.g. edit file) returns an error during
streaming, we would wait until we received the whole input.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/agent/src/tests/mod.rs        | 198 +++++++++++++++++++++++++++++
crates/agent/src/tests/test_tools.rs |  73 ++++++++++
crates/agent/src/thread.rs           |  85 +++++++++---
3 files changed, 332 insertions(+), 24 deletions(-)

Detailed changes

crates/agent/src/tests/mod.rs 🔗

@@ -3616,7 +3616,7 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
     let fake_model = model.as_fake();
 
     thread.update(cx, |thread, _cx| {
-        thread.add_tool(StreamingEchoTool);
+        thread.add_tool(StreamingEchoTool::new());
     });
 
     let _events = thread
@@ -3768,7 +3768,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                             InfiniteTool::NAME: true,
                             CancellationAwareTool::NAME: true,
                             StreamingEchoTool::NAME: true,
-                            (TerminalTool::NAME): true,
+                            StreamingFailingEchoTool::NAME: true,
+                            TerminalTool::NAME: true,
                         }
                     }
                 }
@@ -6335,3 +6336,196 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
         );
     });
 }
+
+#[gpui::test]
+async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestAppContext) {
+    init_test(cx);
+    always_allow_tools(cx);
+
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(StreamingFailingEchoTool {
+            receive_chunks_until_failure: 1,
+        });
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                UserMessageId::new(),
+                ["Use the streaming_failing_echo tool"],
+                cx,
+            )
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let tool_use = LanguageModelToolUse {
+        id: "call_1".into(),
+        name: StreamingFailingEchoTool::NAME.into(),
+        raw_input: "hello".into(),
+        input: json!({}),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+
+    cx.run_until_parked();
+
+    let completions = fake_model.pending_completions();
+    let last_completion = completions.last().unwrap();
+
+    assert_eq!(
+        last_completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Use the streaming_failing_echo tool".into()],
+                cache: false,
+                reasoning_details: None,
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
+                cache: false,
+                reasoning_details: None,
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![language_model::MessageContent::ToolResult(
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use.id.clone(),
+                        tool_name: tool_use.name,
+                        is_error: true,
+                        content: "failed".into(),
+                        output: Some("failed".into()),
+                    }
+                )],
+                cache: true,
+                reasoning_details: None,
+            },
+        ]
+    );
+}
+
+#[gpui::test]
+async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut TestAppContext) {
+    init_test(cx);
+    always_allow_tools(cx);
+
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let (complete_streaming_echo_tool_call_tx, complete_streaming_echo_tool_call_rx) =
+        oneshot::channel();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(
+            StreamingEchoTool::new().with_wait_until_complete(complete_streaming_echo_tool_call_rx),
+        );
+        thread.add_tool(StreamingFailingEchoTool {
+            receive_chunks_until_failure: 1,
+        });
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                UserMessageId::new(),
+                ["Use the streaming_echo tool and the streaming_failing_echo tool"],
+                cx,
+            )
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "call_1".into(),
+            name: StreamingEchoTool::NAME.into(),
+            raw_input: "hello".into(),
+            input: json!({ "text": "hello" }),
+            is_input_complete: false,
+            thought_signature: None,
+        },
+    ));
+    let first_tool_use = LanguageModelToolUse {
+        id: "call_1".into(),
+        name: StreamingEchoTool::NAME.into(),
+        raw_input: "hello world".into(),
+        input: json!({ "text": "hello world" }),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        first_tool_use.clone(),
+    ));
+    let second_tool_use = LanguageModelToolUse {
+        name: StreamingFailingEchoTool::NAME.into(),
+        raw_input: "hello".into(),
+        input: json!({ "text": "hello" }),
+        is_input_complete: false,
+        thought_signature: None,
+        id: "call_2".into(),
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        second_tool_use.clone(),
+    ));
+
+    cx.run_until_parked();
+
+    complete_streaming_echo_tool_call_tx.send(()).unwrap();
+
+    cx.run_until_parked();
+
+    let completions = fake_model.pending_completions();
+    let last_completion = completions.last().unwrap();
+
+    assert_eq!(
+        last_completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![
+                    "Use the streaming_echo tool and the streaming_failing_echo tool".into()
+                ],
+                cache: false,
+                reasoning_details: None,
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![
+                    language_model::MessageContent::ToolUse(first_tool_use.clone()),
+                    language_model::MessageContent::ToolUse(second_tool_use.clone())
+                ],
+                cache: false,
+                reasoning_details: None,
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![
+                    language_model::MessageContent::ToolResult(LanguageModelToolResult {
+                        tool_use_id: second_tool_use.id.clone(),
+                        tool_name: second_tool_use.name,
+                        is_error: true,
+                        content: "failed".into(),
+                        output: Some("failed".into()),
+                    }),
+                    language_model::MessageContent::ToolResult(LanguageModelToolResult {
+                        tool_use_id: first_tool_use.id.clone(),
+                        tool_name: first_tool_use.name,
+                        is_error: false,
+                        content: "hello world".into(),
+                        output: Some("hello world".into()),
+                    }),
+                ],
+                cache: true,
+                reasoning_details: None,
+            },
+        ]
+    );
+}

crates/agent/src/tests/test_tools.rs 🔗

@@ -2,6 +2,7 @@ use super::*;
 use agent_settings::AgentSettings;
 use gpui::{App, SharedString, Task};
 use std::future;
+use std::sync::Mutex;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::Duration;
 
@@ -14,7 +15,22 @@ pub struct StreamingEchoToolInput {
     pub text: String,
 }
 
-pub struct StreamingEchoTool;
+pub struct StreamingEchoTool {
+    wait_until_complete_rx: Mutex<Option<oneshot::Receiver<()>>>,
+}
+
+impl StreamingEchoTool {
+    pub fn new() -> Self {
+        Self {
+            wait_until_complete_rx: Mutex::new(None),
+        }
+    }
+
+    pub fn with_wait_until_complete(mut self, receiver: oneshot::Receiver<()>) -> Self {
+        self.wait_until_complete_rx = Mutex::new(Some(receiver));
+        self
+    }
+}
 
 impl AgentTool for StreamingEchoTool {
     type Input = StreamingEchoToolInput;
@@ -44,17 +60,72 @@ impl AgentTool for StreamingEchoTool {
         _event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<String, String>> {
+        let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
         cx.spawn(async move |_cx| {
             while input.recv_partial().await.is_some() {}
             let input = input
                 .recv()
                 .await
                 .map_err(|e| format!("Failed to receive tool input: {e}"))?;
+            if let Some(rx) = wait_until_complete_rx {
+                rx.await.ok();
+            }
             Ok(input.text)
         })
     }
 }
 
+/// A streaming tool that echoes its input, used to test streaming tool
+/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
+/// before `is_input_complete`).
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingFailingEchoToolInput {
+    /// The text to echo.
+    pub text: String,
+}
+
+pub struct StreamingFailingEchoTool {
+    pub receive_chunks_until_failure: usize,
+}
+
+impl AgentTool for StreamingFailingEchoTool {
+    type Input = StreamingFailingEchoToolInput;
+
+    type Output = String;
+
+    const NAME: &'static str = "streaming_failing_echo";
+
+    fn kind() -> acp::ToolKind {
+        acp::ToolKind::Other
+    }
+
+    fn supports_input_streaming() -> bool {
+        true
+    }
+
+    fn initial_title(
+        &self,
+        _input: Result<Self::Input, serde_json::Value>,
+        _cx: &mut App,
+    ) -> SharedString {
+        "echo".into()
+    }
+
+    fn run(
+        self: Arc<Self>,
+        mut input: ToolInput<Self::Input>,
+        _event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Task<Result<Self::Output, Self::Output>> {
+        cx.spawn(async move |_cx| {
+            for _ in 0..self.receive_chunks_until_failure {
+                let _ = input.recv_partial().await;
+            }
+            Err("failed".into())
+        })
+    }
+}
+
 /// A tool that echoes its input
 #[derive(JsonSchema, Serialize, Deserialize)]
 pub struct EchoToolInput {

crates/agent/src/thread.rs 🔗

@@ -1846,12 +1846,37 @@ impl Thread {
                 Ok(events) => (events.fuse(), None),
                 Err(err) => (stream::empty().boxed().fuse(), Some(err)),
             };
-            let mut tool_results = FuturesUnordered::new();
+            let mut tool_results: FuturesUnordered<Task<LanguageModelToolResult>> =
+                FuturesUnordered::new();
+            let mut early_tool_results: Vec<LanguageModelToolResult> = Vec::new();
             let mut cancelled = false;
             loop {
-                // Race between getting the first event and cancellation
+                // Race between getting the first event, tool completion, and cancellation.
                 let first_event = futures::select! {
                     event = events.next().fuse() => event,
+                    tool_result = futures::StreamExt::select_next_some(&mut tool_results) => {
+                        let is_error = tool_result.is_error;
+                        let is_still_streaming = this
+                            .read_with(cx, |this, _cx| {
+                                this.running_turn
+                                    .as_ref()
+                                    .and_then(|turn| turn.streaming_tool_inputs.get(&tool_result.tool_use_id))
+                                    .map_or(false, |inputs| !inputs.has_received_final())
+                            })
+                            .unwrap_or(false);
+
+                        early_tool_results.push(tool_result);
+
+                        // Only break if the tool errored and we are still
+                        // streaming the input of the tool. If the tool errored
+                        // but we are no longer streaming its input (i.e. there
+                        // are parallel tool calls) we want to continue
+                        // processing those tool inputs.
+                        if is_error && is_still_streaming {
+                            break;
+                        }
+                        continue;
+                    }
                     _ = cancellation_rx.changed().fuse() => {
                         if *cancellation_rx.borrow() {
                             cancelled = true;
@@ -1931,26 +1956,13 @@ impl Thread {
                 }
             })?;
 
-            let end_turn = tool_results.is_empty();
-            while let Some(tool_result) = tool_results.next().await {
-                log::debug!("Tool finished {:?}", tool_result);
+            let end_turn = tool_results.is_empty() && early_tool_results.is_empty();
 
-                event_stream.update_tool_call_fields(
-                    &tool_result.tool_use_id,
-                    acp::ToolCallUpdateFields::new()
-                        .status(if tool_result.is_error {
-                            acp::ToolCallStatus::Failed
-                        } else {
-                            acp::ToolCallStatus::Completed
-                        })
-                        .raw_output(tool_result.output.clone()),
-                    None,
-                );
-                this.update(cx, |this, _cx| {
-                    this.pending_message()
-                        .tool_results
-                        .insert(tool_result.tool_use_id.clone(), tool_result);
-                })?;
+            for tool_result in early_tool_results {
+                Self::process_tool_result(this, event_stream, cx, tool_result)?;
+            }
+            while let Some(tool_result) = tool_results.next().await {
+                Self::process_tool_result(this, event_stream, cx, tool_result)?;
             }
 
             this.update(cx, |this, cx| {
@@ -2004,6 +2016,33 @@ impl Thread {
         }
     }
 
+    fn process_tool_result(
+        this: &WeakEntity<Thread>,
+        event_stream: &ThreadEventStream,
+        cx: &mut AsyncApp,
+        tool_result: LanguageModelToolResult,
+    ) -> Result<(), anyhow::Error> {
+        log::debug!("Tool finished {:?}", tool_result);
+
+        event_stream.update_tool_call_fields(
+            &tool_result.tool_use_id,
+            acp::ToolCallUpdateFields::new()
+                .status(if tool_result.is_error {
+                    acp::ToolCallStatus::Failed
+                } else {
+                    acp::ToolCallStatus::Completed
+                })
+                .raw_output(tool_result.output.clone()),
+            None,
+        );
+        this.update(cx, |this, _cx| {
+            this.pending_message()
+                .tool_results
+                .insert(tool_result.tool_use_id.clone(), tool_result);
+        })?;
+        Ok(())
+    }
+
     fn handle_completion_error(
         &mut self,
         error: LanguageModelCompletionError,
@@ -3072,6 +3111,10 @@ impl ToolInputSender {
         (sender, input)
     }
 
+    pub(crate) fn has_received_final(&self) -> bool {
+        self.final_tx.is_none()
+    }
+
     pub(crate) fn send_partial(&self, value: serde_json::Value) {
         self.partial_tx.unbounded_send(value).ok();
     }