metrics.rs

  1use collections::{HashMap, HashSet};
  2use edit_prediction::udiff::DiffLine;
  3use serde::{Deserialize, Serialize};
  4
/// Maps each token (a character n-gram, or a `¬`-prefixed deletion token) to how many times it occurs.
type Counts = HashMap<String, usize>;
/// Maps each n-gram to the signed change in its count between two texts (`after - before`).
type CountsDelta = HashMap<String, isize>;
  7
  8#[derive(Default, Debug, Clone, Serialize, Deserialize)]
  9pub struct ClassificationMetrics {
 10    pub true_positives: usize,
 11    pub false_positives: usize,
 12    pub false_negatives: usize,
 13}
 14
 15impl ClassificationMetrics {
 16    pub fn from_sets(
 17        expected: &HashSet<String>,
 18        actual: &HashSet<String>,
 19    ) -> ClassificationMetrics {
 20        let true_positives = expected.intersection(actual).count();
 21        let false_positives = actual.difference(expected).count();
 22        let false_negatives = expected.difference(actual).count();
 23
 24        ClassificationMetrics {
 25            true_positives,
 26            false_positives,
 27            false_negatives,
 28        }
 29    }
 30
 31    pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
 32        let mut true_positives = 0;
 33        let mut false_positives = 0;
 34        let mut false_negatives = 0;
 35
 36        for (ngram, &expected_count) in expected {
 37            let actual_count = *actual.get(ngram).unwrap_or(&0);
 38            if actual_count > expected_count {
 39                false_positives += actual_count - expected_count;
 40            } else {
 41                false_negatives += expected_count - actual_count;
 42            }
 43            true_positives += expected_count.min(actual_count);
 44        }
 45
 46        for (ngram, &actual_count) in actual {
 47            if !expected.contains_key(ngram) {
 48                false_positives += actual_count;
 49            }
 50        }
 51
 52        ClassificationMetrics {
 53            true_positives,
 54            false_positives,
 55            false_negatives,
 56        }
 57    }
 58
 59    pub fn aggregate<'a>(
 60        scores: impl Iterator<Item = &'a ClassificationMetrics>,
 61    ) -> ClassificationMetrics {
 62        let mut true_positives = 0;
 63        let mut false_positives = 0;
 64        let mut false_negatives = 0;
 65
 66        for score in scores {
 67            true_positives += score.true_positives;
 68            false_positives += score.false_positives;
 69            false_negatives += score.false_negatives;
 70        }
 71
 72        ClassificationMetrics {
 73            true_positives,
 74            false_positives,
 75            false_negatives,
 76        }
 77    }
 78
 79    pub fn precision(&self) -> f64 {
 80        if self.true_positives + self.false_positives == 0 {
 81            0.0
 82        } else {
 83            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
 84        }
 85    }
 86
 87    pub fn recall(&self) -> f64 {
 88        if self.true_positives + self.false_negatives == 0 {
 89            0.0
 90        } else {
 91            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
 92        }
 93    }
 94
 95    pub fn f1_score(&self) -> f64 {
 96        let recall = self.recall();
 97        let precision = self.precision();
 98        if precision + recall == 0.0 {
 99            0.0
100        } else {
101            2.0 * precision * recall / (precision + recall)
102        }
103    }
104}
105
106pub fn line_match_score(
107    expected_patch: &[DiffLine],
108    actual_patch: &[DiffLine],
109) -> ClassificationMetrics {
110    let expected_change_lines = expected_patch
111        .iter()
112        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
113        .map(|line| line.to_string())
114        .collect();
115
116    let actual_change_lines = actual_patch
117        .iter()
118        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
119        .map(|line| line.to_string())
120        .collect();
121
122    ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
123}
124
/// How whitespace is treated before character n-grams are extracted
/// (selected via `CHR_F_WHITESPACE`, applied in `chr_f_ngram_counts`).
enum ChrfWhitespace {
    /// Keep the text exactly as given.
    // Never constructed while `CHR_F_WHITESPACE` selects `Ignore`, hence the lint
    // suppression; presumably kept as an alternative mode to experiment with.
    #[allow(unused)]
    Unchanged,
    /// Remove every whitespace character before counting n-grams.
    Ignore,
}
130
/// Longest character n-gram scored; n-grams of every length from 1 up to this value are used.
const CHR_F_CHAR_ORDER: usize = 6;
/// Beta of the final F-beta score; values above 1 weight recall more heavily than precision.
const CHR_F_BETA: f64 = 2.0;
/// Whitespace handling applied to every text before n-gram extraction.
const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
134
135/// Computes a delta-chrF score that compares two sets of edits.
136///
137/// This metric works by:
138/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
139/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
140/// 3. Comparing these deltas to measure how well actual edits match expected edits
141pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
142    // Reconstruct texts from diffs
143    let mut original_text = String::new(); // state of the text before any edits
144    let mut golden_text = String::new(); // text after applying golden edits
145    let mut actual_text = String::new(); // text after applying actual edits
146
147    for line in expected {
148        match line {
149            DiffLine::Context(s) => {
150                original_text.push_str(s);
151                golden_text.push_str(s);
152            }
153            DiffLine::Deletion(s) => {
154                original_text.push_str(s);
155            }
156            DiffLine::Addition(s) => {
157                golden_text.push_str(s);
158            }
159            _ => {}
160        }
161    }
162
163    for line in actual {
164        match line {
165            DiffLine::Context(s) | DiffLine::Addition(s) => {
166                actual_text.push_str(s);
167            }
168            _ => {}
169        }
170    }
171
172    // Edge case
173    if original_text == golden_text && golden_text == actual_text {
174        return 100.0;
175    }
176
177    // Compute the metric
178    let original_ngrams = chr_f_ngram_counts(&original_text);
179    let golden_ngrams = chr_f_ngram_counts(&golden_text);
180    let actual_ngrams = chr_f_ngram_counts(&actual_text);
181
182    let mut total_precision = 0.0;
183    let mut total_recall = 0.0;
184
185    for order in 0..CHR_F_CHAR_ORDER {
186        let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
187        let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
188
189        if expected_delta.is_empty() && actual_delta.is_empty() {
190            total_precision += 1.0;
191            total_recall += 1.0;
192            continue;
193        }
194
195        let expected_counts = ngram_delta_to_counts(&expected_delta);
196        let actual_counts = ngram_delta_to_counts(&actual_delta);
197
198        let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
199        total_precision += score.precision();
200        total_recall += score.recall();
201    }
202
203    let prec = total_precision / CHR_F_CHAR_ORDER as f64;
204    let recall = total_recall / CHR_F_CHAR_ORDER as f64;
205    let f_score = if prec + recall == 0.0 {
206        0.0
207    } else {
208        (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
209    };
210
211    f_score * 100.0
212}
213
214fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
215    // Ignore whitespace. The original chrF implementation skips all
216    // whitespace. We should consider compressing multiple consecutive
217    // spaces into one -- this may reflect our task more closely.
218    let text = match CHR_F_WHITESPACE {
219        ChrfWhitespace::Unchanged => text.to_string(),
220        ChrfWhitespace::Ignore => text
221            .chars()
222            .filter(|c| !c.is_whitespace())
223            .collect::<String>(),
224    };
225
226    (1..=CHR_F_CHAR_ORDER)
227        .map(|order| count_ngrams(&text, order))
228        .collect()
229}
230
231fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
232    let mut delta = CountsDelta::default();
233
234    for (ngram, &before_count) in before {
235        let after_count = *after.get(ngram).unwrap_or(&0);
236        delta.insert(ngram.clone(), after_count as isize - before_count as isize);
237    }
238
239    for (ngram, &after_count) in after {
240        if !before.contains_key(ngram) {
241            delta.insert(ngram.clone(), after_count as isize);
242        }
243    }
244
245    delta
246}
247
248/// Convert negative counts to special deletion tokens.
249/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
250/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
251/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
252fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
253    let mut counts = Counts::default();
254
255    for (ngram, &delta) in delta {
256        if delta > 0 {
257            counts.insert(ngram.clone(), delta as usize);
258        } else {
259            counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
260        }
261    }
262
263    counts
264}
265
266fn count_ngrams(text: &str, n: usize) -> Counts {
267    let chars: Vec<char> = text.chars().collect();
268    let mut counts = Counts::default();
269
270    for window in chars.windows(n) {
271        let ngram: String = window.iter().collect();
272        *counts.entry(ngram).or_insert(0) += 1;
273    }
274
275    counts
276}
277
#[cfg(test)]
mod test {
    use super::*;
    use edit_prediction::udiff::DiffLine;

