@@ -3605,6 +3605,113 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
));
}
+/// Regression test for a deadlock: when the LLM stream fails before a tool use
+/// with `is_input_complete: true` arrives, the streaming tool must still
+/// terminate and its error result must be forwarded in the retry request.
+#[gpui::test]
+async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
+    cx: &mut TestAppContext,
+) {
+    // Literals that appear more than once below.
+    const PROMPT: &str = "Use the streaming_echo tool";
+    const EXPECTED_ERROR: &str = "Failed to receive tool input: tool input was not fully received";
+
+    init_test(cx);
+    always_allow_tools(cx);
+
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake = model.as_fake();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(StreamingEchoTool);
+    });
+
+    // Start the turn; the returned events are held but never inspected.
+    let _turn_events = thread
+        .update(cx, |thread, cx| thread.send(UserMessageId::new(), [PROMPT], cx))
+        .unwrap();
+    cx.run_until_parked();
+
+    // Deliver a tool use whose input is still streaming (`is_input_complete`
+    // is false); a completed input is never sent afterwards.
+    let partial_use = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: "streaming_echo".into(),
+        raw_input: r#"{"text": "partial"}"#.into(),
+        input: json!({"text": "partial"}),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    let partial_event = LanguageModelCompletionEvent::ToolUse(partial_use.clone());
+    fake.send_last_completion_stream_event(partial_event);
+    cx.run_until_parked();
+
+    // Fail the stream without ever completing the tool input. Prior to the fix
+    // this deadlocked: the tool waited for more partials (or cancellation),
+    // `run_turn_internal` waited for the tool, and the sender keeping the
+    // channel open lived inside `RunningTurn`.
+    let upstream_failure = LanguageModelCompletionError::UpstreamProviderError {
+        message: "Internal server error".to_string(),
+        status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
+        retry_after: None,
+    };
+    fake.send_last_completion_stream_error(upstream_failure);
+    fake.end_last_completion_stream();
+
+    // Move the clock past the retry delay so that `run_turn_internal` retries.
+    cx.executor().advance_clock(Duration::from_secs(5));
+    cx.run_until_parked();
+
+    // The tool's error result showing up in the retry request proves that the
+    // tool terminated and that its result was forwarded.
+    let retry_request = fake.pending_completions().pop().expect("No running turn");
+
+    // Builds one expected request message; `reasoning_details` is always `None`.
+    let message = |role: Role, content: Vec<language_model::MessageContent>, cache: bool| {
+        LanguageModelRequestMessage {
+            role,
+            content,
+            cache,
+            reasoning_details: None,
+        }
+    };
+
+    // The error result the streaming tool is expected to produce.
+    let tool_result = LanguageModelToolResult {
+        tool_use_id: partial_use.id.clone(),
+        tool_name: partial_use.name.clone(),
+        is_error: true,
+        content: EXPECTED_ERROR.into(),
+        output: Some(EXPECTED_ERROR.into()),
+    };
+
+    assert_eq!(
+        retry_request.messages[1..],
+        vec![
+            message(Role::User, vec![PROMPT.into()], false),
+            message(
+                Role::Assistant,
+                vec![language_model::MessageContent::ToolUse(partial_use)],
+                false,
+            ),
+            message(
+                Role::User,
+                vec![language_model::MessageContent::ToolResult(tool_result)],
+                true,
+            ),
+        ]
+    );
+
+    // Complete the retried request so that the turn can finish cleanly.
+    fake.send_last_completion_stream_text_chunk("Done");
+    fake.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // The turn must have finished rather than being stuck.
+    let turn_complete = thread.read_with(cx, |thread, _cx| thread.is_turn_complete());
+    assert!(turn_complete, "Thread should not be stuck; the turn should have completed");
+}
+
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events
@@ -3660,6 +3767,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
ToolRequiringPermission::NAME: true,
InfiniteTool::NAME: true,
CancellationAwareTool::NAME: true,
+ StreamingEchoTool::NAME: true,
(TerminalTool::NAME): true,
}
}
@@ -5,6 +5,56 @@ use std::future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
+/// A streaming tool that echoes its input, used to test streaming tool
+/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
+/// before `is_input_complete`).
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingEchoToolInput {
+    /// The text to echo.
+    pub text: String,
+}
+
+/// Test tool that supports input streaming and echoes back the final `text` it receives.
+pub struct StreamingEchoTool;
+
+impl AgentTool for StreamingEchoTool {
+    const NAME: &'static str = "streaming_echo";
+    type Input = StreamingEchoToolInput;
+    type Output = String;
+
+    fn kind() -> acp::ToolKind {
+        acp::ToolKind::Other
+    }
+
+    fn supports_input_streaming() -> bool {
+        true
+    }
+
+    fn initial_title(
+        &self,
+        _input: Result<Self::Input, serde_json::Value>,
+        _cx: &mut App,
+    ) -> SharedString {
+        "Streaming Echo".into()
+    }
+
+    /// Drains partial inputs, then echoes the final input's `text` or reports the receive error.
+    fn run(
+        self: Arc<Self>,
+        mut input: ToolInput<Self::Input>,
+        _event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Task<Result<String, String>> {
+        cx.spawn(async move |_cx| {
+            while input.recv_partial().await.is_some() {}
+            match input.recv().await {
+                Ok(final_input) => Ok(final_input.text),
+                Err(e) => Err(format!("Failed to receive tool input: {e}")),
+            }
+        })
+    }
+}
+
/// A tool that echoes its input
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
@@ -1918,6 +1918,19 @@ impl Thread {
// that need their own permits.
drop(events);
+ // Drop streaming tool input senders that never received their final input.
+ // This prevents deadlock when the LLM stream ends (e.g. because of an error)
+ // before sending a tool use with `is_input_complete: true`.
+ this.update(cx, |this, _cx| {
+ if let Some(running_turn) = this.running_turn.as_mut() {
+ if running_turn.streaming_tool_inputs.is_empty() {
+ return;
+ }
+ log::warn!("Dropping partial tool inputs because the stream ended");
+ running_turn.streaming_tool_inputs.drain();
+ }
+ })?;
+
let end_turn = tool_results.is_empty();
while let Some(tool_result) = tool_results.next().await {
log::debug!("Tool finished {:?}", tool_result);
@@ -3019,7 +3032,7 @@ impl<T: DeserializeOwned> ToolInput<T> {
let value = self
.final_rx
.await
- .map_err(|_| anyhow!("tool input sender was dropped before sending final input"))?;
+ .map_err(|_| anyhow!("tool input was not fully received"))?;
serde_json::from_value(value).map_err(Into::into)
}