@@ -3616,7 +3616,7 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
let fake_model = model.as_fake();
thread.update(cx, |thread, _cx| {
- thread.add_tool(StreamingEchoTool);
+ thread.add_tool(StreamingEchoTool::new());
});
let _events = thread
@@ -3768,7 +3768,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
InfiniteTool::NAME: true,
CancellationAwareTool::NAME: true,
StreamingEchoTool::NAME: true,
- (TerminalTool::NAME): true,
+ StreamingFailingEchoTool::NAME: true,
+ TerminalTool::NAME: true,
}
}
}
@@ -6335,3 +6336,196 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
);
});
}
+
+#[gpui::test]
+async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestAppContext) {
+ init_test(cx);
+ always_allow_tools(cx);
+
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ thread.update(cx, |thread, _cx| {
+ thread.add_tool(StreamingFailingEchoTool {
+ receive_chunks_until_failure: 1,
+ });
+ });
+
+ let _events = thread
+ .update(cx, |thread, cx| {
+ thread.send(
+ UserMessageId::new(),
+ ["Use the streaming_failing_echo tool"],
+ cx,
+ )
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let tool_use = LanguageModelToolUse {
+ id: "call_1".into(),
+ name: StreamingFailingEchoTool::NAME.into(),
+ raw_input: "hello".into(),
+ input: json!({}),
+ is_input_complete: false,
+ thought_signature: None,
+ };
+
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+
+ cx.run_until_parked();
+
+ let completions = fake_model.pending_completions();
+ let last_completion = completions.last().unwrap();
+
+ assert_eq!(
+ last_completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Use the streaming_failing_echo tool".into()],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![language_model::MessageContent::ToolResult(
+ LanguageModelToolResult {
+ tool_use_id: tool_use.id.clone(),
+ tool_name: tool_use.name,
+ is_error: true,
+ content: "failed".into(),
+ output: Some("failed".into()),
+ }
+ )],
+ cache: true,
+ reasoning_details: None,
+ },
+ ]
+ );
+}
+
+#[gpui::test]
+async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut TestAppContext) {
+ init_test(cx);
+ always_allow_tools(cx);
+
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let (complete_streaming_echo_tool_call_tx, complete_streaming_echo_tool_call_rx) =
+ oneshot::channel();
+
+ thread.update(cx, |thread, _cx| {
+ thread.add_tool(
+ StreamingEchoTool::new().with_wait_until_complete(complete_streaming_echo_tool_call_rx),
+ );
+ thread.add_tool(StreamingFailingEchoTool {
+ receive_chunks_until_failure: 1,
+ });
+ });
+
+ let _events = thread
+ .update(cx, |thread, cx| {
+ thread.send(
+ UserMessageId::new(),
+ ["Use the streaming_echo tool and the streaming_failing_echo tool"],
+ cx,
+ )
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "call_1".into(),
+ name: StreamingEchoTool::NAME.into(),
+ raw_input: "hello".into(),
+ input: json!({ "text": "hello" }),
+ is_input_complete: false,
+ thought_signature: None,
+ },
+ ));
+ let first_tool_use = LanguageModelToolUse {
+ id: "call_1".into(),
+ name: StreamingEchoTool::NAME.into(),
+ raw_input: "hello world".into(),
+ input: json!({ "text": "hello world" }),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ first_tool_use.clone(),
+ ));
+ let second_tool_use = LanguageModelToolUse {
+ name: StreamingFailingEchoTool::NAME.into(),
+ raw_input: "hello".into(),
+ input: json!({ "text": "hello" }),
+ is_input_complete: false,
+ thought_signature: None,
+ id: "call_2".into(),
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ second_tool_use.clone(),
+ ));
+
+ cx.run_until_parked();
+
+ complete_streaming_echo_tool_call_tx.send(()).unwrap();
+
+ cx.run_until_parked();
+
+ let completions = fake_model.pending_completions();
+ let last_completion = completions.last().unwrap();
+
+ assert_eq!(
+ last_completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ "Use the streaming_echo tool and the streaming_failing_echo tool".into()
+ ],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![
+ language_model::MessageContent::ToolUse(first_tool_use.clone()),
+ language_model::MessageContent::ToolUse(second_tool_use.clone())
+ ],
+ cache: false,
+ reasoning_details: None,
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ language_model::MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: second_tool_use.id.clone(),
+ tool_name: second_tool_use.name,
+ is_error: true,
+ content: "failed".into(),
+ output: Some("failed".into()),
+ }),
+ language_model::MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: first_tool_use.id.clone(),
+ tool_name: first_tool_use.name,
+ is_error: false,
+ content: "hello world".into(),
+ output: Some("hello world".into()),
+ }),
+ ],
+ cache: true,
+ reasoning_details: None,
+ },
+ ]
+ );
+}
@@ -2,6 +2,7 @@ use super::*;
use agent_settings::AgentSettings;
use gpui::{App, SharedString, Task};
use std::future;
+use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
@@ -14,7 +15,22 @@ pub struct StreamingEchoToolInput {
pub text: String,
}
-pub struct StreamingEchoTool;
+pub struct StreamingEchoTool {
+ wait_until_complete_rx: Mutex<Option<oneshot::Receiver<()>>>,
+}
+
+impl StreamingEchoTool {
+ pub fn new() -> Self {
+ Self {
+ wait_until_complete_rx: Mutex::new(None),
+ }
+ }
+
+ pub fn with_wait_until_complete(mut self, receiver: oneshot::Receiver<()>) -> Self {
+ self.wait_until_complete_rx = Mutex::new(Some(receiver));
+ self
+ }
+}
impl AgentTool for StreamingEchoTool {
type Input = StreamingEchoToolInput;
@@ -44,17 +60,72 @@ impl AgentTool for StreamingEchoTool {
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String, String>> {
+ let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
cx.spawn(async move |_cx| {
while input.recv_partial().await.is_some() {}
let input = input
.recv()
.await
.map_err(|e| format!("Failed to receive tool input: {e}"))?;
+ if let Some(rx) = wait_until_complete_rx {
+ rx.await.ok();
+ }
Ok(input.text)
})
}
}
+/// Input for a streaming tool that fails mid-stream: the tool consumes a
+/// fixed number of partial input chunks and then returns an error, used to
+/// test how a streaming tool error interrupts the LLM event stream loop.
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingFailingEchoToolInput {
+    /// The text to echo.
+    pub text: String,
+}
+
+pub struct StreamingFailingEchoTool {
+ pub receive_chunks_until_failure: usize,
+}
+
+impl AgentTool for StreamingFailingEchoTool {
+ type Input = StreamingFailingEchoToolInput;
+
+ type Output = String;
+
+ const NAME: &'static str = "streaming_failing_echo";
+
+ fn kind() -> acp::ToolKind {
+ acp::ToolKind::Other
+ }
+
+ fn supports_input_streaming() -> bool {
+ true
+ }
+
+ fn initial_title(
+ &self,
+ _input: Result<Self::Input, serde_json::Value>,
+ _cx: &mut App,
+ ) -> SharedString {
+ "echo".into()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ mut input: ToolInput<Self::Input>,
+ _event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<Self::Output, Self::Output>> {
+ cx.spawn(async move |_cx| {
+ for _ in 0..self.receive_chunks_until_failure {
+ let _ = input.recv_partial().await;
+ }
+ Err("failed".into())
+ })
+ }
+}
+
/// A tool that echoes its input
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
@@ -1846,12 +1846,37 @@ impl Thread {
Ok(events) => (events.fuse(), None),
Err(err) => (stream::empty().boxed().fuse(), Some(err)),
};
- let mut tool_results = FuturesUnordered::new();
+ let mut tool_results: FuturesUnordered<Task<LanguageModelToolResult>> =
+ FuturesUnordered::new();
+ let mut early_tool_results: Vec<LanguageModelToolResult> = Vec::new();
let mut cancelled = false;
loop {
- // Race between getting the first event and cancellation
+ // Race between getting the first event, tool completion, and cancellation.
let first_event = futures::select! {
event = events.next().fuse() => event,
+ tool_result = futures::StreamExt::select_next_some(&mut tool_results) => {
+ let is_error = tool_result.is_error;
+ let is_still_streaming = this
+ .read_with(cx, |this, _cx| {
+ this.running_turn
+ .as_ref()
+ .and_then(|turn| turn.streaming_tool_inputs.get(&tool_result.tool_use_id))
+ .map_or(false, |inputs| !inputs.has_received_final())
+ })
+ .unwrap_or(false);
+
+ early_tool_results.push(tool_result);
+
+ // Only break if the tool errored and we are still
+ // streaming the input of the tool. If the tool errored
+ // but we are no longer streaming its input (i.e. there
+ // are parallel tool calls) we want to continue
+ // processing those tool inputs.
+ if is_error && is_still_streaming {
+ break;
+ }
+ continue;
+ }
_ = cancellation_rx.changed().fuse() => {
if *cancellation_rx.borrow() {
cancelled = true;
@@ -1931,26 +1956,13 @@ impl Thread {
}
})?;
- let end_turn = tool_results.is_empty();
- while let Some(tool_result) = tool_results.next().await {
- log::debug!("Tool finished {:?}", tool_result);
+ let end_turn = tool_results.is_empty() && early_tool_results.is_empty();
- event_stream.update_tool_call_fields(
- &tool_result.tool_use_id,
- acp::ToolCallUpdateFields::new()
- .status(if tool_result.is_error {
- acp::ToolCallStatus::Failed
- } else {
- acp::ToolCallStatus::Completed
- })
- .raw_output(tool_result.output.clone()),
- None,
- );
- this.update(cx, |this, _cx| {
- this.pending_message()
- .tool_results
- .insert(tool_result.tool_use_id.clone(), tool_result);
- })?;
+ for tool_result in early_tool_results {
+ Self::process_tool_result(this, event_stream, cx, tool_result)?;
+ }
+ while let Some(tool_result) = tool_results.next().await {
+ Self::process_tool_result(this, event_stream, cx, tool_result)?;
}
this.update(cx, |this, cx| {
@@ -2004,6 +2016,33 @@ impl Thread {
}
}
+ fn process_tool_result(
+ this: &WeakEntity<Thread>,
+ event_stream: &ThreadEventStream,
+ cx: &mut AsyncApp,
+ tool_result: LanguageModelToolResult,
+ ) -> Result<(), anyhow::Error> {
+ log::debug!("Tool finished {:?}", tool_result);
+
+ event_stream.update_tool_call_fields(
+ &tool_result.tool_use_id,
+ acp::ToolCallUpdateFields::new()
+ .status(if tool_result.is_error {
+ acp::ToolCallStatus::Failed
+ } else {
+ acp::ToolCallStatus::Completed
+ })
+ .raw_output(tool_result.output.clone()),
+ None,
+ );
+ this.update(cx, |this, _cx| {
+ this.pending_message()
+ .tool_results
+ .insert(tool_result.tool_use_id.clone(), tool_result);
+ })?;
+ Ok(())
+ }
+
fn handle_completion_error(
&mut self,
error: LanguageModelCompletionError,
@@ -3072,6 +3111,10 @@ impl ToolInputSender {
(sender, input)
}
+ pub(crate) fn has_received_final(&self) -> bool {
+ self.final_tx.is_none()
+ }
+
pub(crate) fn send_partial(&self, value: serde_json::Value) {
self.partial_tx.unbounded_send(value).ok();
}