ep: Change beta in deltaChrF to favor precision over recall (#52422)

Oleksiy Syvokon created

Also store and print more detailed information related to this metric
(precision, recall, tp/fp/fn counts)

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/example.rs |  31 ++++
crates/edit_prediction_cli/src/metrics.rs | 157 ++++++++++++++++++------
crates/edit_prediction_cli/src/score.rs   |  81 ++++++++++--
3 files changed, 214 insertions(+), 55 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/example.rs 🔗

@@ -1,4 +1,5 @@
 use crate::PredictionProvider;
+use crate::metrics::ClassificationMetrics;
 use crate::paths::WORKTREES_DIR;
 use crate::qa::QaResult;
 use anyhow::{Context as _, Result};
@@ -150,6 +151,18 @@ where
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ExampleScore {
     pub delta_chr_f: f32,
+    #[serde(default)]
+    pub delta_chr_f_true_positives: usize,
+    #[serde(default)]
+    pub delta_chr_f_false_positives: usize,
+    #[serde(default)]
+    pub delta_chr_f_false_negatives: usize,
+    #[serde(default)]
+    pub delta_chr_f_precision: f64,
+    #[serde(default)]
+    pub delta_chr_f_recall: f64,
+    #[serde(default)]
+    pub delta_chr_f_beta: f64,
     pub braces_disbalance: usize,
     #[serde(default)]
     pub exact_lines_tp: usize,
@@ -176,6 +189,24 @@ pub struct ExampleScore {
     pub avg_logprob: Option<f64>,
 }
 
+impl ExampleScore {
+    pub fn delta_chr_f_counts(&self) -> ClassificationMetrics {
+        ClassificationMetrics {
+            true_positives: self.delta_chr_f_true_positives,
+            false_positives: self.delta_chr_f_false_positives,
+            false_negatives: self.delta_chr_f_false_negatives,
+        }
+    }
+
+    pub fn exact_lines_counts(&self) -> ClassificationMetrics {
+        ClassificationMetrics {
+            true_positives: self.exact_lines_tp,
+            false_positives: self.exact_lines_fp,
+            false_negatives: self.exact_lines_fn,
+        }
+    }
+}
+
 impl Example {
     pub fn repo_name(&self) -> Result<RepoName<'_>> {
         // git@github.com:owner/repo.git

crates/edit_prediction_cli/src/metrics.rs 🔗

@@ -48,6 +48,12 @@ impl ClassificationMetrics {
         }
     }
 
+    pub fn accumulate(&mut self, other: &ClassificationMetrics) {
+        self.true_positives += other.true_positives;
+        self.false_positives += other.false_positives;
+        self.false_negatives += other.false_negatives;
+    }
+
     pub fn precision(&self) -> f64 {
         if self.true_positives + self.false_positives == 0 {
             0.0
@@ -89,10 +95,23 @@ enum ChrfWhitespace {
 }
 
 const CHR_F_CHAR_ORDER: usize = 6;
-const CHR_F_BETA: f64 = 2.0;
+const CHR_F_BETA: f64 = 0.5;
 const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse;
 
-/// Computes a delta-chrF score that compares two sets of edits.
+pub fn delta_chr_f_beta() -> f64 {
+    CHR_F_BETA
+}
+
+#[derive(Default, Debug, Clone)]
+pub struct DeltaChrFMetrics {
+    pub score: f64,
+    pub beta: f64,
+    pub counts: ClassificationMetrics,
+    pub precision: f64,
+    pub recall: f64,
+}
+
+/// Computes delta-chrF metrics that compare two sets of edits.
 ///
 /// This metric works by:
 /// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
@@ -100,13 +119,17 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Collapse;
 ///
 /// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
 /// the expected edits.
-pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
-    // Edge case: if all texts are identical, the edits match perfectly
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
     if original == expected && expected == actual {
-        return 100.0;
+        return DeltaChrFMetrics {
+            score: 100.0,
+            beta: CHR_F_BETA,
+            precision: 1.0,
+            recall: 1.0,
+            ..DeltaChrFMetrics::default()
+        };
     }
 
-    // Pre-filter whitespace once for all texts
     let orig_chars: Vec<char> = filter_whitespace_chars(original);
     let exp_chars: Vec<char> = filter_whitespace_chars(expected);
     let act_chars: Vec<char> = filter_whitespace_chars(actual);
@@ -118,9 +141,9 @@ pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
 
     let mut total_precision = 0.0;
     let mut total_recall = 0.0;
+    let mut total_counts = ClassificationMetrics::default();
 
     for order in 1..=CHR_F_CHAR_ORDER {
-        // Compute n-grams only on the affected regions
         let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order);
         let exp_ngrams = count_ngrams_from_chars(&exp_region, order);
         let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp);
@@ -138,28 +161,43 @@ pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
         let expected_counts = ngram_delta_to_counts(&expected_delta);
         let actual_counts = ngram_delta_to_counts(&actual_delta);
 
-        let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
-        total_precision += score.precision();
-        total_recall += score.recall();
+        let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
+        total_precision += counts.precision();
+        total_recall += counts.recall();
+        total_counts.accumulate(&counts);
     }
 
-    let prec = total_precision / CHR_F_CHAR_ORDER as f64;
-    let recall = total_recall / CHR_F_CHAR_ORDER as f64;
-    let f_score = if prec + recall == 0.0 {
+    let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
+    let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
+    let score = if average_precision + average_recall == 0.0 {
         0.0
     } else {
-        (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
+        (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
+            / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
+            * 100.0
     };
 
-    f_score * 100.0
+    DeltaChrFMetrics {
+        score,
+        beta: CHR_F_BETA,
+        counts: total_counts,
+        precision: average_precision,
+        recall: average_recall,
+    }
 }
 
-/// Reference implementation of delta_chr_f (original, non-optimized version).
+/// Reference implementation of delta-chrF metrics (original, non-optimized version).
 /// Used for testing that the optimized version produces identical results.
 #[cfg(test)]
-fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 {
+fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> DeltaChrFMetrics {
     if original == expected && expected == actual {
-        return 100.0;
+        return DeltaChrFMetrics {
+            score: 100.0,
+            beta: CHR_F_BETA,
+            precision: 1.0,
+            recall: 1.0,
+            ..DeltaChrFMetrics::default()
+        };
     }
 
     let original_ngrams = chr_f_ngram_counts(original);
@@ -168,6 +206,7 @@ fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 {
 
     let mut total_precision = 0.0;
     let mut total_recall = 0.0;
+    let mut total_counts = ClassificationMetrics::default();
 
     for order in 0..CHR_F_CHAR_ORDER {
         let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
@@ -182,20 +221,29 @@ fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 {
         let expected_counts = ngram_delta_to_counts(&expected_delta);
         let actual_counts = ngram_delta_to_counts(&actual_delta);
 
-        let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
-        total_precision += score.precision();
-        total_recall += score.recall();
+        let counts = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
+        total_precision += counts.precision();
+        total_recall += counts.recall();
+        total_counts.accumulate(&counts);
     }
 
-    let prec = total_precision / CHR_F_CHAR_ORDER as f64;
-    let recall = total_recall / CHR_F_CHAR_ORDER as f64;
-    let f_score = if prec + recall == 0.0 {
+    let average_precision = total_precision / CHR_F_CHAR_ORDER as f64;
+    let average_recall = total_recall / CHR_F_CHAR_ORDER as f64;
+    let score = if average_precision + average_recall == 0.0 {
         0.0
     } else {
-        (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
+        (1.0 + CHR_F_BETA * CHR_F_BETA) * average_precision * average_recall
+            / (CHR_F_BETA * CHR_F_BETA * average_precision + average_recall)
+            * 100.0
     };
 
-    f_score * 100.0
+    DeltaChrFMetrics {
+        score,
+        beta: CHR_F_BETA,
+        counts: total_counts,
+        precision: average_precision,
+        recall: average_recall,
+    }
 }
 
 /// Filter whitespace from a string and return as Vec<char>
@@ -664,7 +712,7 @@ mod test_optimization {
         ];
 
         for (original, expected, actual) in test_cases {
-            let score = delta_chr_f(original, expected, actual);
+            let score = delta_chr_f(original, expected, actual).score;
             // Just verify it produces a reasonable score (0-100)
             assert!(
                 score >= 0.0 && score <= 100.0,
@@ -733,20 +781,51 @@ mod test_optimization {
         ];
 
         for (original, expected, actual) in test_cases {
-            let optimized_score = delta_chr_f(original, expected, actual);
-            let reference_score = delta_chr_f_reference(original, expected, actual);
+            let optimized_metrics = delta_chr_f(original, expected, actual);
+            let reference_metrics = delta_chr_f_reference(original, expected, actual);
 
             assert!(
-                (optimized_score - reference_score).abs() < 1e-10,
-                "Mismatch for ({:?}, {:?}, {:?}):\n  optimized: {}\n  reference: {}",
+                (optimized_metrics.score - reference_metrics.score).abs() < 1e-10,
+                "Score mismatch for ({:?}, {:?}, {:?}):\n  optimized: {}\n  reference: {}",
                 original,
                 expected,
                 actual,
-                optimized_score,
-                reference_score
+                optimized_metrics.score,
+                reference_metrics.score
+            );
+            assert_eq!(
+                optimized_metrics.counts.true_positives,
+                reference_metrics.counts.true_positives
+            );
+            assert_eq!(
+                optimized_metrics.counts.false_positives,
+                reference_metrics.counts.false_positives
             );
+            assert_eq!(
+                optimized_metrics.counts.false_negatives,
+                reference_metrics.counts.false_negatives
+            );
+            assert!((optimized_metrics.precision - reference_metrics.precision).abs() < 1e-10);
+            assert!((optimized_metrics.recall - reference_metrics.recall).abs() < 1e-10);
         }
     }
+
+    #[test]
+    fn test_delta_chr_f_metrics_include_counts_and_rates() {
+        let original = "one two three";
+        let expected = "one three";
+        let actual = "one two four";
+
+        let metrics = delta_chr_f(original, expected, actual);
+
+        assert!(metrics.score > 20.0 && metrics.score < 40.0);
+        assert!(metrics.counts.true_positives > 0);
+        assert!(metrics.counts.false_positives > 0);
+        assert!(metrics.counts.false_negatives > 0);
+        assert!(metrics.precision > 0.0 && metrics.precision < 1.0);
+        assert!(metrics.recall > 0.0 && metrics.recall < 1.0);
+        assert_eq!(metrics.beta, CHR_F_BETA);
+    }
 }
 
 #[cfg(test)]
@@ -770,7 +849,7 @@ mod test {
         let original = "fn main() {    println!(\"Hello\");}";
         let expected = "fn main() {    println!(\"Hello, World!\");}";
 
-        let score = delta_chr_f(original, expected, expected);
+        let score = delta_chr_f(original, expected, expected).score;
         assert!((score - 100.0).abs() < 1e-2);
     }
 
@@ -782,7 +861,7 @@ mod test {
         let actual = "one two four"; // deleted "three", added "four"
 
         // Then the score should be low
-        let score = delta_chr_f(original, expected, actual);
+        let score = delta_chr_f(original, expected, actual).score;
         assert!(score > 20.0 && score < 40.0);
     }
 
@@ -794,7 +873,7 @@ mod test {
 
         // We got the edit location right, but the replacement text is wrong.
         // Deleted ngrams will match, bringing the score somewhere in the middle.
-        let score = delta_chr_f(original, expected, actual);
+        let score = delta_chr_f(original, expected, actual).score;
         assert!(score > 40.0 && score < 60.0);
     }
 
@@ -806,7 +885,7 @@ mod test {
         let actual = "prefix old suffix"; // no change
 
         // Then the score should be low (all expected changes are false negatives)
-        let score = delta_chr_f(original, expected, actual);
+        let score = delta_chr_f(original, expected, actual).score;
         assert!(score < 20.0);
     }
 
@@ -818,14 +897,14 @@ mod test {
         let actual = "helloextraworld"; // added "extra"
 
         // Then the score should be low (all actual changes are false positives)
-        let score = delta_chr_f(original, expected, actual);
+        let score = delta_chr_f(original, expected, actual).score;
         assert!(score < 20.0);
     }
 
     #[test]
     fn test_delta_chr_f_no_changes() {
         let text = "unchanged text";
-        let score = delta_chr_f(text, text, text);
+        let score = delta_chr_f(text, text, text).score;
         assert!((score - 100.0).abs() < 1e-2);
     }
 

crates/edit_prediction_cli/src/score.rs 🔗

@@ -67,6 +67,12 @@ pub async fn run_scoring(
 
     let zero_scores = ExampleScore {
         delta_chr_f: 0.0,
+        delta_chr_f_true_positives: 0,
+        delta_chr_f_false_positives: 0,
+        delta_chr_f_false_negatives: 0,
+        delta_chr_f_precision: 0.0,
+        delta_chr_f_recall: 0.0,
+        delta_chr_f_beta: metrics::delta_chr_f_beta(),
         braces_disbalance: 0,
         exact_lines_tp: 0,
         exact_lines_fp: 0,
@@ -111,14 +117,14 @@ pub async fn run_scoring(
             }
         };
 
-        let mut best_delta_chr_f = 0.0f32;
+        let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
         let mut best_expected_cursor: Option<usize> = None;
         let mut best_patch_idx: Option<usize> = None;
 
         for (idx, expected) in expected_texts.iter().enumerate() {
-            let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32;
-            if delta_chr_f > best_delta_chr_f {
-                best_delta_chr_f = delta_chr_f;
+            let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
+            if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
+                best_delta_chr_f_metrics = delta_chr_f_metrics;
                 best_patch_idx = Some(idx);
             }
         }
@@ -179,7 +185,13 @@ pub async fn run_scoring(
         );
 
         scores.push(ExampleScore {
-            delta_chr_f: best_delta_chr_f,
+            delta_chr_f: best_delta_chr_f_metrics.score as f32,
+            delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
+            delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
+            delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
+            delta_chr_f_precision: best_delta_chr_f_metrics.precision,
+            delta_chr_f_recall: best_delta_chr_f_metrics.recall,
+            delta_chr_f_beta: best_delta_chr_f_metrics.beta,
             braces_disbalance,
             exact_lines_tp: best_exact_lines.true_positives,
             exact_lines_fp: best_exact_lines.false_positives,
@@ -238,6 +250,10 @@ pub fn print_report(examples: &[Example], verbose: bool) {
     let mut all_delta_chr_f_scores = Vec::new();
     let mut all_reversal_ratios = Vec::new();
     let mut braces_disbalance_sum: usize = 0;
+    let mut total_delta_chr_f = ClassificationMetrics::default();
+    let mut total_delta_chr_f_precision = 0.0;
+    let mut total_delta_chr_f_recall = 0.0;
+    let mut delta_chr_f_beta = 0.0;
     let mut total_exact_lines = ClassificationMetrics::default();
     let mut total_scores: usize = 0;
     let mut qa_reverts_count: usize = 0;
@@ -260,11 +276,7 @@ pub fn print_report(examples: &[Example], verbose: bool) {
 
     for example in examples {
         for (score_idx, score) in example.score.iter().enumerate() {
-            let exact_lines = ClassificationMetrics {
-                true_positives: score.exact_lines_tp,
-                false_positives: score.exact_lines_fp,
-                false_negatives: score.exact_lines_fn,
-            };
+            let exact_lines = score.exact_lines_counts();
 
             // Get QA results for this prediction if available
             let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
@@ -314,9 +326,11 @@ pub fn print_report(examples: &[Example], verbose: bool) {
             all_reversal_ratios.push(score.reversal_ratio);
             total_scores += 1;
             braces_disbalance_sum += score.braces_disbalance;
-            total_exact_lines.true_positives += score.exact_lines_tp;
-            total_exact_lines.false_positives += score.exact_lines_fp;
-            total_exact_lines.false_negatives += score.exact_lines_fn;
+            total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
+            total_delta_chr_f_precision += score.delta_chr_f_precision;
+            total_delta_chr_f_recall += score.delta_chr_f_recall;
+            delta_chr_f_beta = score.delta_chr_f_beta;
+            total_exact_lines.accumulate(&score.exact_lines_counts());
 
             // Accumulate QA metrics
             if let Some(qa) = qa_result {
@@ -448,6 +462,15 @@ pub fn print_report(examples: &[Example], verbose: bool) {
             wrong_er_str
         );
         println!("{}", separator);
+        println!(
+            "Delta chrF (β={:.1}): TP={}, FP={}, FN={}, P={:.1}%, R={:.1}%",
+            delta_chr_f_beta,
+            total_delta_chr_f.true_positives,
+            total_delta_chr_f.false_positives,
+            total_delta_chr_f.false_negatives,
+            total_delta_chr_f_precision / total_scores as f64 * 100.0,
+            total_delta_chr_f_recall / total_scores as f64 * 100.0
+        );
 
         // Print additional cursor metrics if available
         if let Some(avg_dist) = avg_cursor_distance {
@@ -540,6 +563,12 @@ fn truncate_name(name: &str, max_len: usize) -> String {
 pub struct SummaryJson {
     pub total_examples: usize,
     pub avg_delta_chr_f: f32,
+    pub delta_chr_f_beta: f64,
+    pub delta_chr_f_true_positives: usize,
+    pub delta_chr_f_false_positives: usize,
+    pub delta_chr_f_false_negatives: usize,
+    pub delta_chr_f_precision: f64,
+    pub delta_chr_f_recall: f64,
     pub avg_braces_disbalance: f32,
     pub exact_lines_true_positives: usize,
     pub exact_lines_false_positives: usize,
@@ -569,6 +598,10 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
     let mut all_delta_chr_f_scores = Vec::new();
     let mut all_reversal_ratios = Vec::new();
     let mut braces_disbalance_sum: usize = 0;
+    let mut total_delta_chr_f = ClassificationMetrics::default();
+    let mut total_delta_chr_f_precision = 0.0;
+    let mut total_delta_chr_f_recall = 0.0;
+    let mut delta_chr_f_beta = 0.0;
     let mut total_exact_lines = ClassificationMetrics::default();
     let mut total_scores: usize = 0;
     let mut qa_reverts_count: usize = 0;
@@ -589,9 +622,11 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
             all_reversal_ratios.push(score.reversal_ratio);
             total_scores += 1;
             braces_disbalance_sum += score.braces_disbalance;
-            total_exact_lines.true_positives += score.exact_lines_tp;
-            total_exact_lines.false_positives += score.exact_lines_fp;
-            total_exact_lines.false_negatives += score.exact_lines_fn;
+            total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
+            total_delta_chr_f_precision += score.delta_chr_f_precision;
+            total_delta_chr_f_recall += score.delta_chr_f_recall;
+            delta_chr_f_beta = score.delta_chr_f_beta;
+            total_exact_lines.accumulate(&score.exact_lines_counts());
 
             // Accumulate QA metrics
             if let Some(Some(qa)) = example.qa.get(score_idx) {
@@ -697,6 +732,20 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
     SummaryJson {
         total_examples: total_scores,
         avg_delta_chr_f,
+        delta_chr_f_beta,
+        delta_chr_f_true_positives: total_delta_chr_f.true_positives,
+        delta_chr_f_false_positives: total_delta_chr_f.false_positives,
+        delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
+        delta_chr_f_precision: if total_scores == 0 {
+            0.0
+        } else {
+            total_delta_chr_f_precision / total_scores as f64
+        },
+        delta_chr_f_recall: if total_scores == 0 {
+            0.0
+        } else {
+            total_delta_chr_f_recall / total_scores as f64
+        },
         avg_braces_disbalance,
         exact_lines_true_positives: total_exact_lines.true_positives,
         exact_lines_false_positives: total_exact_lines.false_positives,