diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 196f4f96d99b64aed2ff3ae2d7a9897295a60b29..4827337d37a211056d04cf9ca13f8d49fb91c392 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,4 +1,5 @@ use crate::PredictionProvider; +use crate::metrics::ClassificationMetrics; use crate::paths::WORKTREES_DIR; use crate::qa::QaResult; use anyhow::{Context as _, Result}; @@ -150,6 +151,18 @@ where #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ExampleScore { pub delta_chr_f: f32, + #[serde(default)] + pub delta_chr_f_true_positives: usize, + #[serde(default)] + pub delta_chr_f_false_positives: usize, + #[serde(default)] + pub delta_chr_f_false_negatives: usize, + #[serde(default)] + pub delta_chr_f_precision: f64, + #[serde(default)] + pub delta_chr_f_recall: f64, + #[serde(default)] + pub delta_chr_f_beta: f64, pub braces_disbalance: usize, #[serde(default)] pub exact_lines_tp: usize, @@ -176,6 +189,24 @@ pub struct ExampleScore { pub avg_logprob: Option, } +impl ExampleScore { + pub fn delta_chr_f_counts(&self) -> ClassificationMetrics { + ClassificationMetrics { + true_positives: self.delta_chr_f_true_positives, + false_positives: self.delta_chr_f_false_positives, + false_negatives: self.delta_chr_f_false_negatives, + } + } + + pub fn exact_lines_counts(&self) -> ClassificationMetrics { + ClassificationMetrics { + true_positives: self.exact_lines_tp, + false_positives: self.exact_lines_fp, + false_negatives: self.exact_lines_fn, + } + } +} + impl Example { pub fn repo_name(&self) -> Result> { // git@github.com:owner/repo.git diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 1bfd8e542fa3d74b55f091d2ac13aa22883f6a2f..8037699f4bb6f851fdadb05b435b090b911b010a 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -48,6 +48,12 @@ impl ClassificationMetrics { } } + pub fn accumulate(&mut self, other: &ClassificationMetrics) { + self.true_positives += other.true_positives; + self.false_positives += other.false_positives; + self.false_negatives += other.false_negatives; + } + pub fn precision(&self) -> f64 { if self.true_positives + self.false_positives == 0 { 0.0 @@ -89,10 +95,23 @@ enum ChrfWhitespace { } const CHR_F_CHAR_ORDER: usize = 6; -const CHR_F_BETA: f64 = 2.0; +const CHR_F_BETA: f64 = 0.5; const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse; -/// Computes a delta-chrF score that compares two sets of edits. +pub fn delta_chr_f_beta() -> f64 { + CHR_F_BETA +} + +#[derive(Default, Debug, Clone)] +pub struct DeltaChrFMetrics { + pub score: f64, + pub beta: f64, + pub counts: ClassificationMetrics, + pub precision: f64, + pub recall: f64, +} + +/// Computes delta-chrF metrics that compare two sets of edits. /// /// This metric works by: /// 1. Computing n-gram count differences (deltas) between original→expected and original→actual @@ -100,13 +119,17 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse; /// /// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match /// the expected edits. -pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 { - // Edge case: if all texts are identical, the edits match perfectly +pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics { if original == expected && expected == actual { - return 100.0; + return DeltaChrFMetrics { + score: 100.0, + beta: CHR_F_BETA, + precision: 1.0, + recall: 1.0, + ..DeltaChrFMetrics::default() + }; } - // Pre-filter whitespace once for all texts let orig_chars: Vec = filter_whitespace_chars(original); let exp_chars: Vec = filter_whitespace_chars(expected); let act_chars: Vec = filter_whitespace_chars(actual); @@ -118,9 +141,9 @@ pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 { let mut total_precision = 0.0; let mut total_recall = 0.0; + let mut total_counts = ClassificationMetrics::default(); for order in 1..=CHR_F_CHAR_ORDER { - // Compute n-grams only on the affected regions let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order); let exp_ngrams = count_ngrams_from_chars(&exp_region, order); let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp); @@ -138,28 +161,43 @@ pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 { let expected_counts = ngram_delta_to_counts(&expected_delta); let actual_counts = ngram_delta_to_counts(&actual_delta); - let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); - total_precision += score.precision(); - total_recall += score.recall(); + let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); + total_precision += counts.precision(); + total_recall += counts.recall(); + total_counts.accumulate(&counts); } - let prec = total_precision / CHR_F_CHAR_ORDER as f64; - let recall = total_recall / CHR_F_CHAR_ORDER as f64; - let f_score = if prec + recall == 0.0 { + let average_precision = total_precision / CHR_F_CHAR_ORDER as f64; + let average_recall = total_recall / CHR_F_CHAR_ORDER as f64; + let score = if average_precision + average_recall == 0.0 { 0.0 } else { - (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall) + (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall + / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall) + * 100.0 }; - f_score * 100.0 + DeltaChrFMetrics { + score, + beta: CHR_F_BETA, + counts: total_counts, + precision: average_precision, + recall: average_recall, + } } -/// Reference implementation of delta_chr_f (original, non-optimized version). +/// Reference implementation of delta-chrF metrics (original, non-optimized version). /// Used for testing that the optimized version produces identical results. #[cfg(test)] -fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 { +fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics { if original == expected && expected == actual { - return 100.0; + return DeltaChrFMetrics { + score: 100.0, + beta: CHR_F_BETA, + precision: 1.0, + recall: 1.0, + ..DeltaChrFMetrics::default() + }; } let original_ngrams = chr_f_ngram_counts(original); @@ -168,6 +206,7 @@ fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 { let mut total_precision = 0.0; let mut total_recall = 0.0; + let mut total_counts = ClassificationMetrics::default(); for order in 0..CHR_F_CHAR_ORDER { let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]); @@ -182,20 +221,29 @@ fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 { let expected_counts = ngram_delta_to_counts(&expected_delta); let actual_counts = ngram_delta_to_counts(&actual_delta); - let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); - total_precision += score.precision(); - total_recall += score.recall(); + let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); + total_precision += counts.precision(); + total_recall += counts.recall(); + total_counts.accumulate(&counts); } - let prec = total_precision / CHR_F_CHAR_ORDER as f64; - let recall = total_recall / CHR_F_CHAR_ORDER as f64; - let f_score = if prec + recall == 0.0 { + let average_precision = total_precision / CHR_F_CHAR_ORDER as f64; + let average_recall = total_recall / CHR_F_CHAR_ORDER as f64; + let score = if average_precision + average_recall == 0.0 { 0.0 } else { - (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall) + (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall + / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall) + * 100.0 }; - f_score * 100.0 + DeltaChrFMetrics { + score, + beta: CHR_F_BETA, + counts: total_counts, + precision: average_precision, + recall: average_recall, + } } /// Filter whitespace from a string and return as Vec @@ -664,7 +712,7 @@ mod test_optimization { ]; for (original, expected, actual) in test_cases { - let score = delta_chr_f(original, expected, actual); + let score = delta_chr_f(original, expected, actual).score; // Just verify it produces a reasonable score (0-100) assert!( score >= 0.0 && score <= 100.0, @@ -733,20 +781,51 @@ mod test_optimization { ]; for (original, expected, actual) in test_cases { - let optimized_score = delta_chr_f(original, expected, actual); - let reference_score = delta_chr_f_reference(original, expected, actual); + let optimized_metrics = delta_chr_f(original, expected, actual); + let reference_metrics = delta_chr_f_reference(original, expected, actual); assert!( - (optimized_score - reference_score).abs() < 1e-10, - "Mismatch for ({:?}, {:?}, {:?}):\n optimized: {}\n reference: {}", + (optimized_metrics.score - reference_metrics.score).abs() < 1e-10, + "Score mismatch for ({:?}, {:?}, {:?}):\n optimized: {}\n reference: {}", original, expected, actual, - optimized_score, - reference_score + optimized_metrics.score, + reference_metrics.score + ); + assert_eq!( + optimized_metrics.counts.true_positives, + reference_metrics.counts.true_positives + ); + assert_eq!( + optimized_metrics.counts.false_positives, + reference_metrics.counts.false_positives ); + assert_eq!( + optimized_metrics.counts.false_negatives, + reference_metrics.counts.false_negatives + ); + assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10); + assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10); } } + + #[test] + fn test_delta_chr_f_metrics_include_counts_and_rates() { + let original = "one two three"; + let expected = "one three"; + let actual = "one two four"; + + let metrics = delta_chr_f(original, expected, actual); + + assert!(metrics.score > 20.0 && metrics.score < 40.0); + assert!(metrics.counts.true_positives > 0); + assert!(metrics.counts.false_positives > 0); + assert!(metrics.counts.false_negatives > 0); + assert!(metrics.precision > 0.0 && metrics.precision < 1.0); + assert!(metrics.recall > 0.0 && metrics.recall < 1.0); + assert_eq!(metrics.beta, CHR_F_BETA); + } } #[cfg(test)] @@ -770,7 +849,7 @@ mod test { let original = "fn main() { println!(\"Hello\");}"; let expected = "fn main() { println!(\"Hello, World!\");}"; - let score = delta_chr_f(original, expected, expected); + let score = delta_chr_f(original, expected, expected).score; assert!((score - 100.0).abs() < 1e-2); } @@ -782,7 +861,7 @@ mod test { let actual = "one two four"; // deleted "three", added "four" // Then the score should be low - let score = delta_chr_f(original, expected, actual); + let score = delta_chr_f(original, expected, actual).score; assert!(score > 20.0 && score < 40.0); } @@ -794,7 +873,7 @@ mod test { // We got the edit location right, but the replacement text is wrong. // Deleted ngrams will match, bringing the score somewhere in the middle. - let score = delta_chr_f(original, expected, actual); + let score = delta_chr_f(original, expected, actual).score; assert!(score > 40.0 && score < 60.0); } @@ -806,7 +885,7 @@ mod test { let actual = "prefix old suffix"; // no change // Then the score should be low (all expected changes are false negatives) - let score = delta_chr_f(original, expected, actual); + let score = delta_chr_f(original, expected, actual).score; assert!(score < 20.0); } @@ -818,14 +897,14 @@ mod test { let actual = "helloextraworld"; // added "extra" // Then the score should be low (all actual changes are false positives) - let score = delta_chr_f(original, expected, actual); + let score = delta_chr_f(original, expected, actual).score; assert!(score < 20.0); } #[test] fn test_delta_chr_f_no_changes() { let text = "unchanged text"; - let score = delta_chr_f(text, text, text); + let score = delta_chr_f(text, text, text).score; assert!((score - 100.0).abs() < 1e-2); } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index d75cf55e85b198bc28469e83d8f9209a8a59a83f..be9b185809e6e0cd49e0befbeecec0f317339342 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -67,6 +67,12 @@ pub async fn run_scoring( let zero_scores = ExampleScore { delta_chr_f: 0.0, + delta_chr_f_true_positives: 0, + delta_chr_f_false_positives: 0, + delta_chr_f_false_negatives: 0, + delta_chr_f_precision: 0.0, + delta_chr_f_recall: 0.0, + delta_chr_f_beta: metrics::delta_chr_f_beta(), braces_disbalance: 0, exact_lines_tp: 0, exact_lines_fp: 0, @@ -111,14 +117,14 @@ pub async fn run_scoring( } }; - let mut best_delta_chr_f = 0.0f32; + let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default(); let mut best_expected_cursor: Option = None; let mut best_patch_idx: Option = None; for (idx, expected) in expected_texts.iter().enumerate() { - let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32; - if delta_chr_f > best_delta_chr_f { - best_delta_chr_f = delta_chr_f; + let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text); + if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score { + best_delta_chr_f_metrics = delta_chr_f_metrics; best_patch_idx = Some(idx); } } @@ -179,7 +185,13 @@ pub async fn run_scoring( ); scores.push(ExampleScore { - delta_chr_f: best_delta_chr_f, + delta_chr_f: best_delta_chr_f_metrics.score as f32, + delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives, + delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives, + delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives, + delta_chr_f_precision: best_delta_chr_f_metrics.precision, + delta_chr_f_recall: best_delta_chr_f_metrics.recall, + delta_chr_f_beta: best_delta_chr_f_metrics.beta, braces_disbalance, exact_lines_tp: best_exact_lines.true_positives, exact_lines_fp: best_exact_lines.false_positives, @@ -238,6 +250,10 @@ pub fn print_report(examples: &[Example], verbose: bool) { let mut all_delta_chr_f_scores = Vec::new(); let mut all_reversal_ratios = Vec::new(); let mut braces_disbalance_sum: usize = 0; + let mut total_delta_chr_f = ClassificationMetrics::default(); + let mut total_delta_chr_f_precision = 0.0; + let mut total_delta_chr_f_recall = 0.0; + let mut delta_chr_f_beta = 0.0; let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; let mut qa_reverts_count: usize = 0; @@ -260,11 +276,7 @@ pub fn print_report(examples: &[Example], verbose: bool) { for example in examples { for (score_idx, score) in example.score.iter().enumerate() { - let exact_lines = ClassificationMetrics { - true_positives: score.exact_lines_tp, - false_positives: score.exact_lines_fp, - false_negatives: score.exact_lines_fn, - }; + let exact_lines = score.exact_lines_counts(); // Get QA results for this prediction if available let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref()); @@ -314,9 +326,11 @@ pub fn print_report(examples: &[Example], verbose: bool) { all_reversal_ratios.push(score.reversal_ratio); total_scores += 1; braces_disbalance_sum += score.braces_disbalance; - total_exact_lines.true_positives += score.exact_lines_tp; - total_exact_lines.false_positives += score.exact_lines_fp; - total_exact_lines.false_negatives += score.exact_lines_fn; + total_delta_chr_f.accumulate(&score.delta_chr_f_counts()); + total_delta_chr_f_precision += score.delta_chr_f_precision; + total_delta_chr_f_recall += score.delta_chr_f_recall; + delta_chr_f_beta = score.delta_chr_f_beta; + total_exact_lines.accumulate(&score.exact_lines_counts()); // Accumulate QA metrics if let Some(qa) = qa_result { @@ -448,6 +462,15 @@ pub fn print_report(examples: &[Example], verbose: bool) { wrong_er_str ); println!("{}", separator); + println!( + "Delta chrF (β={:.1}): TP={}, FP={}, FN={}, P={:.1}%, R={:.1}%", + delta_chr_f_beta, + total_delta_chr_f.true_positives, + total_delta_chr_f.false_positives, + total_delta_chr_f.false_negatives, + total_delta_chr_f_precision / total_scores as f64 * 100.0, + total_delta_chr_f_recall / total_scores as f64 * 100.0 + ); // Print additional cursor metrics if available if let Some(avg_dist) = avg_cursor_distance { @@ -540,6 +563,12 @@ fn truncate_name(name: &str, max_len: usize) -> String { pub struct SummaryJson { pub total_examples: usize, pub avg_delta_chr_f: f32, + pub delta_chr_f_beta: f64, + pub delta_chr_f_true_positives: usize, + pub delta_chr_f_false_positives: usize, + pub delta_chr_f_false_negatives: usize, + pub delta_chr_f_precision: f64, + pub delta_chr_f_recall: f64, pub avg_braces_disbalance: f32, pub exact_lines_true_positives: usize, pub exact_lines_false_positives: usize, @@ -569,6 +598,10 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut all_delta_chr_f_scores = Vec::new(); let mut all_reversal_ratios = Vec::new(); let mut braces_disbalance_sum: usize = 0; + let mut total_delta_chr_f = ClassificationMetrics::default(); + let mut total_delta_chr_f_precision = 0.0; + let mut total_delta_chr_f_recall = 0.0; + let mut delta_chr_f_beta = 0.0; let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; let mut qa_reverts_count: usize = 0; @@ -589,9 +622,11 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { all_reversal_ratios.push(score.reversal_ratio); total_scores += 1; braces_disbalance_sum += score.braces_disbalance; - total_exact_lines.true_positives += score.exact_lines_tp; - total_exact_lines.false_positives += score.exact_lines_fp; - total_exact_lines.false_negatives += score.exact_lines_fn; + total_delta_chr_f.accumulate(&score.delta_chr_f_counts()); + total_delta_chr_f_precision += score.delta_chr_f_precision; + total_delta_chr_f_recall += score.delta_chr_f_recall; + delta_chr_f_beta = score.delta_chr_f_beta; + total_exact_lines.accumulate(&score.exact_lines_counts()); // Accumulate QA metrics if let Some(Some(qa)) = example.qa.get(score_idx) { @@ -697,6 +732,20 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { SummaryJson { total_examples: total_scores, avg_delta_chr_f, + delta_chr_f_beta, + delta_chr_f_true_positives: total_delta_chr_f.true_positives, + delta_chr_f_false_positives: total_delta_chr_f.false_positives, + delta_chr_f_false_negatives: total_delta_chr_f.false_negatives, + delta_chr_f_precision: if total_scores == 0 { + 0.0 + } else { + total_delta_chr_f_precision / total_scores as f64 + }, + delta_chr_f_recall: if total_scores == 0 { + 0.0 + } else { + total_delta_chr_f_recall / total_scores as f64 + }, avg_braces_disbalance, exact_lines_true_positives: total_exact_lines.true_positives, exact_lines_false_positives: total_exact_lines.false_positives,