Ensure rejecting a hunk dismisses the diff (#27919)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/assistant_tool/src/action_log.rs              | 317 ++++++-------
crates/assistant_tools/src/edit_files_tool.rs        |   2 
crates/assistant_tools/src/find_replace_file_tool.rs |   3 
3 files changed, 154 insertions(+), 168 deletions(-)

Detailed changes

crates/assistant_tool/src/action_log.rs 🔗

@@ -3,7 +3,7 @@ use buffer_diff::BufferDiff;
 use collections::{BTreeMap, HashSet};
 use futures::{StreamExt, channel::mpsc};
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
-use language::{Buffer, BufferEvent, DiskState, Point};
+use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
 use std::{cmp, ops::Range, sync::Arc};
 use text::{Edit, Patch, Rope};
 use util::RangeExt;
@@ -169,20 +169,15 @@ impl ActionLog {
                         let unreviewed_changes = tracked_buffer.unreviewed_changes.clone();
                         async move {
                             let edits = diff_snapshots(&old_snapshot, &new_snapshot);
-                            let unreviewed_changes = match author {
-                                ChangeAuthor::User => rebase_patch(
+                            if let ChangeAuthor::User = author {
+                                apply_non_conflicting_edits(
                                     &unreviewed_changes,
                                     edits,
                                     &mut base_text,
                                     new_snapshot.as_rope(),
-                                ),
-                                ChangeAuthor::Agent => unreviewed_changes.compose(edits),
-                            };
-                            (
-                                Arc::new(base_text.to_string()),
-                                base_text,
-                                unreviewed_changes,
-                            )
+                                );
+                            }
+                            (Arc::new(base_text.to_string()), base_text)
                         }
                     });
 
@@ -194,7 +189,7 @@ impl ActionLog {
                     ))
                 })??;
 
