metrics.rs

  1use collections::{HashMap, HashSet};
  2use zeta::udiff::DiffLine;
  3
  4type Counts = HashMap<String, usize>;
  5type CountsDelta = HashMap<String, isize>;
  6
  7#[derive(Default, Debug, Clone)]
  8pub struct Scores {
  9    pub true_positives: usize,
 10    pub false_positives: usize,
 11    pub false_negatives: usize,
 12}
 13
 14impl Scores {
 15    pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
 16        let true_positives = expected.intersection(actual).count();
 17        let false_positives = actual.difference(expected).count();
 18        let false_negatives = expected.difference(actual).count();
 19
 20        Scores {
 21            true_positives,
 22            false_positives,
 23            false_negatives,
 24        }
 25    }
 26
 27    pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
 28        let mut true_positives = 0;
 29        let mut false_positives = 0;
 30        let mut false_negatives = 0;
 31
 32        for (ngram, &expected_count) in expected {
 33            let actual_count = *actual.get(ngram).unwrap_or(&0);
 34            if actual_count > expected_count {
 35                false_positives += actual_count - expected_count;
 36            } else {
 37                false_negatives += expected_count - actual_count;
 38            }
 39            true_positives += expected_count.min(actual_count);
 40        }
 41
 42        for (ngram, &actual_count) in actual {
 43            if !expected.contains_key(ngram) {
 44                false_positives += actual_count;
 45            }
 46        }
 47
 48        Scores {
 49            true_positives,
 50            false_positives,
 51            false_negatives,
 52        }
 53    }
 54
 55    pub fn to_markdown(&self) -> String {
 56        format!(
 57            "
 58Precision       : {:.4}
 59Recall          : {:.4}
 60F1 Score        : {:.4}
 61True Positives  : {}
 62False Positives : {}
 63False Negatives : {}",
 64            self.precision(),
 65            self.recall(),
 66            self.f1_score(),
 67            self.true_positives,
 68            self.false_positives,
 69            self.false_negatives
 70        )
 71    }
 72
 73    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
 74        let mut true_positives = 0;
 75        let mut false_positives = 0;
 76        let mut false_negatives = 0;
 77
 78        for score in scores {
 79            true_positives += score.true_positives;
 80            false_positives += score.false_positives;
 81            false_negatives += score.false_negatives;
 82        }
 83
 84        Scores {
 85            true_positives,
 86            false_positives,
 87            false_negatives,
 88        }
 89    }
 90
 91    pub fn precision(&self) -> f64 {
 92        if self.true_positives + self.false_positives == 0 {
 93            0.0
 94        } else {
 95            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
 96        }
 97    }
 98
 99    pub fn recall(&self) -> f64 {
100        if self.true_positives + self.false_negatives == 0 {
101            0.0
102        } else {
103            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
104        }
105    }
106
107    pub fn f1_score(&self) -> f64 {
108        let recall = self.recall();
109        let precision = self.precision();
110        if precision + recall == 0.0 {
111            0.0
112        } else {
113            2.0 * precision * recall / (precision + recall)
114        }
115    }
116}
117
118pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
119    let expected_change_lines = expected_patch
120        .iter()
121        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
122        .map(|line| line.to_string())
123        .collect();
124
125    let actual_change_lines = actual_patch
126        .iter()
127        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
128        .map(|line| line.to_string())
129        .collect();
130
131    Scores::from_sets(&expected_change_lines, &actual_change_lines)
132}
133
/// How whitespace is treated before character n-grams are extracted for chrF.
enum ChrfWhitespace {
    /// Keep the text exactly as written.
    // Not selected by `CHR_F_WHITESPACE` at the moment but kept as a switchable
    // alternative; the allow silences the dead-code warning for it.
    #[allow(unused)]
    Unchanged,
    /// Remove every whitespace character before counting.
    Ignore,
}

