Add a 'rejected patch' field to example specs, for DPO examples (#47043)

Max Brunsfeld created

The `capture example` action now populates the markdown file with a noop
"Rejected Patch", so that you can easily specify the good and bad
output.

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/capture_example.rs  | 69 +++++++++++++------
crates/edit_prediction/src/example_spec.rs     | 31 ++++++--
crates/edit_prediction_cli/src/git.rs          | 19 ++++
crates/edit_prediction_cli/src/load_project.rs | 52 ++++++++++----
crates/edit_prediction_cli/src/split_commit.rs |  2 
crates/edit_prediction_cli/src/synthesize.rs   |  1 
6 files changed, 128 insertions(+), 46 deletions(-)

Detailed changes

crates/edit_prediction/src/capture_example.rs 🔗

@@ -78,6 +78,7 @@ pub fn capture_example(
         // Initialize an empty patch with context lines, to make it easy
         // to write the expected patch by hand.
         let mut expected_patches = Vec::new();
+        let mut rejected_patch = None;
         if populate_expected_patch {
             let mut empty_patch = String::new();
             let start_row = cursor_excerpt_range.start.row + 1;
@@ -93,7 +94,9 @@ pub fn capture_example(
             for line in cursor_excerpt.lines() {
                 writeln!(&mut empty_patch, " {}", line).ok();
             }
-            expected_patches.push(empty_patch);
+
+            expected_patches.push(empty_patch.clone());
+            rejected_patch = Some(empty_patch);
         }
 
         let mut spec = ExampleSpec {
@@ -107,6 +110,7 @@ pub fn capture_example(
             cursor_position: String::new(),
             edit_history,
             expected_patches,
+            rejected_patch,
         };
         spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
         Ok(spec)
@@ -483,27 +487,50 @@ mod tests {
                 .to_string(),
                 expected_patches: vec![
                     indoc! {"
-                    --- a/src/main.rs
-                    +++ b/src/main.rs
-                    @@ -1,16 +1,16 @@
-                     fn main() {
-                         // comment 1
-                         one();
-                         two();
-                         // comment 4
-                         three();
-                         four();
-                         // comment 3
-                         five();
-                         six();
-                         seven();
-                         eight();
-                         // comment 2
-                         nine();
-                     }
-                "}
+                        --- a/src/main.rs
+                        +++ b/src/main.rs
+                        @@ -1,16 +1,16 @@
+                         fn main() {
+                             // comment 1
+                             one();
+                             two();
+                             // comment 4
+                             three();
+                             four();
+                             // comment 3
+                             five();
+                             six();
+                             seven();
+                             eight();
+                             // comment 2
+                             nine();
+                         }
+                    "}
+                    .to_string()
+                ],
+                rejected_patch: Some(
+                    indoc! {"
+                        --- a/src/main.rs
+                        +++ b/src/main.rs
+                        @@ -1,16 +1,16 @@
+                         fn main() {
+                             // comment 1
+                             one();
+                             two();
+                             // comment 4
+                             three();
+                             four();
+                             // comment 3
+                             five();
+                             six();
+                             seven();
+                             eight();
+                             // comment 2
+                             nine();
+                         }
+                    "}
                     .to_string()
-                ]
+                )
             }
         );
     }

crates/edit_prediction/src/example_spec.rs 🔗

@@ -21,6 +21,8 @@ pub struct ExampleSpec {
     pub cursor_position: String,
     pub edit_history: String,
     pub expected_patches: Vec<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub rejected_patch: Option<String>,
 }
 
 const REASONING_HEADING: &str = "Reasoning";
@@ -28,7 +30,7 @@ const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
 const EDIT_HISTORY_HEADING: &str = "Edit History";
 const CURSOR_POSITION_HEADING: &str = "Cursor Position";
 const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
-const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
+const REJECTED_PATCH_HEADING: &str = "Rejected Patch";
 
 #[derive(Serialize, Deserialize)]
 struct FrontMatter<'a> {
@@ -136,6 +138,18 @@ impl ExampleSpec {
             markdown.push('\n');
         }
 
+        if let Some(rejected_patch) = &self.rejected_patch {
+            _ = writeln!(markdown, "## {}", REJECTED_PATCH_HEADING);
+            markdown.push('\n');
+            _ = writeln!(markdown, "```diff");
+            markdown.push_str(rejected_patch);
+            if !markdown.ends_with('\n') {
+                markdown.push('\n');
+            }
+            _ = writeln!(markdown, "```");
+            markdown.push('\n');
+        }
+
         markdown
     }
 
@@ -154,6 +168,7 @@ impl ExampleSpec {
             cursor_position: String::new(),
             edit_history: String::new(),
             expected_patches: Vec::new(),
+            rejected_patch: None,
         };
 
         if let Some(rest) = input.strip_prefix("+++\n")
@@ -177,8 +192,8 @@ impl ExampleSpec {
             UncommittedDiff,
             EditHistory,
             CursorPosition,
-            ExpectedExcerpts,
             ExpectedPatch,
+            RejectedPatch,
             Other,
         }
 
@@ -202,8 +217,8 @@ impl ExampleSpec {
                         Section::CursorPosition
                     } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
                         Section::ExpectedPatch
-                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
-                        Section::ExpectedExcerpts
+                    } else if title.eq_ignore_ascii_case(REJECTED_PATCH_HEADING) {
+                        Section::RejectedPatch
                     } else {
                         Section::Other
                     };
@@ -244,12 +259,12 @@ impl ExampleSpec {
                             spec.cursor_path = Path::new(block_info).into();
                             spec.cursor_position = mem::take(&mut text);
                         }
-                        Section::ExpectedExcerpts => {
-                            mem::take(&mut text);
-                        }
                         Section::ExpectedPatch => {
                             spec.expected_patches.push(mem::take(&mut text));
                         }
+                        Section::RejectedPatch => {
+                            spec.rejected_patch = Some(mem::take(&mut text));
+                        }
                         Section::Start | Section::Other => {}
                     }
                 }
