Fix rejecting multiple hunks in `AgentDiff` (#28806)

Antonio Scandurra and Max Brunsfeld created

Release Notes:

- Fixed a bug that caused `Reject All` to not always reject _all_ the
hunks.

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/agent/src/agent_diff.rs          |  20 +
crates/agent/src/thread.rs              |   6 
crates/assistant_tool/src/action_log.rs | 200 +++++++++++++++++++++++---
3 files changed, 188 insertions(+), 38 deletions(-)

Detailed changes

crates/agent/src/agent_diff.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
 use anyhow::Result;
 use buffer_diff::DiffHunkStatus;
-use collections::HashSet;
+use collections::{HashMap, HashSet};
 use editor::{
     Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
     actions::{GoToHunk, GoToPreviousHunk},
@@ -355,16 +355,24 @@ impl AgentDiff {
             self.update_selection(&diff_hunks_in_ranges, window, cx);
         }
 
+        let mut ranges_by_buffer = HashMap::default();
         for hunk in &diff_hunks_in_ranges {
             let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
             if let Some(buffer) = buffer {
-                self.thread
-                    .update(cx, |thread, cx| {
-                        thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
-                    })
-                    .detach_and_log_err(cx);
+                ranges_by_buffer
+                    .entry(buffer.clone())
+                    .or_insert_with(Vec::new)
+                    .push(hunk.buffer_range.clone());
             }
         }
+
+        for (buffer, ranges) in ranges_by_buffer {
+            self.thread
+                .update(cx, |thread, cx| {
+                    thread.reject_edits_in_ranges(buffer, ranges, cx)
+                })
+                .detach_and_log_err(cx);
+        }
     }
 
     fn update_selection(

crates/agent/src/thread.rs 🔗

@@ -1801,14 +1801,14 @@ impl Thread {
             .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
     }
 
-    pub fn reject_edits_in_range(
+    pub fn reject_edits_in_ranges(
         &mut self,
         buffer: Entity<language::Buffer>,
-        buffer_range: Range<language::Anchor>,
+        buffer_ranges: Vec<Range<language::Anchor>>,
         cx: &mut Context<Self>,
     ) -> Task<Result<()>> {
         self.action_log.update(cx, |action_log, cx| {
-            action_log.reject_edits_in_range(buffer, buffer_range, cx)
+            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
         })
     }
 

crates/assistant_tool/src/action_log.rs 🔗

@@ -3,7 +3,7 @@ use buffer_diff::BufferDiff;
 use collections::BTreeMap;
 use futures::{StreamExt, channel::mpsc};
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
-use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
+use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
 use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
 use std::{cmp, ops::Range, sync::Arc};
 use text::{Edit, Patch, Rope};
@@ -363,10 +363,10 @@ impl ActionLog {
         }
     }
 
-    pub fn reject_edits_in_range(
+    pub fn reject_edits_in_ranges(
         &mut self,
         buffer: Entity<Buffer>,
-        buffer_range: Range<impl language::ToPoint>,
+        buffer_ranges: Vec<Range<impl language::ToPoint>>,
         cx: &mut Context<Self>,
     ) -> Task<Result<()>> {
         let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
@@ -403,29 +403,15 @@ impl ActionLog {
             }
             TrackedBufferStatus::Modified => {
                 buffer.update(cx, |buffer, cx| {
-                    let buffer_range =
-                        buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
+                    let mut buffer_row_ranges = buffer_ranges
+                        .into_iter()
+                        .map(|range| {
+                            range.start.to_point(buffer).row..range.end.to_point(buffer).row
+                        })
+                        .peekable();
 
                     let mut edits_to_revert = Vec::new();
                     for edit in tracked_buffer.unreviewed_changes.edits() {
-                        if buffer_range.end.row < edit.new.start {
-                            break;
-                        } else if buffer_range.start.row > edit.new.end {
-                            continue;
-                        }
-
-                        let old_range = tracked_buffer
-                            .base_text
-                            .point_to_offset(Point::new(edit.old.start, 0))
-                            ..tracked_buffer.base_text.point_to_offset(cmp::min(
-                                Point::new(edit.old.end, 0),
-                                tracked_buffer.base_text.max_point(),
-                            ));
-                        let old_text = tracked_buffer
-                            .base_text
-                            .chunks_in_range(old_range)
-                            .collect::<String>();
-
                         let new_range = tracked_buffer
                             .snapshot
                             .anchor_before(Point::new(edit.new.start, 0))
@@ -433,7 +419,35 @@ impl ActionLog {
                                 Point::new(edit.new.end, 0),
                                 tracked_buffer.snapshot.max_point(),
                             ));
-                        edits_to_revert.push((new_range, old_text));
+                        let new_row_range = new_range.start.to_point(buffer).row
+                            ..new_range.end.to_point(buffer).row;
+
+                        let mut revert = false;
+                        while let Some(buffer_row_range) = buffer_row_ranges.peek() {
+                            if buffer_row_range.end < new_row_range.start {
+                                buffer_row_ranges.next();
+                            } else if buffer_row_range.start > new_row_range.end {
+                                break;
+                            } else {
+                                revert = true;
+                                break;
+                            }
+                        }
+
+                        if revert {
+                            let old_range = tracked_buffer
+                                .base_text
+                                .point_to_offset(Point::new(edit.old.start, 0))
+                                ..tracked_buffer.base_text.point_to_offset(cmp::min(
+                                    Point::new(edit.old.end, 0),
+                                    tracked_buffer.base_text.max_point(),
+                                ));
+                            let old_text = tracked_buffer
+                                .base_text
+                                .chunks_in_range(old_range)
+                                .collect::<String>();
+                            edits_to_revert.push((new_range, old_text));
+                        }
                     }
 
                     buffer.edit(edits_to_revert, None, cx);
@@ -599,6 +613,7 @@ fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edi
     }
 }
 
+#[derive(Copy, Clone, Debug)]
 enum ChangeAuthor {
     User,
     Agent,
@@ -1135,9 +1150,48 @@ mod tests {
             )]
         );
 
+        // If the rejected range doesn't overlap with any hunk, we ignore it.
         action_log
             .update(cx, |log, cx| {
-                log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
+                log.reject_edits_in_ranges(
+                    buffer.clone(),
+                    vec![Point::new(4, 0)..Point::new(4, 0)],
+                    cx,
+                )
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndE\nXYZf\nghi\njkl\nmnO"
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![
+                    HunkStatus {
+                        range: Point::new(1, 0)..Point::new(3, 0),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "def\n".into(),
+                    },
+                    HunkStatus {
+                        range: Point::new(5, 0)..Point::new(5, 3),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "mno".into(),
+                    }
+                ],
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_ranges(
+                    buffer.clone(),
+                    vec![Point::new(0, 0)..Point::new(1, 0)],
+                    cx,
+                )
             })
             .await
             .unwrap();
@@ -1160,7 +1214,11 @@ mod tests {
 
         action_log
             .update(cx, |log, cx| {
-                log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
+                log.reject_edits_in_ranges(
+                    buffer.clone(),
+                    vec![Point::new(4, 0)..Point::new(4, 0)],
+                    cx,
+                )
             })
             .await
             .unwrap();
@@ -1172,6 +1230,82 @@ mod tests {
         assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_reject_multiple_edits(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
+            .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
+                    .unwrap()
+            });
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
+                    .unwrap()
+            });
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndE\nXYZf\nghi\njkl\nmnO"
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![
+                    HunkStatus {
+                        range: Point::new(1, 0)..Point::new(3, 0),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "def\n".into(),
+                    },
+                    HunkStatus {
+                        range: Point::new(5, 0)..Point::new(5, 3),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "mno".into(),
+                    }
+                ],
+            )]
+        );
+
+        action_log.update(cx, |log, cx| {
+            let range_1 = buffer.read(cx).anchor_before(Point::new(0, 0))
+                ..buffer.read(cx).anchor_before(Point::new(1, 0));
+            let range_2 = buffer.read(cx).anchor_before(Point::new(5, 0))
+                ..buffer.read(cx).anchor_before(Point::new(5, 3));
+
+            log.reject_edits_in_ranges(buffer.clone(), vec![range_1, range_2], cx)
+                .detach();
+            assert_eq!(
+                buffer.read_with(cx, |buffer, _| buffer.text()),
+                "abc\ndef\nghi\njkl\nmno"
+            );
+        });
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndef\nghi\njkl\nmno"
+        );
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_reject_deleted_file(cx: &mut TestAppContext) {
         init_test(cx);
@@ -1215,7 +1349,11 @@ mod tests {
 
         action_log
             .update(cx, |log, cx| {
-                log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
+                log.reject_edits_in_ranges(
+                    buffer.clone(),
+                    vec![Point::new(0, 0)..Point::new(0, 0)],
+                    cx,
+                )
             })
             .await
             .unwrap();
@@ -1266,7 +1404,11 @@ mod tests {
 
         action_log
             .update(cx, |log, cx| {
-                log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
+                log.reject_edits_in_ranges(
+                    buffer.clone(),
+                    vec![Point::new(0, 0)..Point::new(0, 11)],
+                    cx,
+                )
             })
             .await
             .unwrap();
@@ -1312,7 +1454,7 @@ mod tests {
                         .update(cx, |log, cx| {
                             let range = buffer.read(cx).random_byte_range(0, &mut rng);
                             log::info!("rejecting edits in range {:?}", range);
-                            log.reject_edits_in_range(buffer.clone(), range, cx)
+                            log.reject_edits_in_ranges(buffer.clone(), vec![range], cx)
                         })
                         .await
                         .unwrap();