/// Whitespace handling used by the chrF n-gram counter.
const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
/// Largest character n-gram order considered; orders `1..=CHR_F_CHAR_ORDER` are used.
const CHR_F_CHAR_ORDER: usize = 6;
/// β of the F-score; a value above 1 weights recall more heavily than precision.
const CHR_F_BETA: f64 = 2.0;
143
144/// Computes a delta-chrF score that compares two sets of edits.
145///
146/// This metric works by:
147/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
148/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
149/// 3. Comparing these deltas to measure how well actual edits match expected edits
150pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
151    // Reconstruct texts from diffs
152    let mut original_text = String::new(); // state of the text before any edits
153    let mut golden_text = String::new(); // text after applying golden edits
154    let mut actual_text = String::new(); // text after applying actual edits
155
156    for line in expected {
157        match line {
158            DiffLine::Context(s) => {
159                original_text.push_str(s);
160                golden_text.push_str(s);
161            }
162            DiffLine::Deletion(s) => {
163                original_text.push_str(s);
164            }
165            DiffLine::Addition(s) => {
166                golden_text.push_str(s);
167            }
168            _ => {}
169        }
170    }
171
172    for line in actual {
173        match line {
174            DiffLine::Context(s) | DiffLine::Addition(s) => {
175                actual_text.push_str(s);
176            }
177            _ => {}
178        }
179    }
180
181    // Edge case
182    if original_text == golden_text && golden_text == actual_text {
183        return 100.0;
184    }
185
186    // Compute the metric
187    let original_ngrams = chr_f_ngram_counts(&original_text);
188    let golden_ngrams = chr_f_ngram_counts(&golden_text);
189    let actual_ngrams = chr_f_ngram_counts(&actual_text);
190
191    let mut total_precision = 0.0;
192    let mut total_recall = 0.0;
193
194    for order in 0..CHR_F_CHAR_ORDER {
195        let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
196        let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
197
198        if expected_delta.is_empty() && actual_delta.is_empty() {
199            total_precision += 1.0;
200            total_recall += 1.0;
201            continue;
202        }
203
204        let expected_counts = ngram_delta_to_counts(&expected_delta);
205        let actual_counts = ngram_delta_to_counts(&actual_delta);
206
207        let score = Scores::from_counts(&expected_counts, &actual_counts);
208        total_precision += score.precision();
209        total_recall += score.recall();
210    }
211
212    let prec = total_precision / CHR_F_CHAR_ORDER as f64;
213    let recall = total_recall / CHR_F_CHAR_ORDER as f64;
214    let f_score = if prec + recall == 0.0 {
215        0.0
216    } else {
217        (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
218    };
219
220    f_score * 100.0
221}
222
223fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
224    // Ignore whitespace. The original chrF implementation skips all
225    // whitespace. We should consider compressing multiple consecutive
226    // spaces into one -- this may reflect our task more closely.
227    let text = match CHR_F_WHITESPACE {
228        ChrfWhitespace::Unchanged => text.to_string(),
229        ChrfWhitespace::Ignore => text
230            .chars()
231            .filter(|c| !c.is_whitespace())
232            .collect::<String>(),
233    };
234
235    (1..=CHR_F_CHAR_ORDER)
236        .map(|order| count_ngrams(&text, order))
237        .collect()
238}
239
240fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
241    let mut delta = CountsDelta::default();
242
243    for (ngram, &before_count) in before {
244        let after_count = *after.get(ngram).unwrap_or(&0);
245        delta.insert(ngram.clone(), after_count as isize - before_count as isize);
246    }
247
248    for (ngram, &after_count) in after {
249        if !before.contains_key(ngram) {
250            delta.insert(ngram.clone(), after_count as isize);
251        }
252    }
253
254    delta
255}
256
257/// Convert negative counts to special deletion tokens.
258/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
259/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
260/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
261fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
262    let mut counts = Counts::default();
263
264    for (ngram, &delta) in delta {
265        if delta > 0 {
266            counts.insert(ngram.clone(), delta as usize);
267        } else {
268            counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
269        }
270    }
271
272    counts
273}
274
275fn count_ngrams(text: &str, n: usize) -> Counts {
276    let chars: Vec<char> = text.chars().collect();
277    let mut counts = Counts::default();
278
279    for window in chars.windows(n) {
280        let ngram: String = window.iter().collect();
281        *counts.entry(ngram).or_insert(0) += 1;
282    }
283
284    counts
285}
286
#[cfg(test)]
mod test {
    use super::*;
    use zeta::udiff::DiffLine;

    /// Asserts that `low < score < high`, printing the actual score on failure.
    fn assert_strictly_between(score: f64, low: f64, high: f64) {
        assert!(
            score > low && score < high,
            "expected a score in ({low}, {high}), got {score}"
        );
    }

    #[test]
    fn test_delta_chr_f_perfect_match() {
        let diff = [
            DiffLine::Context("fn main() {"),
            DiffLine::Deletion("    println!(\"Hello\");"),
            DiffLine::Addition("    println!(\"Hello, World!\");"),
            DiffLine::Context("}"),
        ];

        // A diff compared against itself is a perfect prediction.
        let score = delta_chr_f(&diff, &diff);
        assert!((score - 100.0).abs() < 1e-2, "expected ~100, got {score}");
    }

    #[test]
    fn test_delta_chr_f_wrong_edit() {
        // The expected edit deletes "two "...
        let expected = [
            DiffLine::Context("one "),
            DiffLine::Deletion("two "),
            DiffLine::Context("three"),
        ];

        // ...but the prediction replaces "three" with "four" instead.
        let actual = [
            DiffLine::Context("one "),
            DiffLine::Context("two "),
            DiffLine::Deletion("three"),
            DiffLine::Addition("four"),
        ];

        // A wrong edit should score low.
        assert_strictly_between(delta_chr_f(&expected, &actual), 20.0, 40.0);
    }

    #[test]
    fn test_delta_chr_f_partial_match() {
        let expected = [
            DiffLine::Deletion("let x = 42;"),
            DiffLine::Addition("let x = 100;"),
        ];

        let actual = [
            DiffLine::Deletion("let x = 42;"),
            DiffLine::Addition("let x = 99;"),
        ];

        // The edit location is right but the replacement text is wrong. Deleted
        // n-grams still match, so the score lands somewhere in the middle.
        assert_strictly_between(delta_chr_f(&expected, &actual), 40.0, 60.0);
    }

    #[test]
    fn test_delta_chr_f_missed_edit() {
        // The expected edit replaces "old" with "new"...
        let expected = [
            DiffLine::Context("prefix "),
            DiffLine::Deletion("old"),
            DiffLine::Addition("new"),
            DiffLine::Context(" suffix"),
        ];

        // ...but the prediction makes no changes at all.
        let actual = [
            DiffLine::Context("prefix "),
            DiffLine::Context("old"),
            DiffLine::Context(" suffix"),
        ];

        // Every expected change is a false negative, so the score is low.
        let score = delta_chr_f(&expected, &actual);
        assert!(score < 20.0, "expected a score below 20, got {score}");
    }

    #[test]
    fn test_delta_chr_f_extra_edit() {
        // Nothing is expected to change...
        let expected = [DiffLine::Context("hello"), DiffLine::Context("world")];

        // ...but the prediction inserts unexpected content.
        let actual = [
            DiffLine::Context("hello"),
            DiffLine::Addition("extra"),
            DiffLine::Context("world"),
        ];

        // Every predicted change is a false positive, so the score is low.
        let score = delta_chr_f(&expected, &actual);
        assert!(score < 20.0, "expected a score below 20, got {score}");
    }
}