agent: Fix issue with streaming tools when model produces invalid JSON (#52891)

Bennet Bo Fenner created

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A

Change summary

crates/agent/src/tests/edit_file_thread_test.rs    | 211 ++++++
crates/agent/src/tests/mod.rs                      | 112 +++
crates/agent/src/tests/test_tools.rs               |  67 +
crates/agent/src/thread.rs                         | 312 +++++---
crates/agent/src/tools/streaming_edit_file_tool.rs | 550 +++++++++------
crates/language_model/src/fake_provider.rs         |  10 
6 files changed, 916 insertions(+), 346 deletions(-)

Detailed changes

crates/agent/src/tests/edit_file_thread_test.rs 🔗

@@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
         );
     });
 }
+
+#[gpui::test]
+async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
+    cx: &mut TestAppContext,
+) {
+    super::init_test(cx);
+    super::always_allow_tools(cx);
+
+    // Enable the streaming edit file tool feature flag.
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "src": {
+                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
+            }
+        }),
+    )
+    .await;
+
+    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+    let project_context = cx.new(|_cx| ProjectContext::default());
+    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+    let context_server_registry =
+        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+    let model = Arc::new(FakeLanguageModel::default());
+    model.as_fake().set_supports_streaming_tools(true);
+    let fake_model = model.as_fake();
+
+    let thread = cx.new(|cx| {
+        let mut thread = crate::Thread::new(
+            project.clone(),
+            project_context,
+            context_server_registry,
+            crate::Templates::new(),
+            Some(model.clone()),
+            cx,
+        );
+        let language_registry = project.read(cx).languages().clone();
+        thread.add_tool(crate::StreamingEditFileTool::new(
+            project.clone(),
+            cx.weak_entity(),
+            thread.action_log().clone(),
+            language_registry,
+        ));
+        thread
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                UserMessageId::new(),
+                ["Write new content to src/main.rs"],
+                cx,
+            )
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let tool_use_id = "edit_1";
+    let partial_1 = LanguageModelToolUse {
+        id: tool_use_id.into(),
+        name: EditFileTool::NAME.into(),
+        raw_input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write"
+        }),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
+    cx.run_until_parked();
+
+    let partial_2 = LanguageModelToolUse {
+        id: tool_use_id.into(),
+        name: EditFileTool::NAME.into(),
+        raw_input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() { /* rewritten */ }"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() { /* rewritten */ }"
+        }),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
+    cx.run_until_parked();
+
+    // Now send a json parse error. At this point we have started writing content to the buffer.
+    fake_model.send_last_completion_stream_event(
+        LanguageModelCompletionEvent::ToolUseJsonParseError {
+            id: tool_use_id.into(),
+            tool_name: EditFileTool::NAME.into(),
+            raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
+            json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
+        },
+    );
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // cx.executor().advance_clock(Duration::from_secs(5));
+    // cx.run_until_parked();
+
+    assert!(
+        !fake_model.pending_completions().is_empty(),
+        "Thread should have retried after the error"
+    );
+
+    // Respond with a new, well-formed, complete edit_file tool use.
+    let tool_use = LanguageModelToolUse {
+        id: "edit_2".into(),
+        name: EditFileTool::NAME.into(),
+        raw_input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n"
+        }),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let pending_completions = fake_model.pending_completions();
+    assert!(
+        pending_completions.len() == 1,
+        "Expected only the follow-up completion containing the successful tool result"
+    );
+
+    let completion = pending_completions
+        .into_iter()
+        .last()
+        .expect("Expected a completion containing the tool result for edit_2");
+
+    let tool_result = completion
+        .messages
+        .iter()
+        .flat_map(|msg| &msg.content)
+        .find_map(|content| match content {
+            language_model::MessageContent::ToolResult(result)
+                if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
+            {
+                Some(result)
+            }
+            _ => None,
+        })
+        .expect("Should have a tool result for edit_2");
+
+    // Ensure that the second tool call completed successfully and edits were applied.
+    assert!(
+        !tool_result.is_error,
+        "Tool result should succeed, got: {:?}",
+        tool_result
+    );
+    let content_text = match &tool_result.content {
+        language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
+        other => panic!("Expected text content, got: {:?}", other),
+    };
+    assert!(
+        !content_text.contains("file has been modified since you last read it"),
+        "Did not expect a stale last-read error, got: {content_text}"
+    );
+    assert!(
+        !content_text.contains("This file has unsaved changes"),
+        "Did not expect an unsaved-changes error, got: {content_text}"
+    );
+
+    let file_content = fs
+        .load(path!("/project/src/main.rs").as_ref())
+        .await
+        .expect("file should exist");
+    super::assert_eq!(
+        file_content,
+        "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n",
+        "The second edit should be applied and saved gracefully"
+    );
+
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+}

crates/agent/src/tests/mod.rs 🔗

@@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
     });
 }
 
+#[gpui::test]
+async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool(
+    cx: &mut TestAppContext,
+) {
+    init_test(cx);
+    always_allow_tools(cx);
+
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(StreamingJsonErrorContextTool);
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                UserMessageId::new(),
+                ["Use the streaming_json_error_context tool"],
+                cx,
+            )
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let tool_use = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: StreamingJsonErrorContextTool::NAME.into(),
+        raw_input: r#"{"text": "partial"#.into(),
+        input: json!({"text": "partial"}),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_event(
+        LanguageModelCompletionEvent::ToolUseJsonParseError {
+            id: "tool_1".into(),
+            tool_name: StreamingJsonErrorContextTool::NAME.into(),
+            raw_input: r#"{"text": "partial"#.into(),
+            json_parse_error: "EOF while parsing a string at line 1 column 17".into(),
+        },
+    );
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    cx.executor().advance_clock(Duration::from_secs(5));
+    cx.run_until_parked();
+
+    let completion = fake_model
+        .pending_completions()
+        .pop()
+        .expect("No running turn");
+
+    let tool_results: Vec<_> = completion
+        .messages
+        .iter()
+        .flat_map(|message| &message.content)
+        .filter_map(|content| match content {
+            MessageContent::ToolResult(result)
+                if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") =>
+            {
+                Some(result)
+            }
+            _ => None,
+        })
+        .collect();
+
+    assert_eq!(
+        tool_results.len(),
+        1,
+        "Expected exactly 1 tool result for tool_1, got {}: {:#?}",
+        tool_results.len(),
+        tool_results
+    );
+
+    let result = tool_results[0];
+    assert!(result.is_error);
+    let content_text = match &result.content {
+        language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+        other => panic!("Expected text content, got {:?}", other),
+    };
+    assert!(
+        content_text.contains("Saw partial text 'partial' before invalid JSON"),
+        "Expected tool-enriched partial context, got: {content_text}"
+    );
+    assert!(
+        content_text
+            .contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"),
+        "Expected forwarded JSON parse error, got: {content_text}"
+    );
+    assert!(
+        !content_text.contains("tool input was not fully received"),
+        "Should not contain orphaned sender error, got: {content_text}"
+    );
+
+    fake_model.send_last_completion_stream_text_chunk("Done");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    thread.read_with(cx, |thread, _cx| {
+        assert!(
+            thread.is_turn_complete(),
+            "Thread should not be stuck; the turn should have completed",
+        );
+    });
+}
+
 /// Filters out the stop events for asserting against in tests
 fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
     result_events
