agent: Stream content in `StreamingEditFileTool` (#50187)

Bennet Bo Fenner created

Release Notes:

- N/A

Change summary

crates/agent/src/tools/streaming_edit_file_tool.rs | 370 +++++++++++++++
1 file changed, 356 insertions(+), 14 deletions(-)

Detailed changes

crates/agent/src/tools/streaming_edit_file_tool.rs 🔗

@@ -139,6 +139,8 @@ enum StreamingEditState {
         buffer: Entity<Buffer>,
         old_text: Arc<String>,
         diff: Entity<Diff>,
+        mode: StreamingEditFileMode,
+        last_content_len: usize,
         edit_state: IncrementalEditState,
         _finalize_diff_guard: Deferred<Box<dyn FnOnce()>>,
     },
@@ -346,21 +348,36 @@ impl StreamingEditState {
                 buffer,
                 edit_state,
                 diff,
+                mode,
+                last_content_len,
                 ..
-            } => {
-                if let Some(edits) = partial.edits {
-                    Self::process_streaming_edits(
-                        buffer,
-                        diff,
-                        edit_state,
-                        &edits,
-                        abs_path,
-                        tool,
-                        event_stream,
-                        cx,
-                    )?;
+            } => match mode {
+                StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
+                    if let Some(content) = &partial.content {
+                        Self::process_streaming_content(
+                            buffer,
+                            diff,
+                            last_content_len,
+                            content,
+                            cx,
+                        )?;
+                    }
                 }
-            }
+                StreamingEditFileMode::Edit => {
+                    if let Some(edits) = partial.edits {
+                        Self::process_streaming_edits(
+                            buffer,
+                            diff,
+                            edit_state,
+                            &edits,
+                            abs_path,
+                            tool,
+                            event_stream,
+                            cx,
+                        )?;
+                    }
+                }
+            },
         }
         Ok(())
     }
@@ -375,7 +392,7 @@ impl StreamingEditState {
     ) -> Result<Self, StreamingEditFileToolOutput> {
         let path = PathBuf::from(path_str);
         let project_path = cx
-            .update(|cx| resolve_path(mode, &path, &tool.project, cx))
+            .update(|cx| resolve_path(mode.clone(), &path, &tool.project, cx))
             .map_err(|e| StreamingEditFileToolOutput::Error {
                 error: e.to_string(),
             })?;
@@ -430,11 +447,47 @@ impl StreamingEditState {
             buffer,
             old_text,
             diff,
+            mode,
+            last_content_len: 0,
             edit_state: IncrementalEditState::default(),
             _finalize_diff_guard: finalize_diff_guard,
         })
     }
 
+    fn process_streaming_content(
+        buffer: &Entity<Buffer>,
+        diff: &Entity<Diff>,
+        last_content_len: &mut usize,
+        content: &str,
+        cx: &mut AsyncApp,
+    ) -> Result<(), StreamingEditFileToolOutput> {
+        let new_len = content.len();
+        if new_len > *last_content_len {
+            let new_chunk = &content[*last_content_len..];
+            cx.update(|cx| {
+                buffer.update(cx, |buffer, cx| {
+                    // On the first update, replace the entire buffer (handles Overwrite
+                    // clearing existing content). For Create the buffer is already empty
+                    // so 0..0 is a no-op range prefix.
+                    let insert_at = if *last_content_len == 0 {
+                        0..buffer.len()
+                    } else {
+                        let len = buffer.len();
+                        len..len
+                    };
+                    buffer.edit([(insert_at, new_chunk)], None, cx);
+                });
+            });
+            *last_content_len = new_len;
+
+            let anchor_range = buffer.read_with(cx, |buffer, _cx| {
+                buffer.anchor_range_between(0..buffer.len())
+            });
+            diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+        }
+        Ok(())
+    }
+
     fn process_streaming_edits(
         buffer: &Entity<Buffer>,
         diff: &Entity<Diff>,
@@ -4495,6 +4548,295 @@ mod tests {
         }
     }
 