-            let (new_base_text, new_base_text_rope, unreviewed_changes) = rebase.await;
+            let (new_base_text, new_base_text_rope) = rebase.await;
             let diff_snapshot = BufferDiff::update_diff(
                 diff.clone(),
                 buffer_snapshot.clone(),
@@ -206,7 +201,39 @@ impl ActionLog {
                 cx,
             )
             .await;
+
+            let mut unreviewed_changes = Patch::default();
             if let Ok(diff_snapshot) = diff_snapshot {
+                unreviewed_changes = cx
+                    .background_spawn({
+                        let diff_snapshot = diff_snapshot.clone();
+                        let buffer_snapshot = buffer_snapshot.clone();
+                        let new_base_text_rope = new_base_text_rope.clone();
+                        async move {
+                            let mut unreviewed_changes = Patch::default();
+                            for hunk in diff_snapshot.hunks_intersecting_range(
+                                Anchor::MIN..Anchor::MAX,
+                                &buffer_snapshot,
+                            ) {
+                                let old_range = new_base_text_rope
+                                    .offset_to_point(hunk.diff_base_byte_range.start)
+                                    ..new_base_text_rope
+                                        .offset_to_point(hunk.diff_base_byte_range.end);
+                                let new_range = hunk.range.start..hunk.range.end;
+                                unreviewed_changes.push(point_to_row_edit(
+                                    Edit {
+                                        old: old_range,
+                                        new: new_range,
+                                    },
+                                    &new_base_text_rope,
+                                    &buffer_snapshot.as_rope(),
+                                ));
+                            }
+                            unreviewed_changes
+                        }
+                    })
+                    .await;
+
                 diff.update(cx, |diff, cx| {
                     diff.set_snapshot(diff_snapshot, &buffer_snapshot, None, cx)
                 })?;
@@ -267,21 +294,12 @@ impl ActionLog {
         cx.notify();
     }
 
-    pub fn keep_edits_in_range<T>(
+    pub fn keep_edits_in_range(
         &mut self,
         buffer: Entity<Buffer>,
-        buffer_range: Range<T>,
+        buffer_range: Range<impl language::ToPoint>,
         cx: &mut Context<Self>,
-    ) where
-        T: 'static + language::ToPoint, // + Clone
-                                        // + Copy
-                                        // + Ord
-                                        // + Sub<T, Output = T>
-                                        // + Add<T, Output = T>
-                                        // + AddAssign
-                                        // + Default
-                                        // + PartialEq,
-    {
+    ) {
         let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
             return;
         };
@@ -377,15 +395,12 @@ impl ActionLog {
     }
 }
 
-fn rebase_patch(
+fn apply_non_conflicting_edits(
     patch: &Patch<u32>,
     edits: Vec<Edit<u32>>,
     old_text: &mut Rope,
     new_text: &Rope,
-) -> Patch<u32> {
-    let mut translated_unreviewed_edits = Patch::default();
-    let mut conflicting_edits = Vec::new();
-
+) {
     let mut old_edits = patch.edits().iter().cloned().peekable();
     let mut new_edits = edits.into_iter().peekable();
     let mut applied_delta = 0i32;
@@ -396,43 +411,30 @@ fn rebase_patch(
 
         // Push all the old edits that are before this new edit or that intersect with it.
         while let Some(old_edit) = old_edits.peek() {
-            if new_edit.old.end <= old_edit.new.start {
+            if new_edit.old.end < old_edit.new.start
+                || (!old_edit.new.is_empty() && new_edit.old.end == old_edit.new.start)
+            {
                 break;
-            } else if new_edit.old.start >= old_edit.new.end {
-                let mut old_edit = old_edits.next().unwrap();
-                old_edit.old.start = (old_edit.old.start as i32 + applied_delta) as u32;
-                old_edit.old.end = (old_edit.old.end as i32 + applied_delta) as u32;
-                old_edit.new.start = (old_edit.new.start as i32 + applied_delta) as u32;
-                old_edit.new.end = (old_edit.new.end as i32 + applied_delta) as u32;
+            } else if new_edit.old.start > old_edit.new.end
+                || (!old_edit.new.is_empty() && new_edit.old.start == old_edit.new.end)
+            {
+                let old_edit = old_edits.next().unwrap();
                 rebased_delta += old_edit.new_len() as i32 - old_edit.old_len() as i32;
-                translated_unreviewed_edits.push(old_edit);
             } else {
                 conflict = true;
                 if new_edits
                     .peek()
                     .map_or(false, |next_edit| next_edit.old.overlaps(&old_edit.new))
                 {
-                    new_edit.old.start = (new_edit.old.start as i32 + applied_delta) as u32;
-                    new_edit.old.end = (new_edit.old.end as i32 + applied_delta) as u32;
-                    conflicting_edits.push(new_edit);
                     new_edit = new_edits.next().unwrap();
                 } else {
-                    let mut old_edit = old_edits.next().unwrap();
-                    old_edit.old.start = (old_edit.old.start as i32 + applied_delta) as u32;
-                    old_edit.old.end = (old_edit.old.end as i32 + applied_delta) as u32;
-                    old_edit.new.start = (old_edit.new.start as i32 + applied_delta) as u32;
-                    old_edit.new.end = (old_edit.new.end as i32 + applied_delta) as u32;
+                    let old_edit = old_edits.next().unwrap();
                     rebased_delta += old_edit.new_len() as i32 - old_edit.old_len() as i32;
-                    translated_unreviewed_edits.push(old_edit);
                 }
             }
         }
 
-        if conflict {
-            new_edit.old.start = (new_edit.old.start as i32 + applied_delta) as u32;
-            new_edit.old.end = (new_edit.old.end as i32 + applied_delta) as u32;
-            conflicting_edits.push(new_edit);
-        } else {
+        if !conflict {
             // This edit doesn't intersect with any old edit, so we can apply it to the old text.
             new_edit.old.start = (new_edit.old.start as i32 + applied_delta - rebased_delta) as u32;
             new_edit.old.end = (new_edit.old.end as i32 + applied_delta - rebased_delta) as u32;
@@ -454,17 +456,6 @@ fn rebase_patch(
             applied_delta += new_edit.new_len() as i32 - new_edit.old_len() as i32;
         }
     }
-
-    // Push all the outstanding old edits.
-    for mut old_edit in old_edits {
-        old_edit.old.start = (old_edit.old.start as i32 + applied_delta) as u32;
-        old_edit.old.end = (old_edit.old.end as i32 + applied_delta) as u32;
-        old_edit.new.start = (old_edit.new.start as i32 + applied_delta) as u32;
-        old_edit.new.end = (old_edit.new.end as i32 + applied_delta) as u32;
-        translated_unreviewed_edits.push(old_edit);
-    }
-
-    translated_unreviewed_edits.compose(conflicting_edits)
 }
 
 fn diff_snapshots(
@@ -473,31 +464,7 @@ fn diff_snapshots(
 ) -> Vec<Edit<u32>> {
     let mut edits = new_snapshot
         .edits_since::<Point>(&old_snapshot.version)
-        .map(|edit| {
-            if edit.old.start.column == old_snapshot.line_len(edit.old.start.row)
-                && new_snapshot.chars_at(edit.new.start).next() == Some('\n')
-                && edit.old.start != old_snapshot.max_point()
-            {
-                Edit {
-                    old: edit.old.start.row + 1..edit.old.end.row + 1,
-                    new: edit.new.start.row + 1..edit.new.end.row + 1,
-                }
-            } else if edit.old.start.column == 0
-                && edit.old.end.column == 0
-                && edit.new.end.column == 0
-                && edit.old.end != old_snapshot.max_point()
-            {
-                Edit {
-                    old: edit.old.start.row..edit.old.end.row,
-                    new: edit.new.start.row..edit.new.end.row,
-                }
-            } else {
-                Edit {
-                    old: edit.old.start.row..edit.old.end.row + 1,
-                    new: edit.new.start.row..edit.new.end.row + 1,
-                }
-            }
-        })
+        .map(|edit| point_to_row_edit(edit, old_snapshot.as_rope(), new_snapshot.as_rope()))
         .peekable();
     let mut row_edits = Vec::new();
     while let Some(mut edit) = edits.next() {
@@ -515,6 +482,35 @@ fn diff_snapshots(
     row_edits
 }
 
+fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edit<u32> {
+    if edit.old.start.column == old_text.line_len(edit.old.start.row)
+        && new_text
+            .chars_at(new_text.point_to_offset(edit.new.start))
+            .next()
+            == Some('\n')
+        && edit.old.start != old_text.max_point()
+    {
+        Edit {
+            old: edit.old.start.row + 1..edit.old.end.row + 1,
+            new: edit.new.start.row + 1..edit.new.end.row + 1,
+        }
+    } else if edit.old.start.column == 0
+        && edit.old.end.column == 0
+        && edit.new.end.column == 0
+        && edit.old.end != old_text.max_point()
+    {
+        Edit {
+            old: edit.old.start.row..edit.old.end.row,
+            new: edit.new.start.row..edit.new.end.row,
+        }
+    } else {
+        Edit {
+            old: edit.old.start.row..edit.old.end.row + 1,
+            new: edit.new.start.row..edit.new.end.row + 1,
+        }
+    }
+}
+
 enum ChangeAuthor {
     User,
     Agent,
@@ -572,7 +568,7 @@ mod tests {
     use rand::prelude::*;
     use serde_json::json;
     use settings::SettingsStore;
-    use util::{RandomCharIter, path, post_inc};
+    use util::{RandomCharIter, path};
 
     #[ctor::ctor]
     fn init_logger() {
@@ -582,7 +578,7 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_edit_review(cx: &mut TestAppContext) {
+    async fn test_keep_edits(cx: &mut TestAppContext) {
         let action_log = cx.new(|_| ActionLog::new());
         let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
 
@@ -647,6 +643,70 @@ mod tests {
         assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_undoing_edits(cx: &mut TestAppContext) {
+        let action_log = cx.new(|_| ActionLog::new());
+        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx)
+                    .unwrap();
+                buffer.finalize_last_transaction();
+            });
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(4, 0)..Point::new(5, 0), "")], None, cx)
+                    .unwrap();
+                buffer.finalize_last_transaction();
+            });
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndEf\nghi\njkl\npqr"
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![
+                    HunkStatus {
+                        range: Point::new(1, 0)..Point::new(2, 0),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "def\n".into(),
+                    },
+                    HunkStatus {
+                        range: Point::new(4, 0)..Point::new(4, 0),
+                        diff_status: DiffHunkStatusKind::Deleted,
+                        old_text: "mno\n".into(),
+                    }
+                ],
+            )]
+        );
+
+        buffer.update(cx, |buffer, cx| buffer.undo(cx));
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndEf\nghi\njkl\nmno\npqr"
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(1, 0)..Point::new(2, 0),
+                    diff_status: DiffHunkStatusKind::Modified,
+                    old_text: "def\n".into(),
+                }],
+            )]
+        );
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_overlapping_user_edits(cx: &mut TestAppContext) {
         let action_log = cx.new(|_| ActionLog::new());
@@ -982,85 +1042,6 @@ mod tests {
         }
     }
 
