//! kept_rate.rs — character-level agreement metrics between a candidate edit
//! and a reference edit of the same base text.

  1use crate::tokenize::tokenize;
  2use serde::Serialize;
  3
/// Maximum tolerated gap (in characters) between the candidate's and the
/// reference's length delta versus `base`; beyond this, `compute_kept_rate`
/// bails out and scores the whole candidate edit as discarded.
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
  5
/// Classification of a single candidate token relative to base and reference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenAnnotation {
    /// Token is carried over unchanged from the base text.
    Context,
    /// Token was newly introduced by the candidate and also present in the reference.
    Kept,
    /// Token was newly introduced by the candidate but not kept by the reference.
    Discarded,
}
 13
/// A candidate token paired with its [`TokenAnnotation`] classification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AnnotatedToken {
    // The token's text.
    pub token: String,
    // How the token compares against base and reference.
    pub annotation: TokenAnnotation,
}
 19
/// Character-level agreement statistics between a candidate edit and a
/// reference edit of the same base text, as produced by `compute_kept_rate`.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize)]
pub struct KeptRateResult {
    /// Characters newly introduced by the candidate
    pub candidate_new_chars: usize,
    /// Characters newly introduced by the reference
    pub reference_new_chars: usize,
    /// Characters from `base` that are deleted by the candidate.
    pub candidate_deleted_chars: usize,
    /// Characters from `base` that are deleted by the reference.
    pub reference_deleted_chars: usize,
    /// Candidate new characters that are also present in the reference.
    pub kept_chars: usize,
    /// Base characters deleted by both the candidate and the reference.
    pub correctly_deleted_chars: usize,
    /// Candidate new characters that are not kept in the reference.
    pub discarded_chars: usize,
    /// Candidate characters treated as unchanged context
    pub context_chars: usize,
    /// Fraction of candidate edit characters that match the reference edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub kept_rate: f64,
    /// Fraction of reference edit characters covered by the candidate edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub recall_rate: f64,
    /// Per-token classification for candidate tokens.
    pub token_annotations: Vec<TokenAnnotation>,
}
 52
 53fn dp_index(width: usize, row: usize, column: usize) -> usize {
 54    row * width + column
 55}
 56
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
///
/// Entries of `keep_a` / `keep_b` whose element lies on a longest common
/// subsequence are set to `true`; all other entries are left untouched, so
/// callers should pass masks initialized to `false`. The two tracebacks break
/// DP ties in opposite directions, so `keep_a` matches a historical
/// `lcs_keep_mask(a, b)` and `keep_b` matches `lcs_keep_mask(b, a)`
/// (pinned by `test_lcs_keep_masks_matches_historical_one_sided_masks`).
fn fill_lcs_keep_masks<T: Eq>(
    a: &[T],
    b: &[T],
    mut keep_a: Option<&mut [bool]>,
    mut keep_b: Option<&mut [bool]>,
) {
    // No common subsequence is possible when either side is empty.
    if a.is_empty() || b.is_empty() {
        return;
    }

    // Fast path: identical slices keep every element.
    if a == b {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a.fill(true);
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b.fill(true);
        }
        return;
    }

    // Trim the shared prefix and suffix so the quadratic DP below only runs
    // over the differing middle section.
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        // Cap the suffix so it cannot overlap the prefix on either side.
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    // Prefix and suffix positions are trivially on the LCS.
    for index in 0..prefix_len {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[index] = true;
        }
    }

    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[a_index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[b_index] = true;
        }
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    if a_mid.is_empty() || b_mid.is_empty() {
        return;
    }

    // Standard LCS length table over the middle sections, stored row-major
    // in a flat buffer (see `dp_index`). dp[i][j] = LCS length of
    // a_mid[..i] and b_mid[..j].
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = &a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == &b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

    // Traceback for `a`'s mask: on ties, prefer moving up (consuming from
    // `a`), which reproduces the one-sided lcs_keep_mask(a, b) result.
    if let Some(keep_a) = keep_a.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_a[prefix_len + i - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if up >= left {
                    i -= 1;
                } else {
                    j -= 1;
                }
            }
        }
    }

    // Traceback for `b`'s mask: the mirrored tie-break (prefer moving left,
    // consuming from `b`), reproducing lcs_keep_mask(b, a).
    if let Some(keep_b) = keep_b.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_b[prefix_len + j - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if left >= up {
                    j -= 1;
                } else {
                    i -= 1;
                }
            }
        }
    }
}
187
188fn lcs_keep_mask<T: Eq>(a: &[T], b: &[T]) -> Vec<bool> {
189    let mut keep_a = vec![false; a.len()];
190    fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
191    keep_a
192}
193
194fn lcs_keep_masks<T: Eq>(a: &[T], b: &[T]) -> (Vec<bool>, Vec<bool>) {
195    let mut keep_a = vec![false; a.len()];
196    let mut keep_b = vec![false; b.len()];
197    fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
198    (keep_a, keep_b)
199}
200
/// A run of tokens compared as a single unit.
///
/// Adjacent identifier-like tokens are merged into one unit so that a rename
/// is scored as one edit rather than per sub-token.
#[derive(Debug, Clone)]
struct ComparisonUnit {
    // Concatenated text of the tokens in this unit.
    text: String,
    // Half-open range [token_start, token_end) of original token indices
    // covered by this unit.
    token_start: usize,
    token_end: usize,
}
207
/// Whether `token` looks like (part of) an identifier: non-empty and made up
/// solely of alphanumeric characters or underscores.
fn is_identifier_token(token: &str) -> bool {
    if token.is_empty() {
        return false;
    }
    token
        .chars()
        .all(|c| c.is_alphanumeric() || c == '_')
}
214
215fn build_comparison_units(tokens: &[&str]) -> Vec<ComparisonUnit> {
216    let mut units = Vec::new();
217    let mut index = 0;
218
219    while index < tokens.len() {
220        let token_start = index;
221
222        if is_identifier_token(tokens[index]) {
223            let mut text = String::new();
224
225            while index < tokens.len() && is_identifier_token(tokens[index]) {
226                text.push_str(tokens[index]);
227                index += 1;
228            }
229
230            units.push(ComparisonUnit {
231                text,
232                token_start,
233                token_end: index,
234            });
235        } else {
236            units.push(ComparisonUnit {
237                text: tokens[index].to_string(),
238                token_start,
239                token_end: index + 1,
240            });
241            index += 1;
242        }
243    }
244
245    units
246}
247
248fn analyze_masked_units<'a>(
249    units: &'a [ComparisonUnit],
250    mask: &[bool],
251) -> (Vec<&'a str>, usize, usize) {
252    let mut unmasked_units = Vec::with_capacity(units.len());
253    let mut unmasked_chars = 0;
254    let mut masked_chars = 0;
255
256    for (unit, &is_masked) in units.iter().zip(mask.iter()) {
257        if is_masked {
258            masked_chars += unit.text.len();
259        } else {
260            unmasked_units.push(unit.text.as_str());
261            unmasked_chars += unit.text.len();
262        }
263    }
264
265    (unmasked_units, unmasked_chars, masked_chars)
266}
267
268fn count_unmasked_unit_chars(units: &[ComparisonUnit], mask: &[bool]) -> usize {
269    units
270        .iter()
271        .zip(mask.iter())
272        .filter_map(|(unit, &is_masked)| (!is_masked).then_some(unit.text.len()))
273        .sum()
274}
275
276fn should_bail_for_dirty_final(base: &str, candidate: &str, reference: &str) -> bool {
277    let candidate_delta_chars = candidate.len().abs_diff(base.len());
278    let reference_delta_chars = reference.len().abs_diff(base.len());
279    candidate_delta_chars.abs_diff(reference_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
280}
281
/// Compute character-level agreement statistics between `candidate`'s edit of
/// `base` and `reference`'s edit of `base`.
///
/// Tokens are grouped into comparison units (runs of identifier-like tokens
/// merge; see `build_comparison_units`), each edit is diffed against `base`
/// via LCS to separate context from new/deleted characters, and the
/// candidate's new units are diffed against the reference's new units to
/// split them into kept vs discarded.
pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
    // Fast path: nothing changed anywhere, so every token is context and
    // both rates are perfect by definition.
    if base == candidate && candidate == reference {
        let candidate_tokens = tokenize(candidate);
        let context_chars = candidate_tokens.iter().map(|token| token.len()).sum();
        return KeptRateResult {
            candidate_new_chars: 0,
            reference_new_chars: 0,
            candidate_deleted_chars: 0,
            reference_deleted_chars: 0,
            kept_chars: 0,
            correctly_deleted_chars: 0,
            discarded_chars: 0,
            context_chars,
            kept_rate: 1.0,
            recall_rate: 1.0,
            token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
        };
    }

    // Bail out when the two edits' length deltas diverge too much: score the
    // entire candidate edit as discarded. Note the new-chars counts here are
    // length-delta approximations, not LCS-based counts.
    if should_bail_for_dirty_final(base, candidate, reference) {
        let candidate_new_chars = candidate.len().abs_diff(base.len());
        let reference_new_chars = reference.len().abs_diff(base.len());
        return KeptRateResult {
            candidate_new_chars,
            reference_new_chars,
            candidate_deleted_chars: 0,
            reference_deleted_chars: 0,
            kept_chars: 0,
            correctly_deleted_chars: 0,
            discarded_chars: candidate_new_chars,
            context_chars: 0,
            kept_rate: 0.0,
            recall_rate: 0.0,
            token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
        };
    }

    let base_tokens = tokenize(base);
    let candidate_tokens = tokenize(candidate);
    let reference_tokens = tokenize(reference);

    let candidate_units = build_comparison_units(&candidate_tokens);
    let base_units = build_comparison_units(&base_tokens);
    let reference_units = build_comparison_units(&reference_tokens);

    let candidate_unit_texts: Vec<&str> = candidate_units
        .iter()
        .map(|unit| unit.text.as_str())
        .collect();
    let base_unit_texts: Vec<&str> = base_units.iter().map(|unit| unit.text.as_str()).collect();
    let reference_unit_texts: Vec<&str> = reference_units
        .iter()
        .map(|unit| unit.text.as_str())
        .collect();

    // Diff candidate vs base: unmasked candidate units are "new", unmasked
    // base units are "deleted by the candidate".
    let (candidate_base_mask, base_candidate_mask) =
        lcs_keep_masks(&candidate_unit_texts, &base_unit_texts);
    let (stripped_candidate, candidate_new_chars, context_chars) =
        analyze_masked_units(&candidate_units, &candidate_base_mask);

    // Same diff for reference vs base.
    let (reference_base_mask, base_reference_mask) =
        lcs_keep_masks(&reference_unit_texts, &base_unit_texts);
    let (stripped_reference, reference_new_chars, _) =
        analyze_masked_units(&reference_units, &reference_base_mask);

    // Diff the newly-introduced material of both edits against each other:
    // candidate-new units also present in reference-new are "kept".
    let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);

    let kept_chars: usize = stripped_candidate
        .iter()
        .zip(keep_mask.iter())
        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
        .sum();

    let candidate_deleted_chars = count_unmasked_unit_chars(&base_units, &base_candidate_mask);
    let reference_deleted_chars = count_unmasked_unit_chars(&base_units, &base_reference_mask);
    // Base units absent from both edits were deleted "correctly".
    let correctly_deleted_chars: usize = base_units
        .iter()
        .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
        .filter_map(|(unit, (&in_candidate, &in_reference))| {
            (!in_candidate && !in_reference).then_some(unit.text.len())
        })
        .sum();

    let discarded_chars = candidate_new_chars - kept_chars;
    let matched_edit_chars = kept_chars + correctly_deleted_chars;
    let candidate_edit_chars = candidate_new_chars + candidate_deleted_chars;
    let reference_edit_chars = reference_new_chars + reference_deleted_chars;

    // Precision-like: of the candidate's edit characters, how many agree
    // with the reference edit. Empty/empty counts as perfect agreement.
    let kept_rate = if candidate_edit_chars == 0 {
        if reference_edit_chars == 0 { 1.0 } else { 0.0 }
    } else {
        matched_edit_chars as f64 / candidate_edit_chars as f64
    };

    // Recall-like: of the reference's edit characters, how many the
    // candidate's edit covers.
    let recall_rate = if reference_edit_chars == 0 {
        if candidate_edit_chars == 0 { 1.0 } else { 0.0 }
    } else {
        matched_edit_chars as f64 / reference_edit_chars as f64
    };

    // Project unit-level classifications back onto individual tokens.
    let token_annotations = {
        let mut token_annotations = vec![TokenAnnotation::Context; candidate_tokens.len()];
        let mut new_index = 0;

        for (unit_index, unit) in candidate_units.iter().enumerate() {
            let annotation = if candidate_base_mask[unit_index] {
                TokenAnnotation::Context
            } else {
                // Non-context units appear in `stripped_candidate` in order,
                // so `new_index` walks `keep_mask` in lockstep.
                let annotation = if keep_mask[new_index] {
                    TokenAnnotation::Kept
                } else {
                    TokenAnnotation::Discarded
                };
                new_index += 1;
                annotation
            };

            for token_index in unit.token_start..unit.token_end {
                token_annotations[token_index] = annotation;
            }
        }

        token_annotations
    };

    KeptRateResult {
        candidate_new_chars,
        reference_new_chars,
        candidate_deleted_chars,
        reference_deleted_chars,
        kept_chars,
        correctly_deleted_chars,
        discarded_chars,
        context_chars,
        kept_rate,
        recall_rate,
        token_annotations,
    }
}
421
422pub fn annotate_kept_rate_tokens(
423    base: &str,
424    candidate: &str,
425    reference: &str,
426) -> Vec<AnnotatedToken> {
427    let result = compute_kept_rate(base, candidate, reference);
428    tokenize(candidate)
429        .into_iter()
430        .zip(result.token_annotations)
431        .map(|(token, annotation)| AnnotatedToken {
432            token: token.to_string(),
433            annotation,
434        })
435        .collect()
436}
437
#[cfg(test)]
mod test_kept_rate {
    use super::*;
    use indoc::indoc;