+    #[gpui::test]
+    async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = project::FakeFs::new(cx.executor());
+        fs.insert_tree("/root", json!({"dir": {}})).await;
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        let model = Arc::new(FakeLanguageModel::default());
+        let thread = cx.new(|cx| {
+            crate::Thread::new(
+                project.clone(),
+                cx.new(|_cx| ProjectContext::default()),
+                context_server_registry,
+                Templates::new(),
+                Some(model),
+                cx,
+            )
+        });
+
+        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (event_stream, _receiver) = ToolCallEventStream::test();
+
+        let tool = Arc::new(StreamingEditFileTool::new(
+            project.clone(),
+            thread.downgrade(),
+            language_registry,
+        ));
+
+        let task = cx.update(|cx| tool.run(input, event_stream, cx));
+
+        // Transition to BufferResolved
+        sender.send_partial(json!({
+            "display_description": "Create new file",
+            "path": "root/dir/new_file.txt",
+            "mode": "create"
+        }));
+        cx.run_until_parked();
+
+        // Stream content incrementally
+        sender.send_partial(json!({
+            "display_description": "Create new file",
+            "path": "root/dir/new_file.txt",
+            "mode": "create",
+            "content": "line 1\n"
+        }));
+        cx.run_until_parked();
+
+        // Verify buffer has partial content
+        let buffer = project.update(cx, |project, cx| {
+            let path = project
+                .find_project_path("root/dir/new_file.txt", cx)
+                .unwrap();
+            project.get_open_buffer(&path, cx).unwrap()
+        });
+        assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\n");
+
+        // Stream more content
+        sender.send_partial(json!({
+            "display_description": "Create new file",
+            "path": "root/dir/new_file.txt",
+            "mode": "create",
+            "content": "line 1\nline 2\n"
+        }));
+        cx.run_until_parked();
+        assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\nline 2\n");
+
+        // Stream final chunk
+        sender.send_partial(json!({
+            "display_description": "Create new file",
+            "path": "root/dir/new_file.txt",
+            "mode": "create",
+            "content": "line 1\nline 2\nline 3\n"
+        }));
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |b, _| b.text()),
+            "line 1\nline 2\nline 3\n"
+        );
+
+        // Send final input
+        sender.send_final(json!({
+            "display_description": "Create new file",
+            "path": "root/dir/new_file.txt",
+            "mode": "create",
+            "content": "line 1\nline 2\nline 3\n"
+        }));
+
+        let result = task.await;
+        let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else {
+            panic!("expected success");
+        };
+        assert_eq!(new_text, "line 1\nline 2\nline 3\n");
+    }
+
+    #[gpui::test]
+    async fn test_streaming_overwrite_diff_revealed_during_streaming(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = project::FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "file.txt": "old line 1\nold line 2\nold line 3\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        let model = Arc::new(FakeLanguageModel::default());
+        let thread = cx.new(|cx| {
+            crate::Thread::new(
+                project.clone(),
+                cx.new(|_cx| ProjectContext::default()),
+                context_server_registry,
+                Templates::new(),
+                Some(model),
+                cx,
+            )
+        });
+
+        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (event_stream, mut receiver) = ToolCallEventStream::test();
+
+        let tool = Arc::new(StreamingEditFileTool::new(
+            project.clone(),
+            thread.downgrade(),
+            language_registry,
+        ));
+
+        let task = cx.update(|cx| tool.run(input, event_stream, cx));
+
+        // Transition to BufferResolved
+        sender.send_partial(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite"
+        }));
+        cx.run_until_parked();
+
+        // Get the diff entity from the event stream
+        receiver.expect_update_fields().await;
+        let diff = receiver.expect_diff().await;
+
+        // Diff starts pending with no revealed ranges
+        diff.read_with(cx, |diff, cx| {
+            assert!(matches!(diff, Diff::Pending(_)));
+            assert!(!diff.has_revealed_range(cx));
+        });
+
+        // Stream first content chunk
+        sender.send_partial(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite",
+            "content": "new line 1\n"
+        }));
+        cx.run_until_parked();
+
+        // Diff should now have revealed ranges showing the new content
+        diff.read_with(cx, |diff, cx| {
+            assert!(diff.has_revealed_range(cx));
+        });
+
+        // Send final input
+        sender.send_final(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite",
+            "content": "new line 1\nnew line 2\n"
+        }));
+
+        let result = task.await;
+        let StreamingEditFileToolOutput::Success {
+            new_text, old_text, ..
+        } = result.unwrap()
+        else {
+            panic!("expected success");
+        };
+        assert_eq!(new_text, "new line 1\nnew line 2\n");
+        assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n");
+
+        // Diff is finalized after completion
+        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+    }
+
+    #[gpui::test]
+    async fn test_streaming_overwrite_content_streamed(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = project::FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "file.txt": "old line 1\nold line 2\nold line 3\n"
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        let model = Arc::new(FakeLanguageModel::default());
+        let thread = cx.new(|cx| {
+            crate::Thread::new(
+                project.clone(),
+                cx.new(|_cx| ProjectContext::default()),
+                context_server_registry,
+                Templates::new(),
+                Some(model),
+                cx,
+            )
+        });
+
+        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (event_stream, _receiver) = ToolCallEventStream::test();
+
+        let tool = Arc::new(StreamingEditFileTool::new(
+            project.clone(),
+            thread.downgrade(),
+            language_registry,
+        ));
+
+        let task = cx.update(|cx| tool.run(input, event_stream, cx));
+
+        // Transition to BufferResolved
+        sender.send_partial(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite"
+        }));
+        cx.run_until_parked();
+
+        // Verify buffer still has old content (no content partial yet)
+        let buffer = project.update(cx, |project, cx| {
+            let path = project.find_project_path("root/file.txt", cx).unwrap();
+            project.get_open_buffer(&path, cx).unwrap()
+        });
+        assert_eq!(
+            buffer.read_with(cx, |b, _| b.text()),
+            "old line 1\nold line 2\nold line 3\n"
+        );
+
+        // First content partial replaces old content
+        sender.send_partial(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite",
+            "content": "new line 1\n"
+        }));
+        cx.run_until_parked();
+        assert_eq!(buffer.read_with(cx, |b, _| b.text()), "new line 1\n");
+
+        // Subsequent content partials append
+        sender.send_partial(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite",
+            "content": "new line 1\nnew line 2\n"
+        }));
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |b, _| b.text()),
+            "new line 1\nnew line 2\n"
+        );
+
+        // Send final input with complete content
+        sender.send_final(json!({
+            "display_description": "Overwrite file",
+            "path": "root/file.txt",
+            "mode": "overwrite",
+            "content": "new line 1\nnew line 2\nnew line 3\n"
+        }));
+
+        let result = task.await;
+        let StreamingEditFileToolOutput::Success {
+            new_text, old_text, ..
+        } = result.unwrap()
+        else {
+            panic!("expected success");
+        };
+        assert_eq!(new_text, "new line 1\nnew line 2\nnew line 3\n");
+        assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n");
+    }
+
     fn init_test(cx: &mut TestAppContext) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);