// kept_rate.rs

  1use crate::metrics::tokenize;
  2
  3const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
  4
/// Test-only per-token classification produced by `compute_kept_rate`,
/// used to verify how each predicted token was attributed.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Token shared (via LCS) with both the base and the final text.
    Context,
    /// Newly predicted token that survived into the final text.
    Kept,
    /// Newly predicted token that did not survive into the final text.
    Discarded,
}
 12
/// Character-level statistics comparing a predicted edit against the text the
/// user finally settled on. All `*_chars` counts are byte lengths of tokens.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Characters the prediction introduced beyond shared context.
    pub predicted_new_chars: usize,
    /// Characters the final text introduced beyond shared context.
    pub final_new_chars: usize,
    /// Predicted new characters that also appear in the final text.
    pub kept_chars: usize,
    /// Predicted new characters absent from the final text
    /// (`predicted_new_chars - kept_chars`).
    pub discarded_chars: usize,
    /// Characters shared with both the base and the final text.
    pub context_chars: usize,
    /// `kept_chars / predicted_new_chars` (1.0 or 0.0 in degenerate cases).
    pub kept_rate: f64,
    /// Test-only per-token breakdown, aligned with `tokenize(predicted)`.
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
 25
 26fn dp_index(width: usize, row: usize, column: usize) -> usize {
 27    row * width + column
 28}
 29
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
///
/// A `true` entry in `keep_a`/`keep_b` marks a token that participates in a
/// longest common subsequence of `a` and `b`. Entries are only ever set to
/// `true`; callers supply all-`false` slices sized to `a`/`b`. Each mask uses
/// the tie-breaking of the historical one-sided implementation for its side
/// (see `test_lcs_keep_masks_matches_historical_one_sided_masks`).
fn fill_lcs_keep_masks(
    a: &[&str],
    b: &[&str],
    mut keep_a: Option<&mut [bool]>,
    mut keep_b: Option<&mut [bool]>,
) {
    // An empty side has no common subsequence; leave the masks untouched.
    if a.is_empty() || b.is_empty() {
        return;
    }

    // Identical inputs: every token is common — skip the DP entirely.
    if a == b {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a.fill(true);
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b.fill(true);
        }
        return;
    }

    // Trim the common prefix and suffix first: those tokens are trivially part
    // of any LCS, and shrinking the middle keeps the DP table small.
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        // Cap the suffix so it cannot overlap the already-matched prefix.
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    // Mark the shared prefix on both sides.
    for index in 0..prefix_len {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[index] = true;
        }
    }

    // Mark the shared suffix on both sides (indices differ per side).
    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[a_index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[b_index] = true;
        }
    }

    // Only the differing middle sections need the full O(n*m) DP.
    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    if a_mid.is_empty() || b_mid.is_empty() {
        return;
    }

    // Standard LCS length table: dp[i][j] = LCS length of a_mid[..i] and
    // b_mid[..j], stored flat in row-major order via `dp_index`.
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

    // Backtrace for `a`: on DP ties prefer moving up (`up >= left`) — the
    // tie-break the historical one-sided mask over `a` used.
    if let Some(keep_a) = keep_a.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_a[prefix_len + i - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if up >= left {
                    i -= 1;
                } else {
                    j -= 1;
                }
            }
        }
    }

    // Backtrace for `b`: mirrored tie-break (`left >= up`), equivalent to
    // running the one-sided mask with the arguments swapped.
    if let Some(keep_b) = keep_b.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_b[prefix_len + j - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if left >= up {
                    j -= 1;
                } else {
                    i -= 1;
                }
            }
        }
    }
}
160
161fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
162    let mut keep_a = vec![false; a.len()];
163    fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
164    keep_a
165}
166
167fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
168    let mut keep_a = vec![false; a.len()];
169    let mut keep_b = vec![false; b.len()];
170    fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
171    (keep_a, keep_b)
172}
173
/// Partition `tokens` by `mask`: returns the tokens whose mask entry is
/// `false` (in order), plus the total byte counts of unmasked and masked
/// tokens respectively.
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut survivors = Vec::with_capacity(tokens.len());
    let mut survivor_chars = 0usize;
    let mut removed_chars = 0usize;

    for (&token, &is_masked) in tokens.iter().zip(mask) {
        match is_masked {
            true => removed_chars += token.len(),
            false => {
                survivor_chars += token.len();
                survivors.push(token);
            }
        }
    }

    (survivors, survivor_chars, removed_chars)
}
190
191fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool {
192    let predicted_delta_chars = predicted.len().abs_diff(base.len());
193    let final_delta_chars = final_text.len().abs_diff(base.len());
194    predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
195}
196
197pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
198    if base == predicted && predicted == final_text {
199        let predicted_tokens = tokenize(predicted);
200        let context_chars = predicted_tokens.iter().map(|token| token.len()).sum();
201        return KeptRateResult {
202            predicted_new_chars: 0,
203            final_new_chars: 0,
204            kept_chars: 0,
205            discarded_chars: 0,
206            context_chars,
207            kept_rate: 1.0,
208            #[cfg(test)]
209            token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()],
210        };
211    }
212
213    if should_bail_for_dirty_final(base, predicted, final_text) {
214        let predicted_new_chars = predicted.len().abs_diff(base.len());
215        let final_new_chars = final_text.len().abs_diff(base.len());
216        return KeptRateResult {
217            predicted_new_chars,
218            final_new_chars,
219            kept_chars: 0,
220            discarded_chars: predicted_new_chars,
221            context_chars: 0,
222            kept_rate: 0.0,
223            #[cfg(test)]
224            token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()],
225        };
226    }
227
228    let base_tokens = tokenize(base);
229    let predicted_tokens = tokenize(predicted);
230    let final_tokens = tokenize(final_text);
231
232    let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens);
233    let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
234    let context_mask: Vec<bool> = pred_base_mask
235        .iter()
236        .zip(pred_final_mask.iter())
237        .map(|(&in_base, &in_final)| in_base && in_final)
238        .collect();
239
240    let (stripped_predicted, predicted_new_chars, context_chars) =
241        analyze_masked_tokens(&predicted_tokens, &context_mask);
242
243    let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens);
244    let final_context_mask: Vec<bool> = final_base_mask
245        .iter()
246        .zip(final_pred_mask.iter())
247        .map(|(&in_base, &in_predicted)| in_base && in_predicted)
248        .collect();
249
250    let (stripped_final, final_new_chars, _) =
251        analyze_masked_tokens(&final_tokens, &final_context_mask);
252
253    let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final);
254
255    let kept_chars: usize = stripped_predicted
256        .iter()
257        .zip(keep_mask.iter())
258        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
259        .sum();
260
261    let discarded_chars = predicted_new_chars - kept_chars;
262
263    let kept_rate = if predicted_new_chars == 0 {
264        if final_new_chars == 0 { 1.0 } else { 0.0 }
265    } else {
266        kept_chars as f64 / predicted_new_chars as f64
267    };
268
269    #[cfg(test)]
270    let token_annotations = {
271        let mut token_annotations = Vec::with_capacity(predicted_tokens.len());
272        let mut new_index = 0;
273        for (token_index, _token) in predicted_tokens.iter().enumerate() {
274            if context_mask[token_index] {
275                token_annotations.push(TokenAnnotation::Context);
276            } else {
277                let annotation = if keep_mask[new_index] {
278                    TokenAnnotation::Kept
279                } else {
280                    TokenAnnotation::Discarded
281                };
282                #[cfg(test)]
283                token_annotations.push(annotation);
284                new_index += 1;
285            }
286        }
287        token_annotations
288    };
289
290    KeptRateResult {
291        predicted_new_chars,
292        final_new_chars,
293        kept_chars,
294        discarded_chars,
295        context_chars,
296        kept_rate,
297        #[cfg(test)]
298        token_annotations,
299    }
300}
301
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    // Two-sided mask construction, including the empty-slice edge case.
    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    // The shared-DP implementation must reproduce the tie-breaking of the
    // historical one-sided masks for each side (b's mask == swapped-args mask).
    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        assert_eq!(a_mask, lcs_keep_mask(&a, &b));
        assert_eq!(b_mask, lcs_keep_mask(&b, &a));
    }

    // Degenerate rates: no change -> 1.0 (all context); fully accepted
    // prediction -> 1.0; no prediction but user edits -> 0.0.
    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.predicted_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
    }

    // Text added from an empty base is entirely "new"; kept iff it survives.
    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.predicted_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.predicted_new_chars);
    }

    // Renamed identifiers must not count base-only fragments as context:
    // only the genuinely new tokens are attributed, and all are kept here.
    #[test]
    fn test_decoy_when_base_excluded() {
        let base = "    decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let predicted = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let final_text = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.predicted_new_chars, expected_new);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
    }

    // Prediction appended instead of replacing the partial "epr" line; the
    // user's deletion must show up as discarded characters.
    #[test]
    fn test_missing_deletion() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let predicted = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\neprintln!(\"\");\n";
        let final_text = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    // An empty prediction against a real edit keeps nothing.
    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert!((result.kept_rate - 0.0).abs() < 1e-6);
    }

    // Mixed outcome: some predicted lines kept, some removed.
    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    // A final text whose length delta exceeds MAX_DIRTY_LENGTH_DELTA_CHARS
    // triggers the dirty-final bailout: everything predicted is discarded.
    #[test]
    fn test_bails_for_dirty_final() {
        let base = "fn example() {\n    work();\n}\n";
        let predicted = "fn example() {\n    work();\n    predicted();\n}\n";
        let final_text = format!(
            "fn example() {{\n    work();\n    {}\n}}\n",
            "settled();\n    ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
        );

        let result = compute_kept_rate(base, predicted, &final_text);
        assert_eq!(result.kept_rate, 0.0);
        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
    }

    // Exact character accounting when the prediction's string literal differs
    // from the final one ("hello world!" vs "").
    #[test]
    fn test_eprintln_token_alignment() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let predicted = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let final_text = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    // A simple rename: only the new identifier is annotated Kept, the
    // surrounding tokens stay Context.
    #[test]
    fn test_annotations_rename() {
        let base = "    foo(old_name)\n";
        let predicted = "    foo(new_name)\n";
        let final_text = "    foo(new_name)\n";
        let result = compute_kept_rate(base, predicted, final_text);

        assert_eq!(result.predicted_new_chars, "new_name".len());
        assert_eq!(result.token_annotations.len(), tokenize(predicted).len());

        for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    // Token-level coloring around the eprintln call: the call scaffolding is
    // kept, the "hello world!" literal is discarded, the tail stays context.
    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let predicted = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let final_text = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let predicted_tokens = tokenize(predicted);

        let eprintln_index = predicted_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    // Highly repetitive context must not let the discarded prediction token
    // be mistaken for a kept one via a spurious LCS alignment.
    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &predicted, &final_text);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
        assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16);
    }
}