ep: Add line-level exact match metric (#47383)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/example.rs |   6 
crates/edit_prediction_cli/src/metrics.rs | 183 +++++++++++++++++++++++-
crates/edit_prediction_cli/src/score.rs   |  75 +++++++--
3 files changed, 236 insertions(+), 28 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/example.rs 🔗

@@ -83,6 +83,12 @@ pub struct ExamplePrediction {
 pub struct ExampleScore {
     pub delta_chr_f: f32,
     pub braces_disbalance: usize,
+    #[serde(default)]
+    pub exact_lines_tp: usize,
+    #[serde(default)]
+    pub exact_lines_fp: usize,
+    #[serde(default)]
+    pub exact_lines_fn: usize,
 }
 
 impl Example {

crates/edit_prediction_cli/src/metrics.rs 🔗

@@ -1,20 +1,20 @@
 use collections::HashMap;
 
-type Counts = HashMap<String, usize>;
+pub type Counts = HashMap<String, usize>;
 type CountsDelta = HashMap<String, isize>;
 
 /// Context characters needed on each side of a change to capture all affected n-grams
 const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1;
 
 #[derive(Default, Debug, Clone)]
-struct ClassificationMetrics {
-    true_positives: usize,
-    false_positives: usize,
-    false_negatives: usize,
+pub struct ClassificationMetrics {
+    pub true_positives: usize,
+    pub false_positives: usize,
+    pub false_negatives: usize,
 }
 
 impl ClassificationMetrics {
-    fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
+    pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
         let mut true_positives = 0;
         let mut false_positives = 0;
         let mut false_negatives = 0;
@@ -42,7 +42,7 @@ impl ClassificationMetrics {
         }
     }
 
-    fn precision(&self) -> f64 {
+    pub fn precision(&self) -> f64 {
         if self.true_positives + self.false_positives == 0 {
             0.0
         } else {
@@ -50,13 +50,23 @@ impl ClassificationMetrics {
         }
     }
 
-    fn recall(&self) -> f64 {
+    pub fn recall(&self) -> f64 {
         if self.true_positives + self.false_negatives == 0 {
             0.0
         } else {
             self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
         }
     }
+
+    pub fn f1(&self) -> f64 {
+        let precision = self.precision();
+        let recall = self.recall();
+        if precision + recall == 0.0 {
+            0.0
+        } else {
+            2.0 * precision * recall / (precision + recall)
+        }
+    }
 }
 
 enum ChrfWhitespace {
@@ -335,6 +345,43 @@ pub fn braces_disbalance(text: &str) -> usize {
     disbalance as usize
 }
 
+/// Extracts changed lines from a unified diff string.
+/// Returns a bag (multiset) of lines that were added (+) or removed (-).
+/// The +/- prefix is included in the line to distinguish additions from deletions.
+pub fn extract_changed_lines_from_diff(diff: &str) -> Counts {
+    let mut counts = Counts::default();
+
+    for line in diff.lines() {
+        // Skip file headers (--- and +++)
+        if line.starts_with("---") || line.starts_with("+++") {
+            continue;
+        }
+        // Skip hunk headers (@@)
+        if line.starts_with("@@") {
+            continue;
+        }
+        // Skip diff header lines (diff --git, index, etc.)
+        if line.starts_with("diff ") || line.starts_with("index ") {
+            continue;
+        }
+        // Include added and removed lines (with their prefix)
+        if line.starts_with('+') || line.starts_with('-') {
+            *counts.entry(line.to_string()).or_insert(0) += 1;
+        }
+    }
+
+    counts
+}
+
+/// Computes exact lines match metrics between expected and actual patches.
+/// Treats changed lines as a bag (multiset) - order is discarded but count matters.
+/// Returns ClassificationMetrics with TP/FP/FN counts.
+pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics {
+    let expected_lines = extract_changed_lines_from_diff(expected_patch);
+    let actual_lines = extract_changed_lines_from_diff(actual_patch);
+    ClassificationMetrics::from_counts(&expected_lines, &actual_lines)
+}
+
 #[cfg(test)]
 mod test_optimization {
     use super::*;
@@ -559,4 +606,124 @@ mod test {
         let text = "let x = { 1 + 2 )";
         assert_eq!(braces_disbalance(text), 2);
     }
+
+    #[test]
+    fn test_extract_changed_lines_from_diff() {
+        let diff = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+ fn main() {
+-    println!("hello");
++    println!("world");
+ }"#;
+
+        let counts = extract_changed_lines_from_diff(diff);
+        assert_eq!(counts.get("-    println!(\"hello\");"), Some(&1));
+        assert_eq!(counts.get("+    println!(\"world\");"), Some(&1));
+        assert_eq!(counts.len(), 2);
+    }
+
+    #[test]
+    fn test_extract_changed_lines_skips_headers() {
+        let diff = r#"diff --git a/file.rs b/file.rs
+index abc123..def456 100644
+--- a/file.rs
++++ b/file.rs
+@@ -1,2 +1,2 @@
+-old line
++new line"#;
+
+        let counts = extract_changed_lines_from_diff(diff);
+        assert_eq!(counts.get("-old line"), Some(&1));
+        assert_eq!(counts.get("+new line"), Some(&1));
+        assert_eq!(counts.len(), 2);
+    }
+
+    #[test]
+    fn test_exact_lines_match_perfect() {
+        let expected = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+        let actual = r#"--- a/file.rs
++++ b/file.rs
+@@ -1,3 +1,3 @@
+-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+        let metrics = exact_lines_match(expected, actual);
+        assert_eq!(metrics.true_positives, 4);
+        assert_eq!(metrics.false_positives, 0);
+        assert_eq!(metrics.false_negatives, 0);
+        assert!((metrics.precision() - 1.0).abs() < 1e-6);
+        assert!((metrics.recall() - 1.0).abs() < 1e-6);
+        assert!((metrics.f1() - 1.0).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_exact_lines_match_partial() {
+        let expected = r#"-old line 1
+-old line 2
++new line 1
++new line 2"#;
+
+        let actual = r#"-old line 1
++new line 1
++extra line"#;
+
+        let metrics = exact_lines_match(expected, actual);
+        // TP: "-old line 1" and "+new line 1" (2)
+        // FP: "+extra line" (1)
+        // FN: "-old line 2" and "+new line 2" (2)
+        assert_eq!(metrics.true_positives, 2);
+        assert_eq!(metrics.false_positives, 1);
+        assert_eq!(metrics.false_negatives, 2);
+    }
+
+    #[test]
+    fn test_exact_lines_match_no_overlap() {
+        let expected = r#"-line a
++line b"#;
+
+        let actual = r#"-line x
++line y"#;
+
+        let metrics = exact_lines_match(expected, actual);
+        assert_eq!(metrics.true_positives, 0);
+        assert_eq!(metrics.false_positives, 2);
+        assert_eq!(metrics.false_negatives, 2);
+        assert!((metrics.precision()).abs() < 1e-6);
+        assert!((metrics.recall()).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_exact_lines_match_duplicate_lines() {
+        let expected = r#"+line a
++line a
++line a"#;
+
+        let actual = r#"+line a
++line a"#;
+
+        let metrics = exact_lines_match(expected, actual);
+        // Expected has 3 "+line a", actual has 2
+        // TP: 2, FN: 1, FP: 0
+        assert_eq!(metrics.true_positives, 2);
+        assert_eq!(metrics.false_positives, 0);
+        assert_eq!(metrics.false_negatives, 1);
+    }
+
+    #[test]
+    fn test_exact_lines_match_empty_patches() {
+        let metrics = exact_lines_match("", "");
+        assert_eq!(metrics.true_positives, 0);
+        assert_eq!(metrics.false_positives, 0);
+        assert_eq!(metrics.false_negatives, 0);
+    }
 }

crates/edit_prediction_cli/src/score.rs 🔗

@@ -42,6 +42,9 @@ pub async fn run_scoring(
     let zero_scores = ExampleScore {
         delta_chr_f: 0.0,
         braces_disbalance: 0,
+        exact_lines_tp: 0,
+        exact_lines_fp: 0,
+        exact_lines_fn: 0,
     };
 
     progress.set_substatus("computing metrics");
@@ -82,9 +85,21 @@ pub async fn run_scoring(
             std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
         }
 
+        // Compute exact lines match against best matching expected patch
+        let best_exact_lines = example
+            .spec
+            .expected_patches
+            .iter()
+            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
+            .max_by_key(|m| m.true_positives)
+            .unwrap_or_default();
+
         scores.push(ExampleScore {
             delta_chr_f: best_delta_chr_f,
             braces_disbalance,
+            exact_lines_tp: best_exact_lines.true_positives,
+            exact_lines_fp: best_exact_lines.false_positives,
+            exact_lines_fn: best_exact_lines.false_negatives,
         });
     }
 
@@ -93,53 +108,73 @@ pub async fn run_scoring(
 }
 
 pub fn print_report(examples: &[Example]) {
+    use crate::metrics::ClassificationMetrics;
+
+    const LINE_WIDTH: usize = 100;
+    let separator = "─".repeat(LINE_WIDTH);
+
+    eprintln!("{}", separator);
     eprintln!(
-        "──────────────────────────────────────────────────────────────────────────────────────"
-    );
-    eprintln!(
-        "{:<50} {:>14} {:>10}",
-        "Example name", "BracesDisbalance", "DeltaChrF"
-    );
-    eprintln!(
-        "──────────────────────────────────────────────────────────────────────────────────────"
+        "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
+        "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
     );
+    eprintln!("{}", separator);
 
     let mut all_delta_chr_f_scores = Vec::new();
     let mut braces_disbalance_sum: usize = 0;
+    let mut total_exact_lines = ClassificationMetrics::default();
     let mut total_scores: usize = 0;
 
     for example in examples {
         for score in example.score.iter() {
+            let exact_lines = ClassificationMetrics {
+                true_positives: score.exact_lines_tp,
+                false_positives: score.exact_lines_fp,
+                false_negatives: score.exact_lines_fn,
+            };
+
             eprintln!(
-                "{:<50} {:>14} {:>9.2}",
-                truncate_name(&example.spec.name, 50),
+                "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
+                truncate_name(&example.spec.name, 40),
+                score.delta_chr_f,
                 score.braces_disbalance,
-                score.delta_chr_f
+                score.exact_lines_tp,
+                score.exact_lines_fp,
+                score.exact_lines_fn,
+                exact_lines.precision() * 100.0,
+                exact_lines.recall() * 100.0,
+                exact_lines.f1() * 100.0
             );
 
             all_delta_chr_f_scores.push(score.delta_chr_f);
             total_scores += 1;
             braces_disbalance_sum += score.braces_disbalance;
+            total_exact_lines.true_positives += score.exact_lines_tp;
+            total_exact_lines.false_positives += score.exact_lines_fp;
+            total_exact_lines.false_negatives += score.exact_lines_fn;
         }
     }
 
-    eprintln!(
-        "──────────────────────────────────────────────────────────────────────────────────────"
-    );
+    eprintln!("{}", separator);
 
     if !all_delta_chr_f_scores.is_empty() {
         let avg_delta_chr_f: f32 =
             all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
         let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
-        let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg);
 
         eprintln!(
-            "{:<50} {:>14} {:>9.2}",
-            "AVERAGE", braces_disbalance_display, avg_delta_chr_f
-        );
-        eprintln!(
-            "──────────────────────────────────────────────────────────────────────────────────────"
+            "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
+            "TOTAL / AVERAGE",
+            avg_delta_chr_f,
+            braces_disbalance_avg,
+            total_exact_lines.true_positives,
+            total_exact_lines.false_positives,
+            total_exact_lines.false_negatives,
+            total_exact_lines.precision() * 100.0,
+            total_exact_lines.recall() * 100.0,
+            total_exact_lines.f1() * 100.0
         );
+        eprintln!("{}", separator);
     }
 
     eprintln!("\n");