From e2bba5526aad44206abe1f54db8a593b06ae34d3 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Mon, 6 Apr 2026 22:26:26 +0200 Subject: [PATCH] agent: Fix issue with streaming tools when model produces invalid JSON (#52891) Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A --- .../agent/src/tests/edit_file_thread_test.rs | 211 ++++++ crates/agent/src/tests/mod.rs | 112 ++++ crates/agent/src/tests/test_tools.rs | 67 +- crates/agent/src/thread.rs | 312 +++++---- .../src/tools/streaming_edit_file_tool.rs | 621 +++++++++++------- crates/language_model/src/fake_provider.rs | 10 + 6 files changed, 959 insertions(+), 374 deletions(-) diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index 3beb5cb0d51abc55fbf3cf0849ced248a9d1fa5c..b5ce6441e790e0b79b2798dfe0008cc74eec69b8 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { ); }); } + +#[gpui::test] +async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes( + cx: &mut TestAppContext, +) { + super::init_test(cx); + super::always_allow_tools(cx); + + // Enable the streaming edit file tool feature flag. + cx.update(|cx| { + cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n" + } + }), + ) + .await; + + let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + model.as_fake().set_supports_streaming_tools(true); + let fake_model = model.as_fake(); + + let thread = cx.new(|cx| { + let mut thread = crate::Thread::new( + project.clone(), + project_context, + context_server_registry, + crate::Templates::new(), + Some(model.clone()), + cx, + ); + let language_registry = project.read(cx).languages().clone(); + thread.add_tool(crate::StreamingEditFileTool::new( + project.clone(), + cx.weak_entity(), + thread.action_log().clone(), + language_registry, + )); + thread + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Write new content to src/main.rs"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use_id = "edit_1"; + let partial_1 = LanguageModelToolUse { + id: tool_use_id.into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write" + }), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1)); + cx.run_until_parked(); + + let partial_2 = LanguageModelToolUse { + id: tool_use_id.into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() { /* rewritten */ }" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() { /* rewritten */ }" + }), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2)); + cx.run_until_parked(); + + // Now send a json parse error. At this point we have started writing content to the buffer. + fake_model.send_last_completion_stream_event( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use_id.into(), + tool_name: EditFileTool::NAME.into(), + raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(), + json_parse_error: "EOF while parsing a string at line 1 column 95".into(), + }, + ); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // cx.executor().advance_clock(Duration::from_secs(5)); + // cx.run_until_parked(); + + assert!( + !fake_model.pending_completions().is_empty(), + "Thread should have retried after the error" + ); + + // Respond with a new, well-formed, complete edit_file tool use. + let tool_use = LanguageModelToolUse { + id: "edit_2".into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n" + }), + is_input_complete: true, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let pending_completions = fake_model.pending_completions(); + assert!( + pending_completions.len() == 1, + "Expected only the follow-up completion containing the successful tool result" + ); + + let completion = pending_completions + .into_iter() + .last() + .expect("Expected a completion containing the tool result for edit_2"); + + let tool_result = completion + .messages + .iter() + .flat_map(|msg| &msg.content) + .find_map(|content| match content { + language_model::MessageContent::ToolResult(result) + if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") => + { + Some(result) + } + _ => None, + }) + .expect("Should have a tool result for edit_2"); + + // Ensure that the second tool call completed successfully and edits were applied. + assert!( + !tool_result.is_error, + "Tool result should succeed, got: {:?}", + tool_result + ); + let content_text = match &tool_result.content { + language_model::LanguageModelToolResultContent::Text(t) => t.to_string(), + other => panic!("Expected text content, got: {:?}", other), + }; + assert!( + !content_text.contains("file has been modified since you last read it"), + "Did not expect a stale last-read error, got: {content_text}" + ); + assert!( + !content_text.contains("This file has unsaved changes"), + "Did not expect an unsaved-changes error, got: {content_text}" + ); + + let file_content = fs + .load(path!("/project/src/main.rs").as_ref()) + .await + .expect("file should exist"); + super::assert_eq!( + file_content, + "fn main() {\n println!(\"Hello, rewritten!\");\n}\n", + "The second edit should be applied and saved gracefully" + ); + + fake_model.end_last_completion_stream(); + cx.run_until_parked(); +} diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index f7b52b2573144e4c2fd378cfb19c9ee2473a37db..ff53136a0ded4bbc283fea30598d8d30e6e29709 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input( }); } +#[gpui::test] +async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool( + cx: &mut TestAppContext, +) { + init_test(cx); + always_allow_tools(cx); + + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(StreamingJsonErrorContextTool); + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Use the streaming_json_error_context tool"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_1".into(), + name: StreamingJsonErrorContextTool::NAME.into(), + raw_input: r#"{"text": "partial"#.into(), + input: json!({"text": "partial"}), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_event( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: "tool_1".into(), + tool_name: StreamingJsonErrorContextTool::NAME.into(), + raw_input: r#"{"text": "partial"#.into(), + json_parse_error: "EOF while parsing a string at line 1 column 17".into(), + }, + ); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + cx.executor().advance_clock(Duration::from_secs(5)); + cx.run_until_parked(); + + let completion = fake_model + .pending_completions() + .pop() + .expect("No running turn"); + + let tool_results: Vec<_> = completion + .messages + .iter() + .flat_map(|message| &message.content) + .filter_map(|content| match content { + MessageContent::ToolResult(result) + if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") => + { + Some(result) + } + _ => None, + }) + .collect(); + + assert_eq!( + tool_results.len(), + 1, + "Expected exactly 1 tool result for tool_1, got {}: {:#?}", + tool_results.len(), + tool_results + ); + + let result = tool_results[0]; + assert!(result.is_error); + let content_text = match &result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + other => panic!("Expected text content, got {:?}", other), + }; + assert!( + content_text.contains("Saw partial text 'partial' before invalid JSON"), + "Expected tool-enriched partial context, got: {content_text}" + ); + assert!( + content_text + .contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"), + "Expected forwarded JSON parse error, got: {content_text}" + ); + assert!( + !content_text.contains("tool input was not fully received"), + "Should not contain orphaned sender error, got: {content_text}" + ); + + fake_model.send_last_completion_stream_text_chunk("Done"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _cx| { + assert!( + thread.is_turn_complete(), + "Thread should not be stuck; the turn should have completed", + ); + }); +} + /// Filters out the stop events for asserting against in tests fn stop_events(result_events: Vec>) -> Vec { result_events @@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { InfiniteTool::NAME: true, CancellationAwareTool::NAME: true, StreamingEchoTool::NAME: true, + StreamingJsonErrorContextTool::NAME: true, StreamingFailingEchoTool::NAME: true, TerminalTool::NAME: true, UpdatePlanTool::NAME: true, diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index f36549a6c42f9e810c7794d8ec683613b6ae6933..4744204fae1213d49af92339b8847e9d1f470125 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool { fn run( self: Arc, - mut input: ToolInput, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take(); cx.spawn(async move |_cx| { - while input.recv_partial().await.is_some() {} let input = input .recv() .await @@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool { } } +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct StreamingJsonErrorContextToolInput { + /// The text to echo. + pub text: String, +} + +pub struct StreamingJsonErrorContextTool; + +impl AgentTool for StreamingJsonErrorContextTool { + type Input = StreamingJsonErrorContextToolInput; + type Output = String; + + const NAME: &'static str = "streaming_json_error_context"; + + fn supports_input_streaming() -> bool { + true + } + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "Streaming JSON Error Context".into() + } + + fn run( + self: Arc, + mut input: ToolInput, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |_cx| { + let mut last_partial_text = None; + + loop { + match input.next().await { + Ok(ToolInputPayload::Partial(partial)) => { + if let Some(text) = partial.get("text").and_then(|value| value.as_str()) { + last_partial_text = Some(text.to_string()); + } + } + Ok(ToolInputPayload::Full(input)) => return Ok(input.text), + Ok(ToolInputPayload::InvalidJson { error_message }) => { + let partial_text = last_partial_text.unwrap_or_default(); + return Err(format!( + "Saw partial text '{partial_text}' before invalid JSON: {error_message}" + )); + } + Err(error) => { + return Err(format!("Failed to receive tool input: {error}")); + } + } + } + }) + } +} + /// A streaming tool that echoes its input, used to test streaming tool /// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends /// before `is_input_complete`). @@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool { ) -> Task> { cx.spawn(async move |_cx| { for _ in 0..self.receive_chunks_until_failure { - let _ = input.recv_partial().await; + let _ = input.next().await; } Err("failed".into()) }) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index bcb5b7b2d2f3eb8cffd5be8b70fc08fef8e9fe37..ea342e8db4e4d97d5eccc849121cd0fd2e403017 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -22,13 +22,13 @@ use client::UserStore; use cloud_api_types::Plan; use collections::{HashMap, HashSet, IndexMap}; use fs::Fs; -use futures::stream; use futures::{ FutureExt, channel::{mpsc, oneshot}, future::Shared, stream::FuturesUnordered, }; +use futures::{StreamExt, stream}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; @@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file}; -use smol::stream::StreamExt; use std::{ collections::BTreeMap, marker::PhantomData, @@ -2095,7 +2094,7 @@ impl Thread { this.update(cx, |this, _cx| { this.pending_message() .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); + .insert(tool_result.tool_use_id.clone(), tool_result) })?; Ok(()) } @@ -2195,15 +2194,15 @@ impl Thread { raw_input, json_parse_error, } => { - return Ok(Some(Task::ready( - self.handle_tool_use_json_parse_error_event( - id, - tool_name, - raw_input, - json_parse_error, - event_stream, - ), - ))); + return Ok(self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + event_stream, + cancellation_rx, + cx, + )); } UsageUpdate(usage) => { telemetry::event!( @@ -2304,12 +2303,12 @@ impl Thread { if !tool_use.is_input_complete { if tool.supports_input_streaming() { let running_turn = self.running_turn.as_mut()?; - if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) { + if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) { sender.send_partial(tool_use.input); return None; } - let (sender, tool_input) = ToolInputSender::channel(); + let (mut sender, tool_input) = ToolInputSender::channel(); sender.send_partial(tool_use.input); running_turn .streaming_tool_inputs @@ -2331,13 +2330,13 @@ impl Thread { } } - if let Some(sender) = self + if let Some(mut sender) = self .running_turn .as_mut()? .streaming_tool_inputs .remove(&tool_use.id) { - sender.send_final(tool_use.input); + sender.send_full(tool_use.input); return None; } @@ -2410,10 +2409,12 @@ impl Thread { raw_input: Arc, json_parse_error: String, event_stream: &ThreadEventStream, - ) -> LanguageModelToolResult { + cancellation_rx: watch::Receiver, + cx: &mut Context, + ) -> Option> { let tool_use = LanguageModelToolUse { - id: tool_use_id.clone(), - name: tool_name.clone(), + id: tool_use_id, + name: tool_name, raw_input: raw_input.to_string(), input: serde_json::json!({}), is_input_complete: true, @@ -2426,14 +2427,43 @@ impl Thread { event_stream, ); - let tool_output = format!("Error parsing input JSON: {json_parse_error}"); - LanguageModelToolResult { - tool_use_id, - tool_name, - is_error: true, - content: LanguageModelToolResultContent::Text(tool_output.into()), - output: Some(serde_json::Value::String(raw_input.to_string())), + let tool = self.tool(tool_use.name.as_ref()); + + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; + + let error_message = format!("Error parsing input JSON: {json_parse_error}"); + + if tool.supports_input_streaming() + && let Some(mut sender) = self + .running_turn + .as_mut()? + .streaming_tool_inputs + .remove(&tool_use.id) + { + sender.send_invalid_json(error_message); + return None; } + + log::debug!("Running tool {}. Received invalid JSON", tool_use.name); + let tool_input = ToolInput::invalid_json(error_message); + Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )) } fn send_or_update_tool_use( @@ -3114,8 +3144,7 @@ impl EventEmitter for Thread {} /// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams /// them, followed by the final complete input available through `.recv()`. pub struct ToolInput { - partial_rx: mpsc::UnboundedReceiver, - final_rx: oneshot::Receiver, + rx: mpsc::UnboundedReceiver>, _phantom: PhantomData, } @@ -3127,13 +3156,20 @@ impl ToolInput { } pub fn ready(value: serde_json::Value) -> Self { - let (partial_tx, partial_rx) = mpsc::unbounded(); - drop(partial_tx); - let (final_tx, final_rx) = oneshot::channel(); - final_tx.send(value).ok(); + let (tx, rx) = mpsc::unbounded(); + tx.unbounded_send(ToolInputPayload::Full(value)).ok(); Self { - partial_rx, - final_rx, + rx, + _phantom: PhantomData, + } + } + + pub fn invalid_json(error_message: String) -> Self { + let (tx, rx) = mpsc::unbounded(); + tx.unbounded_send(ToolInputPayload::InvalidJson { error_message }) + .ok(); + Self { + rx, _phantom: PhantomData, } } @@ -3147,65 +3183,89 @@ impl ToolInput { /// Wait for the final deserialized input, ignoring all partial updates. /// Non-streaming tools can use this to wait until the whole input is available. pub async fn recv(mut self) -> Result { - // Drain any remaining partials - while self.partial_rx.next().await.is_some() {} + while let Ok(value) = self.next().await { + match value { + ToolInputPayload::Full(value) => return Ok(value), + ToolInputPayload::Partial(_) => {} + ToolInputPayload::InvalidJson { error_message } => { + return Err(anyhow!(error_message)); + } + } + } + Err(anyhow!("tool input was not fully received")) + } + + pub async fn next(&mut self) -> Result> { let value = self - .final_rx + .rx + .next() .await - .map_err(|_| anyhow!("tool input was not fully received"))?; - serde_json::from_value(value).map_err(Into::into) - } + .ok_or_else(|| anyhow!("tool input was not fully received"))?; - /// Returns the next partial JSON snapshot, or `None` when input is complete. - /// Once this returns `None`, call `recv()` to get the final input. - pub async fn recv_partial(&mut self) -> Option { - self.partial_rx.next().await + Ok(match value { + ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload), + ToolInputPayload::Full(payload) => { + ToolInputPayload::Full(serde_json::from_value(payload)?) + } + ToolInputPayload::InvalidJson { error_message } => { + ToolInputPayload::InvalidJson { error_message } + } + }) } fn cast(self) -> ToolInput { ToolInput { - partial_rx: self.partial_rx, - final_rx: self.final_rx, + rx: self.rx, _phantom: PhantomData, } } } +pub enum ToolInputPayload { + Partial(serde_json::Value), + Full(T), + InvalidJson { error_message: String }, +} + pub struct ToolInputSender { - partial_tx: mpsc::UnboundedSender, - final_tx: Option>, + has_received_final: bool, + tx: mpsc::UnboundedSender>, } impl ToolInputSender { pub(crate) fn channel() -> (Self, ToolInput) { - let (partial_tx, partial_rx) = mpsc::unbounded(); - let (final_tx, final_rx) = oneshot::channel(); + let (tx, rx) = mpsc::unbounded(); let sender = Self { - partial_tx, - final_tx: Some(final_tx), + tx, + has_received_final: false, }; let input = ToolInput { - partial_rx, - final_rx, + rx, _phantom: PhantomData, }; (sender, input) } pub(crate) fn has_received_final(&self) -> bool { - self.final_tx.is_none() + self.has_received_final } - pub(crate) fn send_partial(&self, value: serde_json::Value) { - self.partial_tx.unbounded_send(value).ok(); + pub fn send_partial(&mut self, payload: serde_json::Value) { + self.tx + .unbounded_send(ToolInputPayload::Partial(payload)) + .ok(); } - pub(crate) fn send_final(mut self, value: serde_json::Value) { - // Close the partial channel so recv_partial() returns None - self.partial_tx.close_channel(); - if let Some(final_tx) = self.final_tx.take() { - final_tx.send(value).ok(); - } + pub fn send_full(&mut self, payload: serde_json::Value) { + self.has_received_final = true; + self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok(); + } + + pub fn send_invalid_json(&mut self, error_message: String) { + self.has_received_final = true; + self.tx + .unbounded_send(ToolInputPayload::InvalidJson { error_message }) + .ok(); } } @@ -4251,68 +4311,78 @@ mod tests { ) { let (thread, event_stream) = setup_thread_for_test(cx).await; - cx.update(|cx| { - thread.update(cx, |thread, _cx| { - let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); - let tool_name: Arc = Arc::from("test_tool"); - let raw_input: Arc = Arc::from("{invalid json"); - let json_parse_error = "expected value at line 1 column 1".to_string(); - - // Call the function under test - let result = thread.handle_tool_use_json_parse_error_event( - tool_use_id.clone(), - tool_name.clone(), - raw_input.clone(), - json_parse_error, - &event_stream, - ); - - // Verify the result is an error - assert!(result.is_error); - assert_eq!(result.tool_use_id, tool_use_id); - assert_eq!(result.tool_name, tool_name); - assert!(matches!( - result.content, - LanguageModelToolResultContent::Text(_) - )); - - // Verify the tool use was added to the message content - { - let last_message = thread.pending_message(); - assert_eq!( - last_message.content.len(), - 1, - "Should have one tool_use in content" - ); - - match &last_message.content[0] { - AgentMessageContent::ToolUse(tool_use) => { - assert_eq!(tool_use.id, tool_use_id); - assert_eq!(tool_use.name, tool_name); - assert_eq!(tool_use.raw_input, raw_input.to_string()); - assert!(tool_use.is_input_complete); - // Should fall back to empty object for invalid JSON - assert_eq!(tool_use.input, json!({})); - } - _ => panic!("Expected ToolUse content"), - } - } - - // Insert the tool result (simulating what the caller does) - thread - .pending_message() - .tool_results - .insert(result.tool_use_id.clone(), result); + let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); + let tool_name: Arc = Arc::from("test_tool"); + let raw_input: Arc = Arc::from("{invalid json"); + let json_parse_error = "expected value at line 1 column 1".to_string(); + + let (_cancellation_tx, cancellation_rx) = watch::channel(false); + + let result = cx + .update(|cx| { + thread.update(cx, |thread, cx| { + // Call the function under test + thread + .handle_tool_use_json_parse_error_event( + tool_use_id.clone(), + tool_name.clone(), + raw_input.clone(), + json_parse_error, + &event_stream, + cancellation_rx, + cx, + ) + .unwrap() + }) + }) + .await; + + // Verify the result is an error + assert!(result.is_error); + assert_eq!(result.tool_use_id, tool_use_id); + assert_eq!(result.tool_name, tool_name); + assert!(matches!( + result.content, + LanguageModelToolResultContent::Text(_) + )); - // Verify the tool result was added + thread.update(cx, |thread, _cx| { + // Verify the tool use was added to the message content + { let last_message = thread.pending_message(); assert_eq!( - last_message.tool_results.len(), + last_message.content.len(), 1, - "Should have one tool_result" + "Should have one tool_use in content" ); - assert!(last_message.tool_results.contains_key(&tool_use_id)); - }); - }); + + match &last_message.content[0] { + AgentMessageContent::ToolUse(tool_use) => { + assert_eq!(tool_use.id, tool_use_id); + assert_eq!(tool_use.name, tool_name); + assert_eq!(tool_use.raw_input, raw_input.to_string()); + assert!(tool_use.is_input_complete); + // Should fall back to empty object for invalid JSON + assert_eq!(tool_use.input, json!({})); + } + _ => panic!("Expected ToolUse content"), + } + } + + // Insert the tool result (simulating what the caller does) + thread + .pending_message() + .tool_results + .insert(result.tool_use_id.clone(), result); + + // Verify the tool result was added + let last_message = thread.pending_message(); + assert_eq!( + last_message.tool_results.len(), + 1, + "Should have one tool_result" + ); + assert!(last_message.tool_results.contains_key(&tool_use_id)); + }) } } diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index bc99515e499696e3df11101be8b813afa027c8f4..47da35bbf25ad188f3f6b98e843b2955910bb7ac 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool; use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use super::tool_edit_parser::{ToolEditEvent, ToolEditParser}; +use crate::ToolInputPayload; use crate::{ AgentTool, Thread, ToolCallEventStream, ToolInput, edit_agent::{ @@ -12,7 +13,7 @@ use crate::{ use acp_thread::Diff; use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}; -use anyhow::{Context as _, Result}; +use anyhow::Result; use collections::HashSet; use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; @@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput { }, Error { error: String, + #[serde(default)] + input_path: Option, + #[serde(default)] + diff: String, }, } @@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput { pub fn error(error: impl Into) -> Self { Self::Error { error: error.into(), + input_path: None, + diff: String::new(), } } } @@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput { ) } } - StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"), + StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } => { + write!(f, "{error}\n")?; + if let Some(input_path) = input_path + && !diff.is_empty() + { + write!( + f, + "Edited {}:\n\n```diff\n{diff}\n```", + input_path.display() + ) + } else { + write!(f, "No edits were made.") + } + } } } } @@ -233,6 +257,14 @@ pub struct StreamingEditFileTool { language_registry: Arc, } +enum EditSessionResult { + Completed(EditSession), + Failed { + error: String, + session: Option, + }, +} + impl StreamingEditFileTool { pub fn new( project: Entity, @@ -276,6 +308,158 @@ impl StreamingEditFileTool { }); } } + + async fn ensure_buffer_saved(&self, buffer: &Entity, cx: &mut AsyncApp) { + let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); + settings.format_on_save != FormatOnSave::Off + }); + + if format_on_save_enabled { + self.project + .update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + }) + .await + .log_err(); + } + + self.project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .log_err(); + + self.action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + } + + async fn process_streaming_edits( + &self, + input: &mut ToolInput, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> EditSessionResult { + let mut session: Option = None; + let mut last_partial: Option = None; + + loop { + futures::select! { + payload = input.next().fuse() => { + match payload { + Ok(payload) => match payload { + ToolInputPayload::Partial(partial) => { + if let Ok(parsed) = serde_json::from_value::(partial) { + let path_complete = parsed.path.is_some() + && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref()); + + last_partial = Some(parsed.clone()); + + if session.is_none() + && path_complete + && let StreamingEditFileToolPartialInput { + path: Some(path), + display_description: Some(display_description), + mode: Some(mode), + .. + } = &parsed + { + match EditSession::new( + PathBuf::from(path), + display_description, + *mode, + self, + event_stream, + cx, + ) + .await + { + Ok(created_session) => session = Some(created_session), + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + } + + if let Some(current_session) = &mut session + && let Err(error) = current_session.process(parsed, self, event_stream, cx) + { + log::error!("Failed to process edit: {}", error); + return EditSessionResult::Failed { error, session }; + } + } + } + ToolInputPayload::Full(full_input) => { + let mut session = if let Some(session) = session { + session + } else { + match EditSession::new( + full_input.path.clone(), + &full_input.display_description, + full_input.mode, + self, + event_stream, + cx, + ) + .await + { + Ok(created_session) => created_session, + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + }; + + return match session.finalize(full_input, self, event_stream, cx).await { + Ok(()) => EditSessionResult::Completed(session), + Err(error) => { + log::error!("Failed to finalize edit: {}", error); + EditSessionResult::Failed { + error, + session: Some(session), + } + } + }; + } + ToolInputPayload::InvalidJson { error_message } => { + log::error!("Received invalid JSON: {error_message}"); + return EditSessionResult::Failed { + error: error_message, + session, + }; + } + }, + Err(error) => { + return EditSessionResult::Failed { + error: format!("Failed to receive tool input: {error}"), + session, + }; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + return EditSessionResult::Failed { + error: "Edit cancelled by user".to_string(), + session, + }; + } + } + } + } } impl AgentTool for StreamingEditFileTool { @@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool { cx: &mut App, ) -> Task> { cx.spawn(async move |cx: &mut AsyncApp| { - let mut state: Option = None; - let mut last_partial: Option = None; - loop { - futures::select! { - partial = input.recv_partial().fuse() => { - let Some(partial_value) = partial else { break }; - if let Ok(parsed) = serde_json::from_value::(partial_value) { - let path_complete = parsed.path.is_some() - && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref()); - - last_partial = Some(parsed.clone()); - - if state.is_none() - && path_complete - && let StreamingEditFileToolPartialInput { - path: Some(path), - display_description: Some(display_description), - mode: Some(mode), - .. - } = &parsed - { - match EditSession::new( - &PathBuf::from(path), - display_description, - *mode, - &self, - &event_stream, - cx, - ) - .await - { - Ok(session) => state = Some(session), - Err(e) => { - log::error!("Failed to create edit session: {}", e); - return Err(e); - } - } - } - - if let Some(state) = &mut state { - if let Err(e) = state.process(parsed, &self, &event_stream, cx) { - log::error!("Failed to process edit: {}", e); - return Err(e); - } - } - } - } - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - } - } - let full_input = - input - .recv() - .await - .map_err(|e| { - let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")); - log::error!("Failed to receive tool input: {e}"); - err - })?; - - let mut state = if let Some(state) = state { - state - } else { - match EditSession::new( - &full_input.path, - &full_input.display_description, - full_input.mode, - &self, - &event_stream, - cx, - ) + match self + .process_streaming_edits(&mut input, &event_stream, cx) .await - { - Ok(session) => session, - Err(e) => { - log::error!("Failed to create edit session: {}", e); - return Err(e); - } + { + EditSessionResult::Completed(session) => { + self.ensure_buffer_saved(&session.buffer, cx).await; + let (new_text, diff) = session.compute_new_text_and_diff(cx).await; + Ok(StreamingEditFileToolOutput::Success { + old_text: session.old_text.clone(), + new_text, + input_path: session.input_path, + diff, + }) } - }; - match state.finalize(full_input, &self, &event_stream, cx).await { - Ok(output) => Ok(output), - Err(e) => { - log::error!("Failed to finalize edit: {}", e); - Err(e) + EditSessionResult::Failed { + error, + session: Some(session), + } => { + self.ensure_buffer_saved(&session.buffer, cx).await; + let (_new_text, diff) = session.compute_new_text_and_diff(cx).await; + Err(StreamingEditFileToolOutput::Error { + error, + input_path: Some(session.input_path), + diff, + }) } + EditSessionResult::Failed { + error, + session: None, + } => Err(StreamingEditFileToolOutput::Error { + error, + input_path: None, + diff: String::new(), + }), } }) } @@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool { pub struct EditSession { abs_path: PathBuf, + input_path: PathBuf, buffer: Entity, old_text: Arc, diff: Entity, @@ -518,23 +649,21 @@ impl EditPipeline { impl EditSession { async fn new( - path: &PathBuf, + path: PathBuf, display_description: &str, mode: StreamingEditFileMode, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result { - let project_path = cx - .update(|cx| resolve_path(mode, &path, &tool.project, cx)) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + ) -> Result { + let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?; let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) else { - return Err(StreamingEditFileToolOutput::error(format!( + return Err(format!( "Worktree at '{}' does not exist", path.to_string_lossy() - ))); + )); }; event_stream.update_fields( @@ -543,13 +672,13 @@ impl EditSession { cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx)) .await - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + .map_err(|e| e.to_string())?; let buffer = tool .project .update(cx, |project, cx| project.open_buffer(project_path, cx)) .await - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + .map_err(|e| e.to_string())?; ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; @@ -578,6 +707,7 @@ impl EditSession { Ok(Self { abs_path, + input_path: path, buffer, old_text, diff, @@ -594,22 +724,20 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result { - let old_text = self.old_text.clone(); - + ) -> Result<(), String> { match input.mode { StreamingEditFileMode::Write => { - let content = input.content.ok_or_else(|| { - StreamingEditFileToolOutput::error("'content' field is required for write mode") - })?; + let content = input + .content + .ok_or_else(|| "'content' field is required for write mode".to_string())?; let events = self.parser.finalize_content(&content); self.process_events(&events, tool, event_stream, cx)?; } StreamingEditFileMode::Edit => { - let edits = input.edits.ok_or_else(|| { - StreamingEditFileToolOutput::error("'edits' field is required for edit mode") - })?; + let edits = input + .edits + .ok_or_else(|| "'edits' field is required for edit mode".to_string())?; let events = self.parser.finalize_edits(&edits); self.process_events(&events, tool, event_stream, cx)?; @@ -625,53 +753,15 @@ impl EditSession { } } } + Ok(()) + } - let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| { - let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); - settings.format_on_save != FormatOnSave::Off - }); - - if format_on_save_enabled { - tool.action_log.update(cx, |log, cx| { - log.buffer_edited(self.buffer.clone(), cx); - }); - - let format_task = tool.project.update(cx, |project, cx| { - project.format( - HashSet::from_iter([self.buffer.clone()]), - LspFormatTarget::Buffers, - false, - FormatTrigger::Save, - cx, - ) - }); - futures::select! { - result = format_task.fuse() => { result.log_err(); }, - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - }; - } - - let save_task = tool.project.update(cx, |project, cx| { - project.save_buffer(self.buffer.clone(), cx) - }); - futures::select! { - result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; }, - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - }; - - tool.action_log.update(cx, |log, cx| { - log.buffer_edited(self.buffer.clone(), cx); - }); - + async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) { let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ let new_snapshot = new_snapshot.clone(); - let old_text = old_text.clone(); + let old_text = self.old_text.clone(); async move { let new_text = new_snapshot.text(); let diff = language::unified_diff(&old_text, &new_text); @@ -679,14 +769,7 @@ impl EditSession { } }) .await; - - let output = StreamingEditFileToolOutput::Success { - input_path: input.path, - new_text, - old_text: old_text.clone(), - diff: unified_diff, - }; - Ok(output) + (new_text, unified_diff) } fn process( @@ -695,7 +778,7 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result<(), StreamingEditFileToolOutput> { + ) -> Result<(), String> { match &self.mode { StreamingEditFileMode::Write => { if let Some(content) = &partial.content { @@ -719,7 +802,7 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result<(), StreamingEditFileToolOutput> { + ) -> Result<(), String> { for event in events { match event { ToolEditEvent::ContentChunk { chunk } => { @@ -969,14 +1052,14 @@ fn extract_match( buffer: &Entity, edit_index: &usize, cx: &mut AsyncApp, -) -> Result, StreamingEditFileToolOutput> { +) -> Result, String> { match matches.len() { - 0 => Err(StreamingEditFileToolOutput::error(format!( + 0 => Err(format!( "Could not find matching text for edit at index {}. \ The old_text did not match any content in the file. \ Please read the file again to get the current content.", edit_index, - ))), + )), 1 => Ok(matches.into_iter().next().unwrap()), _ => { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); @@ -985,12 +1068,12 @@ fn extract_match( .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) .collect::>() .join(", "); - Err(StreamingEditFileToolOutput::error(format!( + Err(format!( "Edit {} matched multiple locations in the file at lines: {}. \ Please provide more context in old_text to uniquely \ identify the location.", edit_index, lines - ))) + )) } } } @@ -1022,7 +1105,7 @@ fn ensure_buffer_saved( abs_path: &PathBuf, tool: &StreamingEditFileTool, cx: &mut AsyncApp, -) -> Result<(), StreamingEditFileToolOutput> { +) -> Result<(), String> { let last_read_mtime = tool .action_log .read_with(cx, |log, _| log.file_read_time(abs_path)); @@ -1063,15 +1146,14 @@ fn ensure_buffer_saved( then ask them to save or revert the file manually and inform you when it's ok to proceed." } }; - return Err(StreamingEditFileToolOutput::error(message)); + return Err(message.to_string()); } if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) { if current != last_read { - return Err(StreamingEditFileToolOutput::error( - "The file has been modified since you last read it. \ - Please read the file again to get the current state before editing it.", - )); + return Err("The file has been modified since you last read it. \ + Please read the file again to get the current state before editing it." + .to_string()); } } @@ -1083,56 +1165,63 @@ fn resolve_path( path: &PathBuf, project: &Entity, cx: &mut App, -) -> Result { +) -> Result { let project = project.read(cx); match mode { StreamingEditFileMode::Edit => { let path = project .find_project_path(&path, cx) - .context("Can't edit file: path not found")?; + .ok_or_else(|| "Can't edit file: path not found".to_string())?; let entry = project .entry_for_path(&path, cx) - .context("Can't edit file: path not found")?; + .ok_or_else(|| "Can't edit file: path not found".to_string())?; - anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory"); - Ok(path) + if entry.is_file() { + Ok(path) + } else { + Err("Can't edit file: path is a directory".to_string()) + } } StreamingEditFileMode::Write => { if let Some(path) = project.find_project_path(&path, cx) && let Some(entry) = project.entry_for_path(&path, cx) { - anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory"); - return Ok(path); + if entry.is_file() { + return Ok(path); + } else { + return Err("Can't write to file: path is a directory".to_string()); + } } - let parent_path = path.parent().context("Can't create file: incorrect path")?; + let parent_path = path + .parent() + .ok_or_else(|| "Can't create file: incorrect path".to_string())?; let parent_project_path = project.find_project_path(&parent_path, cx); let parent_entry = parent_project_path .as_ref() .and_then(|path| project.entry_for_path(path, cx)) - .context("Can't create file: parent directory doesn't exist")?; + .ok_or_else(|| "Can't create file: parent directory doesn't exist")?; - anyhow::ensure!( - parent_entry.is_dir(), - "Can't create file: parent is not a directory" - ); + if !parent_entry.is_dir() { + return Err("Can't create file: parent is not a directory".to_string()); + } let file_name = path .file_name() .and_then(|file_name| file_name.to_str()) .and_then(|file_name| RelPath::unix(file_name).ok()) - .context("Can't create file: invalid filename")?; + .ok_or_else(|| "Can't create file: invalid filename".to_string())?; let new_file_path = parent_project_path.map(|parent| ProjectPath { path: parent.path.join(file_name), ..parent }); - new_file_path.context("Can't create file") + new_file_path.ok_or_else(|| "Can't create file".to_string()) } } } @@ -1382,10 +1471,17 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; assert_eq!(error, "Can't edit file: path not found"); + assert!(diff.is_empty()); + assert_eq!(input_path, None); } #[gpui::test] @@ -1411,7 +1507,7 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else { panic!("expected error"); }; assert!( @@ -1424,7 +1520,7 @@ mod tests { async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1447,7 +1543,7 @@ mod tests { cx.run_until_parked(); // Now send the final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -1465,7 +1561,7 @@ mod tests { async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1485,7 +1581,7 @@ mod tests { cx.run_until_parked(); // Send final - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -1503,7 +1599,7 @@ mod tests { async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver, mut cancellation_tx) = ToolCallEventStream::test_with_cancellation(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1521,7 +1617,7 @@ mod tests { drop(sender); let result = task.await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else { panic!("expected error"); }; assert!( @@ -1537,7 +1633,7 @@ mod tests { json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1578,7 +1674,7 @@ mod tests { cx.run_until_parked(); // Send final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit multiple lines", "path": "root/file.txt", "mode": "edit", @@ -1601,7 +1697,7 @@ mod tests { #[gpui::test] async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1625,7 +1721,7 @@ mod tests { cx.run_until_parked(); // Final with full content - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create new file", "path": "root/dir/new_file.txt", "mode": "write", @@ -1643,12 +1739,12 @@ mod tests { async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send final immediately with no partials (simulates non-streaming path) - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -1669,7 +1765,7 @@ mod tests { json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1739,7 +1835,7 @@ mod tests { ); // Send final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit multiple lines", "path": "root/file.txt", "mode": "edit", @@ -1767,7 +1863,7 @@ mod tests { async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1835,7 +1931,7 @@ mod tests { assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n")); // Send final - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit three lines", "path": "root/file.txt", "mode": "edit", @@ -1857,7 +1953,7 @@ mod tests { async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1893,16 +1989,17 @@ mod tests { })); cx.run_until_parked(); - // Verify edit 1 was applied - let buffer_text = project.update(cx, |project, cx| { + let buffer = project.update(cx, |project, cx| { let pp = project .find_project_path(&PathBuf::from("root/file.txt"), cx) .unwrap(); - project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + project.get_open_buffer(&pp, cx).unwrap() }); + + // Verify edit 1 was applied + let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text()); assert_eq!( - buffer_text.as_deref(), - Some("MODIFIED\nline 2\nline 3\n"), + buffer_text, "MODIFIED\nline 2\nline 3\n", "First edit should be applied even though second edit will fail" ); @@ -1925,20 +2022,32 @@ mod tests { drop(sender); let result = task.await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; + assert!( error.contains("Could not find matching text for edit at index 1"), "Expected error about edit 1 failing, got: {error}" ); + // Ensure that first edit was applied successfully and that we saved the buffer + assert_eq!(input_path, Some(PathBuf::from("root/file.txt"))); + assert_eq!( + diff, + "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n" + ); } #[gpui::test] async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1975,7 +2084,7 @@ mod tests { ); // Send final — the edit is applied during finalization - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Single edit", "path": "root/file.txt", "mode": "edit", @@ -1993,7 +2102,7 @@ mod tests { async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2020,7 +2129,7 @@ mod tests { cx.run_until_parked(); // Send the final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -2038,7 +2147,7 @@ mod tests { async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2064,7 +2173,7 @@ mod tests { // Create a channel and send multiple partials before a final, then use // ToolInput::resolved-style immediate delivery to confirm recv() works // when partials are already buffered. - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2077,7 +2186,7 @@ mod tests { "path": "root/dir/new.txt", "mode": "write" })); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create", "path": "root/dir/new.txt", "mode": "write", @@ -2109,13 +2218,13 @@ mod tests { let result = test_resolve_path(&mode, "root/dir/subdir", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't write to file: path is a directory" ); let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't create file: parent directory doesn't exist" ); } @@ -2133,14 +2242,11 @@ mod tests { assert_resolved_path_eq(result.await, rel_path(path_without_root)); let result = test_resolve_path(&mode, "root/nonexistent.txt", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't edit file: path not found" - ); + assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found"); let result = test_resolve_path(&mode, "root/dir", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't edit file: path is a directory" ); } @@ -2149,7 +2255,7 @@ mod tests { mode: &StreamingEditFileMode, path: &str, cx: &mut TestAppContext, - ) -> anyhow::Result { + ) -> Result { init_test(cx); let fs = project::FakeFs::new(cx.executor()); @@ -2170,7 +2276,7 @@ mod tests { } #[track_caller] - fn assert_resolved_path_eq(path: anyhow::Result, expected: &RelPath) { + fn assert_resolved_path_eq(path: Result, expected: &RelPath) { let actual = path.expect("Should return valid path").path; assert_eq!(actual.as_ref(), expected); } @@ -2259,7 +2365,7 @@ mod tests { }); // Use streaming pattern so executor can pump the LSP request/response - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2271,7 +2377,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create main function", "path": "root/src/main.rs", "mode": "write", @@ -2310,7 +2416,7 @@ mod tests { }); }); - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let tool2 = Arc::new(StreamingEditFileTool::new( @@ -2329,7 +2435,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Update main function", "path": "root/src/main.rs", "mode": "write", @@ -3288,14 +3394,22 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; + assert!( error.contains("has been modified since you last read it"), "Error should mention file modification, got: {}", error ); + assert!(diff.is_empty()); + assert!(input_path.is_none()); } #[gpui::test] @@ -3362,7 +3476,12 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; assert!( @@ -3380,6 +3499,8 @@ mod tests { "Error should ask user to manually save or revert when tools aren't available, got: {}", error ); + assert!(diff.is_empty()); + assert!(input_path.is_none()); } #[gpui::test] @@ -3390,7 +3511,7 @@ mod tests { // the modified buffer and succeeds. let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3420,7 +3541,7 @@ mod tests { cx.run_until_parked(); // Send the final input with all three edits. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overlapping edits", "path": "root/file.txt", "mode": "edit", @@ -3441,7 +3562,7 @@ mod tests { #[gpui::test] async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3495,7 +3616,7 @@ mod tests { ); // Send final input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create new file", "path": "root/dir/new_file.txt", "mode": "write", @@ -3516,7 +3637,7 @@ mod tests { json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, mut receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3559,7 +3680,7 @@ mod tests { }); // Send final input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -3587,7 +3708,7 @@ mod tests { json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3634,7 +3755,7 @@ mod tests { ); // Send final input with complete content - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -3656,7 +3777,7 @@ mod tests { async fn test_streaming_edit_json_fixer_escape_corruption(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\nfoo\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3690,7 +3811,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit", "path": "root/file.txt", "mode": "edit", @@ -3708,7 +3829,7 @@ mod tests { async fn test_streaming_final_input_stringified_edits_succeeds(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3719,7 +3840,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit", "path": "root/file.txt", "mode": "edit", @@ -3823,7 +3944,7 @@ mod tests { ) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3849,7 +3970,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "mode": "write", "content": "new_content", @@ -3869,7 +3990,7 @@ mod tests { ) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3902,7 +4023,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "mode": "edit", "edits": [{"old_text": "old_content", "new_text": "new_content"}], @@ -3939,11 +4060,11 @@ mod tests { let old_text = "}\n\n\n\nfn render_search"; let new_text = "}\n\nfn render_search"; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Remove extra blank lines", "path": "root/file.rs", "mode": "edit", @@ -3980,11 +4101,11 @@ mod tests { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.rs": file_content})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "description", "path": "root/file.rs", "mode": "edit", diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index ae01084a2657abdc86e7510aa49663cf98aabe70..50037f31facbac446de7ecf38536d1e4a24c7867 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -125,6 +125,7 @@ pub struct FakeLanguageModel { >, forbid_requests: AtomicBool, supports_thinking: AtomicBool, + supports_streaming_tools: AtomicBool, } impl Default for FakeLanguageModel { @@ -137,6 +138,7 @@ impl Default for FakeLanguageModel { current_completion_txs: Mutex::new(Vec::new()), forbid_requests: AtomicBool::new(false), supports_thinking: AtomicBool::new(false), + supports_streaming_tools: AtomicBool::new(false), } } } @@ -169,6 +171,10 @@ impl FakeLanguageModel { self.supports_thinking.store(supports, SeqCst); } + pub fn set_supports_streaming_tools(&self, supports: bool) { + self.supports_streaming_tools.store(supports, SeqCst); + } + pub fn pending_completions(&self) -> Vec { self.current_completion_txs .lock() @@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel { self.supports_thinking.load(SeqCst) } + fn supports_streaming_tools(&self) -> bool { + self.supports_streaming_tools.load(SeqCst) + } + fn telemetry_id(&self) -> String { "fake".to_string() }