//! kept_rate.rs — character-level kept-rate metrics comparing a candidate
//! edit of a base text against a reference edit of the same base.

  1use crate::metrics::tokenize;
  2
/// Maximum allowed divergence (in characters) between the candidate's and
/// the reference's length delta relative to `base`. When exceeded,
/// `compute_kept_rate` bails out and scores the candidate edit as fully
/// discarded instead of running the LCS comparison.
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
  4
/// Per-token classification of candidate tokens, used by tests to check how
/// `compute_kept_rate` labels each token.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Unchanged token present in both the base and the reference.
    Context,
    /// Newly introduced token that the reference edit also contains.
    Kept,
    /// Newly introduced token absent from the reference edit.
    Discarded,
}
 12
/// Character-level accounting of how a candidate edit of some base text
/// compares to a reference edit of the same base, as produced by
/// `compute_kept_rate`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Characters newly introduced by the candidate.
    pub candidate_new_chars: usize,
    /// Characters newly introduced by the reference.
    pub reference_new_chars: usize,
    /// Characters from `base` that are deleted by the candidate.
    pub candidate_deleted_chars: usize,
    /// Characters from `base` that are deleted by the reference.
    pub reference_deleted_chars: usize,
    /// Candidate new characters that are also present in the reference.
    pub kept_chars: usize,
    /// Base characters deleted by both the candidate and the reference.
    pub correctly_deleted_chars: usize,
    /// Candidate new characters that are not kept in the reference.
    pub discarded_chars: usize,
    /// Candidate characters treated as unchanged context (aligned with both
    /// the base and the reference).
    pub context_chars: usize,
    /// Fraction of candidate edit characters that match the reference edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub kept_rate: f64,
    /// Fraction of reference edit characters covered by the candidate edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub recall_rate: f64,
    /// Per-token classification for candidate tokens used by tests.
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
 46
 47fn dp_index(width: usize, row: usize, column: usize) -> usize {
 48    row * width + column
 49}
 50
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
///
/// On return, `keep_a[i]` is set when token `a[i]` belongs to the longest
/// common subsequence chosen by the `a`-side backtrace, and `keep_b[j]`
/// likewise for the `b`-side backtrace. The two backtraces break LCS ties
/// differently (see comments below), so each mask reproduces the historical
/// one-sided `lcs_keep_mask` result for its own side — pinned by
/// `test_lcs_keep_masks_matches_historical_one_sided_masks`.
fn fill_lcs_keep_masks(
    a: &[&str],
    b: &[&str],
    mut keep_a: Option<&mut [bool]>,
    mut keep_b: Option<&mut [bool]>,
) {
    // An empty side has an empty LCS: leave both masks all-false.
    if a.is_empty() || b.is_empty() {
        return;
    }

    // Identical sequences: every token is in the LCS; skip the DP entirely.
    if a == b {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a.fill(true);
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b.fill(true);
        }
        return;
    }

    // Trim the common prefix and suffix so the O(len(a) * len(b)) DP below
    // only covers the differing middle sections.
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        // Cap the suffix so it cannot overlap the already-claimed prefix.
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    // Prefix tokens are trivially part of the LCS on both sides.
    for index in 0..prefix_len {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[index] = true;
        }
    }

    // Likewise for suffix tokens (the index differs per side).
    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[a_index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[b_index] = true;
        }
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    // If either middle section is empty, nothing remains to align.
    if a_mid.is_empty() || b_mid.is_empty() {
        return;
    }

    // Standard LCS length table over the middle sections, flattened into a
    // single Vec: dp[(i, j)] = LCS length of a_mid[..i] and b_mid[..j].
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

    // Backtrace for `a`: on ties prefer moving up (dropping an `a` token),
    // matching the alignment `lcs_keep_mask(a, b)` historically produced.
    if let Some(keep_a) = keep_a.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_a[prefix_len + i - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if up >= left {
                    i -= 1;
                } else {
                    j -= 1;
                }
            }
        }
    }

    // Backtrace for `b`: mirrored tie-break (prefer moving left, dropping a
    // `b` token), matching the alignment `lcs_keep_mask(b, a)` historically
    // produced.
    if let Some(keep_b) = keep_b.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_b[prefix_len + j - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if left >= up {
                    j -= 1;
                } else {
                    i -= 1;
                }
            }
        }
    }
}
181
182fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
183    let mut keep_a = vec![false; a.len()];
184    fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
185    keep_a
186}
187
188fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
189    let mut keep_a = vec![false; a.len()];
190    let mut keep_b = vec![false; b.len()];
191    fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
192    (keep_a, keep_b)
193}
194
/// Split `tokens` by `mask`: returns the unmasked tokens themselves, the
/// total character length of the unmasked tokens, and the total character
/// length of the masked tokens.
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut kept = Vec::with_capacity(tokens.len());
    let mut kept_chars = 0;
    let mut masked_chars = 0;

    for (&token, &is_masked) in tokens.iter().zip(mask) {
        match is_masked {
            true => masked_chars += token.len(),
            false => {
                kept_chars += token.len();
                kept.push(token);
            }
        }
    }

    (kept, kept_chars, masked_chars)
}
211
/// Total character length of the tokens whose mask entry is `false`.
fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
    let mut total = 0;
    for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
        if !is_masked {
            total += token.len();
        }
    }
    total
}
219
220fn should_bail_for_dirty_final(base: &str, candidate: &str, reference: &str) -> bool {
221    let candidate_delta_chars = candidate.len().abs_diff(base.len());
222    let reference_delta_chars = reference.len().abs_diff(base.len());
223    candidate_delta_chars.abs_diff(reference_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
224}
225
226pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
227    if base == candidate && candidate == reference {
228        let candidate_tokens = tokenize(candidate);
229        let context_chars = candidate_tokens.iter().map(|token| token.len()).sum();
230        return KeptRateResult {
231            candidate_new_chars: 0,
232            reference_new_chars: 0,
233            candidate_deleted_chars: 0,
234            reference_deleted_chars: 0,
235            kept_chars: 0,
236            correctly_deleted_chars: 0,
237            discarded_chars: 0,
238            context_chars,
239            kept_rate: 1.0,
240            recall_rate: 1.0,
241            #[cfg(test)]
242            token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
243        };
244    }
245
246    if should_bail_for_dirty_final(base, candidate, reference) {
247        let candidate_new_chars = candidate.len().abs_diff(base.len());
248        let reference_new_chars = reference.len().abs_diff(base.len());
249        return KeptRateResult {
250            candidate_new_chars,
251            reference_new_chars,
252            candidate_deleted_chars: 0,
253            reference_deleted_chars: 0,
254            kept_chars: 0,
255            correctly_deleted_chars: 0,
256            discarded_chars: candidate_new_chars,
257            context_chars: 0,
258            kept_rate: 0.0,
259            recall_rate: 0.0,
260            #[cfg(test)]
261            token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
262        };
263    }
264
265    let base_tokens = tokenize(base);
266    let candidate_tokens = tokenize(candidate);
267    let reference_tokens = tokenize(reference);
268
269    let (candidate_base_mask, base_candidate_mask) =
270        lcs_keep_masks(&candidate_tokens, &base_tokens);
271    let (candidate_reference_mask, reference_candidate_mask) =
272        lcs_keep_masks(&candidate_tokens, &reference_tokens);
273    let context_mask: Vec<bool> = candidate_base_mask
274        .iter()
275        .zip(candidate_reference_mask.iter())
276        .map(|(&in_base, &in_reference)| in_base && in_reference)
277        .collect();
278
279    let (stripped_candidate, candidate_new_chars, context_chars) =
280        analyze_masked_tokens(&candidate_tokens, &context_mask);
281
282    let (reference_base_mask, base_reference_mask) =
283        lcs_keep_masks(&reference_tokens, &base_tokens);
284    let reference_context_mask: Vec<bool> = reference_base_mask
285        .iter()
286        .zip(reference_candidate_mask.iter())
287        .map(|(&in_base, &in_candidate)| in_base && in_candidate)
288        .collect();
289
290    let (stripped_reference, reference_new_chars, _) =
291        analyze_masked_tokens(&reference_tokens, &reference_context_mask);
292
293    let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
294
295    let kept_chars: usize = stripped_candidate
296        .iter()
297        .zip(keep_mask.iter())
298        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
299        .sum();
300
301    let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
302    let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
303    let correctly_deleted_chars: usize = base_tokens
304        .iter()
305        .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
306        .filter_map(|(&token, (&in_candidate, &in_reference))| {
307            (!in_candidate && !in_reference).then_some(token.len())
308        })
309        .sum();
310
311    let discarded_chars = candidate_new_chars - kept_chars;
312    let matched_edit_chars = kept_chars + correctly_deleted_chars;
313    let candidate_edit_chars = candidate_new_chars + candidate_deleted_chars;
314    let reference_edit_chars = reference_new_chars + reference_deleted_chars;
315
316    let kept_rate = if candidate_edit_chars == 0 {
317        if reference_edit_chars == 0 { 1.0 } else { 0.0 }
318    } else {
319        matched_edit_chars as f64 / candidate_edit_chars as f64
320    };
321
322    let recall_rate = if reference_edit_chars == 0 {
323        if candidate_edit_chars == 0 { 1.0 } else { 0.0 }
324    } else {
325        matched_edit_chars as f64 / reference_edit_chars as f64
326    };
327
328    #[cfg(test)]
329    let token_annotations = {
330        let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
331        let mut new_index = 0;
332        for (token_index, _token) in candidate_tokens.iter().enumerate() {
333            if context_mask[token_index] {
334                token_annotations.push(TokenAnnotation::Context);
335            } else {
336                let annotation = if keep_mask[new_index] {
337                    TokenAnnotation::Kept
338                } else {
339                    TokenAnnotation::Discarded
340                };
341                #[cfg(test)]
342                token_annotations.push(annotation);
343                new_index += 1;
344            }
345        }
346        token_annotations
347    };
348
349    KeptRateResult {
350        candidate_new_chars,
351        reference_new_chars,
352        candidate_deleted_chars,
353        reference_deleted_chars,
354        kept_chars,
355        correctly_deleted_chars,
356        discarded_chars,
357        context_chars,
358        kept_rate,
359        recall_rate,
360        #[cfg(test)]
361        token_annotations,
362    }
363}
364
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    // Mask shapes for a plain subsequence match and for an empty left side.
    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    // The shared-DP implementation must reproduce the one-sided masks that
    // `lcs_keep_mask` yields for each argument order, even when the LCS is
    // ambiguous (as with this interleaved input).
    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        assert_eq!(a_mask, lcs_keep_mask(&a, &b));
        assert_eq!(b_mask, lcs_keep_mask(&b, &a));
    }

    // Boundary values: unchanged input scores 1.0 with all-context tokens;
    // a fully accepted edit scores 1.0; a fully missed edit scores 0.0.
    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert!((no_change.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.candidate_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);
        assert!((accepted.recall_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
        assert!((discarded.recall_rate - 0.0).abs() < 1e-6);
    }

    // Insertions from an empty base: all new characters are kept when the
    // reference matches, fewer when it differs.
    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.candidate_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.candidate_new_chars);
    }

    // Tokens shared between candidate and reference but absent from base
    // must count as new (and the replaced base tokens as deletions), not as
    // context.
    #[test]
    fn test_decoy_when_base_excluded() {
        let base = "    decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let candidate = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let reference = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, candidate, reference);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.candidate_new_chars, expected_new);
        assert!(result.correctly_deleted_chars > 0);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
    }

    // A candidate that keeps a base fragment the reference deleted should be
    // penalized (discarded chars, reduced kept_rate).
    #[test]
    fn test_missing_deletion() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\neprintln!(\"\");\n";
        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    // An empty candidate deletes everything; only the deletions the
    // reference also made count as correct, so both rates are partial.
    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert_eq!(result.candidate_new_chars, 0);
        assert!(result.candidate_deleted_chars > 0);
        assert!(result.correctly_deleted_chars > 0);
        assert!(result.correctly_deleted_chars < result.candidate_deleted_chars);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert!(result.recall_rate > 0.0 && result.recall_rate < 1.0);
    }

    // Mixed outcome: some new tokens kept, some discarded.
    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    // When the reference length delta diverges from the candidate's by more
    // than MAX_DIRTY_LENGTH_DELTA_CHARS, the whole candidate edit must be
    // scored as discarded with zero rates.
    #[test]
    fn test_bails_for_dirty_final() {
        let base = "fn example() {\n    work();\n}\n";
        let candidate = "fn example() {\n    work();\n    predicted();\n}\n";
        let reference = format!(
            "fn example() {{\n    work();\n    {}\n}}\n",
            "settled();\n    ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
        );

        let result = compute_kept_rate(base, candidate, &reference);
        assert_eq!(result.kept_rate, 0.0);
        assert_eq!(result.recall_rate, 0.0);
        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
    }

    // Exact char accounting when the candidate and reference share the
    // `eprintln!(` call but differ in the string argument.
    #[test]
    fn test_eprintln_token_alignment() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    // A simple rename: the new name is Kept, everything else is Context,
    // and the old name counts as a correct deletion.
    #[test]
    fn test_annotations_rename() {
        let base = "    foo(old_name)\n";
        let candidate = "    foo(new_name)\n";
        let reference = "    foo(new_name)\n";
        let result = compute_kept_rate(base, candidate, reference);

        assert_eq!(result.candidate_new_chars, "new_name".len());
        assert_eq!(result.candidate_deleted_chars, "old_name".len());
        assert_eq!(result.reference_deleted_chars, "old_name".len());
        assert_eq!(result.correctly_deleted_chars, "old_name".len());
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(result.token_annotations.len(), tokenize(candidate).len());

        for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    // Token-by-token coloring around the `eprintln!` call: the call tokens
    // are Kept, the differing string-literal tokens Discarded, and the
    // surrounding code Context.
    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        let candidate_tokens = tokenize(candidate);

        let eprintln_index = candidate_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    // Highly repetitive context must not let the LCS "absorb" the predicted
    // token: the replaced token stays discarded on every repeated line.
    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let candidate = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let reference = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &candidate, &reference);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.correctly_deleted_chars, "foo".len() * 16);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
        assert_eq!(result.candidate_new_chars, "prediction_token".len() * 16);
        assert!(result.kept_rate > 0.0);
        assert!(result.recall_rate > 0.0);
    }
}
577}