    // Basic two-sided mask behavior, including an empty side.
    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    // The shared-DP masks must reproduce the old one-sided tie-breaking:
    // lcs_keep_masks(a, b) == (lcs_keep_mask(a, b), lcs_keep_mask(b, a)).
    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        assert_eq!(a_mask, lcs_keep_mask(&a, &b));
        assert_eq!(b_mask, lcs_keep_mask(&b, &a));
    }

    // No-change, fully-accepted, and fully-missed edits hit the rate extremes.
    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert!((no_change.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.candidate_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);
        assert!((accepted.recall_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
        assert!((discarded.recall_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.candidate_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.candidate_new_chars);
    }

    // Identifier-run merging must not let a base substring ("sp") mask part
    // of a new identifier when the whole base token was replaced.
    #[test]
    fn test_decoy_when_base_excluded() {
        let base = "    decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let candidate = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let reference = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, candidate, reference);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.candidate_new_chars, expected_new);
        assert!(result.correctly_deleted_chars > 0);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
    }

    // Candidate keeps a line the reference deletes: most new chars are kept,
    // one is discarded.
    #[test]
    fn test_missing_deletion() {
        let base = indoc! {"
            fn example() {
                epr
        "};
        let candidate = indoc! {r#"
            fn example() {
                epr
            eprintln!("");
        "#};
        let reference = indoc! {r#"
            fn example() {
            eprintln!("");
        "#};

        let result = compute_kept_rate(base, candidate, reference);
        assert!((result.kept_rate - (14.0 / 15.0)).abs() < 1e-6);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 1);
    }

    // An empty candidate is a pure deletion; partially agreeing with the
    // reference's deletion yields intermediate rates.
    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert_eq!(result.candidate_new_chars, 0);
        assert!(result.candidate_deleted_chars > 0);
        assert!(result.correctly_deleted_chars > 0);
        assert!(result.correctly_deleted_chars < result.candidate_deleted_chars);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert!(result.recall_rate > 0.0 && result.recall_rate < 1.0);
    }

    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    // When the reference diverges in length beyond the dirty threshold, the
    // whole candidate edit is scored as discarded.
    #[test]
    fn test_bails_for_dirty_final() {
        let base = indoc! {"
            fn example() {
                work();
            }
        "};
        let candidate = indoc! {"
            fn example() {
                work();
                predicted();
            }
        "};
        let reference = format!(
            "fn example() {{\n    work();\n    {}\n}}\n",
            "settled();\n    ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
        );

        let result = compute_kept_rate(base, candidate, &reference);
        assert_eq!(result.kept_rate, 0.0);
        assert_eq!(result.recall_rate, 0.0);
        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
    }

    // The prefix "epr" in base must not prevent "eprintln" from being
    // treated as a single new unit.
    #[test]
    fn test_eprintln_token_alignment() {
        let base = indoc! {"
            fn example() {
                epr
        "};
        let candidate = indoc! {r#"
            fn example() {
                eprintln!("hello world!");
        "#};
        let reference = indoc! {r#"
            fn example() {
                eprintln!("");
        "#};

        let result = compute_kept_rate(base, candidate, reference);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    // A stale token the candidate left untouched ("old") counts as context,
    // not as a discarded edit; the missed change only lowers recall.
    #[test]
    fn test_kept_rate_treats_unchanged_stale_text_as_context() {
        let base = indoc! {"
            a=fomr
            b=old
        "};
        let candidate = indoc! {"
            a=formula;
            b=old
        "};
        let reference = indoc! {"
            a=formula;
            b=new
        "};

        let result = compute_kept_rate(base, candidate, reference);
        let candidate_tokens = tokenize(candidate);

        assert_eq!(result.candidate_new_chars, "formula".len() + ";".len());
        assert_eq!(result.kept_chars, "formula".len() + ";".len());
        assert_eq!(result.discarded_chars, 0);
        assert_eq!(result.candidate_deleted_chars, "fomr".len());
        assert_eq!(result.correctly_deleted_chars, "fomr".len());
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
        assert!((result.recall_rate - (2.0 / 3.0)).abs() < 1e-6);

        let old_index = candidate_tokens
            .iter()
            .position(|&token| token == "old")
            .expect("old token not found");
        assert_eq!(
            result.token_annotations[old_index],
            TokenAnnotation::Context
        );
    }

    // A rename annotates all sub-tokens of the merged identifier as Kept.
    #[test]
    fn test_annotations_rename() {
        let base = "    foo(old_name)\n";
        let candidate = "    foo(new_name)\n";
        let reference = "    foo(new_name)\n";
        let result = compute_kept_rate(base, candidate, reference);

        assert_eq!(result.candidate_new_chars, "new_name".len());
        assert_eq!(result.candidate_deleted_chars, "old_name".len());
        assert_eq!(result.reference_deleted_chars, "old_name".len());
        assert_eq!(result.correctly_deleted_chars, "old_name".len());
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(result.token_annotations.len(), tokenize(candidate).len());

        for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
            if matches!(token, "new" | "_" | "name") {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    // Token-by-token coloring around the eprintln call: the call skeleton is
    // Kept, the differing string contents Discarded.
    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = indoc! {"
            fn example() {
                epr
        "};
        let candidate = indoc! {r#"
            fn example() {
                eprintln!("hello world!");
        "#};
        let reference = indoc! {r#"
            fn example() {
                eprintln!("");
        "#};
        let result = compute_kept_rate(base, candidate, reference);
        let candidate_tokens = tokenize(candidate);

        let eprintln_index = candidate_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    // Highly repetitive context must not let the LCS "rescue" a wrong
    // prediction token as kept.
    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let candidate = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let reference = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &candidate, &reference);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.correctly_deleted_chars, "foo".len() * 16);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
        assert_eq!(result.candidate_new_chars, "prediction_token".len() * 16);
        assert!(result.kept_rate > 0.0);
        assert!(result.recall_rate > 0.0);
    }
}
722    }
723}