-    #[gpui::test(iterations = 100)]
-    fn test_rebase_random(mut rng: StdRng) {
-        let operations = env::var("OPERATIONS")
-            .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
-            .unwrap_or(20);
-
-        let mut next_line_id = 0;
-        let base_lines = (0..rng.gen_range(1..=20))
-            .map(|_| post_inc(&mut next_line_id).to_string())
-            .collect::<Vec<_>>();
-        log::info!("base lines: {:?}", base_lines);
-
-        let (new_lines, patch_1) =
-            build_edits(&base_lines, operations, &mut rng, &mut next_line_id);
-        log::info!("agent edits: {:#?}", patch_1);
-        let (new_lines, patch_2) = build_edits(&new_lines, operations, &mut rng, &mut next_line_id);
-        log::info!("user edits: {:#?}", patch_2);
-
-        let mut old_text = Rope::from(base_lines.join("\n"));
-        let new_text = Rope::from(new_lines.join("\n"));
-        let patch = rebase_patch(&patch_1, patch_2.into_inner(), &mut old_text, &new_text);
-        log::info!("rebased edits: {:#?}", patch.edits());
-
-        for edit in patch.edits() {
-            let old_start = old_text.point_to_offset(Point::new(edit.new.start, 0));
-            let old_end = old_text.point_to_offset(cmp::min(
-                Point::new(edit.new.start + edit.old_len(), 0),
-                old_text.max_point(),
-            ));
-            old_text.replace(
-                old_start..old_end,
-                &new_text.slice_rows(edit.new.clone()).to_string(),
-            );
-        }
-        pretty_assertions::assert_eq!(old_text.to_string(), new_text.to_string());
-    }
-
-    fn build_edits(
-        lines: &Vec<String>,
-        count: usize,
-        rng: &mut StdRng,
-        next_line_id: &mut usize,
-    ) -> (Vec<String>, Patch<u32>) {
-        let mut delta = 0i32;
-        let mut last_edit_end = 0;
-        let mut edits = Patch::default();
-        let mut edited_lines = lines.clone();
-        for _ in 0..count {
-            if last_edit_end >= lines.len() {
-                break;
-            }
-
-            let end = rng.gen_range(last_edit_end..lines.len());
-            let start = rng.gen_range(last_edit_end..=end);
-            let old_len = end - start;
-
-            let mut new_len: usize = rng.gen_range(0..=3);
-            if start == end && new_len == 0 {
-                new_len += 1;
-            }
-
-            last_edit_end = end + 1;
-
-            let new_lines = (0..new_len)
-                .map(|_| post_inc(next_line_id).to_string())
-                .collect::<Vec<_>>();
-            log::info!("  editing {:?}: {:?}", start..end, new_lines);
-            let old = start as u32..end as u32;
-            let new = (start as i32 + delta) as u32..(start as i32 + delta + new_len as i32) as u32;
-            edited_lines.splice(
-                new.start as usize..new.start as usize + old.len(),
-                new_lines,
-            );
-            edits.push(Edit { old, new });
-            delta += new_len as i32 - old_len as i32;
-        }
-        (edited_lines, edits)
-    }
-
     #[derive(Debug, Clone, PartialEq, Eq)]
     struct HunkStatus {
         range: Range<Point>,

crates/assistant_tools/src/edit_files_tool.rs 🔗

@@ -339,6 +339,8 @@ impl EditToolRequest {
             }
             DiffResult::Diff(diff) => {
                 cx.update(|cx| {
+                    self.action_log
+                        .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
                     buffer.update(cx, |buffer, cx| {
                         buffer.finalize_last_transaction();
                         buffer.apply_diff(diff, cx);

crates/assistant_tools/src/find_replace_file_tool.rs 🔗

@@ -226,6 +226,9 @@ impl Tool for FindReplaceFileTool {
             };
 
             let snapshot = cx.update(|cx| {
+                action_log.update(cx, |log, cx| {
+                    log.buffer_read(buffer.clone(), cx)
+                });
                 let snapshot = buffer.update(cx, |buffer, cx| {
                     buffer.finalize_last_transaction();
                     buffer.apply_diff(diff, cx);