    #[test]
    fn test_delta_chr_f_perfect_match() {
        let diff = vec![
            DiffLine::Context("fn main() {"),
            DiffLine::Deletion("    println!(\"Hello\");"),
            DiffLine::Addition("    println!(\"Hello, World!\");"),
            DiffLine::Context("}"),
        ];

        let score = delta_chr_f(&diff, &diff);
        assert!((score - 100.0).abs() < 1e-2);
    }

    #[test]
    fn test_delta_chr_f_wrong_edit() {
        // When the edit is wrong
        let expected = vec![
            DiffLine::Context("one "),
            DiffLine::Deletion("two "),
            DiffLine::Context("three"),
        ];

        let actual = vec![
            DiffLine::Context("one "),
            DiffLine::Context("two "),
            DiffLine::Deletion("three"),
            DiffLine::Addition("four"),
        ];

        // Then the score should be low
        let score = delta_chr_f(&expected, &actual);
        assert!(score > 20.0 && score < 40.0);
    }

    #[test]
    fn test_delta_chr_f_partial_match() {
        let expected = vec![
            DiffLine::Deletion("let x = 42;"),
            DiffLine::Addition("let x = 100;"),
        ];

        let actual = vec![
            DiffLine::Deletion("let x = 42;"),
            DiffLine::Addition("let x = 99;"),
        ];

        // We got the edit location right, but the replacement text is wrong.
        // Deleted ngrams will match, bringing the score somewhere in the middle.
        let score = delta_chr_f(&expected, &actual);
        assert!(score > 40.0 && score < 60.0);
    }

