agent: Fix deadlock when LLM API returns an error while streaming tool input (#50813)

Bennet Bo Fenner created

Nightly only since no tool is using streaming on Preview.

Tools that used streaming could deadlock when the upstream LLM API
stream exited early without ever sending a ToolUse with `input_complete`
set to `true`. This is now fixed by manually dropping the tool input
channels.

Release Notes:

- N/A

Change summary

crates/agent/src/tests/mod.rs        | 108 ++++++++++++++++++++++++++++++
crates/agent/src/tests/test_tools.rs |  50 +++++++++++++
crates/agent/src/thread.rs           |  15 +++
3 files changed, 172 insertions(+), 1 deletion(-)

Detailed changes

crates/agent/src/tests/mod.rs 🔗

@@ -3605,6 +3605,113 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
     ));
 }
 
+#[gpui::test]
+async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
+    cx: &mut TestAppContext,
+) {
+    init_test(cx);
+    always_allow_tools(cx);
+
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(StreamingEchoTool);
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Use the streaming_echo tool"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // Send a partial tool use (is_input_complete = false), simulating the LLM
+    // streaming input for a tool.
+    let tool_use = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: "streaming_echo".into(),
+        raw_input: r#"{"text": "partial"}"#.into(),
+        input: json!({"text": "partial"}),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+    cx.run_until_parked();
+
+    // Send a stream error WITHOUT ever sending is_input_complete = true.
+    // Before the fix, this would deadlock: the tool waits for more partials
+    // (or cancellation), run_turn_internal waits for the tool, and the sender
+    // keeping the channel open lives inside RunningTurn.
+    fake_model.send_last_completion_stream_error(
+        LanguageModelCompletionError::UpstreamProviderError {
+            message: "Internal server error".to_string(),
+            status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
+            retry_after: None,
+        },
+    );
+    fake_model.end_last_completion_stream();
+
+    // Advance past the retry delay so run_turn_internal retries.
+    cx.executor().advance_clock(Duration::from_secs(5));
+    cx.run_until_parked();
+
+    // The retry request should contain the streaming tool's error result,
+    // proving the tool terminated and its result was forwarded.
+    let completion = fake_model
+        .pending_completions()
+        .pop()
+        .expect("No running turn");
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Use the streaming_echo tool".into()],
+                cache: false,
+                reasoning_details: None,
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
+                cache: false,
+                reasoning_details: None,
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![language_model::MessageContent::ToolResult(
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use.id.clone(),
+                        tool_name: tool_use.name,
+                        is_error: true,
+                        content: "Failed to receive tool input: tool input was not fully received"
+                            .into(),
+                        output: Some(
+                            "Failed to receive tool input: tool input was not fully received"
+                                .into()
+                        ),
+                    }
+                )],
+                cache: true,
+                reasoning_details: None,
+            },
+        ]
+    );
+
+    // Finish the retry round so the turn completes cleanly.
+    fake_model.send_last_completion_stream_text_chunk("Done");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    thread.read_with(cx, |thread, _cx| {
+        assert!(
+            thread.is_turn_complete(),
+            "Thread should not be stuck; the turn should have completed",
+        );
+    });
+}
+
 /// Filters out the stop events for asserting against in tests
 fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
     result_events
@@ -3660,6 +3767,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                             ToolRequiringPermission::NAME: true,
                             InfiniteTool::NAME: true,
                             CancellationAwareTool::NAME: true,
+                            StreamingEchoTool::NAME: true,
                             (TerminalTool::NAME): true,
                         }
                     }

crates/agent/src/tests/test_tools.rs 🔗

@@ -5,6 +5,56 @@ use std::future;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::Duration;
 
+/// A streaming tool that echoes its input, used to test streaming tool
+/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
+/// before `is_input_complete`).
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingEchoToolInput {
+    /// The text to echo.
+    pub text: String,
+}
+
+pub struct StreamingEchoTool;
+
+impl AgentTool for StreamingEchoTool {
+    type Input = StreamingEchoToolInput;
+    type Output = String;
+
+    const NAME: &'static str = "streaming_echo";
+
+    fn supports_input_streaming() -> bool {
+        true
+    }
+
+    fn kind() -> acp::ToolKind {
+        acp::ToolKind::Other
+    }
+
+    fn initial_title(
+        &self,
+        _input: Result<Self::Input, serde_json::Value>,
+        _cx: &mut App,
+    ) -> SharedString {
+        "Streaming Echo".into()
+    }
+
+    fn run(
+        self: Arc<Self>,
+        mut input: ToolInput<Self::Input>,
+        _event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Task<Result<String, String>> {
+        cx.spawn(async move |_cx| {
+            while input.recv_partial().await.is_some() {}
+            let input = input
+                .recv()
+                .await
+                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
+            Ok(input.text)
+        })
+    }
+}
+
 /// A tool that echoes its input
 #[derive(JsonSchema, Serialize, Deserialize)]
 pub struct EchoToolInput {

crates/agent/src/thread.rs 🔗

@@ -1918,6 +1918,19 @@ impl Thread {
             // that need their own permits.
             drop(events);
 
+            // Drop streaming tool input senders that never received their final input.
+            // This prevents deadlock when the LLM stream ends (e.g. because of an error)
+            // before sending a tool use with `is_input_complete: true`.
+            this.update(cx, |this, _cx| {
+                if let Some(running_turn) = this.running_turn.as_mut() {
+                    if running_turn.streaming_tool_inputs.is_empty() {
+                        return;
+                    }
+                    log::warn!("Dropping partial tool inputs because the stream ended");
+                    running_turn.streaming_tool_inputs.drain();
+                }
+            })?;
+
             let end_turn = tool_results.is_empty();
             while let Some(tool_result) = tool_results.next().await {
                 log::debug!("Tool finished {:?}", tool_result);
@@ -3019,7 +3032,7 @@ impl<T: DeserializeOwned> ToolInput<T> {
         let value = self
             .final_rx
             .await
-            .map_err(|_| anyhow!("tool input sender was dropped before sending final input"))?;
+            .map_err(|_| anyhow!("tool input was not fully received"))?;
         serde_json::from_value(value).map_err(Into::into)
     }