From d340670fd495507438878c5180687923abdf5796 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 2 Apr 2025 16:06:36 +0200 Subject: [PATCH] Ensure rejecting a hunk dismisses the diff (#27919) Release Notes: - N/A --- crates/assistant_tool/src/action_log.rs | 317 ++++++++---------- crates/assistant_tools/src/edit_files_tool.rs | 2 + .../src/find_replace_file_tool.rs | 3 + 3 files changed, 154 insertions(+), 168 deletions(-) diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 535ef4c930e7f24a9e1622fd0d5be170a0dbcf31..3fb6370eaef545986127bfbdc72cb2cdd5f7018e 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -3,7 +3,7 @@ use buffer_diff::BufferDiff; use collections::{BTreeMap, HashSet}; use futures::{StreamExt, channel::mpsc}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity}; -use language::{Buffer, BufferEvent, DiskState, Point}; +use language::{Anchor, Buffer, BufferEvent, DiskState, Point}; use std::{cmp, ops::Range, sync::Arc}; use text::{Edit, Patch, Rope}; use util::RangeExt; @@ -169,20 +169,15 @@ impl ActionLog { let unreviewed_changes = tracked_buffer.unreviewed_changes.clone(); async move { let edits = diff_snapshots(&old_snapshot, &new_snapshot); - let unreviewed_changes = match author { - ChangeAuthor::User => rebase_patch( + if let ChangeAuthor::User = author { + apply_non_conflicting_edits( &unreviewed_changes, edits, &mut base_text, new_snapshot.as_rope(), - ), - ChangeAuthor::Agent => unreviewed_changes.compose(edits), - }; - ( - Arc::new(base_text.to_string()), - base_text, - unreviewed_changes, - ) + ); + } + (Arc::new(base_text.to_string()), base_text) } }); @@ -194,7 +189,7 @@ impl ActionLog { )) })??; - let (new_base_text, new_base_text_rope, unreviewed_changes) = rebase.await; + let (new_base_text, new_base_text_rope) = rebase.await; let diff_snapshot = BufferDiff::update_diff( diff.clone(), buffer_snapshot.clone(), @@ -206,7 +201,39 @@ impl ActionLog { cx, ) .await; + + let mut unreviewed_changes = Patch::default(); if let Ok(diff_snapshot) = diff_snapshot { + unreviewed_changes = cx + .background_spawn({ + let diff_snapshot = diff_snapshot.clone(); + let buffer_snapshot = buffer_snapshot.clone(); + let new_base_text_rope = new_base_text_rope.clone(); + async move { + let mut unreviewed_changes = Patch::default(); + for hunk in diff_snapshot.hunks_intersecting_range( + Anchor::MIN..Anchor::MAX, + &buffer_snapshot, + ) { + let old_range = new_base_text_rope + .offset_to_point(hunk.diff_base_byte_range.start) + ..new_base_text_rope + .offset_to_point(hunk.diff_base_byte_range.end); + let new_range = hunk.range.start..hunk.range.end; + unreviewed_changes.push(point_to_row_edit( + Edit { + old: old_range, + new: new_range, + }, + &new_base_text_rope, + &buffer_snapshot.as_rope(), + )); + } + unreviewed_changes + } + }) + .await; + diff.update(cx, |diff, cx| { diff.set_snapshot(diff_snapshot, &buffer_snapshot, None, cx) })?; @@ -267,21 +294,12 @@ impl ActionLog { cx.notify(); } - pub fn keep_edits_in_range( + pub fn keep_edits_in_range( &mut self, buffer: Entity, - buffer_range: Range, + buffer_range: Range, cx: &mut Context, - ) where - T: 'static + language::ToPoint, // + Clone - // + Copy - // + Ord - // + Sub - // + Add - // + AddAssign - // + Default - // + PartialEq, - { + ) { let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else { return; }; @@ -377,15 +395,12 @@ impl ActionLog { } } -fn rebase_patch( +fn apply_non_conflicting_edits( patch: &Patch, edits: Vec>, old_text: &mut Rope, new_text: &Rope, -) -> Patch { - let mut translated_unreviewed_edits = Patch::default(); - let mut conflicting_edits = Vec::new(); - +) { let mut old_edits = patch.edits().iter().cloned().peekable(); let mut new_edits = edits.into_iter().peekable(); let mut applied_delta = 0i32; @@ -396,43 +411,30 @@ fn rebase_patch( // Push all the old edits that are before this new edit or that intersect with it. while let Some(old_edit) = old_edits.peek() { - if new_edit.old.end <= old_edit.new.start { + if new_edit.old.end < old_edit.new.start + || (!old_edit.new.is_empty() && new_edit.old.end == old_edit.new.start) + { break; - } else if new_edit.old.start >= old_edit.new.end { - let mut old_edit = old_edits.next().unwrap(); - old_edit.old.start = (old_edit.old.start as i32 + applied_delta) as u32; - old_edit.old.end = (old_edit.old.end as i32 + applied_delta) as u32; - old_edit.new.start = (old_edit.new.start as i32 + applied_delta) as u32; - old_edit.new.end = (old_edit.new.end as i32 + applied_delta) as u32; + } else if new_edit.old.start > old_edit.new.end + || (!old_edit.new.is_empty() && new_edit.old.start == old_edit.new.end) + { + let old_edit = old_edits.next().unwrap(); rebased_delta += old_edit.new_len() as i32 - old_edit.old_len() as i32; - translated_unreviewed_edits.push(old_edit); } else { conflict = true; if new_edits .peek() .map_or(false, |next_edit| next_edit.old.overlaps(&old_edit.new)) { - new_edit.old.start = (new_edit.old.start as i32 + applied_delta) as u32; - new_edit.old.end = (new_edit.old.end as i32 + applied_delta) as u32; - conflicting_edits.push(new_edit); new_edit = new_edits.next().unwrap(); } else { - let mut old_edit = old_edits.next().unwrap(); - old_edit.old.start = (old_edit.old.start as i32 + applied_delta) as u32; - old_edit.old.end = (old_edit.old.end as i32 + applied_delta) as u32; - old_edit.new.start = (old_edit.new.start as i32 + applied_delta) as u32; - old_edit.new.end = (old_edit.new.end as i32 + applied_delta) as u32; + let old_edit = old_edits.next().unwrap(); rebased_delta += old_edit.new_len() as i32 - old_edit.old_len() as i32; - translated_unreviewed_edits.push(old_edit); } } } - if conflict { - new_edit.old.start = (new_edit.old.start as i32 + applied_delta) as u32; - new_edit.old.end = (new_edit.old.end as i32 + applied_delta) as u32; - conflicting_edits.push(new_edit); - } else { + if !conflict { // This edit doesn't intersect with any old edit, so we can apply it to the old text. new_edit.old.start = (new_edit.old.start as i32 + applied_delta - rebased_delta) as u32; new_edit.old.end = (new_edit.old.end as i32 + applied_delta - rebased_delta) as u32; @@ -454,17 +456,6 @@ fn rebase_patch( applied_delta += new_edit.new_len() as i32 - new_edit.old_len() as i32; } } - - // Push all the outstanding old edits. - for mut old_edit in old_edits { - old_edit.old.start = (old_edit.old.start as i32 + applied_delta) as u32; - old_edit.old.end = (old_edit.old.end as i32 + applied_delta) as u32; - old_edit.new.start = (old_edit.new.start as i32 + applied_delta) as u32; - old_edit.new.end = (old_edit.new.end as i32 + applied_delta) as u32; - translated_unreviewed_edits.push(old_edit); - } - - translated_unreviewed_edits.compose(conflicting_edits) } fn diff_snapshots( @@ -473,31 +464,7 @@ fn diff_snapshots( ) -> Vec> { let mut edits = new_snapshot .edits_since::(&old_snapshot.version) - .map(|edit| { - if edit.old.start.column == old_snapshot.line_len(edit.old.start.row) - && new_snapshot.chars_at(edit.new.start).next() == Some('\n') - && edit.old.start != old_snapshot.max_point() - { - Edit { - old: edit.old.start.row + 1..edit.old.end.row + 1, - new: edit.new.start.row + 1..edit.new.end.row + 1, - } - } else if edit.old.start.column == 0 - && edit.old.end.column == 0 - && edit.new.end.column == 0 - && edit.old.end != old_snapshot.max_point() - { - Edit { - old: edit.old.start.row..edit.old.end.row, - new: edit.new.start.row..edit.new.end.row, - } - } else { - Edit { - old: edit.old.start.row..edit.old.end.row + 1, - new: edit.new.start.row..edit.new.end.row + 1, - } - } - }) + .map(|edit| point_to_row_edit(edit, old_snapshot.as_rope(), new_snapshot.as_rope())) .peekable(); let mut row_edits = Vec::new(); while let Some(mut edit) = edits.next() { @@ -515,6 +482,35 @@ fn diff_snapshots( row_edits } +fn point_to_row_edit(edit: Edit, old_text: &Rope, new_text: &Rope) -> Edit { + if edit.old.start.column == old_text.line_len(edit.old.start.row) + && new_text + .chars_at(new_text.point_to_offset(edit.new.start)) + .next() + == Some('\n') + && edit.old.start != old_text.max_point() + { + Edit { + old: edit.old.start.row + 1..edit.old.end.row + 1, + new: edit.new.start.row + 1..edit.new.end.row + 1, + } + } else if edit.old.start.column == 0 + && edit.old.end.column == 0 + && edit.new.end.column == 0 + && edit.old.end != old_text.max_point() + { + Edit { + old: edit.old.start.row..edit.old.end.row, + new: edit.new.start.row..edit.new.end.row, + } + } else { + Edit { + old: edit.old.start.row..edit.old.end.row + 1, + new: edit.new.start.row..edit.new.end.row + 1, + } + } +} + enum ChangeAuthor { User, Agent, @@ -572,7 +568,7 @@ mod tests { use rand::prelude::*; use serde_json::json; use settings::SettingsStore; - use util::{RandomCharIter, path, post_inc}; + use util::{RandomCharIter, path}; #[ctor::ctor] fn init_logger() { @@ -582,7 +578,7 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_edit_review(cx: &mut TestAppContext) { + async fn test_keep_edits(cx: &mut TestAppContext) { let action_log = cx.new(|_| ActionLog::new()); let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx)); @@ -647,6 +643,70 @@ mod tests { assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); } + #[gpui::test(iterations = 10)] + async fn test_undoing_edits(cx: &mut TestAppContext) { + let action_log = cx.new(|_| ActionLog::new()); + let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx)); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx) + .unwrap(); + buffer.finalize_last_transaction(); + }); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(4, 0)..Point::new(5, 0), "")], None, cx) + .unwrap(); + buffer.finalize_last_transaction(); + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndEf\nghi\njkl\npqr" + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![ + HunkStatus { + range: Point::new(1, 0)..Point::new(2, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\n".into(), + }, + HunkStatus { + range: Point::new(4, 0)..Point::new(4, 0), + diff_status: DiffHunkStatusKind::Deleted, + old_text: "mno\n".into(), + } + ], + )] + ); + + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndEf\nghi\njkl\nmno\npqr" + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![HunkStatus { + range: Point::new(1, 0)..Point::new(2, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\n".into(), + }], + )] + ); + } + #[gpui::test(iterations = 10)] async fn test_overlapping_user_edits(cx: &mut TestAppContext) { let action_log = cx.new(|_| ActionLog::new()); @@ -982,85 +1042,6 @@ mod tests { } } - #[gpui::test(iterations = 100)] - fn test_rebase_random(mut rng: StdRng) { - let operations = env::var("OPERATIONS") - .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) - .unwrap_or(20); - - let mut next_line_id = 0; - let base_lines = (0..rng.gen_range(1..=20)) - .map(|_| post_inc(&mut next_line_id).to_string()) - .collect::>(); - log::info!("base lines: {:?}", base_lines); - - let (new_lines, patch_1) = - build_edits(&base_lines, operations, &mut rng, &mut next_line_id); - log::info!("agent edits: {:#?}", patch_1); - let (new_lines, patch_2) = build_edits(&new_lines, operations, &mut rng, &mut next_line_id); - log::info!("user edits: {:#?}", patch_2); - - let mut old_text = Rope::from(base_lines.join("\n")); - let new_text = Rope::from(new_lines.join("\n")); - let patch = rebase_patch(&patch_1, patch_2.into_inner(), &mut old_text, &new_text); - log::info!("rebased edits: {:#?}", patch.edits()); - - for edit in patch.edits() { - let old_start = old_text.point_to_offset(Point::new(edit.new.start, 0)); - let old_end = old_text.point_to_offset(cmp::min( - Point::new(edit.new.start + edit.old_len(), 0), - old_text.max_point(), - )); - old_text.replace( - old_start..old_end, - &new_text.slice_rows(edit.new.clone()).to_string(), - ); - } - pretty_assertions::assert_eq!(old_text.to_string(), new_text.to_string()); - } - - fn build_edits( - lines: &Vec, - count: usize, - rng: &mut StdRng, - next_line_id: &mut usize, - ) -> (Vec, Patch) { - let mut delta = 0i32; - let mut last_edit_end = 0; - let mut edits = Patch::default(); - let mut edited_lines = lines.clone(); - for _ in 0..count { - if last_edit_end >= lines.len() { - break; - } - - let end = rng.gen_range(last_edit_end..lines.len()); - let start = rng.gen_range(last_edit_end..=end); - let old_len = end - start; - - let mut new_len: usize = rng.gen_range(0..=3); - if start == end && new_len == 0 { - new_len += 1; - } - - last_edit_end = end + 1; - - let new_lines = (0..new_len) - .map(|_| post_inc(next_line_id).to_string()) - .collect::>(); - log::info!(" editing {:?}: {:?}", start..end, new_lines); - let old = start as u32..end as u32; - let new = (start as i32 + delta) as u32..(start as i32 + delta + new_len as i32) as u32; - edited_lines.splice( - new.start as usize..new.start as usize + old.len(), - new_lines, - ); - edits.push(Edit { old, new }); - delta += new_len as i32 - old_len as i32; - } - (edited_lines, edits) - } - #[derive(Debug, Clone, PartialEq, Eq)] struct HunkStatus { range: Range, diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 67a2ac901d8ceb96376889242c25f844d3c661ba..3cf3a2c6a65f61ff7a8093591106f0860e4c0246 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -339,6 +339,8 @@ impl EditToolRequest { } DiffResult::Diff(diff) => { cx.update(|cx| { + self.action_log + .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); buffer.update(cx, |buffer, cx| { buffer.finalize_last_transaction(); buffer.apply_diff(diff, cx); diff --git a/crates/assistant_tools/src/find_replace_file_tool.rs b/crates/assistant_tools/src/find_replace_file_tool.rs index 93b6d2fef76f5ad436d1eea1ad092c4cc06bdca9..db51fe9891302827f3f1b2026092813e96a35509 100644 --- a/crates/assistant_tools/src/find_replace_file_tool.rs +++ b/crates/assistant_tools/src/find_replace_file_tool.rs @@ -226,6 +226,9 @@ impl Tool for FindReplaceFileTool { }; let snapshot = cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx) + }); let snapshot = buffer.update(cx, |buffer, cx| { buffer.finalize_last_transaction(); buffer.apply_diff(diff, cx);