ep: Heuristic for detecting wrong editable region (#48343)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

Cargo.lock                                |  8 +-
crates/edit_prediction_cli/src/example.rs |  1 
crates/edit_prediction_cli/src/metrics.rs | 48 ++++++++++++++++
crates/edit_prediction_cli/src/score.rs   | 75 +++++++++++++++++++-----
4 files changed, 111 insertions(+), 21 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2090,7 +2090,7 @@ dependencies = [
  "bitflags 2.9.4",
  "cexpr",
  "clang-sys",
- "itertools 0.10.5",
+ "itertools 0.12.1",
  "log",
  "prettyplease",
  "proc-macro2",
@@ -2110,7 +2110,7 @@ dependencies = [
  "bitflags 2.9.4",
  "cexpr",
  "clang-sys",
- "itertools 0.10.5",
+ "itertools 0.12.1",
  "proc-macro2",
  "quote",
  "regex",
@@ -12965,7 +12965,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
 dependencies = [
  "bytes 1.11.1",
  "heck 0.5.0",
- "itertools 0.10.5",
+ "itertools 0.12.1",
  "log",
  "multimap 0.10.1",
  "once_cell",
@@ -12998,7 +12998,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
 dependencies = [
  "anyhow",
- "itertools 0.10.5",
+ "itertools 0.12.1",
  "proc-macro2",
  "quote",
  "syn 2.0.106",

crates/edit_prediction_cli/src/example.rs 🔗

@@ -116,6 +116,7 @@ pub struct ExampleScore {
     pub cursor_distance: Option<usize>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub cursor_exact_match: Option<bool>,
+    pub wrong_editable_region: Option<bool>,
 }
 
 impl Example {

crates/edit_prediction_cli/src/metrics.rs 🔗

@@ -1,5 +1,7 @@
 use collections::HashMap;
 
+use crate::reorder_patch::{Patch, PatchLine};
+
 pub type Counts = HashMap<String, usize>;
 type CountsDelta = HashMap<String, isize>;
 
@@ -382,6 +384,32 @@ pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> Classifica
     ClassificationMetrics::from_counts(&expected_lines, &actual_lines)
 }
 
+/// A simple proxy for whether the prediction respects editable region.
+pub fn is_editable_region_correct(actual_patch: &str) -> bool {
+    // A typical sign of a wrong editable region: a bunch of lines deletion
+    // at the beginning or end of the patch.
+    let patch = Patch::parse_unified_diff(actual_patch);
+    if patch.hunks.is_empty() {
+        return true;
+    }
+
+    let hunk = &patch.hunks[0];
+    let mut deletions_at_start = 0;
+
+    for line in hunk.lines.iter() {
+        match line {
+            PatchLine::Deletion(_) => deletions_at_start += 1,
+            _ => break,
+        }
+    }
+
+    if deletions_at_start >= 3 {
+        return false;
+    }
+
+    true
+}
+
 #[cfg(test)]
 mod test_optimization {
     use super::*;
@@ -530,6 +558,7 @@ mod test_optimization {
 #[cfg(test)]
 mod test {
     use super::*;
+    use indoc::indoc;
 
     #[test]
     fn test_delta_chr_f_perfect_match() {
@@ -726,4 +755,23 @@ index abc123..def456 100644
         assert_eq!(metrics.false_positives, 0);
         assert_eq!(metrics.false_negatives, 0);
     }
+
+    #[test]
+    fn test_is_editable_region_correct() {
+        let patch = indoc! {"
+            @@ -1,1 +1,1 @@
+            -context
+            -removed
+            -from the beginning of the file
+            import sys
+            +sys.exit(0)
+
+            "};
+        assert!(!is_editable_region_correct(patch));
+
+        let patch = indoc! {"
+            @@ -1,1 +1,1 @@
+            "};
+        assert!(is_editable_region_correct(patch));
+    }
 }

crates/edit_prediction_cli/src/score.rs 🔗

@@ -74,6 +74,7 @@ pub async fn run_scoring(
         reversal_ratio: 0.0,
         cursor_distance: None,
         cursor_exact_match: None,
+        wrong_editable_region: None,
     };
 
     let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
@@ -140,16 +141,6 @@ pub async fn run_scoring(
         let disbalance_before = metrics::braces_disbalance(&original_text);
         let disbalance_after = metrics::braces_disbalance(&actual_text);
         let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
-        if braces_disbalance > 0 {
-            std::fs::write(
-                "/tmp/unbalanced-count.before",
-                disbalance_before.to_string(),
-            )
-            .ok();
-            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
-            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
-            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
-        }
 
         // Compute exact lines match against best matching expected patch
         let best_exact_lines = expected_patches_with_cursors
@@ -169,6 +160,9 @@ pub async fn run_scoring(
         let (cursor_distance, cursor_exact_match) =
             compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor_offset);
 
+        // Compute approximation of editable region correctness
+        let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
+
         scores.push(ExampleScore {
             delta_chr_f: best_delta_chr_f,
             braces_disbalance,
@@ -178,6 +172,7 @@ pub async fn run_scoring(
             reversal_ratio,
             cursor_distance,
             cursor_exact_match,
+            wrong_editable_region,
         });
     }
 
@@ -209,13 +204,13 @@ fn compute_cursor_metrics(
 pub fn print_report(examples: &[Example]) {
     use crate::metrics::ClassificationMetrics;
 
-    const LINE_WIDTH: usize = 94;
+    const LINE_WIDTH: usize = 101;
     let separator = "─".repeat(LINE_WIDTH);
 
     println!("{}", separator);
     println!(
-        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6}",
-        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor"
+        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
+        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
     );
     println!("{}", separator);
 
@@ -232,6 +227,8 @@ pub fn print_report(examples: &[Example]) {
     let mut cursor_total: usize = 0;
     let mut cursor_distance_sum: usize = 0;
     let mut cursor_distance_count: usize = 0;
+    let mut wrong_editable_region_count: usize = 0;
+    let mut wrong_editable_region_total: usize = 0;
 
     for example in examples {
         for (score_idx, score) in example.score.iter().enumerate() {
@@ -252,6 +249,13 @@ pub fn print_report(examples: &[Example]) {
                 .map(|v| format!("{}", v))
                 .unwrap_or("-".to_string());
 
+            // Format wrong editable region metric
+            let wrong_er_str = match score.wrong_editable_region {
+                Some(true) => "✗",
+                Some(false) => "",
+                None => "",
+            };
+
             // Format cursor metric
             let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
                 (Some(true), _) => "✓".to_string(),
@@ -261,7 +265,7 @@ pub fn print_report(examples: &[Example]) {
             };
 
             println!(
-                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
+                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
                 truncate_name(&example.spec.name, 40),
                 score.delta_chr_f,
                 score.braces_disbalance,
@@ -269,7 +273,8 @@ pub fn print_report(examples: &[Example]) {
                 score.reversal_ratio * 100.0,
                 qa_reverts_str,
                 qa_conf_str,
-                cursor_str
+                cursor_str,
+                wrong_er_str
             );
 
             all_delta_chr_f_scores.push(score.delta_chr_f);
@@ -294,6 +299,14 @@ pub fn print_report(examples: &[Example]) {
                 }
             }
 
+            // Accumulate wrong editable region metrics
+            if let Some(wrong) = score.wrong_editable_region {
+                wrong_editable_region_total += 1;
+                if wrong {
+                    wrong_editable_region_count += 1;
+                }
+            }
+
             // Accumulate cursor metrics
             if let Some(exact_match) = score.cursor_exact_match {
                 cursor_total += 1;
@@ -341,6 +354,14 @@ pub fn print_report(examples: &[Example]) {
         } else {
             "-".to_string()
         };
+        let wrong_er_str = if wrong_editable_region_total > 0 {
+            format!(
+                "{:.2}%",
+                wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
+            )
+        } else {
+            "-".to_string()
+        };
         let avg_cursor_distance = if cursor_distance_count > 0 {
             Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
         } else {
@@ -348,7 +369,7 @@ pub fn print_report(examples: &[Example]) {
         };
 
         println!(
-            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
+            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
             "TOTAL / AVERAGE",
             avg_delta_chr_f,
             braces_disbalance_avg,
@@ -356,7 +377,8 @@ pub fn print_report(examples: &[Example]) {
             avg_reversal_ratio * 100.0,
             qa_reverts_str,
             qa_conf_str,
-            cursor_str
+            cursor_str,
+            wrong_er_str
         );
         println!("{}", separator);
 
@@ -405,6 +427,8 @@ pub struct SummaryJson {
     pub cursor_avg_distance: Option<f32>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub cursor_total_evaluated: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub wrong_editable_region_rate: Option<f32>,
 }
 
 pub fn compute_summary(examples: &[Example]) -> SummaryJson {
@@ -423,6 +447,8 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
     let mut cursor_total: usize = 0;
     let mut cursor_distance_sum: usize = 0;
     let mut cursor_distance_count: usize = 0;
+    let mut wrong_editable_region_count: usize = 0;
+    let mut wrong_editable_region_total: usize = 0;
 
     for example in examples {
         for (score_idx, score) in example.score.iter().enumerate() {
@@ -448,6 +474,14 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
                 }
             }
 
+            // Accumulate wrong editable region metrics
+            if let Some(wrong) = score.wrong_editable_region {
+                wrong_editable_region_total += 1;
+                if wrong {
+                    wrong_editable_region_count += 1;
+                }
+            }
+
             // Accumulate cursor metrics
             if let Some(exact_match) = score.cursor_exact_match {
                 cursor_total += 1;
@@ -510,6 +544,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
         None
     };
 
+    let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
+        Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
+    } else {
+        None
+    };
+
     SummaryJson {
         total_examples: total_scores,
         avg_delta_chr_f,
@@ -526,6 +566,7 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
         cursor_exact_match_rate,
         cursor_avg_distance,
         cursor_total_evaluated,
+        wrong_editable_region_rate,
     }
 }