Implement edit rejection in `ActionLog` (#28080)

Antonio Scandurra created

Release Notes:

- Fixed a bug that would prevent rejecting certain agent edits.

Change summary

crates/agent/src/agent_diff.rs          |  21 
crates/agent/src/thread.rs              |  17 +
crates/assistant_tool/src/action_log.rs | 369 +++++++++++++++++++++++++-
3 files changed, 368 insertions(+), 39 deletions(-)

Detailed changes

crates/agent/src/agent_diff.rs 🔗

@@ -3,7 +3,7 @@ use anyhow::Result;
 use buffer_diff::DiffHunkStatus;
 use collections::HashSet;
 use editor::{
-    AnchorRangeExt, Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
+    Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
     actions::{GoToHunk, GoToPreviousHunk},
     scroll::Autoscroll,
 };
@@ -350,13 +350,16 @@ impl AgentDiff {
             self.update_selection(&diff_hunks_in_ranges, window, cx);
         }
 
-        let point_ranges = ranges
-            .into_iter()
-            .map(|range| range.to_point(&snapshot))
-            .collect();
-        self.editor.update(cx, |editor, cx| {
-            editor.restore_hunks_in_ranges(point_ranges, window, cx)
-        });
+        for hunk in &diff_hunks_in_ranges {
+            let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
+            if let Some(buffer) = buffer {
+                self.thread
+                    .update(cx, |thread, cx| {
+                        thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
+                    })
+                    .detach_and_log_err(cx);
+            }
+        }
     }
 
     fn update_selection(
@@ -986,7 +989,7 @@ mod tests {
             Point::new(3, 0)..Point::new(3, 0)
         );
 
-        // Restoring a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
+        // Rejecting a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
         editor.update_in(cx, |editor, window, cx| {
             editor.change_selections(None, window, cx, |selections| {
                 selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)])

crates/agent/src/thread.rs 🔗

@@ -290,7 +290,7 @@ impl Thread {
             last_restore_checkpoint: None,
             pending_checkpoint: None,
             tool_use: ToolUseState::new(tools.clone()),
-            action_log: cx.new(|_| ActionLog::new()),
+            action_log: cx.new(|_| ActionLog::new(project.clone())),
             initial_project_snapshot: {
                 let project_snapshot = Self::project_snapshot(project, cx);
                 cx.foreground_executor()
@@ -354,11 +354,11 @@ impl Thread {
             pending_completions: Vec::new(),
             last_restore_checkpoint: None,
             pending_checkpoint: None,
-            project,
+            project: project.clone(),
             prompt_builder,
             tools,
             tool_use,
-            action_log: cx.new(|_| ActionLog::new()),
+            action_log: cx.new(|_| ActionLog::new(project)),
             initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
             cumulative_token_usage: serialized.cumulative_token_usage,
             feedback: None,
@@ -1757,6 +1757,17 @@ impl Thread {
             .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
     }
 
+    pub fn reject_edits_in_range(
+        &mut self,
+        buffer: Entity<language::Buffer>,
+        buffer_range: Range<language::Anchor>,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        self.action_log.update(cx, |action_log, cx| {
+            action_log.reject_edits_in_range(buffer, buffer_range, cx)
+        })
+    }
+
     pub fn action_log(&self) -> &Entity<ActionLog> {
         &self.action_log
     }

crates/assistant_tool/src/action_log.rs 🔗

@@ -4,6 +4,7 @@ use collections::BTreeMap;
 use futures::{StreamExt, channel::mpsc};
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
 use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
+use project::{Project, ProjectItem};
 use std::{cmp, ops::Range, sync::Arc};
 use text::{Edit, Patch, Rope};
 use util::RangeExt;
@@ -14,14 +15,17 @@ pub struct ActionLog {
     tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
     /// Has the model edited a file since it last checked diagnostics?
     edited_since_project_diagnostics_check: bool,
+    /// The project this action log is associated with
+    project: Entity<Project>,
 }
 
 impl ActionLog {
-    /// Creates a new, empty action log.
-    pub fn new() -> Self {
+    /// Creates a new, empty action log associated with the given project.
+    pub fn new(project: Entity<Project>) -> Self {
         Self {
             tracked_buffers: BTreeMap::default(),
             edited_since_project_diagnostics_check: false,
+            project,
         }
     }
 
@@ -324,14 +328,14 @@ impl ActionLog {
                     {
                         true
                     } else {
-                        let old_bytes = tracked_buffer
+                        let old_range = tracked_buffer
                             .base_text
                             .point_to_offset(Point::new(edit.old.start, 0))
                             ..tracked_buffer.base_text.point_to_offset(cmp::min(
                                 Point::new(edit.old.end, 0),
                                 tracked_buffer.base_text.max_point(),
                             ));
-                        let new_bytes = tracked_buffer
+                        let new_range = tracked_buffer
                             .snapshot
                             .point_to_offset(Point::new(edit.new.start, 0))
                             ..tracked_buffer.snapshot.point_to_offset(cmp::min(
@@ -339,10 +343,10 @@ impl ActionLog {
                                 tracked_buffer.snapshot.max_point(),
                             ));
                         tracked_buffer.base_text.replace(
-                            old_bytes,
+                            old_range,
                             &tracked_buffer
                                 .snapshot
-                                .text_for_range(new_bytes)
+                                .text_for_range(new_range)
                                 .collect::<String>(),
                         );
                         delta += edit.new_len() as i32 - edit.old_len() as i32;
@@ -354,6 +358,87 @@ impl ActionLog {
         }
     }
 
+    pub fn reject_edits_in_range(
+        &mut self,
+        buffer: Entity<Buffer>,
+        buffer_range: Range<impl language::ToPoint>,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
+            return Task::ready(Ok(()));
+        };
+
+        match tracked_buffer.status {
+            TrackedBufferStatus::Created => {
+                let delete = buffer
+                    .read(cx)
+                    .entry_id(cx)
+                    .and_then(|entry_id| {
+                        self.project
+                            .update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
+                    })
+                    .unwrap_or(Task::ready(Ok(())));
+                self.tracked_buffers.remove(&buffer);
+                cx.notify();
+                delete
+            }
+            TrackedBufferStatus::Deleted => {
+                buffer.update(cx, |buffer, cx| {
+                    buffer.set_text(tracked_buffer.base_text.to_string(), cx)
+                });
+                let save = self
+                    .project
+                    .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
+
+                // Clear all tracked changes for this buffer and start over as if we just read it.
+                self.tracked_buffers.remove(&buffer);
+                self.track_buffer(buffer.clone(), false, cx);
+                cx.notify();
+                save
+            }
+            TrackedBufferStatus::Modified => {
+                buffer.update(cx, |buffer, cx| {
+                    let buffer_range =
+                        buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
+
+                    let mut edits_to_revert = Vec::new();
+                    for edit in tracked_buffer.unreviewed_changes.edits() {
+                        if buffer_range.end.row < edit.new.start {
+                            break;
+                        } else if buffer_range.start.row > edit.new.end {
+                            continue;
+                        }
+
+                        let old_range = tracked_buffer
+                            .base_text
+                            .point_to_offset(Point::new(edit.old.start, 0))
+                            ..tracked_buffer.base_text.point_to_offset(cmp::min(
+                                Point::new(edit.old.end, 0),
+                                tracked_buffer.base_text.max_point(),
+                            ));
+                        let old_text = tracked_buffer
+                            .base_text
+                            .chunks_in_range(old_range)
+                            .collect::<String>();
+
+                        let new_range = tracked_buffer
+                            .snapshot
+                            .anchor_before(Point::new(edit.new.start, 0))
+                            ..tracked_buffer.snapshot.anchor_after(cmp::min(
+                                Point::new(edit.new.end, 0),
+                                tracked_buffer.snapshot.max_point(),
+                            ));
+                        edits_to_revert.push((new_range, old_text));
+                    }
+
+                    buffer.edit(edits_to_revert, None, cx);
+                });
+                self.project
+                    .update(cx, |project, cx| project.save_buffer(buffer, cx))
+            }
+        }
+    }
+
     pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
         self.tracked_buffers
             .retain(|_buffer, tracked_buffer| match tracked_buffer.status {
@@ -575,9 +660,22 @@ mod tests {
         }
     }
 
+    fn init_test(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+        });
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_keep_edits(cx: &mut TestAppContext) {
-        let action_log = cx.new(|_| ActionLog::new());
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
 
         cx.update(|cx| {
@@ -643,7 +741,11 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
     async fn test_deletions(cx: &mut TestAppContext) {
-        let action_log = cx.new(|_| ActionLog::new());
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
 
         cx.update(|cx| {
@@ -713,7 +815,11 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
     async fn test_overlapping_user_edits(cx: &mut TestAppContext) {
-        let action_log = cx.new(|_| ActionLog::new());
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
 
         cx.update(|cx| {
@@ -797,15 +903,12 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_creation(cx: &mut TestAppContext) {
-        cx.update(|cx| {
-            let settings_store = SettingsStore::test(cx);
-            cx.set_global(settings_store);
-            language::init(cx);
-            Project::init_settings(cx);
-        });
+    async fn test_creating_files(cx: &mut TestAppContext) {
+        init_test(cx);
 
-        let action_log = cx.new(|_| ActionLog::new());
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(path!("/dir"), json!({})).await;
@@ -864,12 +967,7 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
     async fn test_deleting_files(cx: &mut TestAppContext) {
-        cx.update(|cx| {
-            let settings_store = SettingsStore::test(cx);
-            cx.set_global(settings_store);
-            language::init(cx);
-            Project::init_settings(cx);
-        });
+        init_test(cx);
 
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
@@ -886,7 +984,7 @@ mod tests {
             .read_with(cx, |project, cx| project.find_project_path("dir/file2", cx))
             .unwrap();
 
-        let action_log = cx.new(|_| ActionLog::new());
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let buffer1 = project
             .update(cx, |project, cx| {
                 project.open_buffer(file1_path.clone(), cx)
@@ -976,15 +1074,222 @@ mod tests {
         assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_reject_edits(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
+            .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
+                    .unwrap()
+            });
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
+                    .unwrap()
+            });
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndE\nXYZf\nghi\njkl\nmnO"
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![
+                    HunkStatus {
+                        range: Point::new(1, 0)..Point::new(3, 0),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "def\n".into(),
+                    },
+                    HunkStatus {
+                        range: Point::new(5, 0)..Point::new(5, 3),
+                        diff_status: DiffHunkStatusKind::Modified,
+                        old_text: "mno".into(),
+                    }
+                ],
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndef\nghi\njkl\nmnO"
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(4, 0)..Point::new(4, 3),
+                    diff_status: DiffHunkStatusKind::Modified,
+                    old_text: "mno".into(),
+                }],
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            "abc\ndef\nghi\njkl\nmno"
+        );
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+    }
+
+    #[gpui::test(iterations = 10)]
+    async fn test_reject_deleted_file(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": "content"}))
+            .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx));
+        });
+        project
+            .update(cx, |project, cx| {
+                project.delete_file(file_path.clone(), false, cx)
+            })
+            .unwrap()
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert!(!fs.is_file(path!("/dir/file").as_ref()).await);
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(0, 0)..Point::new(0, 0),
+                    diff_status: DiffHunkStatusKind::Deleted,
+                    old_text: "content".into(),
+                }]
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "content");
+        assert!(fs.is_file(path!("/dir/file").as_ref()).await);
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+    }
+
+    #[gpui::test(iterations = 10)]
+    async fn test_reject_created_file(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| {
+                project.find_project_path("dir/new_file", cx)
+            })
+            .unwrap();
+
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+        cx.update(|cx| {
+            buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
+            action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
+        });
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .unwrap();
+        assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
+        cx.run_until_parked();
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(0, 0)..Point::new(0, 7),
+                    diff_status: DiffHunkStatusKind::Added,
+                    old_text: "".into(),
+                }],
+            )]
+        );
+
+        action_log
+            .update(cx, |log, cx| {
+                log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
+            })
+            .await
+            .unwrap();
+        cx.run_until_parked();
+        assert!(!fs.is_file(path!("/dir/new_file").as_ref()).await);
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+    }
+
     #[gpui::test(iterations = 100)]
     async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) {
+        init_test(cx);
+
         let operations = env::var("OPERATIONS")
             .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
             .unwrap_or(20);
 
-        let action_log = cx.new(|_| ActionLog::new());
         let text = RandomCharIter::new(&mut rng).take(50).collect::<String>();
-        let buffer = cx.new(|cx| Buffer::local(text, cx));
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": text})).await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
         action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
 
         for _ in 0..operations {
@@ -992,10 +1297,20 @@ mod tests {
                 0..25 => {
                     action_log.update(cx, |log, cx| {
                         let range = buffer.read(cx).random_byte_range(0, &mut rng);
-                        log::info!("keeping all edits in range {:?}", range);
+                        log::info!("keeping edits in range {:?}", range);
                         log.keep_edits_in_range(buffer.clone(), range, cx)
                     });
                 }
+                25..50 => {
+                    action_log
+                        .update(cx, |log, cx| {
+                            let range = buffer.read(cx).random_byte_range(0, &mut rng);
+                            log::info!("rejecting edits in range {:?}", range);
+                            log.reject_edits_in_range(buffer.clone(), range, cx)
+                        })
+                        .await
+                        .unwrap();
+                }
                 _ => {
                     let is_agent_change = rng.gen_bool(0.5);
                     if is_agent_change {