Restore file to original content when rejecting file recreated by agent (#29264)

Antonio Scandurra created

Release Notes:

- Fixed a bug that could sometimes cause a file to be deleted when
rejecting an agent change.

Change summary

crates/agent/src/agent_diff.rs                  |   2 
crates/agent/src/thread.rs                      |  14 -
crates/assistant_tool/src/action_log.rs         | 101 +++++++++++-------
crates/assistant_tools/src/code_action_tool.rs  |   2 
crates/assistant_tools/src/code_symbols_tool.rs |   2 
crates/assistant_tools/src/contents_tool.rs     |   4 
crates/assistant_tools/src/create_file_tool.rs  |   5 
crates/assistant_tools/src/edit_file_tool.rs    |   2 
crates/assistant_tools/src/read_file_tool.rs    |   4 
crates/assistant_tools/src/rename_tool.rs       |   2 
crates/assistant_tools/src/symbol_info_tool.rs  |   2 
11 files changed, 77 insertions(+), 63 deletions(-)

Detailed changes

crates/agent/src/agent_diff.rs 🔗

@@ -988,7 +988,7 @@ mod tests {
             .await
             .unwrap();
         cx.update(|_, cx| {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| {
                 buffer
                     .edit(

crates/agent/src/thread.rs 🔗

@@ -770,24 +770,18 @@ impl Thread {
                 for ctx in &new_context {
                     match ctx {
                         AssistantContext::File(file_ctx) => {
-                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
+                            log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
                         }
                         AssistantContext::Directory(dir_ctx) => {
                             for context_buffer in &dir_ctx.context_buffers {
-                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
+                                log.track_buffer(context_buffer.buffer.clone(), cx);
                             }
                         }
                         AssistantContext::Symbol(symbol_ctx) => {
-                            log.buffer_added_as_context(
-                                symbol_ctx.context_symbol.buffer.clone(),
-                                cx,
-                            );
+                            log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
                         }
                         AssistantContext::Selection(selection_context) => {
-                            log.buffer_added_as_context(
-                                selection_context.context_buffer.buffer.clone(),
-                                cx,
-                            );
+                            log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
                         }
                         AssistantContext::FetchedUrl(_)
                         | AssistantContext::Thread(_)

crates/assistant_tool/src/action_log.rs 🔗

@@ -39,10 +39,9 @@ impl ActionLog {
         self.edited_since_project_diagnostics_check
     }
 
-    fn track_buffer(
+    fn track_buffer_internal(
         &mut self,
         buffer: Entity<Buffer>,
-        created: bool,
         cx: &mut Context<Self>,
     ) -> &mut TrackedBuffer {
         let tracked_buffer = self
@@ -59,7 +58,11 @@ impl ActionLog {
                 let base_text;
                 let status;
                 let unreviewed_changes;
-                if created {
+                if buffer
+                    .read(cx)
+                    .file()
+                    .map_or(true, |file| !file.disk_state().exists())
+                {
                     base_text = Rope::default();
                     status = TrackedBufferStatus::Created;
                     unreviewed_changes = Patch::new(vec![Edit {
@@ -146,7 +149,7 @@ impl ActionLog {
                     // resurrected externally, we want to clear the changes we
                     // were tracking and reset the buffer's state.
                     self.tracked_buffers.remove(&buffer);
-                    self.track_buffer(buffer, false, cx);
+                    self.track_buffer_internal(buffer, cx);
                 }
                 cx.notify();
             }
@@ -260,26 +263,15 @@ impl ActionLog {
     }
 
     /// Track a buffer as read, so we can notify the model about user edits.
-    pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
-        self.track_buffer(buffer, false, cx);
-    }
-
-    /// Track a buffer that was added as context, so we can notify the model about user edits.
-    pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
-        self.track_buffer(buffer, false, cx);
-    }
-
-    /// Track a buffer as read, so we can notify the model about user edits.
-    pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
-        self.track_buffer(buffer.clone(), true, cx);
-        self.buffer_edited(buffer, cx)
+    pub fn track_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
+        self.track_buffer_internal(buffer, cx);
     }
 
     /// Mark a buffer as edited, so we can refresh it in the context
     pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
         self.edited_since_project_diagnostics_check = true;
 
-        let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
+        let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
         if let TrackedBufferStatus::Deleted = tracked_buffer.status {
             tracked_buffer.status = TrackedBufferStatus::Modified;
         }
@@ -287,7 +279,7 @@ impl ActionLog {
     }
 
     pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
-        let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
+        let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
         match tracked_buffer.status {
             TrackedBufferStatus::Created => {
                 self.tracked_buffers.remove(&buffer);
@@ -397,7 +389,7 @@ impl ActionLog {
 
                 // Clear all tracked changes for this buffer and start over as if we just read it.
                 self.tracked_buffers.remove(&buffer);
-                self.track_buffer(buffer.clone(), false, cx);
+                self.track_buffer_internal(buffer.clone(), cx);
                 cx.notify();
                 save
             }
@@ -695,12 +687,20 @@ mod tests {
         init_test(cx);
 
         let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs.clone(), [], cx).await;
+        fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
+            .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
 
         cx.update(|cx| {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| {
                 buffer
                     .edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx)
@@ -765,12 +765,23 @@ mod tests {
         init_test(cx);
 
         let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs.clone(), [], cx).await;
+        fs.insert_tree(
+            path!("/dir"),
+            json!({"file": "abc\ndef\nghi\njkl\nmno\npqr"}),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
 
         cx.update(|cx| {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| {
                 buffer
                     .edit([(Point::new(1, 0)..Point::new(2, 0), "")], None, cx)
@@ -839,12 +850,20 @@ mod tests {
         init_test(cx);
 
         let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs.clone(), [], cx).await;
+        fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
+            .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
 
         cx.update(|cx| {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| {
                 buffer
                     .edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
@@ -927,26 +946,22 @@ mod tests {
     async fn test_creating_files(cx: &mut TestAppContext) {
         init_test(cx);
 
-        let fs = FakeFs::new(cx.executor());
-        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
-        let action_log = cx.new(|_| ActionLog::new(project.clone()));
-
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(path!("/dir"), json!({})).await;
-
         let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let file_path = project
             .read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
             .unwrap();
 
-        // Simulate file2 being recreated by a tool.
         let buffer = project
             .update(cx, |project, cx| project.open_buffer(file_path, cx))
             .await
             .unwrap();
         cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| buffer.set_text("lorem", cx));
-            action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
         });
         project
             .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
@@ -1067,8 +1082,9 @@ mod tests {
             .update(cx, |project, cx| project.open_buffer(file2_path, cx))
             .await
             .unwrap();
+        action_log.update(cx, |log, cx| log.track_buffer(buffer2.clone(), cx));
         buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx));
-        action_log.update(cx, |log, cx| log.will_create_buffer(buffer2.clone(), cx));
+        action_log.update(cx, |log, cx| log.buffer_edited(buffer2.clone(), cx));
         project
             .update(cx, |project, cx| project.save_buffer(buffer2.clone(), cx))
             .await
@@ -1113,7 +1129,7 @@ mod tests {
             .unwrap();
 
         cx.update(|cx| {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| {
                 buffer
                     .edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
@@ -1248,7 +1264,7 @@ mod tests {
             .unwrap();
 
         cx.update(|cx| {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| {
                 buffer
                     .edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
@@ -1381,8 +1397,9 @@ mod tests {
             .await
             .unwrap();
         cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
             buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
-            action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
         });
         project
             .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
@@ -1438,7 +1455,7 @@ mod tests {
             .await
             .unwrap();
 
-        action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+        action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
 
         for _ in 0..operations {
             match rng.gen_range(0..100) {
@@ -1490,7 +1507,7 @@ mod tests {
             log::info!("quiescing...");
             cx.run_until_parked();
             action_log.update(cx, |log, cx| {
-                let tracked_buffer = log.track_buffer(buffer.clone(), false, cx);
+                let tracked_buffer = log.track_buffer_internal(buffer.clone(), cx);
                 let mut old_text = tracked_buffer.base_text.clone();
                 let new_text = buffer.read(cx).as_rope();
                 for edit in tracked_buffer.unreviewed_changes.edits() {

crates/assistant_tools/src/code_action_tool.rs 🔗

@@ -159,7 +159,7 @@ impl Tool for CodeActionTool {
             };
 
             action_log.update(cx, |action_log, cx| {
-                action_log.buffer_read(buffer.clone(), cx);
+                action_log.track_buffer(buffer.clone(), cx);
             })?;
 
             let range = {

crates/assistant_tools/src/code_symbols_tool.rs 🔗

@@ -174,7 +174,7 @@ pub async fn file_outline(
     };
 
     action_log.update(cx, |action_log, cx| {
-        action_log.buffer_read(buffer.clone(), cx);
+        action_log.track_buffer(buffer.clone(), cx);
     })?;
 
     // Wait until the buffer has been fully parsed, so that we can read its outline.

crates/assistant_tools/src/contents_tool.rs 🔗

@@ -209,7 +209,7 @@ impl Tool for ContentsTool {
                     })?;
 
                     action_log.update(cx, |log, cx| {
-                        log.buffer_read(buffer, cx);
+                        log.track_buffer(buffer, cx);
                     })?;
 
                     Ok(result)
@@ -221,7 +221,7 @@ impl Tool for ContentsTool {
                         let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
 
                         action_log.update(cx, |log, cx| {
-                            log.buffer_read(buffer, cx);
+                            log.track_buffer(buffer, cx);
                         })?;
 
                         Ok(result)

crates/assistant_tools/src/create_file_tool.rs 🔗

@@ -112,9 +112,12 @@ impl Tool for CreateFileTool {
                 .await
                 .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
             cx.update(|cx| {
+                action_log.update(cx, |action_log, cx| {
+                    action_log.track_buffer(buffer.clone(), cx)
+                });
                 buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
                 action_log.update(cx, |action_log, cx| {
-                    action_log.will_create_buffer(buffer.clone(), cx)
+                    action_log.buffer_edited(buffer.clone(), cx)
                 });
             })?;
 

crates/assistant_tools/src/edit_file_tool.rs 🔗

@@ -182,7 +182,7 @@ impl Tool for EditFileTool {
 
             let snapshot = cx.update(|cx| {
                 action_log.update(cx, |log, cx| {
-                    log.buffer_read(buffer.clone(), cx)
+                    log.track_buffer(buffer.clone(), cx)
                 });
                 let snapshot = buffer.update(cx, |buffer, cx| {
                     buffer.finalize_last_transaction();

crates/assistant_tools/src/read_file_tool.rs 🔗

@@ -134,7 +134,7 @@ impl Tool for ReadFileTool {
                 })?;
 
                 action_log.update(cx, |log, cx| {
-                    log.buffer_read(buffer, cx);
+                    log.track_buffer(buffer, cx);
                 })?;
 
                 Ok(result)
@@ -147,7 +147,7 @@ impl Tool for ReadFileTool {
                     let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
 
                     action_log.update(cx, |log, cx| {
-                        log.buffer_read(buffer, cx);
+                        log.track_buffer(buffer, cx);
                     })?;
 
                     Ok(result)

crates/assistant_tools/src/rename_tool.rs 🔗

@@ -106,7 +106,7 @@ impl Tool for RenameTool {
             };
 
             action_log.update(cx, |action_log, cx| {
-                action_log.buffer_read(buffer.clone(), cx);
+                action_log.track_buffer(buffer.clone(), cx);
             })?;
 
             let position = {

crates/assistant_tools/src/symbol_info_tool.rs 🔗

@@ -140,7 +140,7 @@ impl Tool for SymbolInfoTool {
             };
 
             action_log.update(cx, |action_log, cx| {
-                action_log.buffer_read(buffer.clone(), cx);
+                action_log.track_buffer(buffer.clone(), cx);
             })?;
 
             let position = {