use crate::metrics::tokenize;

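/// Maximum allowed difference, in characters, between how much the predicted
/// text and the final text change the base's length. Beyond this the final
/// text is treated as "dirty" and the prediction is scored as fully discarded.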
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;

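/// Test-only classification of each predicted token: unchanged context,
/// newly predicted text that survived in the final buffer, or newly
/// predicted text that was discarded.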
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    Context,
    Kept,
    Discarded,
}

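/// Character-level accounting for a single prediction: how many characters
/// the prediction introduced, how many of those survived in the final text,
/// and the resulting kept rate.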
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    pub predicted_new_chars: usize,
    pub final_new_chars: usize,
    pub kept_chars: usize,
    pub discarded_chars: usize,
    pub context_chars: usize,
    pub kept_rate: f64,
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}

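/// Index into the row-major, flattened LCS DP table.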
fn dp_index(width: usize, row: usize, column: usize) -> usize {
    row * width + column
}

/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
fn fill_lcs_keep_masks(
    a: &[&str],
    b: &[&str],
    mut keep_a: Option<&mut [bool]>,
    mut keep_b: Option<&mut [bool]>,
) {
    if a.is_empty() || b.is_empty() {
        return;
    }

    if a == b {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a.fill(true);
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b.fill(true);
        }
        return;
    }

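    // Trim the common prefix and suffix first; those tokens are part of every
    // LCS, so only the differing middle needs the quadratic DP below.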
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    for index in 0..prefix_len {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[index] = true;
        }
    }

    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[a_index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[b_index] = true;
        }
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    if a_mid.is_empty() || b_mid.is_empty() {
        return;
    }

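    // dp[i][j] holds the LCS length of a_mid[..i] and b_mid[..j],
    // stored row-major in a flat vector.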
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

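    // Backtrack once per requested mask. On ties, the `keep_a` pass advances
    // `i` while the `keep_b` pass advances `j`, matching the historical
    // one-sided masks (asserted in
    // `test_lcs_keep_masks_matches_historical_one_sided_masks`).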
    if let Some(keep_a) = keep_a.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_a[prefix_len + i - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if up >= left {
                    i -= 1;
                } else {
                    j -= 1;
                }
            }
        }
    }

    if let Some(keep_b) = keep_b.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_b[prefix_len + j - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if left >= up {
                    j -= 1;
                } else {
                    i -= 1;
                }
            }
        }
    }
}

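/// Mask over `a` marking the tokens that participate in an LCS of `a` and `b`.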
fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
    let mut keep_a = vec![false; a.len()];
    fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
    keep_a
}

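/// Masks over both `a` and `b`, computed from a single shared DP table.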
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
    let mut keep_a = vec![false; a.len()];
    let mut keep_b = vec![false; b.len()];
    fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
    (keep_a, keep_b)
}

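/// Partition `tokens` by `mask`, returning the unmasked tokens along with the
/// total character counts of the unmasked and masked tokens.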
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut unmasked_tokens = Vec::with_capacity(tokens.len());
    let mut unmasked_chars = 0;
    let mut masked_chars = 0;

    for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
        if is_masked {
            masked_chars += token.len();
        } else {
            unmasked_tokens.push(token);
            unmasked_chars += token.len();
        }
    }

    (unmasked_tokens, unmasked_chars, masked_chars)
}

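/// Returns true when the final text's length diverges from the predicted
/// text's length (both measured relative to the base) by more than
/// `MAX_DIRTY_LENGTH_DELTA_CHARS`, indicating the buffer likely changed for
/// reasons unrelated to the prediction.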
fn should_bail_for_dirty_final(base: &str, predicted: &str, final_text: &str) -> bool {
    let predicted_delta_chars = predicted.len().abs_diff(base.len());
    let final_delta_chars = final_text.len().abs_diff(base.len());
    predicted_delta_chars.abs_diff(final_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
}

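/// Computes how much of a prediction survived in the final text.
///
/// Tokens shared with both the base and the final text count as context;
/// the remaining predicted tokens are "new", and `kept_rate` is the fraction
/// of their characters that also appear among the final text's new tokens.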
pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
    if base == predicted && predicted == final_text {
        let predicted_tokens = tokenize(predicted);
        let context_chars = predicted_tokens.iter().map(|token| token.len()).sum();
        return KeptRateResult {
            predicted_new_chars: 0,
            final_new_chars: 0,
            kept_chars: 0,
            discarded_chars: 0,
            context_chars,
            kept_rate: 1.0,
            #[cfg(test)]
            token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()],
        };
    }

    if should_bail_for_dirty_final(base, predicted, final_text) {
        let predicted_new_chars = predicted.len().abs_diff(base.len());
        let final_new_chars = final_text.len().abs_diff(base.len());
        return KeptRateResult {
            predicted_new_chars,
            final_new_chars,
            kept_chars: 0,
            discarded_chars: predicted_new_chars,
            context_chars: 0,
            kept_rate: 0.0,
            #[cfg(test)]
            token_annotations: vec![TokenAnnotation::Discarded; tokenize(predicted).len()],
        };
    }

    let base_tokens = tokenize(base);
    let predicted_tokens = tokenize(predicted);
    let final_tokens = tokenize(final_text);

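    // A predicted token counts as context only if it aligns (via LCS) with
    // both the base and the final text; everything else is newly predicted.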
    let pred_base_mask = lcs_keep_mask(&predicted_tokens, &base_tokens);
    let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
    let context_mask: Vec<bool> = pred_base_mask
        .iter()
        .zip(pred_final_mask.iter())
        .map(|(&in_base, &in_final)| in_base && in_final)
        .collect();

    let (stripped_predicted, predicted_new_chars, context_chars) =
        analyze_masked_tokens(&predicted_tokens, &context_mask);

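    // Strip the final text the same way so both sides compare only their
    // newly introduced tokens.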
    let final_base_mask = lcs_keep_mask(&final_tokens, &base_tokens);
    let final_context_mask: Vec<bool> = final_base_mask
        .iter()
        .zip(final_pred_mask.iter())
        .map(|(&in_base, &in_predicted)| in_base && in_predicted)
        .collect();

    let (stripped_final, final_new_chars, _) =
        analyze_masked_tokens(&final_tokens, &final_context_mask);

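    // New predicted tokens that align with new final tokens are "kept".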
    let keep_mask = lcs_keep_mask(&stripped_predicted, &stripped_final);

    let kept_chars: usize = stripped_predicted
        .iter()
        .zip(keep_mask.iter())
        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
        .sum();

    let discarded_chars = predicted_new_chars - kept_chars;

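    // A prediction that adds nothing is perfect only if the final text also
    // adds nothing.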
    let kept_rate = if predicted_new_chars == 0 {
        if final_new_chars == 0 { 1.0 } else { 0.0 }
    } else {
        kept_chars as f64 / predicted_new_chars as f64
    };

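    // In tests, record a per-token annotation so coloring can be asserted.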
    #[cfg(test)]
    let token_annotations = {
        let mut token_annotations = Vec::with_capacity(predicted_tokens.len());
        let mut new_index = 0;
        for (token_index, _token) in predicted_tokens.iter().enumerate() {
            if context_mask[token_index] {
                token_annotations.push(TokenAnnotation::Context);
            } else {
                let annotation = if keep_mask[new_index] {
                    TokenAnnotation::Kept
                } else {
                    TokenAnnotation::Discarded
                };
                token_annotations.push(annotation);
                new_index += 1;
            }
        }
        token_annotations
    };

    KeptRateResult {
        predicted_new_chars,
        final_new_chars,
        kept_chars,
        discarded_chars,
        context_chars,
        kept_rate,
        #[cfg(test)]
        token_annotations,
    }
}

#[cfg(test)]
mod test_kept_rate {
    use super::*;

    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        assert_eq!(a_mask, lcs_keep_mask(&a, &b));
        assert_eq!(b_mask, lcs_keep_mask(&b, &a));
    }

    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.predicted_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.predicted_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.predicted_new_chars);
    }

    #[test]
    fn test_decoy_when_base_excluded() {
        let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let predicted = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let final_text = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.predicted_new_chars, expected_new);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_missing_deletion() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert!((result.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    #[test]
    fn test_bails_for_dirty_final() {
        let base = "fn example() {\n work();\n}\n";
        let predicted = "fn example() {\n work();\n predicted();\n}\n";
        let final_text = format!(
            "fn example() {{\n work();\n {}\n}}\n",
            "settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
        );

        let result = compute_kept_rate(base, predicted, &final_text);
        assert_eq!(result.kept_rate, 0.0);
        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
    }

    #[test]
    fn test_eprintln_token_alignment() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    #[test]
    fn test_annotations_rename() {
        let base = " foo(old_name)\n";
        let predicted = " foo(new_name)\n";
        let final_text = " foo(new_name)\n";
        let result = compute_kept_rate(base, predicted, final_text);

        assert_eq!(result.predicted_new_chars, "new_name".len());
        assert_eq!(result.token_annotations.len(), tokenize(predicted).len());

        for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let predicted_tokens = tokenize(predicted);

        let eprintln_index = predicted_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &predicted, &final_text);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
        assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16);
    }
}