@@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                             InfiniteTool::NAME: true,
                             CancellationAwareTool::NAME: true,
                             StreamingEchoTool::NAME: true,
+                            StreamingJsonErrorContextTool::NAME: true,
                             StreamingFailingEchoTool::NAME: true,
                             TerminalTool::NAME: true,
                             UpdatePlanTool::NAME: true,

crates/agent/src/tests/test_tools.rs 🔗

@@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool {
 
     fn run(
         self: Arc<Self>,
-        mut input: ToolInput<Self::Input>,
+        input: ToolInput<Self::Input>,
         _event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<String, String>> {
         let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
         cx.spawn(async move |_cx| {
-            while input.recv_partial().await.is_some() {}
             let input = input
                 .recv()
                 .await
@@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool {
     }
 }
 
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingJsonErrorContextToolInput {
+    /// The text to echo.
+    pub text: String,
+}
+
+pub struct StreamingJsonErrorContextTool;
+
+impl AgentTool for StreamingJsonErrorContextTool {
+    type Input = StreamingJsonErrorContextToolInput;
+    type Output = String;
+
+    const NAME: &'static str = "streaming_json_error_context";
+
+    fn supports_input_streaming() -> bool {
+        true
+    }
+
+    fn kind() -> acp::ToolKind {
+        acp::ToolKind::Other
+    }
+
+    fn initial_title(
+        &self,
+        _input: Result<Self::Input, serde_json::Value>,
+        _cx: &mut App,
+    ) -> SharedString {
+        "Streaming JSON Error Context".into()
+    }
+
+    fn run(
+        self: Arc<Self>,
+        mut input: ToolInput<Self::Input>,
+        _event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Task<Result<String, String>> {
+        cx.spawn(async move |_cx| {
+            let mut last_partial_text = None;
+
+            loop {
+                match input.next().await {
+                    Ok(ToolInputPayload::Partial(partial)) => {
+                        if let Some(text) = partial.get("text").and_then(|value| value.as_str()) {
+                            last_partial_text = Some(text.to_string());
+                        }
+                    }
+                    Ok(ToolInputPayload::Full(input)) => return Ok(input.text),
+                    Ok(ToolInputPayload::InvalidJson { error_message }) => {
+                        let partial_text = last_partial_text.unwrap_or_default();
+                        return Err(format!(
+                            "Saw partial text '{partial_text}' before invalid JSON: {error_message}"
+                        ));
+                    }
+                    Err(error) => {
+                        return Err(format!("Failed to receive tool input: {error}"));
+                    }
+                }
+            }
+        })
+    }
+}
+
 /// A streaming tool that echoes its input, used to test streaming tool
 /// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
 /// before `is_input_complete`).
@@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool {
     ) -> Task<Result<Self::Output, Self::Output>> {
         cx.spawn(async move |_cx| {
             for _ in 0..self.receive_chunks_until_failure {
-                let _ = input.recv_partial().await;
+                let _ = input.next().await;
             }
             Err("failed".into())
         })

crates/agent/src/thread.rs 🔗

@@ -22,13 +22,13 @@ use client::UserStore;
 use cloud_api_types::Plan;
 use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
-use futures::stream;
 use futures::{
     FutureExt,
     channel::{mpsc, oneshot},
     future::Shared,
     stream::FuturesUnordered,
 };
+use futures::{StreamExt, stream};
 use gpui::{
     App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
 };
@@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema};
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
 use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file};
-use smol::stream::StreamExt;
 use std::{
     collections::BTreeMap,
     marker::PhantomData,
@@ -2095,7 +2094,7 @@ impl Thread {
         this.update(cx, |this, _cx| {
             this.pending_message()
                 .tool_results
-                .insert(tool_result.tool_use_id.clone(), tool_result);
+                .insert(tool_result.tool_use_id.clone(), tool_result)
         })?;
         Ok(())
     }
@@ -2195,15 +2194,15 @@ impl Thread {
                 raw_input,
                 json_parse_error,
             } => {
-                return Ok(Some(Task::ready(
-                    self.handle_tool_use_json_parse_error_event(
-                        id,
-                        tool_name,
-                        raw_input,
-                        json_parse_error,
-                        event_stream,
-                    ),
-                )));
+                return Ok(self.handle_tool_use_json_parse_error_event(
+                    id,
+                    tool_name,
+                    raw_input,
+                    json_parse_error,
+                    event_stream,
+                    cancellation_rx,
+                    cx,
+                ));
             }
             UsageUpdate(usage) => {
                 telemetry::event!(
@@ -2304,12 +2303,12 @@ impl Thread {
         if !tool_use.is_input_complete {
             if tool.supports_input_streaming() {
                 let running_turn = self.running_turn.as_mut()?;
-                if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) {
+                if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) {
                     sender.send_partial(tool_use.input);
                     return None;
                 }
 
-                let (sender, tool_input) = ToolInputSender::channel();
+                let (mut sender, tool_input) = ToolInputSender::channel();
                 sender.send_partial(tool_use.input);
                 running_turn
                     .streaming_tool_inputs
@@ -2331,13 +2330,13 @@ impl Thread {
             }
         }
 
-        if let Some(sender) = self
+        if let Some(mut sender) = self
             .running_turn
             .as_mut()?
             .streaming_tool_inputs
             .remove(&tool_use.id)
         {
-            sender.send_final(tool_use.input);
+            sender.send_full(tool_use.input);
             return None;
         }
 
@@ -2410,10 +2409,12 @@ impl Thread {
         raw_input: Arc<str>,
         json_parse_error: String,
         event_stream: &ThreadEventStream,
-    ) -> LanguageModelToolResult {
+        cancellation_rx: watch::Receiver<bool>,
+        cx: &mut Context<Self>,
+    ) -> Option<Task<LanguageModelToolResult>> {
         let tool_use = LanguageModelToolUse {
-            id: tool_use_id.clone(),
-            name: tool_name.clone(),
+            id: tool_use_id,
+            name: tool_name,
             raw_input: raw_input.to_string(),
             input: serde_json::json!({}),
             is_input_complete: true,
@@ -2426,14 +2427,43 @@ impl Thread {
             event_stream,
         );
 
-        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
-        LanguageModelToolResult {
-            tool_use_id,
-            tool_name,
-            is_error: true,
-            content: LanguageModelToolResultContent::Text(tool_output.into()),
-            output: Some(serde_json::Value::String(raw_input.to_string())),
+        let tool = self.tool(tool_use.name.as_ref());
+
+        let Some(tool) = tool else {
+            let content = format!("No tool named {} exists", tool_use.name);
+            return Some(Task::ready(LanguageModelToolResult {
+                content: LanguageModelToolResultContent::Text(Arc::from(content)),
+                tool_use_id: tool_use.id,
+                tool_name: tool_use.name,
+                is_error: true,
+                output: None,
+            }));
+        };
+
+        let error_message = format!("Error parsing input JSON: {json_parse_error}");
+
+        if tool.supports_input_streaming()
+            && let Some(mut sender) = self
+                .running_turn
+                .as_mut()?
+                .streaming_tool_inputs
+                .remove(&tool_use.id)
+        {
+            sender.send_invalid_json(error_message);
+            return None;
         }
+
+        log::debug!("Running tool {}. Received invalid JSON", tool_use.name);
+        let tool_input = ToolInput::invalid_json(error_message);
+        Some(self.run_tool(
+            tool,
+            tool_input,
+            tool_use.id,
+            tool_use.name,
+            event_stream,
+            cancellation_rx,
+            cx,
+        ))
     }
 
     fn send_or_update_tool_use(
@@ -3114,8 +3144,7 @@ impl EventEmitter<TitleUpdated> for Thread {}
 /// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams
 /// them, followed by the final complete input available through `.recv()`.
 pub struct ToolInput<T> {
-    partial_rx: mpsc::UnboundedReceiver<serde_json::Value>,
-    final_rx: oneshot::Receiver<serde_json::Value>,
+    rx: mpsc::UnboundedReceiver<ToolInputPayload<serde_json::Value>>,
     _phantom: PhantomData<T>,
 }
 
@@ -3127,13 +3156,20 @@ impl<T: DeserializeOwned> ToolInput<T> {
     }
 
     pub fn ready(value: serde_json::Value) -> Self {
-        let (partial_tx, partial_rx) = mpsc::unbounded();
-        drop(partial_tx);
-        let (final_tx, final_rx) = oneshot::channel();
-        final_tx.send(value).ok();
+        let (tx, rx) = mpsc::unbounded();
+        tx.unbounded_send(ToolInputPayload::Full(value)).ok();
         Self {
-            partial_rx,
-            final_rx,
+            rx,
+            _phantom: PhantomData,
+        }
+    }
+
+    pub fn invalid_json(error_message: String) -> Self {
+        let (tx, rx) = mpsc::unbounded();
+        tx.unbounded_send(ToolInputPayload::InvalidJson { error_message })
+            .ok();
+        Self {
+            rx,
             _phantom: PhantomData,
         }
     }
@@ -3147,65 +3183,89 @@ impl<T: DeserializeOwned> ToolInput<T> {
     /// Wait for the final deserialized input, ignoring all partial updates.
     /// Non-streaming tools can use this to wait until the whole input is available.
     pub async fn recv(mut self) -> Result<T> {
-        // Drain any remaining partials
-        while self.partial_rx.next().await.is_some() {}
+        while let Ok(value) = self.next().await {
+            match value {
+                ToolInputPayload::Full(value) => return Ok(value),
+                ToolInputPayload::Partial(_) => {}
+                ToolInputPayload::InvalidJson { error_message } => {
+                    return Err(anyhow!(error_message));
+                }
+            }
+        }
+        Err(anyhow!("tool input was not fully received"))
+    }
+
+    pub async fn next(&mut self) -> Result<ToolInputPayload<T>> {
         let value = self
-            .final_rx
+            .rx
+            .next()
             .await
-            .map_err(|_| anyhow!("tool input was not fully received"))?;
-        serde_json::from_value(value).map_err(Into::into)
-    }
+            .ok_or_else(|| anyhow!("tool input was not fully received"))?;
 
-    /// Returns the next partial JSON snapshot, or `None` when input is complete.
-    /// Once this returns `None`, call `recv()` to get the final input.
-    pub async fn recv_partial(&mut self) -> Option<serde_json::Value> {
-        self.partial_rx.next().await
+        Ok(match value {
+            ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload),
+            ToolInputPayload::Full(payload) => {
+                ToolInputPayload::Full(serde_json::from_value(payload)?)
+            }
+            ToolInputPayload::InvalidJson { error_message } => {
+                ToolInputPayload::InvalidJson { error_message }
+            }
+        })
     }
 
     fn cast<U: DeserializeOwned>(self) -> ToolInput<U> {
         ToolInput {
-            partial_rx: self.partial_rx,
-            final_rx: self.final_rx,
+            rx: self.rx,
             _phantom: PhantomData,
         }
     }
 }
 
+pub enum ToolInputPayload<T> {
+    Partial(serde_json::Value),
+    Full(T),
+    InvalidJson { error_message: String },
+}
+
 pub struct ToolInputSender {
-    partial_tx: mpsc::UnboundedSender<serde_json::Value>,
-    final_tx: Option<oneshot::Sender<serde_json::Value>>,
+    has_received_final: bool,
+    tx: mpsc::UnboundedSender<ToolInputPayload<serde_json::Value>>,
 }
 
 impl ToolInputSender {
     pub(crate) fn channel() -> (Self, ToolInput<serde_json::Value>) {
-        let (partial_tx, partial_rx) = mpsc::unbounded();
-        let (final_tx, final_rx) = oneshot::channel();
+        let (tx, rx) = mpsc::unbounded();
         let sender = Self {
-            partial_tx,
-            final_tx: Some(final_tx),
+            tx,
+            has_received_final: false,
         };
         let input = ToolInput {
-            partial_rx,
-            final_rx,
+            rx,
             _phantom: PhantomData,
         };
         (sender, input)
     }
 
     pub(crate) fn has_received_final(&self) -> bool {
-        self.final_tx.is_none()
+        self.has_received_final
     }
 
-    pub(crate) fn send_partial(&self, value: serde_json::Value) {
-        self.partial_tx.unbounded_send(value).ok();
+    pub fn send_partial(&mut self, payload: serde_json::Value) {
+        self.tx
+            .unbounded_send(ToolInputPayload::Partial(payload))
+            .ok();
     }
 
-    pub(crate) fn send_final(mut self, value: serde_json::Value) {
-        // Close the partial channel so recv_partial() returns None
-        self.partial_tx.close_channel();
-        if let Some(final_tx) = self.final_tx.take() {
-            final_tx.send(value).ok();
-        }
+    pub fn send_full(&mut self, payload: serde_json::Value) {
+        self.has_received_final = true;
+        self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok();
+    }
+
+    pub fn send_invalid_json(&mut self, error_message: String) {
+        self.has_received_final = true;
+        self.tx
+            .unbounded_send(ToolInputPayload::InvalidJson { error_message })
+            .ok();
     }
 }
 
@@ -4251,68 +4311,78 @@ mod tests {
     ) {
         let (thread, event_stream) = setup_thread_for_test(cx).await;
 
-        cx.update(|cx| {
-            thread.update(cx, |thread, _cx| {
-                let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
-                let tool_name: Arc<str> = Arc::from("test_tool");
-                let raw_input: Arc<str> = Arc::from("{invalid json");
-                let json_parse_error = "expected value at line 1 column 1".to_string();
-
-                // Call the function under test
-                let result = thread.handle_tool_use_json_parse_error_event(
-                    tool_use_id.clone(),
-                    tool_name.clone(),
-                    raw_input.clone(),
-                    json_parse_error,
-                    &event_stream,
-                );
-
-                // Verify the result is an error
-                assert!(result.is_error);
-                assert_eq!(result.tool_use_id, tool_use_id);
-                assert_eq!(result.tool_name, tool_name);
-                assert!(matches!(
-                    result.content,
-                    LanguageModelToolResultContent::Text(_)
-                ));
-
-                // Verify the tool use was added to the message content
-                {
-                    let last_message = thread.pending_message();
-                    assert_eq!(
-                        last_message.content.len(),
-                        1,
-                        "Should have one tool_use in content"
-                    );
-
-                    match &last_message.content[0] {
-                        AgentMessageContent::ToolUse(tool_use) => {
-                            assert_eq!(tool_use.id, tool_use_id);
-                            assert_eq!(tool_use.name, tool_name);
-                            assert_eq!(tool_use.raw_input, raw_input.to_string());
-                            assert!(tool_use.is_input_complete);
-                            // Should fall back to empty object for invalid JSON
-                            assert_eq!(tool_use.input, json!({}));
-                        }
-                        _ => panic!("Expected ToolUse content"),
-                    }
-                }
-
-                // Insert the tool result (simulating what the caller does)
-                thread
-                    .pending_message()
-                    .tool_results
-                    .insert(result.tool_use_id.clone(), result);
+        let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
+        let tool_name: Arc<str> = Arc::from("test_tool");
+        let raw_input: Arc<str> = Arc::from("{invalid json");
+        let json_parse_error = "expected value at line 1 column 1".to_string();
+
+        let (_cancellation_tx, cancellation_rx) = watch::channel(false);
+
+        let result = cx
+            .update(|cx| {
+                thread.update(cx, |thread, cx| {
+                    // Call the function under test
+                    thread
+                        .handle_tool_use_json_parse_error_event(
+                            tool_use_id.clone(),
+                            tool_name.clone(),
+                            raw_input.clone(),
+                            json_parse_error,
+                            &event_stream,
+                            cancellation_rx,
+                            cx,
+                        )
+                        .unwrap()
+                })
+            })
+            .await;
+
+        // Verify the result is an error
+        assert!(result.is_error);
+        assert_eq!(result.tool_use_id, tool_use_id);
+        assert_eq!(result.tool_name, tool_name);
+        assert!(matches!(
+            result.content,
+            LanguageModelToolResultContent::Text(_)
+        ));
 
-                // Verify the tool result was added
+        thread.update(cx, |thread, _cx| {
+            // Verify the tool use was added to the message content
+            {
                 let last_message = thread.pending_message();
                 assert_eq!(
-                    last_message.tool_results.len(),
+                    last_message.content.len(),
                     1,
-                    "Should have one tool_result"
+                    "Should have one tool_use in content"
                 );
-                assert!(last_message.tool_results.contains_key(&tool_use_id));
-            });
-        });
+
+                match &last_message.content[0] {
+                    AgentMessageContent::ToolUse(tool_use) => {
+                        assert_eq!(tool_use.id, tool_use_id);
+                        assert_eq!(tool_use.name, tool_name);
+                        assert_eq!(tool_use.raw_input, raw_input.to_string());
+                        assert!(tool_use.is_input_complete);
+                        // Should fall back to empty object for invalid JSON
+                        assert_eq!(tool_use.input, json!({}));
+                    }
+                    _ => panic!("Expected ToolUse content"),
+                }
+            }
+
+            // Insert the tool result (simulating what the caller does)
+            thread
+                .pending_message()
+                .tool_results
+                .insert(result.tool_use_id.clone(), result);
+
+            // Verify the tool result was added
+            let last_message = thread.pending_message();
+            assert_eq!(
+                last_message.tool_results.len(),
+                1,
+                "Should have one tool_result"
+            );
+            assert!(last_message.tool_results.contains_key(&tool_use_id));
+        })
     }
 }

crates/agent/src/tools/streaming_edit_file_tool.rs 🔗

@@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool;
 use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
 use super::save_file_tool::SaveFileTool;
 use super::tool_edit_parser::{ToolEditEvent, ToolEditParser};
+use crate::ToolInputPayload;
 use crate::{
     AgentTool, Thread, ToolCallEventStream, ToolInput,
     edit_agent::{
@@ -12,7 +13,7 @@ use crate::{
 use acp_thread::Diff;
 use action_log::ActionLog;
 use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
-use anyhow::{Context as _, Result};
+use anyhow::Result;
 use collections::HashSet;
 use futures::FutureExt as _;
 use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
@@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput {
     },
     Error {
         error: String,
+        #[serde(default)]
+        input_path: Option<PathBuf>,
+        #[serde(default)]
+        diff: String,
     },
 }
 
@@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput {
     pub fn error(error: impl Into<String>) -> Self {
         Self::Error {
             error: error.into(),
+            input_path: None,
+            diff: String::new(),
         }
     }
 }
@@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput {
                     )
                 }
             }
-            StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"),
+            StreamingEditFileToolOutput::Error {
+                error,
+                diff,
+                input_path,
+            } => {
+                write!(f, "{error}\n")?;
+                if let Some(input_path) = input_path
+                    && !diff.is_empty()
+                {
+                    write!(
+                        f,
+                        "Edited {}:\n\n```diff\n{diff}\n```",
+                        input_path.display()
+                    )
+                } else {
+                    write!(f, "No edits were made.")
+                }
+            }
         }
     }
 }
@@ -233,6 +257,14 @@ pub struct StreamingEditFileTool {
     language_registry: Arc<LanguageRegistry>,
 }
 
+enum EditSessionResult {
+    Completed(EditSession),
+    Failed {
+        error: String,
+        session: Option<EditSession>,
+    },
+}
+
 impl StreamingEditFileTool {
     pub fn new(
         project: Entity<Project>,
@@ -276,6 +308,158 @@ impl StreamingEditFileTool {
             });
         }
     }
+
+    async fn ensure_buffer_saved(&self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
+        let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+            let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
+            settings.format_on_save != FormatOnSave::Off
+        });
+
+        if format_on_save_enabled {
+            self.project
+                .update(cx, |project, cx| {
+                    project.format(
+                        HashSet::from_iter([buffer.clone()]),
+                        LspFormatTarget::Buffers,
+                        false,
+                        FormatTrigger::Save,
+                        cx,
+                    )
+                })
+                .await
+                .log_err();
+        }
+
+        self.project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .log_err();
+
+        self.action_log.update(cx, |log, cx| {
+            log.buffer_edited(buffer.clone(), cx);
+        });
+    }
+
+    async fn process_streaming_edits(
+        &self,
+        input: &mut ToolInput<StreamingEditFileToolInput>,
+        event_stream: &ToolCallEventStream,
+        cx: &mut AsyncApp,
+    ) -> EditSessionResult {
+        let mut session: Option<EditSession> = None;
+        let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
+
+        loop {
+            futures::select! {
+                payload = input.next().fuse() => {
+                    match payload {
+                        Ok(payload) => match payload {
+                            ToolInputPayload::Partial(partial) => {
+                                if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial) {
+                                    let path_complete = parsed.path.is_some()
+                                        && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref());
+
+                                    last_partial = Some(parsed.clone());
+
+                                    if session.is_none()
+                                        && path_complete
+                                        && let StreamingEditFileToolPartialInput {
+                                            path: Some(path),
+                                            display_description: Some(display_description),
+                                            mode: Some(mode),
+                                            ..
+                                        } = &parsed
+                                    {
+                                        match EditSession::new(
+                                            PathBuf::from(path),
+                                            display_description,
+                                            *mode,
+                                            self,
+                                            event_stream,
+                                            cx,
+                                        )
+                                        .await
+                                        {
+                                            Ok(created_session) => session = Some(created_session),
+                                            Err(error) => {
+                                                log::error!("Failed to create edit session: {}", error);
+                                                return EditSessionResult::Failed {
+                                                    error,
+                                                    session: None,
+                                                };
+                                            }
+                                        }
+                                    }
+
+                                    if let Some(current_session) = &mut session
+                                        && let Err(error) = current_session.process(parsed, self, event_stream, cx)
+                                    {
+                                        log::error!("Failed to process edit: {}", error);
+                                        return EditSessionResult::Failed { error, session };
+                                    }
+                                }
+                            }
+                            ToolInputPayload::Full(full_input) => {
+                                let mut session = if let Some(session) = session {
+                                    session
+                                } else {
+                                    match EditSession::new(
+                                        full_input.path.clone(),
+                                        &full_input.display_description,
+                                        full_input.mode,
+                                        self,
+                                        event_stream,
+                                        cx,
+                                    )
+                                    .await
+                                    {
+                                        Ok(created_session) => created_session,
+                                        Err(error) => {
+                                            log::error!("Failed to create edit session: {}", error);
+                                            return EditSessionResult::Failed {
+                                                error,
+                                                session: None,
+                                            };
+                                        }
+                                    }
+                                };
+
+                                return match session.finalize(full_input, self, event_stream, cx).await {
+                                    Ok(()) => EditSessionResult::Completed(session),
+                                    Err(error) => {
+                                        log::error!("Failed to finalize edit: {}", error);
+                                        EditSessionResult::Failed {
+                                            error,
+                                            session: Some(session),
+                                        }
+                                    }
+                                };
+                            }
+                            ToolInputPayload::InvalidJson { error_message } => {
+                                log::error!("Received invalid JSON: {error_message}");
+                                return EditSessionResult::Failed {
+                                    error: error_message,
+                                    session,
+                                };
+                            }
+                        },
+                        Err(error) => {
+                            return EditSessionResult::Failed {
+                                error: format!("Failed to receive tool input: {error}"),
+                                session,
+                            };
+                        }
+                    }
+                }
+                _ = event_stream.cancelled_by_user().fuse() => {
+                    return EditSessionResult::Failed {
+                        error: "Edit cancelled by user".to_string(),
+                        session,
+                    };
+                }
+            }
+        }
+    }
 }
 
 impl AgentTool for StreamingEditFileTool {
@@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool {
         cx: &mut App,
     ) -> Task<Result<Self::Output, Self::Output>> {
         cx.spawn(async move |cx: &mut AsyncApp| {
-            let mut state: Option<EditSession> = None;
-            let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
-            loop {
-                futures::select! {
-                    partial = input.recv_partial().fuse() => {
-                        let Some(partial_value) = partial else { break };
-                        if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
-                            let path_complete = parsed.path.is_some()
-                                && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref());
-
-                            last_partial = Some(parsed.clone());
-
-                            if state.is_none()
-                                && path_complete
-                                && let StreamingEditFileToolPartialInput {
-                                    path: Some(path),
-                                    display_description: Some(display_description),
-                                    mode: Some(mode),
-                                    ..
-                                } = &parsed
-                            {
-                                match EditSession::new(
-                                    &PathBuf::from(path),
-                                    display_description,
-                                    *mode,
-                                    &self,
-                                    &event_stream,
-                                    cx,
-                                )
-                                .await
-                                {
-                                    Ok(session) => state = Some(session),
-                                    Err(e) => {
-                                        log::error!("Failed to create edit session: {}", e);
-                                        return Err(e);
-                                    }
-                                }
-                            }
-
-                            if let Some(state) = &mut state {
-                                if let Err(e) = state.process(parsed, &self, &event_stream, cx) {
-                                    log::error!("Failed to process edit: {}", e);
-                                    return Err(e);
-                                }
-                            }
-                        }
-                    }
-                    _ = event_stream.cancelled_by_user().fuse() => {
-                        return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
-                    }
-                }
-            }
-            let full_input =
-                input
-                    .recv()
-                    .await
-                    .map_err(|e| {
-                        let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}"));
-                        log::error!("Failed to receive tool input: {e}");
-                        err
-                    })?;
-
-            let mut state = if let Some(state) = state {
-                state
-            } else {
-                match EditSession::new(
-                    &full_input.path,
-                    &full_input.display_description,
-                    full_input.mode,
-                    &self,
-                    &event_stream,
-                    cx,
-                )
+            match self
+                .process_streaming_edits(&mut input, &event_stream, cx)
                 .await
-                {
-                    Ok(session) => session,
-                    Err(e) => {
-                        log::error!("Failed to create edit session: {}", e);
-                        return Err(e);
-                    }
+            {
+                EditSessionResult::Completed(session) => {
+                    self.ensure_buffer_saved(&session.buffer, cx).await;
+                    let (new_text, diff) = session.compute_new_text_and_diff(cx).await;
+                    Ok(StreamingEditFileToolOutput::Success {
+                        old_text: session.old_text.clone(),
+                        new_text,
+                        input_path: session.input_path,
+                        diff,
+                    })
                 }
-            };
-            match state.finalize(full_input, &self, &event_stream, cx).await {
-                Ok(output) => Ok(output),
-                Err(e) => {
-                    log::error!("Failed to finalize edit: {}", e);
-                    Err(e)
+                EditSessionResult::Failed {
+                    error,
+                    session: Some(session),
+                } => {
+                    self.ensure_buffer_saved(&session.buffer, cx).await;
+                    let (_new_text, diff) = session.compute_new_text_and_diff(cx).await;
+                    Err(StreamingEditFileToolOutput::Error {
+                        error,
+                        input_path: Some(session.input_path),
+                        diff,
+                    })
                 }
+                EditSessionResult::Failed {
+                    error,
+                    session: None,
+                } => Err(StreamingEditFileToolOutput::Error {
+                    error,
+                    input_path: None,
+                    diff: String::new(),
+                }),
             }
         })
     }
@@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool {
 
 pub struct EditSession {
     abs_path: PathBuf,
+    input_path: PathBuf,
     buffer: Entity<Buffer>,
     old_text: Arc<String>,
     diff: Entity<Diff>,
@@ -518,23 +649,21 @@ impl EditPipeline {
 
 impl EditSession {
     async fn new(
-        path: &PathBuf,
+        path: PathBuf,
         display_description: &str,
         mode: StreamingEditFileMode,
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<Self, StreamingEditFileToolOutput> {
-        let project_path = cx
-            .update(|cx| resolve_path(mode, &path, &tool.project, cx))
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+    ) -> Result<Self, String> {
+        let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?;
 
         let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
         else {
-            return Err(StreamingEditFileToolOutput::error(format!(
+            return Err(format!(
                 "Worktree at '{}' does not exist",
                 path.to_string_lossy()
-            )));
+            ));
         };
 
         event_stream.update_fields(
@@ -543,13 +672,13 @@ impl EditSession {
 
         cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx))
             .await
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+            .map_err(|e| e.to_string())?;
 
         let buffer = tool
             .project
             .update(cx, |project, cx| project.open_buffer(project_path, cx))
             .await
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+            .map_err(|e| e.to_string())?;
 
         ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
 
@@ -578,6 +707,7 @@ impl EditSession {
 
         Ok(Self {
             abs_path,
+            input_path: path,
             buffer,
             old_text,
             diff,
@@ -594,22 +724,20 @@ impl EditSession {
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<StreamingEditFileToolOutput, StreamingEditFileToolOutput> {
-        let old_text = self.old_text.clone();
-
+    ) -> Result<(), String> {
         match input.mode {
             StreamingEditFileMode::Write => {
-                let content = input.content.ok_or_else(|| {
-                    StreamingEditFileToolOutput::error("'content' field is required for write mode")
-                })?;
+                let content = input
+                    .content
+                    .ok_or_else(|| "'content' field is required for write mode".to_string())?;
 
                 let events = self.parser.finalize_content(&content);
                 self.process_events(&events, tool, event_stream, cx)?;
             }
             StreamingEditFileMode::Edit => {
-                let edits = input.edits.ok_or_else(|| {
-                    StreamingEditFileToolOutput::error("'edits' field is required for edit mode")
-                })?;
+                let edits = input
+                    .edits
+                    .ok_or_else(|| "'edits' field is required for edit mode".to_string())?;
                 let events = self.parser.finalize_edits(&edits);
                 self.process_events(&events, tool, event_stream, cx)?;
 
@@ -625,53 +753,15 @@ impl EditSession {
                 }
             }
         }
+        Ok(())
+    }
 
-        let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| {
-            let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
-            settings.format_on_save != FormatOnSave::Off
-        });
-
-        if format_on_save_enabled {
-            tool.action_log.update(cx, |log, cx| {
-                log.buffer_edited(self.buffer.clone(), cx);
-            });
-
-            let format_task = tool.project.update(cx, |project, cx| {
-                project.format(
-                    HashSet::from_iter([self.buffer.clone()]),
-                    LspFormatTarget::Buffers,
-                    false,
-                    FormatTrigger::Save,
-                    cx,
-                )
-            });
-            futures::select! {
-                result = format_task.fuse() => { result.log_err(); },
-                _ = event_stream.cancelled_by_user().fuse() => {
-                    return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
-                }
-            };
-        }
-
-        let save_task = tool.project.update(cx, |project, cx| {
-            project.save_buffer(self.buffer.clone(), cx)
-        });
-        futures::select! {
-            result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; },
-            _ = event_stream.cancelled_by_user().fuse() => {
-                return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
-            }
-        };
-
-        tool.action_log.update(cx, |log, cx| {
-            log.buffer_edited(self.buffer.clone(), cx);
-        });
-
+    async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) {
         let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let (new_text, unified_diff) = cx
             .background_spawn({
                 let new_snapshot = new_snapshot.clone();
-                let old_text = old_text.clone();
+                let old_text = self.old_text.clone();
                 async move {
                     let new_text = new_snapshot.text();
                     let diff = language::unified_diff(&old_text, &new_text);
@@ -679,14 +769,7 @@ impl EditSession {
                 }
             })
             .await;
-
-        let output = StreamingEditFileToolOutput::Success {
-            input_path: input.path,
-            new_text,
-            old_text: old_text.clone(),
-            diff: unified_diff,
-        };
-        Ok(output)
+        (new_text, unified_diff)
     }
 
     fn process(
@@ -695,7 +778,7 @@ impl EditSession {
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<(), StreamingEditFileToolOutput> {
+    ) -> Result<(), String> {
         match &self.mode {
             StreamingEditFileMode::Write => {
                 if let Some(content) = &partial.content {
@@ -719,7 +802,7 @@ impl EditSession {
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<(), StreamingEditFileToolOutput> {
+    ) -> Result<(), String> {
         for event in events {
             match event {
                 ToolEditEvent::ContentChunk { chunk } => {
@@ -969,14 +1052,14 @@ fn extract_match(
     buffer: &Entity<Buffer>,
     edit_index: &usize,
     cx: &mut AsyncApp,
-) -> Result<Range<usize>, StreamingEditFileToolOutput> {
+) -> Result<Range<usize>, String> {
     match matches.len() {
-        0 => Err(StreamingEditFileToolOutput::error(format!(
+        0 => Err(format!(
             "Could not find matching text for edit at index {}. \
                 The old_text did not match any content in the file. \
                 Please read the file again to get the current content.",
             edit_index,
-        ))),
+        )),
         1 => Ok(matches.into_iter().next().unwrap()),
         _ => {
             let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
@@ -985,12 +1068,12 @@ fn extract_match(
                 .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
                 .collect::<Vec<_>>()
                 .join(", ");
-            Err(StreamingEditFileToolOutput::error(format!(
+            Err(format!(
                 "Edit {} matched multiple locations in the file at lines: {}. \
                     Please provide more context in old_text to uniquely \
                     identify the location.",
                 edit_index, lines
-            )))
+            ))
         }
     }
 }
@@ -1022,7 +1105,7 @@ fn ensure_buffer_saved(
     abs_path: &PathBuf,
     tool: &StreamingEditFileTool,
     cx: &mut AsyncApp,
-) -> Result<(), StreamingEditFileToolOutput> {
+) -> Result<(), String> {
     let last_read_mtime = tool
         .action_log
         .read_with(cx, |log, _| log.file_read_time(abs_path));
@@ -1063,15 +1146,14 @@ fn ensure_buffer_saved(
                          then ask them to save or revert the file manually and inform you when it's ok to proceed."
             }
         };
-        return Err(StreamingEditFileToolOutput::error(message));
+        return Err(message.to_string());
     }
 
     if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
         if current != last_read {
-            return Err(StreamingEditFileToolOutput::error(
-                "The file has been modified since you last read it. \
-                             Please read the file again to get the current state before editing it.",
-            ));
+            return Err("The file has been modified since you last read it. \
+                    Please read the file again to get the current state before editing it."
+                .to_string());
         }
     }
 
@@ -1083,56 +1165,63 @@ fn resolve_path(
     path: &PathBuf,
     project: &Entity<Project>,
     cx: &mut App,
-) -> Result<ProjectPath> {
+) -> Result<ProjectPath, String> {
     let project = project.read(cx);
 
     match mode {
         StreamingEditFileMode::Edit => {
             let path = project
                 .find_project_path(&path, cx)
-                .context("Can't edit file: path not found")?;
+                .ok_or_else(|| "Can't edit file: path not found".to_string())?;
 
             let entry = project
                 .entry_for_path(&path, cx)
-                .context("Can't edit file: path not found")?;
+                .ok_or_else(|| "Can't edit file: path not found".to_string())?;
 
-            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
-            Ok(path)
+            if entry.is_file() {
+                Ok(path)
+            } else {
+                Err("Can't edit file: path is a directory".to_string())
+            }
         }
         StreamingEditFileMode::Write => {
             if let Some(path) = project.find_project_path(&path, cx)
                 && let Some(entry) = project.entry_for_path(&path, cx)
             {
-                anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory");
-                return Ok(path);
+                if entry.is_file() {
+                    return Ok(path);
+                } else {
+                    return Err("Can't write to file: path is a directory".to_string());
+                }
             }
 
-            let parent_path = path.parent().context("Can't create file: incorrect path")?;
+            let parent_path = path
+                .parent()
+                .ok_or_else(|| "Can't create file: incorrect path".to_string())?;
 
             let parent_project_path = project.find_project_path(&parent_path, cx);
 
             let parent_entry = parent_project_path
                 .as_ref()
                 .and_then(|path| project.entry_for_path(path, cx))
-                .context("Can't create file: parent directory doesn't exist")?;
+                .ok_or_else(|| "Can't create file: parent directory doesn't exist")?;
 
-            anyhow::ensure!(
-                parent_entry.is_dir(),
-                "Can't create file: parent is not a directory"
-            );
+            if !parent_entry.is_dir() {
+                return Err("Can't create file: parent is not a directory".to_string());
+            }
 
             let file_name = path
                 .file_name()
                 .and_then(|file_name| file_name.to_str())
                 .and_then(|file_name| RelPath::unix(file_name).ok())
-                .context("Can't create file: invalid filename")?;
+                .ok_or_else(|| "Can't create file: invalid filename".to_string())?;
 
             let new_file_path = parent_project_path.map(|parent| ProjectPath {
                 path: parent.path.join(file_name),
                 ..parent
             });
 
-            new_file_path.context("Can't create file")
+            new_file_path.ok_or_else(|| "Can't create file".to_string())
         }
     }
 }
@@ -1382,10 +1471,17 @@ mod tests {
             })
             .await;
 
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error {
+            error,
+            diff,
+            input_path,
+        } = result.unwrap_err()
+        else {
             panic!("expected error");
         };
         assert_eq!(error, "Can't edit file: path not found");
+        assert!(diff.is_empty());
+        assert_eq!(input_path, None);
     }
 
     #[gpui::test]
@@ -1411,7 +1507,7 @@ mod tests {
             })
             .await;
 
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
             panic!("expected error");
         };
         assert!(
@@ -1424,7 +1520,7 @@ mod tests {
     async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1447,7 +1543,7 @@ mod tests {
         cx.run_until_parked();
 
         // Now send the final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1465,7 +1561,7 @@ mod tests {
     async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1485,7 +1581,7 @@ mod tests {
         cx.run_until_parked();
 
         // Send final
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Overwrite file",
             "path": "root/file.txt",
             "mode": "write",
@@ -1503,7 +1599,7 @@ mod tests {
     async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver, mut cancellation_tx) =
             ToolCallEventStream::test_with_cancellation();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1521,7 +1617,7 @@ mod tests {
         drop(sender);
 
         let result = task.await;
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
             panic!("expected error");
         };
         assert!(
@@ -1537,7 +1633,7 @@ mod tests {
             json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
         )
         .await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1578,7 +1674,7 @@ mod tests {
         cx.run_until_parked();
 
         // Send final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit multiple lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1601,7 +1697,7 @@ mod tests {
     #[gpui::test]
     async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1625,7 +1721,7 @@ mod tests {
         cx.run_until_parked();
 
         // Final with full content
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Create new file",
             "path": "root/dir/new_file.txt",
             "mode": "write",
@@ -1643,12 +1739,12 @@ mod tests {
     async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
         // Send final immediately with no partials (simulates non-streaming path)
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1669,7 +1765,7 @@ mod tests {
             json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
         )
         .await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1739,7 +1835,7 @@ mod tests {
         );
 
         // Send final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit multiple lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1767,7 +1863,7 @@ mod tests {
     async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) {
         let (tool, project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1835,7 +1931,7 @@ mod tests {
         assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n"));
 
         // Send final
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit three lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1857,7 +1953,7 @@ mod tests {
     async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) {
         let (tool, project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1893,16 +1989,17 @@ mod tests {
         }));
         cx.run_until_parked();
 
-        // Verify edit 1 was applied
-        let buffer_text = project.update(cx, |project, cx| {
+        let buffer = project.update(cx, |project, cx| {
             let pp = project
                 .find_project_path(&PathBuf::from("root/file.txt"), cx)
                 .unwrap();
-            project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text())
+            project.get_open_buffer(&pp, cx).unwrap()
         });
+
+        // Verify edit 1 was applied
+        let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text());
         assert_eq!(
-            buffer_text.as_deref(),
-            Some("MODIFIED\nline 2\nline 3\n"),
+            buffer_text, "MODIFIED\nline 2\nline 3\n",
             "First edit should be applied even though second edit will fail"
         );
 
@@ -1925,20 +2022,32 @@ mod tests {
         drop(sender);
 
         let result = task.await;
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error {
+            error,
+            diff,
+            input_path,
+        } = result.unwrap_err()
+        else {
             panic!("expected error");
         };
+
         assert!(
             error.contains("Could not find matching text for edit at index 1"),
             "Expected error about edit 1 failing, got: {error}"
         );
+        // Ensure that first edit was applied successfully and that we saved the buffer
+        assert_eq!(input_path, Some(PathBuf::from("root/file.txt")));
+        assert_eq!(
+            diff,
+            "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n"
+        );
     }
 
     #[gpui::test]
     async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) {
         let (tool, project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1975,7 +2084,7 @@ mod tests {
         );
 
         // Send final — the edit is applied during finalization
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Single edit",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1993,7 +2102,7 @@ mod tests {
     async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+        let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
             ToolInput::test();
         let (event_stream, _event_rx) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2020,7 +2129,7 @@ mod tests {
         cx.run_until_parked();
 
         // Send the final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -2038,7 +2147,7 @@ mod tests {
     async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world\n"})).await;
-        let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+        let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
             ToolInput::test();
         let (event_stream, _event_rx) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2064,7 +2173,7 @@ mod tests {
         // Create a channel and send multiple partials before a final, then use
         // ToolInput::resolved-style immediate delivery to confirm recv() works
         // when partials are already buffered.
-        let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+        let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
             ToolInput::test();
         let (event_stream, _event_rx) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2077,7 +2186,7 @@ mod tests {
             "path": "root/dir/new.txt",
             "mode": "write"
         }));
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Create",
             "path": "root/dir/new.txt",
             "mode": "write",
@@ -2109,13 +2218,13 @@ mod tests {
 
         let result = test_resolve_path(&mode, "root/dir/subdir", cx);
         assert_eq!(
-            result.await.unwrap_err().to_string(),
+            result.await.unwrap_err(),
             "Can't write to file: path is a directory"
         );
 
         let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx);
         assert_eq!(
-            result.await.unwrap_err().to_string(),
+            result.await.unwrap_err(),
             "Can't create file: parent directory doesn't exist"
         );
     }
@@ -2133,14 +2242,11 @@ mod tests {
         assert_resolved_path_eq(result.await, rel_path(path_without_root));
 
         let result = test_resolve_path(&mode, "root/nonexistent.txt", cx);
-        assert_eq!(
-            result.await.unwrap_err().to_string(),
-            "Can't edit file: path not found"
-        );
+        assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found");
 
         let result = test_resolve_path(&mode, "root/dir", cx);
         assert_eq!(
-            result.await.unwrap_err().to_string(),
+            result.await.unwrap_err(),
             "Can't edit file: path is a directory"
         );
     }

crates/language_model/src/fake_provider.rs 🔗

@@ -125,6 +125,7 @@ pub struct FakeLanguageModel {
     >,
     forbid_requests: AtomicBool,
     supports_thinking: AtomicBool,
+    supports_streaming_tools: AtomicBool,
 }
 
 impl Default for FakeLanguageModel {
@@ -137,6 +138,7 @@ impl Default for FakeLanguageModel {
             current_completion_txs: Mutex::new(Vec::new()),
             forbid_requests: AtomicBool::new(false),
             supports_thinking: AtomicBool::new(false),
+            supports_streaming_tools: AtomicBool::new(false),
         }
     }
 }
@@ -169,6 +171,10 @@ impl FakeLanguageModel {
         self.supports_thinking.store(supports, SeqCst);
     }
 
+    pub fn set_supports_streaming_tools(&self, supports: bool) {
+        self.supports_streaming_tools.store(supports, SeqCst);
+    }
+
     pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
         self.current_completion_txs
             .lock()
@@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel {
         self.supports_thinking.load(SeqCst)
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        self.supports_streaming_tools.load(SeqCst)
+    }
+
     fn telemetry_id(&self) -> String {
         "fake".to_string()
     }