From d2070975fe5249e5cf653189800b68b0f58890a2 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Thu, 22 Jan 2026 18:18:32 +0200 Subject: [PATCH] ep: Add line-level exact match metric (#47383) Release Notes: - N/A --- crates/edit_prediction_cli/src/example.rs | 6 + crates/edit_prediction_cli/src/metrics.rs | 183 +++++++++++++++++++++- crates/edit_prediction_cli/src/score.rs | 75 ++++++--- 3 files changed, 236 insertions(+), 28 deletions(-) diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 3ef359f35da0dfa9989c0f8feb400421d2c44a83..d232ece51f88c0b03f72e05f846b858c136d2e5d 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -83,6 +83,12 @@ pub struct ExamplePrediction { pub struct ExampleScore { pub delta_chr_f: f32, pub braces_disbalance: usize, + #[serde(default)] + pub exact_lines_tp: usize, + #[serde(default)] + pub exact_lines_fp: usize, + #[serde(default)] + pub exact_lines_fn: usize, } impl Example { diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 4382b775e237e31fff17c20dfb7a2bfb1656f2cb..6e4f4dc09b04a7c9cfdc15ec92fc9e927370c989 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,20 +1,20 @@ use collections::HashMap; -type Counts = HashMap; +pub type Counts = HashMap; type CountsDelta = HashMap; /// Context characters needed on each side of a change to capture all affected n-grams const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1; #[derive(Default, Debug, Clone)] -struct ClassificationMetrics { - true_positives: usize, - false_positives: usize, - false_negatives: usize, +pub struct ClassificationMetrics { + pub true_positives: usize, + pub false_positives: usize, + pub false_negatives: usize, } impl ClassificationMetrics { - fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { + pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { let mut true_positives = 0; let mut false_positives = 0; let mut false_negatives = 0; @@ -42,7 +42,7 @@ impl ClassificationMetrics { } } - fn precision(&self) -> f64 { + pub fn precision(&self) -> f64 { if self.true_positives + self.false_positives == 0 { 0.0 } else { @@ -50,13 +50,23 @@ impl ClassificationMetrics { } } - fn recall(&self) -> f64 { + pub fn recall(&self) -> f64 { if self.true_positives + self.false_negatives == 0 { 0.0 } else { self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64 } } + + pub fn f1(&self) -> f64 { + let precision = self.precision(); + let recall = self.recall(); + if precision + recall == 0.0 { + 0.0 + } else { + 2.0 * precision * recall / (precision + recall) + } + } } enum ChrfWhitespace { @@ -335,6 +345,43 @@ pub fn braces_disbalance(text: &str) -> usize { disbalance as usize } +/// Extracts changed lines from a unified diff string. +/// Returns a bag (multiset) of lines that were added (+) or removed (-). +/// The +/- prefix is included in the line to distinguish additions from deletions. +pub fn extract_changed_lines_from_diff(diff: &str) -> Counts { + let mut counts = Counts::default(); + + for line in diff.lines() { + // Skip file headers (--- and +++) + if line.starts_with("---") || line.starts_with("+++") { + continue; + } + // Skip hunk headers (@@) + if line.starts_with("@@") { + continue; + } + // Skip diff header lines (diff --git, index, etc.) + if line.starts_with("diff ") || line.starts_with("index ") { + continue; + } + // Include added and removed lines (with their prefix) + if line.starts_with('+') || line.starts_with('-') { + *counts.entry(line.to_string()).or_insert(0) += 1; + } + } + + counts +} + +/// Computes exact lines match metrics between expected and actual patches. +/// Treats changed lines as a bag (multiset) - order is discarded but count matters. +/// Returns ClassificationMetrics with TP/FP/FN counts. +pub fn exact_lines_match(expected_patch: &str, actual_patch: &str) -> ClassificationMetrics { + let expected_lines = extract_changed_lines_from_diff(expected_patch); + let actual_lines = extract_changed_lines_from_diff(actual_patch); + ClassificationMetrics::from_counts(&expected_lines, &actual_lines) +} + #[cfg(test)] mod test_optimization { use super::*; @@ -559,4 +606,124 @@ mod test { let text = "let x = { 1 + 2 )"; assert_eq!(braces_disbalance(text), 2); } + + #[test] + fn test_extract_changed_lines_from_diff() { + let diff = r#"--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,3 @@ + fn main() { +- println!("hello"); ++ println!("world"); + }"#; + + let counts = extract_changed_lines_from_diff(diff); + assert_eq!(counts.get("- println!(\"hello\");"), Some(&1)); + assert_eq!(counts.get("+ println!(\"world\");"), Some(&1)); + assert_eq!(counts.len(), 2); + } + + #[test] + fn test_extract_changed_lines_skips_headers() { + let diff = r#"diff --git a/file.rs b/file.rs +index abc123..def456 100644 +--- a/file.rs ++++ b/file.rs +@@ -1,2 +1,2 @@ +-old line ++new line"#; + + let counts = extract_changed_lines_from_diff(diff); + assert_eq!(counts.get("-old line"), Some(&1)); + assert_eq!(counts.get("+new line"), Some(&1)); + assert_eq!(counts.len(), 2); + } + + #[test] + fn test_exact_lines_match_perfect() { + let expected = r#"--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,3 @@ +-old line 1 +-old line 2 ++new line 1 ++new line 2"#; + + let actual = r#"--- a/file.rs ++++ b/file.rs +@@ -1,3 +1,3 @@ +-old line 1 +-old line 2 ++new line 1 ++new line 2"#; + + let metrics = exact_lines_match(expected, actual); + assert_eq!(metrics.true_positives, 4); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 0); + assert!((metrics.precision() - 1.0).abs() < 1e-6); + assert!((metrics.recall() - 1.0).abs() < 1e-6); + assert!((metrics.f1() - 1.0).abs() < 1e-6); + } + + #[test] + fn test_exact_lines_match_partial() { + let expected = r#"-old line 1 +-old line 2 ++new line 1 ++new line 2"#; + + let actual = r#"-old line 1 ++new line 1 ++extra line"#; + + let metrics = exact_lines_match(expected, actual); + // TP: "-old line 1" and "+new line 1" (2) + // FP: "+extra line" (1) + // FN: "-old line 2" and "+new line 2" (2) + assert_eq!(metrics.true_positives, 2); + assert_eq!(metrics.false_positives, 1); + assert_eq!(metrics.false_negatives, 2); + } + + #[test] + fn test_exact_lines_match_no_overlap() { + let expected = r#"-line a ++line b"#; + + let actual = r#"-line x ++line y"#; + + let metrics = exact_lines_match(expected, actual); + assert_eq!(metrics.true_positives, 0); + assert_eq!(metrics.false_positives, 2); + assert_eq!(metrics.false_negatives, 2); + assert!((metrics.precision()).abs() < 1e-6); + assert!((metrics.recall()).abs() < 1e-6); + } + + #[test] + fn test_exact_lines_match_duplicate_lines() { + let expected = r#"+line a ++line a ++line a"#; + + let actual = r#"+line a ++line a"#; + + let metrics = exact_lines_match(expected, actual); + // Expected has 3 "+line a", actual has 2 + // TP: 2, FN: 1, FP: 0 + assert_eq!(metrics.true_positives, 2); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 1); + } + + #[test] + fn test_exact_lines_match_empty_patches() { + let metrics = exact_lines_match("", ""); + assert_eq!(metrics.true_positives, 0); + assert_eq!(metrics.false_positives, 0); + assert_eq!(metrics.false_negatives, 0); + } } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index eaa42da71883f1069a947ca827cb3f1ef27eb891..4c5a8d5d0a97c07691fcfab2b5b7b13631f9a5b8 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -42,6 +42,9 @@ pub async fn run_scoring( let zero_scores = ExampleScore { delta_chr_f: 0.0, braces_disbalance: 0, + exact_lines_tp: 0, + exact_lines_fp: 0, + exact_lines_fn: 0, }; progress.set_substatus("computing metrics"); @@ -82,9 +85,21 @@ pub async fn run_scoring( std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok(); } + // Compute exact lines match against best matching expected patch + let best_exact_lines = example + .spec + .expected_patches + .iter() + .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch)) + .max_by_key(|m| m.true_positives) + .unwrap_or_default(); + scores.push(ExampleScore { delta_chr_f: best_delta_chr_f, braces_disbalance, + exact_lines_tp: best_exact_lines.true_positives, + exact_lines_fp: best_exact_lines.false_positives, + exact_lines_fn: best_exact_lines.false_negatives, }); } @@ -93,53 +108,73 @@ pub async fn run_scoring( } pub fn print_report(examples: &[Example]) { + use crate::metrics::ClassificationMetrics; + + const LINE_WIDTH: usize = 100; + let separator = "─".repeat(LINE_WIDTH); + + eprintln!("{}", separator); eprintln!( - "──────────────────────────────────────────────────────────────────────────────────────" - ); - eprintln!( - "{:<50} {:>14} {:>10}", - "Example name", "BracesDisbalance", "DeltaChrF" - ); - eprintln!( - "──────────────────────────────────────────────────────────────────────────────────────" + "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}", + "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1" ); + eprintln!("{}", separator); let mut all_delta_chr_f_scores = Vec::new(); let mut braces_disbalance_sum: usize = 0; + let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; for example in examples { for score in example.score.iter() { + let exact_lines = ClassificationMetrics { + true_positives: score.exact_lines_tp, + false_positives: score.exact_lines_fp, + false_negatives: score.exact_lines_fn, + }; + eprintln!( - "{:<50} {:>14} {:>9.2}", - truncate_name(&example.spec.name, 50), + "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%", + truncate_name(&example.spec.name, 40), + score.delta_chr_f, score.braces_disbalance, - score.delta_chr_f + score.exact_lines_tp, + score.exact_lines_fp, + score.exact_lines_fn, + exact_lines.precision() * 100.0, + exact_lines.recall() * 100.0, + exact_lines.f1() * 100.0 ); all_delta_chr_f_scores.push(score.delta_chr_f); total_scores += 1; braces_disbalance_sum += score.braces_disbalance; + total_exact_lines.true_positives += score.exact_lines_tp; + total_exact_lines.false_positives += score.exact_lines_fp; + total_exact_lines.false_negatives += score.exact_lines_fn; } } - eprintln!( - "──────────────────────────────────────────────────────────────────────────────────────" - ); + eprintln!("{}", separator); if !all_delta_chr_f_scores.is_empty() { let avg_delta_chr_f: f32 = all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32; let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32; - let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg); eprintln!( - "{:<50} {:>14} {:>9.2}", - "AVERAGE", braces_disbalance_display, avg_delta_chr_f - ); - eprintln!( - "──────────────────────────────────────────────────────────────────────────────────────" + "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%", + "TOTAL / AVERAGE", + avg_delta_chr_f, + braces_disbalance_avg, + total_exact_lines.true_positives, + total_exact_lines.false_positives, + total_exact_lines.false_negatives, + total_exact_lines.precision() * 100.0, + total_exact_lines.recall() * 100.0, + total_exact_lines.f1() * 100.0 ); + eprintln!("{}", separator); } eprintln!("\n");