    #[test]
    fn test_delta_chr_f_missed_edit() {
        // When predictions makes no changes
        let expected = vec![
            DiffLine::Context("prefix "),
            DiffLine::Deletion("old"),
            DiffLine::Addition("new"),
            DiffLine::Context(" suffix"),
        ];

        let actual = vec![
            DiffLine::Context("prefix "),
            DiffLine::Context("old"),
            DiffLine::Context(" suffix"),
        ];

        // Then the score should be low (all expected changes are false negatives)
        let score = delta_chr_f(&expected, &actual);
        assert!(score < 20.0);
    }

    #[test]
    fn test_delta_chr_f_extra_edit() {
        // When adding unexpected content
        let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];

        let actual = vec![
            DiffLine::Context("hello"),
            DiffLine::Addition("extra"),
            DiffLine::Context("world"),
        ];

        // Then the score should be low (all actual changes are false positives)
        let score = delta_chr_f(&expected, &actual);
        assert!(score < 20.0);
    }

    #[test]
    fn test_classification_metrics_from_counts() {
        let mut expected = Counts::default();
        expected.insert("a".to_string(), 3);
        expected.insert("b".to_string(), 1);

        let mut actual = Counts::default();
        actual.insert("a".to_string(), 1);
        actual.insert("b".to_string(), 2);
        actual.insert("c".to_string(), 4);

        // a: 1 matched, 2 missing; b: 1 matched, 1 surplus; c: 4 surplus.
        let metrics = ClassificationMetrics::from_counts(&expected, &actual);
        assert_eq!(metrics.true_positives, 2);
        assert_eq!(metrics.false_positives, 5);
        assert_eq!(metrics.false_negatives, 2);
        assert!((metrics.precision() - 2.0 / 7.0).abs() < 1e-9);
        assert!((metrics.recall() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_classification_metrics_empty_scores_zero() {
        let metrics = ClassificationMetrics::default();
        assert_eq!(metrics.precision(), 0.0);
        assert_eq!(metrics.recall(), 0.0);
        assert_eq!(metrics.f1_score(), 0.0);
    }

    #[test]
    fn test_count_ngrams() {
        let counts = count_ngrams("abab", 2);
        assert_eq!(counts.len(), 2);
        assert_eq!(counts.get("ab"), Some(&2));
        assert_eq!(counts.get("ba"), Some(&1));
        assert!(count_ngrams("abc", 4).is_empty());
    }
}