//! kept_rate.rs — character-level "kept rate" metrics comparing a candidate
//! edit against a reference edit relative to a shared base text.

  1use crate::word_diff::tokenize;
  2
/// Per-token classification of a candidate token, used by tests to check how
/// `compute_kept_rate` attributes each token of the candidate text.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Token is unchanged context: shared with both the base and the reference.
    Context,
    /// Token is newly introduced by the candidate and matched in the reference.
    Kept,
    /// Token is newly introduced by the candidate but not matched in the reference.
    Discarded,
}
 10
/// Character-level agreement metrics between a candidate edit and a reference
/// edit, both taken relative to the same base text.
///
/// All character counts are sums of token byte lengths (`token.len()`).
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Characters newly introduced by the candidate.
    pub candidate_new_chars: usize,
    /// Characters newly introduced by the reference.
    pub reference_new_chars: usize,
    /// Characters from `base` that are deleted by the candidate.
    pub candidate_deleted_chars: usize,
    /// Characters from `base` that are deleted by the reference.
    pub reference_deleted_chars: usize,
    /// Candidate new characters that are also present in the reference.
    pub kept_chars: usize,
    /// Base characters deleted by both the candidate and the reference.
    pub correctly_deleted_chars: usize,
    /// Candidate new characters that are not kept in the reference.
    pub discarded_chars: usize,
    /// Candidate characters treated as unchanged context.
    pub context_chars: usize,
    /// Fraction of candidate edit characters that match the reference edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub kept_rate: f64,
    /// Fraction of reference edit characters covered by the candidate edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub recall_rate: f64,
    /// Per-token classification for candidate tokens used by tests.
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
 44
 45fn dp_index(width: usize, row: usize, column: usize) -> usize {
 46    row * width + column
 47}
 48
