metrics.rs

use collections::HashMap;

type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;

/// Context characters needed on each side of a change to capture all affected n-grams
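/// (with `CHR_F_CHAR_ORDER` = 6 below, this works out to 5 characters)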
const CONTEXT_CHARS: usize = CHR_F_CHAR_ORDER - 1;

#[derive(Default, Debug, Clone)]
struct ClassificationMetrics {
    true_positives: usize,
    false_positives: usize,
    false_negatives: usize,
}

impl ClassificationMetrics {
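    /// Multiset precision/recall bookkeeping over n-gram counts: matched counts become
    /// true positives, surplus counts in `actual` become false positives, and counts
    /// missing from `actual` become false negatives.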
    fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for (ngram, &expected_count) in expected {
            let actual_count = *actual.get(ngram).unwrap_or(&0);
            if actual_count > expected_count {
                false_positives += actual_count - expected_count;
            } else {
                false_negatives += expected_count - actual_count;
            }
            true_positives += expected_count.min(actual_count);
        }

        for (ngram, &actual_count) in actual {
            if !expected.contains_key(ngram) {
                false_positives += actual_count;
            }
        }

        ClassificationMetrics {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    fn precision(&self) -> f64 {
        if self.true_positives + self.false_positives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
        }
    }

    fn recall(&self) -> f64 {
        if self.true_positives + self.false_negatives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
        }
    }
}

enum ChrfWhitespace {
    #[allow(unused)]
    Unchanged,
    Ignore,
}

const CHR_F_CHAR_ORDER: usize = 6;
const CHR_F_BETA: f64 = 2.0;
const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;

/// Computes a delta-chrF score that compares two sets of edits.
///
/// This metric works by:
/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
/// 2. Comparing these deltas to measure how well actual edits match expected edits
///
/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
/// the expected edits.
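///
/// A usage sketch (kept out of doctests since the crate path is omitted); an actual
/// edit that exactly matches the expected edit scores 100.0:
///
/// ```ignore
/// let score = delta_chr_f("let x = 42;", "let x = 100;", "let x = 100;");
/// assert!((score - 100.0).abs() < 1e-6);
/// ```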
pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
    // Edge case: if all texts are identical, the edits match perfectly
    if original == expected && expected == actual {
        return 100.0;
    }

    // Pre-filter whitespace once for all texts
    let orig_chars: Vec<char> = filter_whitespace_chars(original);
    let exp_chars: Vec<char> = filter_whitespace_chars(expected);
    let act_chars: Vec<char> = filter_whitespace_chars(actual);

    // Find the changed regions between original→expected and original→actual
    // We only need to compute n-grams on these regions (plus context for boundary n-grams)
    let (orig_for_exp, exp_region) = extract_changed_regions(&orig_chars, &exp_chars);
    let (orig_for_act, act_region) = extract_changed_regions(&orig_chars, &act_chars);

    let mut total_precision = 0.0;
    let mut total_recall = 0.0;

    for order in 1..=CHR_F_CHAR_ORDER {
        // Compute n-grams only on the affected regions
        let orig_ngrams_for_exp = count_ngrams_from_chars(&orig_for_exp, order);
        let exp_ngrams = count_ngrams_from_chars(&exp_region, order);
        let expected_delta = compute_ngram_delta(&exp_ngrams, &orig_ngrams_for_exp);

        let orig_ngrams_for_act = count_ngrams_from_chars(&orig_for_act, order);
        let act_ngrams = count_ngrams_from_chars(&act_region, order);
        let actual_delta = compute_ngram_delta(&act_ngrams, &orig_ngrams_for_act);

        if expected_delta.is_empty() && actual_delta.is_empty() {
            total_precision += 1.0;
            total_recall += 1.0;
            continue;
        }

        let expected_counts = ngram_delta_to_counts(&expected_delta);
        let actual_counts = ngram_delta_to_counts(&actual_delta);

        let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
        total_precision += score.precision();
        total_recall += score.recall();
    }

    let prec = total_precision / CHR_F_CHAR_ORDER as f64;
    let recall = total_recall / CHR_F_CHAR_ORDER as f64;
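    // chrF's F_beta score: (1 + β²)·P·R / (β²·P + R) with β = CHR_F_BETA, which
    // weights recall more heavily than precision for β > 1.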
    let f_score = if prec + recall == 0.0 {
        0.0
    } else {
        (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
    };

    f_score * 100.0
}

/// Reference implementation of delta_chr_f (original, non-optimized version).
/// Used for testing that the optimized version produces identical results.
#[cfg(test)]
fn delta_chr_f_reference(original: &str, expected: &str, actual: &str) -> f64 {
    if original == expected && expected == actual {
        return 100.0;
    }

    let original_ngrams = chr_f_ngram_counts(original);
    let expected_ngrams = chr_f_ngram_counts(expected);
    let actual_ngrams = chr_f_ngram_counts(actual);

    let mut total_precision = 0.0;
    let mut total_recall = 0.0;

    for order in 0..CHR_F_CHAR_ORDER {
        let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
        let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);

        if expected_delta.is_empty() && actual_delta.is_empty() {
            total_precision += 1.0;
            total_recall += 1.0;
            continue;
        }

        let expected_counts = ngram_delta_to_counts(&expected_delta);
        let actual_counts = ngram_delta_to_counts(&actual_delta);

        let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
        total_precision += score.precision();
        total_recall += score.recall();
    }

    let prec = total_precision / CHR_F_CHAR_ORDER as f64;
    let recall = total_recall / CHR_F_CHAR_ORDER as f64;
    let f_score = if prec + recall == 0.0 {
        0.0
    } else {
        (1.0 + CHR_F_BETA * CHR_F_BETA) * prec * recall / (CHR_F_BETA * CHR_F_BETA * prec + recall)
    };

    f_score * 100.0
}

/// Filter whitespace from a string and return as Vec<char>
fn filter_whitespace_chars(text: &str) -> Vec<char> {
    match CHR_F_WHITESPACE {
        ChrfWhitespace::Unchanged => text.chars().collect(),
        ChrfWhitespace::Ignore => text.chars().filter(|c| !c.is_whitespace()).collect(),
    }
}

/// Extract only the changed regions between two texts, with context for n-gram boundaries.
///
/// Returns (original_affected_region, modified_affected_region) as Vec<char>.
///
/// The key insight: when computing n-gram delta between two nearly-identical texts,
/// n-grams from unchanged regions cancel out. We only need to process:
/// 1. The changed content itself
/// 2. CONTEXT_CHARS (n-1) characters before and after, to capture boundary-crossing n-grams
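///
/// A worked example (assuming CONTEXT_CHARS = 5, i.e. CHR_F_CHAR_ORDER = 6):
///
/// ```text
/// original: "aaaaaaaaaaWORLDbbbbbbbbbb"   (10 a's, "WORLD", 10 b's)
/// modified: "aaaaaaaaaaTHEREbbbbbbbbbb"
/// common prefix = 10, common suffix = 10, changed region = indices 10..15;
/// with 5 characters of context the returned regions are
/// "aaaaaWORLDbbbbb" and "aaaaaTHEREbbbbb"
/// ```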
fn extract_changed_regions(original: &[char], modified: &[char]) -> (Vec<char>, Vec<char>) {
    // Find longest common prefix
    let prefix_len = original
        .iter()
        .zip(modified.iter())
        .take_while(|(a, b)| a == b)
        .count();

    // Find longest common suffix (that doesn't overlap with prefix)
    let orig_remaining = original.len().saturating_sub(prefix_len);
    let mod_remaining = modified.len().saturating_sub(prefix_len);
    let max_suffix = orig_remaining.min(mod_remaining);

    let suffix_len = original
        .iter()
        .rev()
        .zip(modified.iter().rev())
        .take(max_suffix)
        .take_while(|(a, b)| a == b)
        .count();

    // Calculate the changed region boundaries
    let orig_change_start = prefix_len;
    let orig_change_end = original.len().saturating_sub(suffix_len);
    let mod_change_start = prefix_len;
    let mod_change_end = modified.len().saturating_sub(suffix_len);

    // If there's no actual change, return empty regions
    if orig_change_start >= orig_change_end && mod_change_start >= mod_change_end {
        return (Vec::new(), Vec::new());
    }

    // Expand to include context for n-gram boundaries
    let orig_context_start = orig_change_start.saturating_sub(CONTEXT_CHARS);
    let orig_context_end = (orig_change_end + CONTEXT_CHARS).min(original.len());
    let mod_context_start = mod_change_start.saturating_sub(CONTEXT_CHARS);
    let mod_context_end = (mod_change_end + CONTEXT_CHARS).min(modified.len());

    let orig_region: Vec<char> = original[orig_context_start..orig_context_end].to_vec();
    let mod_region: Vec<char> = modified[mod_context_start..mod_context_end].to_vec();

    (orig_region, mod_region)
}

/// Count n-grams directly from a char slice (avoids String allocation for the full text)
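///
/// e.g. the chars of "abab" with n = 2 yield {"ab": 2, "ba": 1}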
fn count_ngrams_from_chars(chars: &[char], n: usize) -> Counts {
    let mut counts = Counts::default();

    if chars.len() < n {
        return counts;
    }

    for window in chars.windows(n) {
        let ngram: String = window.iter().collect();
        *counts.entry(ngram).or_insert(0) += 1;
    }

    counts
}

#[allow(dead_code)]
fn chr_f_ngram_counts(text: &str) -> Vec<Counts> {
    // Ignore whitespace. The original chrF implementation skips all
    // whitespace. We should consider compressing multiple consecutive
    // spaces into one -- this may reflect our task more closely.
    let text = match CHR_F_WHITESPACE {
        ChrfWhitespace::Unchanged => text.to_string(),
        ChrfWhitespace::Ignore => text
            .chars()
            .filter(|c| !c.is_whitespace())
            .collect::<String>(),
    };

    (1..=CHR_F_CHAR_ORDER)
        .map(|order| count_ngrams(&text, order))
        .collect()
}

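/// Computes the per-n-gram count difference `after - before`. N-grams with equal counts
/// on both sides get a delta of 0, which `ngram_delta_to_counts` drops later.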
fn compute_ngram_delta(after: &Counts, before: &Counts) -> CountsDelta {
    let mut delta = CountsDelta::default();

    for (ngram, &before_count) in before {
        let after_count = *after.get(ngram).unwrap_or(&0);
        delta.insert(ngram.clone(), after_count as isize - before_count as isize);
    }

    for (ngram, &after_count) in after {
        if !before.contains_key(ngram) {
            delta.insert(ngram.clone(), after_count as isize);
        }
    }

    delta
}

/// Convert negative counts to special deletion tokens.
/// For example, if expected delta is {"foo": -1} and actual delta is {"bar": -1},
/// we convert it to {"¬foo": +1} and {"¬bar": +1}. This way _not_ deleting "foo"
/// will result in a false negative, and mistakenly deleting "bar" will result in a false positive.
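///
/// ```text
/// {"foo": -1, "bar": +2, "baz": 0}  ->  {"¬foo": 1, "bar": 2}   (zero deltas are dropped)
/// ```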
fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
    let mut counts = Counts::default();

    for (ngram, &delta) in delta {
        if delta > 0 {
            counts.insert(ngram.clone(), delta as usize);
        } else if delta < 0 {
            counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
        }
    }

    counts
}

#[allow(dead_code)]
fn count_ngrams(text: &str, n: usize) -> Counts {
    let chars: Vec<char> = text.chars().collect();
    let mut counts = Counts::default();

    for window in chars.windows(n) {
        let ngram: String = window.iter().collect();
        *counts.entry(ngram).or_insert(0) += 1;
    }

    counts
}

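/// Measures how unbalanced `{}`, `()`, and `[]` are in `text`: the sum, over the three
/// pairs, of the absolute difference between opening and closing counts. The check is
/// purely count-based, so mis-nested but count-balanced text such as "([)]" scores 0.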
pub fn braces_disbalance(text: &str) -> usize {
    let mut disbalance = 0isize;

    let a = text.chars().filter(|&c| c == '{').count() as isize;
    let b = text.chars().filter(|&c| c == '}').count() as isize;
    disbalance += (a - b).abs();

    let a = text.chars().filter(|&c| c == '(').count() as isize;
    let b = text.chars().filter(|&c| c == ')').count() as isize;
    disbalance += (a - b).abs();

    let a = text.chars().filter(|&c| c == '[').count() as isize;
    let b = text.chars().filter(|&c| c == ']').count() as isize;
    disbalance += (a - b).abs();

    disbalance as usize
}

#[cfg(test)]
mod test_optimization {
    use super::*;

    #[test]
    fn test_extract_changed_regions_simple() {
        let original: Vec<char> = "hello world".chars().collect();
        let modified: Vec<char> = "hello there".chars().collect();

        let (orig_region, mod_region) = extract_changed_regions(&original, &modified);

        // "world" vs "there" - with 5 chars context, we get "ello world" vs "ello there"
        // (or less if not enough chars available)
        assert!(orig_region.len() < original.len());
        assert!(mod_region.len() < modified.len());
    }

    #[test]
    fn test_extract_changed_regions_insertion() {
        let original: Vec<char> = "abcdef".chars().collect();
        let modified: Vec<char> = "abcXYZdef".chars().collect();

        let (orig_region, mod_region) = extract_changed_regions(&original, &modified);

        // The insertion is between c and d, so we need context around that point
        assert!(orig_region.len() <= original.len());
        assert!(mod_region.iter().collect::<String>().contains("XYZ"));
    }

    #[test]
    fn test_extract_changed_regions_identical() {
        let text: Vec<char> = "identical text".chars().collect();

        let (orig_region, mod_region) = extract_changed_regions(&text, &text);

        // When texts are identical, regions should be empty
        assert!(orig_region.is_empty());
        assert!(mod_region.is_empty());
    }

    #[test]
    fn test_delta_chr_f_scores_in_range() {
        // Smoke test: delta_chr_f should always produce a score in the valid 0-100 range
        let test_cases = vec![
            ("hello world", "hello there", "hello world"),
            (
                "fn main() {}",
                "fn main() { println!(); }",
                "fn main() { print!(); }",
            ),
            ("abcdefghij", "abcXXXghij", "abcYYghij"),
            ("unchanged", "unchanged", "unchanged"),
            (
                "prefix middle suffix",
                "prefix CHANGED suffix",
                "prefix middle suffix",
            ),
        ];

        for (original, expected, actual) in test_cases {
            let score = delta_chr_f(original, expected, actual);
            // Just verify it produces a reasonable score (0-100)
            assert!(
                score >= 0.0 && score <= 100.0,
                "Score {} out of range for ({}, {}, {})",
                score,
                original,
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_optimized_equals_reference() {
        // Comprehensive test that optimized version matches reference implementation exactly
        let test_cases = vec![
            // Basic cases
            ("hello world", "hello there", "hello world"),
            ("hello world", "hello there", "hello there"),
            ("unchanged", "unchanged", "unchanged"),
            // Code-like cases
            (
                "fn main() { println!(\"Hello\"); }",
                "fn main() { println!(\"Hello, World!\"); }",
                "fn main() { println!(\"Hello, World!\"); }",
            ),
            (
                "fn main() { println!(\"Hello\"); }",
                "fn main() { println!(\"Hello, World!\"); }",
                "fn main() { println!(\"Goodbye\"); }",
            ),
            // Insertion
            ("abcdef", "abcXYZdef", "abcdef"),
            ("abcdef", "abcXYZdef", "abcXYZdef"),
            ("abcdef", "abcXYZdef", "abcABCdef"),
            // Deletion
            ("abcXYZdef", "abcdef", "abcXYZdef"),
            ("abcXYZdef", "abcdef", "abcdef"),
            // Multiple changes (simulated by different expected/actual)
            ("one two three four", "one THREE four", "one two FOUR"),
            // Edge cases
            ("a", "b", "c"),
            ("", "abc", ""),
            ("abc", "", "abc"),
            // Longer text with small change
            (
                "This is a longer piece of text that contains many words and characters to process",
                "This is a longer piece of TEXT that contains many words and characters to process",
                "This is a longer piece of text that contains many words and characters to process",
            ),
            // Change at the beginning
            (
                "ORIGINAL start of text",
                "NEW start of text",
                "DIFFERENT start of text",
            ),
            // Change at the end
            (
                "text ending ORIGINAL",
                "text ending NEW",
                "text ending DIFFERENT",
            ),
            // Whitespace (should be ignored)
            ("hello   world", "hello   there", "hello   world"),
            ("a b c d", "a X c d", "a Y c d"),
        ];

        for (original, expected, actual) in test_cases {
            let optimized_score = delta_chr_f(original, expected, actual);
            let reference_score = delta_chr_f_reference(original, expected, actual);

            assert!(
                (optimized_score - reference_score).abs() < 1e-10,
                "Mismatch for ({:?}, {:?}, {:?}):\n  optimized: {}\n  reference: {}",
                original,
                expected,
                actual,
                optimized_score,
                reference_score
            );
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_delta_chr_f_perfect_match() {
        let original = "fn main() {    println!(\"Hello\");}";
        let expected = "fn main() {    println!(\"Hello, World!\");}";

        let score = delta_chr_f(original, expected, expected);
        assert!((score - 100.0).abs() < 1e-2);
    }

    #[test]
    fn test_delta_chr_f_wrong_edit() {
        // When the edit is wrong
        let original = "one two three";
        let expected = "one three"; // deleted "two "
        let actual = "one two four"; // deleted "three", added "four"

        // Then the score should be low
        let score = delta_chr_f(original, expected, actual);
        assert!(score > 20.0 && score < 40.0);
    }

    #[test]
    fn test_delta_chr_f_partial_match() {
        let original = "let x = 42;";
        let expected = "let x = 100;";
        let actual = "let x = 99;";

        // We got the edit location right, but the replacement text is wrong.
        // Deleted ngrams will match, bringing the score somewhere in the middle.
        let score = delta_chr_f(original, expected, actual);
        assert!(score > 40.0 && score < 60.0);
    }

    #[test]
    fn test_delta_chr_f_missed_edit() {
        // When the prediction makes no changes
        let original = "prefix old suffix";
        let expected = "prefix new suffix";
        let actual = "prefix old suffix"; // no change

        // Then the score should be low (all expected changes are false negatives)
        let score = delta_chr_f(original, expected, actual);
        assert!(score < 20.0);
    }

    #[test]
    fn test_delta_chr_f_extra_edit() {
        // When adding unexpected content
        let original = "helloworld";
        let expected = "helloworld"; // no change expected
        let actual = "helloextraworld"; // added "extra"

        // Then the score should be low (all actual changes are false positives)
        let score = delta_chr_f(original, expected, actual);
        assert!(score < 20.0);
    }

    #[test]
    fn test_delta_chr_f_no_changes() {
        let text = "unchanged text";
        let score = delta_chr_f(text, text, text);
        assert!((score - 100.0).abs() < 1e-2);
    }

    #[test]
    fn test_braces_disbalance() {
        let text = "let x = { 1 + 2 };";
        assert_eq!(braces_disbalance(text), 0);

        let text = "let x = { 1 + 2";
        assert_eq!(braces_disbalance(text), 1);

        let text = "let x = { 1 + 2 )";
        assert_eq!(braces_disbalance(text), 2);
    }
}