@@ -399,6 +414,7 @@ mod tests {
             cursor_position: String::new(),
             edit_history: String::new(),
             expected_patches: Vec::new(),
+            rejected_patch: None,
         };
 
         // Cursor before `42`
@@ -531,6 +547,7 @@ mod tests {
             cursor_position: String::new(),
             edit_history: String::new(),
             expected_patches: Vec::new(),
+            rejected_patch: None,
         };
 
         // Cursor before `42` using inline marker

crates/edit_prediction_cli/src/git.rs 🔗

@@ -68,11 +68,26 @@ pub async fn ensure_repo_cloned(repo_url: &str) -> Result<PathBuf> {
     let repo_path = repo_path_for_url(repo_url)?;
     let _lock = lock_repo(&repo_path).await;
 
-    if !repo_path.is_dir() {
+    // Validate existing repo has correct origin, otherwise remove and re-init.
+    let mut git_repo_exists = false;
+    if repo_path.is_dir() {
+        if run_git(&repo_path, &["remote", "get-url", "origin"])
+            .await
+            .map_or(false, |origin| origin.trim() == repo_url)
+        {
+            git_repo_exists = true;
+        } else {
+            std::fs::remove_dir_all(&repo_path).ok();
+        }
+    }
+
+    if !git_repo_exists {
         log::info!("Cloning {} into {:?}", repo_url, repo_path);
         std::fs::create_dir_all(&repo_path)?;
         run_git(&repo_path, &["init"]).await?;
-        run_git(&repo_path, &["remote", "add", "origin", repo_url]).await?;
+        run_git(&repo_path, &["remote", "add", "origin", repo_url])
+            .await
+            .ok();
     }
 
     // Always fetch to get latest commits

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -254,14 +254,23 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
     let worktree_git_dir = repo_dir
         .join(".git/worktrees")
         .join(repo_name.name.as_ref());
-    let index_lock = worktree_git_dir.join("index.lock");
-    if index_lock.exists() {
-        fs::remove_file(&index_lock).ok();
+    for lock_file in &["index.lock", "HEAD.lock", "config.lock"] {
+        let worktree_lock_path = worktree_git_dir.join(lock_file);
+        let repo_lock_path = repo_dir.join(".git").join(lock_file);
+        if worktree_lock_path.exists() {
+            fs::remove_file(&worktree_lock_path).ok();
+        }
+        if repo_lock_path.exists() {
+            fs::remove_file(&repo_lock_path).ok();
+        }
     }
 
     let mut git_repo_exists = false;
     if repo_dir.is_dir() {
-        if git::run_git(&repo_dir, &["status"]).await.is_ok() {
+        if git::run_git(&repo_dir, &["remote", "get-url", "origin"])
+            .await
+            .map_or(false, |origin| origin.trim() == example.spec.repository_url)
+        {
             git_repo_exists = true;
         } else {
             fs::remove_dir_all(&repo_dir).ok();
@@ -283,25 +292,36 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
     step_progress.set_substatus("fetching");
     let revision = git::fetch_if_needed(&repo_dir, &example.spec.revision).await?;
 
+    // Clean up any stale worktree registrations from previous crashed runs.
+    git::run_git(&repo_dir, &["worktree", "prune"]).await.ok();
+
     // Create the worktree for this example if needed.
     step_progress.set_substatus("preparing worktree");
 
-    let mut worktree_exists = false;
-    if worktree_path.is_dir() {
-        if git::run_git(&worktree_path, &["clean", "--force", "-d"])
+    // Check if worktree exists and is valid (not just a directory from a crashed run).
+    let worktree_valid = worktree_path.is_dir()
+        && git::run_git(&worktree_path, &["rev-parse", "--git-dir"])
             .await
-            .is_ok()
-        {
-            git::run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
-            git::run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
-            worktree_exists = true;
-        } else {
+            .is_ok();
+
+    if worktree_valid {
+        git::run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+        git::run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+        git::run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+    } else {
+        let worktree_path_string = worktree_path.to_string_lossy();
+
+        // Clean up invalid worktree directory and registration if they exist.
+        if worktree_path.exists() {
             fs::remove_dir_all(&worktree_path).ok();
         }
-    }
+        git::run_git(
+            &repo_dir,
+            &["worktree", "remove", "--force", &worktree_path_string],
+        )
+        .await
+        .ok();
 
-    if !worktree_exists {
-        let worktree_path_string = worktree_path.to_string_lossy();
         let branch_name = example.spec.filename();
         git::run_git(
             &repo_dir,

crates/edit_prediction_cli/src/split_commit.rs 🔗

@@ -344,6 +344,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
         tags: vec![],
         reasoning: None,
         uncommitted_diff: String::new(),
+        rejected_patch: None,
     })
 }
 
@@ -1374,6 +1375,7 @@ Date: Mon Jan 1 00:00:00 2024
             tags: vec![],
             reasoning: None,
             uncommitted_diff: String::new(),
+            rejected_patch: None,
         };
 
         let json = serde_json::to_string(&case).unwrap();

crates/edit_prediction_cli/src/synthesize.rs 🔗

@@ -791,6 +791,7 @@ async fn build_example(
         cursor_position: String::new(),
         edit_history,
         expected_patches: vec![expected_patch_with_header],
+        rejected_patch: None,
     };
     spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);