/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
/// side while sharing a single DP table construction.
///
/// `keep_a[i]` is true when `a[i]` lies on the common subsequence chosen by a
/// traceback that prefers consuming `a` on ties; `keep_b[j]` comes from a
/// second traceback over the same table that prefers consuming `b` on ties.
/// The two tracebacks may therefore select different (equal-length) common
/// subsequences.
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
    // Fast path: an empty side shares nothing with the other.
    if a.is_empty() || b.is_empty() {
        return (vec![false; a.len()], vec![false; b.len()]);
    }

    // Fast path: identical inputs keep every token on both sides.
    if a == b {
        return (vec![true; a.len()], vec![true; b.len()]);
    }

    let mut keep_a = vec![false; a.len()];
    let mut keep_b = vec![false; b.len()];

    // Trim the common prefix and suffix so the quadratic DP only runs over
    // the differing middle section.
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        // Cap the suffix so it cannot overlap the prefix on either side.
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    // Prefix and suffix tokens are trivially part of any maximal common
    // subsequence, so mark them kept on both sides.
    for index in 0..prefix_len {
        keep_a[index] = true;
        keep_b[index] = true;
    }

    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        keep_a[a_index] = true;
        keep_b[b_index] = true;
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    if a_mid.is_empty() || b_mid.is_empty() {
        return (keep_a, keep_b);
    }

    // Standard LCS-length DP over the middle sections: dp[(i, j)] is the LCS
    // length of a_mid[..i] and b_mid[..j], stored in a flat row-major buffer.
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

    // Traceback for `a`: ties step up (consume `a` first), producing the
    // `a`-sided mask.
    let mut i = a_mid.len();
    let mut j = b_mid.len();

    while i > 0 && j > 0 {
        if a_mid[i - 1] == b_mid[j - 1] {
            keep_a[prefix_len + i - 1] = true;
            i -= 1;
            j -= 1;
        } else {
            let up = dp[dp_index(column_count, i - 1, j)];
            let left = dp[dp_index(column_count, i, j - 1)];
            if up >= left {
                i -= 1;
            } else {
                j -= 1;
            }
        }
    }

    // Traceback for `b`: same table, but ties step left (consume `b` first),
    // producing the `b`-sided mask.
    let mut i = a_mid.len();
    let mut j = b_mid.len();

    while i > 0 && j > 0 {
        if a_mid[i - 1] == b_mid[j - 1] {
            keep_b[prefix_len + j - 1] = true;
            i -= 1;
            j -= 1;
        } else {
            let up = dp[dp_index(column_count, i - 1, j)];
            let left = dp[dp_index(column_count, i, j - 1)];
            if left >= up {
                j -= 1;
            } else {
                i -= 1;
            }
        }
    }

    (keep_a, keep_b)
}
161
/// Partition `tokens` according to `mask`.
///
/// Returns the tokens whose mask entry is `false` (in their original order),
/// the total byte length of those tokens, and the total byte length of the
/// masked-out tokens.
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let pairs = || tokens.iter().zip(mask.iter());

    let unmasked_tokens: Vec<&str> = pairs()
        .filter(|&(_, &is_masked)| !is_masked)
        .map(|(&token, _)| token)
        .collect();
    let unmasked_chars = unmasked_tokens.iter().map(|token| token.len()).sum();
    let masked_chars = pairs()
        .filter(|&(_, &is_masked)| is_masked)
        .map(|(&token, _)| token.len())
        .sum();

    (unmasked_tokens, unmasked_chars, masked_chars)
}
178
/// Sum the byte lengths of the tokens whose mask entry is `false`.
fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
    let mut total = 0;
    for (token, &is_masked) in tokens.iter().zip(mask.iter()) {
        if !is_masked {
            total += token.len();
        }
    }
    total
}
186
187pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
188    if base == candidate && candidate == reference {
189        let candidate_tokens = tokenize(candidate);
190        let context_chars = candidate_tokens.iter().map(|token| token.len()).sum();
191        return KeptRateResult {
192            candidate_new_chars: 0,
193            reference_new_chars: 0,
194            candidate_deleted_chars: 0,
195            reference_deleted_chars: 0,
196            kept_chars: 0,
197            correctly_deleted_chars: 0,
198            discarded_chars: 0,
199            context_chars,
200            kept_rate: 1.0,
201            recall_rate: 1.0,
202            #[cfg(test)]
203            token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
204        };
205    }
206
207    let base_tokens = tokenize(base);
208    let candidate_tokens = tokenize(candidate);
209    let reference_tokens = tokenize(reference);
210
211    let (candidate_base_mask, base_candidate_mask) =
212        lcs_keep_masks(&candidate_tokens, &base_tokens);
213    let (candidate_reference_mask, reference_candidate_mask) =
214        lcs_keep_masks(&candidate_tokens, &reference_tokens);
215    let context_mask: Vec<bool> = candidate_base_mask
216        .iter()
217        .zip(candidate_reference_mask.iter())
218        .map(|(&in_base, &in_final)| in_base && in_final)
219        .collect();
220
221    let (stripped_candidate, candidate_new_chars, context_chars) =
222        analyze_masked_tokens(&candidate_tokens, &context_mask);
223
224    let (reference_base_mask, base_reference_mask) =
225        lcs_keep_masks(&reference_tokens, &base_tokens);
226    let reference_context_mask: Vec<bool> = reference_base_mask
227        .iter()
228        .zip(reference_candidate_mask.iter())
229        .map(|(&in_base, &in_candidate)| in_base && in_candidate)
230        .collect();
231
232    let (stripped_reference, reference_new_chars, _) =
233        analyze_masked_tokens(&reference_tokens, &reference_context_mask);
234
235    let keep_mask = lcs_keep_masks(&stripped_candidate, &stripped_reference).0;
236
237    let kept_chars: usize = stripped_candidate
238        .iter()
239        .zip(keep_mask.iter())
240        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
241        .sum();
242
243    let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
244    let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
245    let correctly_deleted_chars: usize = base_tokens
246        .iter()
247        .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
248        .filter_map(|(&token, (&in_candidate, &in_reference))| {
249            (!in_candidate && !in_reference).then_some(token.len())
250        })
251        .sum();
252
253    let discarded_chars = candidate_new_chars - kept_chars;
254    let matched_edit_chars = kept_chars + correctly_deleted_chars;
255    let candidate_edit_chars = candidate_new_chars + candidate_deleted_chars;
256    let reference_edit_chars = reference_new_chars + reference_deleted_chars;
257
258    let kept_rate = if candidate_edit_chars == 0 {
259        if reference_edit_chars == 0 { 1.0 } else { 0.0 }
260    } else {
261        matched_edit_chars as f64 / candidate_edit_chars as f64
262    };
263
264    let recall_rate = if reference_edit_chars == 0 {
265        if candidate_edit_chars == 0 { 1.0 } else { 0.0 }
266    } else {
267        matched_edit_chars as f64 / reference_edit_chars as f64
268    };
269
270    #[cfg(test)]
271    let token_annotations = {
272        let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
273        let mut new_index = 0;
274        for (token_index, _token) in candidate_tokens.iter().enumerate() {
275            if context_mask[token_index] {
276                token_annotations.push(TokenAnnotation::Context);
277            } else {
278                let annotation = if keep_mask[new_index] {
279                    TokenAnnotation::Kept
280                } else {
281                    TokenAnnotation::Discarded
282                };
283                #[cfg(test)]
284                token_annotations.push(annotation);
285                new_index += 1;
286            }
287        }
288        token_annotations
289    };
290
291    KeptRateResult {
292        candidate_new_chars,
293        reference_new_chars,
294        candidate_deleted_chars,
295        reference_deleted_chars,
296        kept_chars,
297        correctly_deleted_chars,
298        discarded_chars,
299        context_chars,
300        kept_rate,
301        recall_rate,
302        #[cfg(test)]
303        token_annotations,
304    }
305}
306
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    #[test]
    fn test_lcs_keep_masks() {
        // Common-subsequence tokens are marked true on both sides.
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        // An empty side shares nothing with the other.
        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        // Pin the one-sided tie-breaking: the `a` traceback prefers consuming
        // `a` on ties and the `b` traceback prefers consuming `b`, so the two
        // masks select different (equal-length) common subsequences.
        // The previous first assertion compared `a_mask` against an identical
        // call and was vacuously true; assert concrete masks instead.
        assert_eq!(a_mask, vec![false, true, true, true]);
        assert_eq!(b_mask, vec![true, true, true, false]);
        // The `b`-side mask equals the `a`-side mask of the swapped call.
        assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
    }

    #[test]
    fn test_rate_extremes() {
        // No change anywhere: everything is context and both rates are 1.
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert!((no_change.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.candidate_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        // Candidate matches the reference edit exactly.
        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);
        assert!((accepted.recall_rate - 1.0).abs() < 1e-6);

        // Candidate made no edit while the reference changed everything.
        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
        assert!((discarded.recall_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        // Insertion into an empty base: every new token is kept when the
        // candidate equals the reference.
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.candidate_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.candidate_new_chars);
    }

    #[test]
    fn test_decoy_when_base_excluded() {
        // Base tokens repeated as substrings must not inflate the new-token
        // counts when both sides replace them identically.
        let base = "    decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let candidate = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let reference = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, candidate, reference);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.candidate_new_chars, expected_new);
        assert!(result.correctly_deleted_chars > 0);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_missing_deletion() {
        // The candidate appends the right line but fails to delete the stale
        // `epr` fragment, so part of its edit is discarded.
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\neprintln!(\"\");\n";
        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    #[test]
    fn test_empty_prediction() {
        // An empty candidate deletes everything: partially right (the
        // reference also deletes the old line) but misses the insertion.
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert_eq!(result.candidate_new_chars, 0);
        assert!(result.candidate_deleted_chars > 0);
        assert!(result.correctly_deleted_chars > 0);
        assert!(result.correctly_deleted_chars < result.candidate_deleted_chars);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert!(result.recall_rate > 0.0 && result.recall_rate < 1.0);
    }

    #[test]
    fn test_partial_kept() {
        // Candidate inserts an extra `beta` line that the reference lacks.
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    #[test]
    fn test_eprintln_token_alignment() {
        // Candidate and reference agree on the `eprintln!("");` skeleton but
        // disagree on the string contents.
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    #[test]
    fn test_annotations_rename() {
        let base = "    foo(old_name)\n";
        let candidate = "    foo(new_name)\n";
        let reference = "    foo(new_name)\n";
        let result = compute_kept_rate(base, candidate, reference);

        assert_eq!(result.candidate_new_chars, "new_name".len());
        assert_eq!(result.candidate_deleted_chars, "old_name".len());
        assert_eq!(result.reference_deleted_chars, "old_name".len());
        assert_eq!(result.correctly_deleted_chars, "old_name".len());
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(result.token_annotations.len(), tokenize(candidate).len());

        // Only the renamed token is "kept"; everything else is context.
        for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        let candidate_tokens = tokenize(candidate);

        let eprintln_index = candidate_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        // Everything before the rewritten call is unchanged context.
        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        // The call skeleton is kept; the "hello world!" payload is discarded.
        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        // Highly repetitive context must not let the candidate's wrong token
        // accidentally align with the reference's replacement token.
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let candidate = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let reference = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &candidate, &reference);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.correctly_deleted_chars, "foo".len() * 16);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
        assert_eq!(result.candidate_new_chars, "prediction_token".len() * 16);
        assert!(result.kept_rate > 0.0);
        assert!(result.recall_rate > 0.0);
    }
}
503}