Fix rejecting overwritten files if the agent previously edited them (#30744)

Antonio Scandurra created

Release Notes:

- Fixed rejecting overwritten files if the agent had previously edited them.

Change summary

crates/assistant_tool/src/action_log.rs | 174 ++++++++++++++++++++------
1 file changed, 134 insertions(+), 40 deletions(-)

Detailed changes

crates/assistant_tool/src/action_log.rs 🔗

@@ -49,6 +49,37 @@ impl ActionLog {
         is_created: bool,
         cx: &mut Context<Self>,
     ) -> &mut TrackedBuffer {
+        let status = if is_created {
+            if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
+                match tracked.status {
+                    TrackedBufferStatus::Created {
+                        existing_file_content,
+                    } => TrackedBufferStatus::Created {
+                        existing_file_content,
+                    },
+                    TrackedBufferStatus::Modified | TrackedBufferStatus::Deleted => {
+                        TrackedBufferStatus::Created {
+                            existing_file_content: Some(tracked.diff_base),
+                        }
+                    }
+                }
+            } else if buffer
+                .read(cx)
+                .file()
+                .map_or(false, |file| file.disk_state().exists())
+            {
+                TrackedBufferStatus::Created {
+                    existing_file_content: Some(buffer.read(cx).as_rope().clone()),
+                }
+            } else {
+                TrackedBufferStatus::Created {
+                    existing_file_content: None,
+                }
+            }
+        } else {
+            TrackedBufferStatus::Modified
+        };
+
         let tracked_buffer = self
             .tracked_buffers
             .entry(buffer.clone())
@@ -60,36 +91,21 @@ impl ActionLog {
                 let text_snapshot = buffer.read(cx).text_snapshot();
                 let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
                 let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
-                let base_text;
-                let status;
+                let diff_base;
                 let unreviewed_changes;
                 if is_created {
-                    let existing_file_content = if buffer
-                        .read(cx)
-                        .file()
-                        .map_or(false, |file| file.disk_state().exists())
-                    {
-                        Some(text_snapshot.as_rope().clone())
-                    } else {
-                        None
-                    };
-
-                    base_text = Rope::default();
-                    status = TrackedBufferStatus::Created {
-                        existing_file_content,
-                    };
+                    diff_base = Rope::default();
                     unreviewed_changes = Patch::new(vec![Edit {
                         old: 0..1,
                         new: 0..text_snapshot.max_point().row + 1,
                     }])
                 } else {
-                    base_text = buffer.read(cx).as_rope().clone();
-                    status = TrackedBufferStatus::Modified;
+                    diff_base = buffer.read(cx).as_rope().clone();
                     unreviewed_changes = Patch::default();
                 }
                 TrackedBuffer {
                     buffer: buffer.clone(),
-                    base_text,
+                    diff_base,
                     unreviewed_changes,
                     snapshot: text_snapshot.clone(),
                     status,
@@ -184,7 +200,7 @@ impl ActionLog {
                         .context("buffer not tracked")?;
 
                     let rebase = cx.background_spawn({
-                        let mut base_text = tracked_buffer.base_text.clone();
+                        let mut base_text = tracked_buffer.diff_base.clone();
                         let old_snapshot = tracked_buffer.snapshot.clone();
                         let new_snapshot = buffer_snapshot.clone();
                         let unreviewed_changes = tracked_buffer.unreviewed_changes.clone();
@@ -210,7 +226,7 @@ impl ActionLog {
                     ))
                 })??;
 
-            let (new_base_text, new_base_text_rope) = rebase.await;
+            let (new_base_text, new_diff_base) = rebase.await;
             let diff_snapshot = BufferDiff::update_diff(
                 diff.clone(),
                 buffer_snapshot.clone(),
@@ -229,24 +245,23 @@ impl ActionLog {
                     .background_spawn({
                         let diff_snapshot = diff_snapshot.clone();
                         let buffer_snapshot = buffer_snapshot.clone();
-                        let new_base_text_rope = new_base_text_rope.clone();
+                        let new_diff_base = new_diff_base.clone();
                         async move {
                             let mut unreviewed_changes = Patch::default();
                             for hunk in diff_snapshot.hunks_intersecting_range(
                                 Anchor::MIN..Anchor::MAX,
                                 &buffer_snapshot,
                             ) {
-                                let old_range = new_base_text_rope
+                                let old_range = new_diff_base
                                     .offset_to_point(hunk.diff_base_byte_range.start)
-                                    ..new_base_text_rope
-                                        .offset_to_point(hunk.diff_base_byte_range.end);
+                                    ..new_diff_base.offset_to_point(hunk.diff_base_byte_range.end);
                                 let new_range = hunk.range.start..hunk.range.end;
                                 unreviewed_changes.push(point_to_row_edit(
                                     Edit {
                                         old: old_range,
                                         new: new_range,
                                     },
-                                    &new_base_text_rope,
+                                    &new_diff_base,
                                     &buffer_snapshot.as_rope(),
                                 ));
                             }
@@ -264,7 +279,7 @@ impl ActionLog {
                     .tracked_buffers
                     .get_mut(&buffer)
                     .context("buffer not tracked")?;
-                tracked_buffer.base_text = new_base_text_rope;
+                tracked_buffer.diff_base = new_diff_base;
                 tracked_buffer.snapshot = buffer_snapshot;
                 tracked_buffer.unreviewed_changes = unreviewed_changes;
                 cx.notify();
@@ -283,7 +298,6 @@ impl ActionLog {
     /// Mark a buffer as edited, so we can refresh it in the context
     pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
         self.edited_since_project_diagnostics_check = true;
-        self.tracked_buffers.remove(&buffer);
         self.track_buffer_internal(buffer.clone(), true, cx);
     }
 
@@ -346,11 +360,11 @@ impl ActionLog {
                         true
                     } else {
                         let old_range = tracked_buffer
-                            .base_text
+                            .diff_base
                             .point_to_offset(Point::new(edit.old.start, 0))
-                            ..tracked_buffer.base_text.point_to_offset(cmp::min(
+                            ..tracked_buffer.diff_base.point_to_offset(cmp::min(
                                 Point::new(edit.old.end, 0),
-                                tracked_buffer.base_text.max_point(),
+                                tracked_buffer.diff_base.max_point(),
                             ));
                         let new_range = tracked_buffer
                             .snapshot
@@ -359,7 +373,7 @@ impl ActionLog {
                                 Point::new(edit.new.end, 0),
                                 tracked_buffer.snapshot.max_point(),
                             ));
-                        tracked_buffer.base_text.replace(
+                        tracked_buffer.diff_base.replace(
                             old_range,
                             &tracked_buffer
                                 .snapshot
@@ -417,7 +431,7 @@ impl ActionLog {
             }
             TrackedBufferStatus::Deleted => {
                 buffer.update(cx, |buffer, cx| {
-                    buffer.set_text(tracked_buffer.base_text.to_string(), cx)
+                    buffer.set_text(tracked_buffer.diff_base.to_string(), cx)
                 });
                 let save = self
                     .project
@@ -464,14 +478,14 @@ impl ActionLog {
 
                         if revert {
                             let old_range = tracked_buffer
-                                .base_text
+                                .diff_base
                                 .point_to_offset(Point::new(edit.old.start, 0))
-                                ..tracked_buffer.base_text.point_to_offset(cmp::min(
+                                ..tracked_buffer.diff_base.point_to_offset(cmp::min(
                                     Point::new(edit.old.end, 0),
-                                    tracked_buffer.base_text.max_point(),
+                                    tracked_buffer.diff_base.max_point(),
                                 ));
                             let old_text = tracked_buffer
-                                .base_text
+                                .diff_base
                                 .chunks_in_range(old_range)
                                 .collect::<String>();
                             edits_to_revert.push((new_range, old_text));
@@ -492,7 +506,7 @@ impl ActionLog {
                 TrackedBufferStatus::Deleted => false,
                 _ => {
                     tracked_buffer.unreviewed_changes.clear();
-                    tracked_buffer.base_text = tracked_buffer.snapshot.as_rope().clone();
+                    tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone();
                     tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
                     true
                 }
@@ -655,7 +669,7 @@ enum TrackedBufferStatus {
 
 struct TrackedBuffer {
     buffer: Entity<Buffer>,
-    base_text: Rope,
+    diff_base: Rope,
     unreviewed_changes: Patch<u32>,
     status: TrackedBufferStatus,
     version: clock::Global,
@@ -1094,6 +1108,86 @@ mod tests {
         );
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_overwriting_previously_edited_files(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/dir"),
+            json!({
+                "file1": "Lorem ipsum dolor"
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
+            .unwrap();
+
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| buffer.append(" sit amet consecteur", cx));
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(0, 0)..Point::new(0, 37),
+                    diff_status: DiffHunkStatusKind::Modified,
+                    old_text: "Lorem ipsum dolor".into(),
+                }],
+            )]
+        );
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| buffer.set_text("rewritten", cx));
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(0, 0)..Point::new(0, 9),
+                    diff_status: DiffHunkStatusKind::Added,
+                    old_text: "".into(),
+                }],
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_ranges(buffer.clone(), vec![2..5], cx)
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _cx| buffer.text()),
+            "Lorem ipsum dolor"
+        );
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_deleting_files(cx: &mut TestAppContext) {
         init_test(cx);
@@ -1601,7 +1695,7 @@ mod tests {
             cx.run_until_parked();
             action_log.update(cx, |log, cx| {
                 let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap();
-                let mut old_text = tracked_buffer.base_text.clone();
+                let mut old_text = tracked_buffer.diff_base.clone();
                 let new_text = buffer.read(cx).as_rope();
                 for edit in tracked_buffer.unreviewed_changes.edits() {
                     let old_start = old_text.point_to_offset(Point::new(edit.new.start, 0));