kept_rate.rs

  1use crate::word_diff::tokenize;
  2
/// Test-only classification of each predicted token, so assertions can
/// inspect how `compute_kept_rate` attributed it.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Token aligned with both the base and the final text (unchanged context).
    Context,
    /// Newly predicted token that survived into the final text.
    Kept,
    /// Newly predicted token that did not survive into the final text.
    Discarded,
}
 10
/// Character-count breakdown of how much of a prediction survived into the
/// final text, as produced by `compute_kept_rate`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Characters of predicted tokens that are not shared context.
    pub predicted_new_chars: usize,
    /// Characters of final-text tokens that are not shared context.
    pub final_new_chars: usize,
    /// Characters of new predicted tokens that also appear in the final text.
    pub kept_chars: usize,
    /// `predicted_new_chars - kept_chars`.
    pub discarded_chars: usize,
    /// Characters of predicted tokens shared with both base and final text.
    pub context_chars: usize,
    /// `kept_chars / predicted_new_chars`; 1.0 when nothing new was predicted
    /// and nothing new was needed, 0.0 when new text was needed but none was
    /// predicted.
    pub kept_rate: f64,
    /// Per-token labels over the predicted tokens (tests only).
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
 23
 24fn dp_index(width: usize, row: usize, column: usize) -> usize {
 25    row * width + column
 26}
 27
 28/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
 29/// side while sharing a single DP table construction.
 30fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
 31    if a.is_empty() || b.is_empty() {
 32        return (vec![false; a.len()], vec![false; b.len()]);
 33    }
 34
 35    if a == b {
 36        return (vec![true; a.len()], vec![true; b.len()]);
 37    }
 38
 39    let mut keep_a = vec![false; a.len()];
 40    let mut keep_b = vec![false; b.len()];
 41
 42    let prefix_len = a
 43        .iter()
 44        .zip(b.iter())
 45        .take_while(|(left, right)| left == right)
 46        .count();
 47    let suffix_len = {
 48        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
 49        let mut suffix_len = 0;
 50
 51        while suffix_len < max_suffix {
 52            let a_index = a.len() - 1 - suffix_len;
 53            let b_index = b.len() - 1 - suffix_len;
 54            if a[a_index] != b[b_index] {
 55                break;
 56            }
 57            suffix_len += 1;
 58        }
 59
 60        suffix_len
 61    };
 62
 63    for index in 0..prefix_len {
 64        keep_a[index] = true;
 65        keep_b[index] = true;
 66    }
 67
 68    for offset in 0..suffix_len {
 69        let a_index = a.len() - suffix_len + offset;
 70        let b_index = b.len() - suffix_len + offset;
 71        keep_a[a_index] = true;
 72        keep_b[b_index] = true;
 73    }
 74
 75    let a_mid = &a[prefix_len..a.len() - suffix_len];
 76    let b_mid = &b[prefix_len..b.len() - suffix_len];
 77
 78    if a_mid.is_empty() || b_mid.is_empty() {
 79        return (keep_a, keep_b);
 80    }
 81
 82    let row_count = a_mid.len() + 1;
 83    let column_count = b_mid.len() + 1;
 84    let mut dp = vec![0u32; row_count * column_count];
 85
 86    for i in 1..row_count {
 87        let token_a = a_mid[i - 1];
 88        for j in 1..column_count {
 89            let index = dp_index(column_count, i, j);
 90            if token_a == b_mid[j - 1] {
 91                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
 92            } else {
 93                let up = dp[dp_index(column_count, i - 1, j)];
 94                let left = dp[dp_index(column_count, i, j - 1)];
 95                dp[index] = up.max(left);
 96            }
 97        }
 98    }
 99
