Restore original file content when rejecting an overwritten file (#29974)

Antonio Scandurra created

Release Notes:

- Fixed a bug that would cause rejecting a hunk from the agent to delete
the file if the agent had decided to rewrite that file from scratch.

Change summary

crates/assistant_tool/src/action_log.rs        | 123 ++++++++++++++++---
crates/assistant_tools/src/edit_agent/evals.rs |   2 
2 files changed, 105 insertions(+), 20 deletions(-)

Detailed changes

crates/assistant_tool/src/action_log.rs 🔗

@@ -64,8 +64,20 @@ impl ActionLog {
                 let status;
                 let unreviewed_changes;
                 if is_created {
+                    let existing_file_content = if buffer
+                        .read(cx)
+                        .file()
+                        .map_or(false, |file| file.disk_state().exists())
+                    {
+                        Some(text_snapshot.as_rope().clone())
+                    } else {
+                        None
+                    };
+
                     base_text = Rope::default();
-                    status = TrackedBufferStatus::Created;
+                    status = TrackedBufferStatus::Created {
+                        existing_file_content,
+                    };
                     unreviewed_changes = Patch::new(vec![Edit {
                         old: 0..1,
                         new: 0..text_snapshot.max_point().row + 1,
@@ -128,7 +140,7 @@ impl ActionLog {
         };
 
         match tracked_buffer.status {
-            TrackedBufferStatus::Created | TrackedBufferStatus::Modified => {
+            TrackedBufferStatus::Created { .. } | TrackedBufferStatus::Modified => {
                 if buffer
                     .read(cx)
                     .file()
@@ -289,7 +301,7 @@ impl ActionLog {
     pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
         let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
         match tracked_buffer.status {
-            TrackedBufferStatus::Created => {
+            TrackedBufferStatus::Created { .. } => {
                 self.tracked_buffers.remove(&buffer);
                 cx.notify();
             }
@@ -373,19 +385,35 @@ impl ActionLog {
             return Task::ready(Ok(()));
         };
 
-        match tracked_buffer.status {
-            TrackedBufferStatus::Created => {
-                let delete = buffer
-                    .read(cx)
-                    .entry_id(cx)
-                    .and_then(|entry_id| {
-                        self.project
-                            .update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
-                    })
-                    .unwrap_or(Task::ready(Ok(())));
+        match &tracked_buffer.status {
+            TrackedBufferStatus::Created {
+                existing_file_content,
+            } => {
+                let task = if let Some(existing_file_content) = existing_file_content {
+                    buffer.update(cx, |buffer, cx| {
+                        buffer.start_transaction();
+                        buffer.set_text("", cx);
+                        for chunk in existing_file_content.chunks() {
+                            buffer.append(chunk, cx);
+                        }
+                        buffer.end_transaction(cx);
+                    });
+                    self.project
+                        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+                } else {
+                    buffer
+                        .read(cx)
+                        .entry_id(cx)
+                        .and_then(|entry_id| {
+                            self.project
+                                .update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
+                        })
+                        .unwrap_or(Task::ready(Ok(())))
+                };
+
                 self.tracked_buffers.remove(&buffer);
                 cx.notify();
-                delete
+                task
             }
             TrackedBufferStatus::Deleted => {
                 buffer.update(cx, |buffer, cx| {
@@ -619,9 +647,8 @@ enum ChangeAuthor {
     Agent,
 }
 
-#[derive(Copy, Clone, Eq, PartialEq)]
 enum TrackedBufferStatus {
-    Created,
+    Created { existing_file_content: Option<Rope> },
     Modified,
     Deleted,
 }
@@ -1009,6 +1036,64 @@ mod tests {
         assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_overwriting_files(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/dir"),
+            json!({
+                "file1": "Lorem ipsum dolor"
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
+            .unwrap();
+
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| buffer.set_text("sit amet consecteur", cx));
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(0, 0)..Point::new(0, 19),
+                    diff_status: DiffHunkStatusKind::Added,
+                    old_text: "".into(),
+                }],
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_ranges(buffer.clone(), vec![2..5], cx)
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _cx| buffer.text()),
+            "Lorem ipsum dolor"
+        );
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_deleting_files(cx: &mut TestAppContext) {
         init_test(cx);
@@ -1090,7 +1175,7 @@ mod tests {
             .update(cx, |project, cx| project.open_buffer(file2_path, cx))
             .await
             .unwrap();
-        action_log.update(cx, |log, cx| log.buffer_read(buffer2.clone(), cx));
+        action_log.update(cx, |log, cx| log.buffer_created(buffer2.clone(), cx));
         buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx));
         action_log.update(cx, |log, cx| log.buffer_edited(buffer2.clone(), cx));
         project
@@ -1105,8 +1190,8 @@ mod tests {
                 buffer2.clone(),
                 vec![HunkStatus {
                     range: Point::new(0, 0)..Point::new(0, 5),
-                    diff_status: DiffHunkStatusKind::Modified,
-                    old_text: "ipsum\n".into(),
+                    diff_status: DiffHunkStatusKind::Added,
+                    old_text: "".into(),
                 }],
             )]
         );

crates/assistant_tools/src/edit_agent/evals.rs 🔗

@@ -957,7 +957,7 @@ impl EditAgentTest {
 
                 cx.spawn(async move |cx| {
                     let agent_model =
-                        Self::load_model("google", "gemini-2.5-pro-preview-03-25", cx).await;
+                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                     let judge_model =
                         Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                     (agent_model.unwrap(), judge_model.unwrap())