100    let mut i = a_mid.len();
101    let mut j = b_mid.len();
102
103    while i > 0 && j > 0 {
104        if a_mid[i - 1] == b_mid[j - 1] {
105            keep_a[prefix_len + i - 1] = true;
106            i -= 1;
107            j -= 1;
108        } else {
109            let up = dp[dp_index(column_count, i - 1, j)];
110            let left = dp[dp_index(column_count, i, j - 1)];
111            if up >= left {
112                i -= 1;
113            } else {
114                j -= 1;
115            }
116        }
117    }
118
119    let mut i = a_mid.len();
120    let mut j = b_mid.len();
121
122    while i > 0 && j > 0 {
123        if a_mid[i - 1] == b_mid[j - 1] {
124            keep_b[prefix_len + j - 1] = true;
125            i -= 1;
126            j -= 1;
127        } else {
128            let up = dp[dp_index(column_count, i - 1, j)];
129            let left = dp[dp_index(column_count, i, j - 1)];
130            if left >= up {
131                j -= 1;
132            } else {
133                i -= 1;
134            }
135        }
136    }
137
138    (keep_a, keep_b)
139}
140
/// Partition `tokens` by `mask`, returning the tokens whose mask entry is
/// `false` along with their total character count and the total character
/// count of the masked-out tokens.
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    tokens.iter().zip(mask).fold(
        (Vec::with_capacity(tokens.len()), 0, 0),
        |(mut visible, visible_chars, hidden_chars), (&token, &is_masked)| {
            if is_masked {
                (visible, visible_chars, hidden_chars + token.len())
            } else {
                visible.push(token);
                (visible, visible_chars + token.len(), hidden_chars)
            }
        },
    )
}
157
158pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
159    if base == predicted && predicted == final_text {
160        let predicted_tokens = tokenize(predicted);
161        let context_chars = predicted_tokens.iter().map(|token| token.len()).sum();
162        return KeptRateResult {
163            predicted_new_chars: 0,
164            final_new_chars: 0,
165            kept_chars: 0,
166            discarded_chars: 0,
167            context_chars,
168            kept_rate: 1.0,
169            #[cfg(test)]
170            token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()],
171        };
172    }
173
174    let base_tokens = tokenize(base);
175    let predicted_tokens = tokenize(predicted);
176    let final_tokens = tokenize(final_text);
177
178    let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
179    let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
180    let context_mask: Vec<bool> = pred_base_mask
181        .iter()
182        .zip(pred_final_mask.iter())
183        .map(|(&in_base, &in_final)| in_base && in_final)
184        .collect();
185
186    let (stripped_predicted, predicted_new_chars, context_chars) =
187        analyze_masked_tokens(&predicted_tokens, &context_mask);
188
189    let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
190    let final_context_mask: Vec<bool> = final_base_mask
191        .iter()
192        .zip(final_pred_mask.iter())
193        .map(|(&in_base, &in_predicted)| in_base && in_predicted)
194        .collect();
195
196    let (stripped_final, final_new_chars, _) =
197        analyze_masked_tokens(&final_tokens, &final_context_mask);
198
199    let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;
200
201    let kept_chars: usize = stripped_predicted
202        .iter()
203        .zip(keep_mask.iter())
204        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
205        .sum();
206
207    let discarded_chars = predicted_new_chars - kept_chars;
208
209    let kept_rate = if predicted_new_chars == 0 {
210        if final_new_chars == 0 { 1.0 } else { 0.0 }
211    } else {
212        kept_chars as f64 / predicted_new_chars as f64
213    };
214
215    #[cfg(test)]
216    let token_annotations = {
217        let mut token_annotations = Vec::with_capacity(predicted_tokens.len());
218        let mut new_index = 0;
219        for (token_index, _token) in predicted_tokens.iter().enumerate() {
220            if context_mask[token_index] {
221                token_annotations.push(TokenAnnotation::Context);
222            } else {
223                let annotation = if keep_mask[new_index] {
224                    TokenAnnotation::Kept
225                } else {
226                    TokenAnnotation::Discarded
227                };
228                #[cfg(test)]
229                token_annotations.push(annotation);
230                new_index += 1;
231            }
232        }
233        token_annotations
234    };
235
236    KeptRateResult {
237        predicted_new_chars,
238        final_new_chars,
239        kept_chars,
240        discarded_chars,
241        context_chars,
242        kept_rate,
243        #[cfg(test)]
244        token_annotations,
245    }
246}
247
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        // An empty side keeps nothing on either side.
        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        // Swapping the arguments swaps which side each tie-breaking rule
        // serves, so each mask must equal its counterpart from the reversed
        // call. (The previous first assertion compared `a_mask` against an
        // identical call to itself, which could never fail.)
        let (swapped_b_mask, swapped_a_mask) = lcs_keep_masks(&b, &a);
        assert_eq!(a_mask, swapped_a_mask);
        assert_eq!(b_mask, swapped_b_mask);
        // Pin the concrete masks so a tie-breaking regression is visible.
        assert_eq!(a_mask, vec![false, true, true, true]);
        assert_eq!(b_mask, vec![true, true, true, false]);
    }

    #[test]
    fn test_rate_extremes() {
        // All three texts equal: everything is context, rate is 1.0.
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.predicted_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        // Prediction matches the final text exactly: fully kept.
        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);

        // Prediction offered nothing while new text was needed: rate 0.0.
        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.predicted_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.predicted_new_chars);
    }

    #[test]
    fn test_decoy_when_base_excluded() {
        // Tokens shared with the base must not inflate the new-character
        // count even when they resemble the predicted identifiers.
        let base = "    decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let predicted = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let final_text = "    decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.predicted_new_chars, expected_new);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_missing_deletion() {
        // The prediction appended instead of replacing the `epr` stub, so a
        // chunk of it must count as discarded.
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let predicted = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\neprintln!(\"\");\n";
        let final_text = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert!((result.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    #[test]
    fn test_eprintln_token_alignment() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let predicted = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let final_text = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    #[test]
    fn test_annotations_rename() {
        let base = "    foo(old_name)\n";
        let predicted = "    foo(new_name)\n";
        let final_text = "    foo(new_name)\n";
        let result = compute_kept_rate(base, predicted, final_text);

        assert_eq!(result.predicted_new_chars, "new_name".len());
        assert_eq!(result.token_annotations.len(), tokenize(predicted).len());

        for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
        let predicted = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
        let final_text = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let predicted_tokens = tokenize(predicted);

        let eprintln_index = predicted_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        // Everything before `eprintln` is unchanged context.
        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        // Highly repetitive context must not let the replaced token
        // spuriously align with the different final-side token.
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &predicted, &final_text);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
        assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16);
    }
}
427}