1use crate::metrics::tokenize;
2
/// Bail-out threshold for `should_bail_for_dirty_final`: when the candidate's
/// and the reference's length deltas versus the base differ by more than this
/// many characters, `compute_kept_rate` skips the quadratic LCS comparison and
/// reports a zero score.
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
4
/// Test-only classification of each candidate token produced by
/// `compute_kept_rate` (see `KeptRateResult::token_annotations`).
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Token present in both the base and the reference alignments
    /// (unchanged context).
    Context,
    /// Token newly introduced by the candidate and also present among the
    /// reference's newly introduced tokens.
    Kept,
    /// Token newly introduced by the candidate but absent from the
    /// reference's newly introduced tokens.
    Discarded,
}
12
/// Character-level accounting of how a candidate edit of some base text
/// compares to a reference edit of the same base; produced by
/// [`compute_kept_rate`].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Characters newly introduced by the candidate
    pub candidate_new_chars: usize,
    /// Characters newly introduced by the reference
    pub reference_new_chars: usize,
    /// Characters from `base` that are deleted by the candidate.
    pub candidate_deleted_chars: usize,
    /// Characters from `base` that are deleted by the reference.
    pub reference_deleted_chars: usize,
    /// Candidate new characters that are also present in the reference.
    pub kept_chars: usize,
    /// Base characters deleted by both the candidate and the reference.
    pub correctly_deleted_chars: usize,
    /// Candidate new characters that are not kept in the reference.
    pub discarded_chars: usize,
    /// Candidate characters treated as unchanged context
    pub context_chars: usize,
    /// Fraction of candidate edit characters that match the reference edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub kept_rate: f64,
    /// Fraction of reference edit characters covered by the candidate edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub recall_rate: f64,
    /// Per-token classification for candidate tokens used by tests.
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
46
/// Map a `(row, column)` coordinate onto a flat, row-major DP buffer that is
/// `width` columns wide.
fn dp_index(width: usize, row: usize, column: usize) -> usize {
    column + width * row
}
50
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
///
/// On return, `keep_a[i]` is `true` iff `a[i]` belongs to the longest common
/// subsequence selected by the `a`-side backtrace, and likewise `keep_b[j]`
/// for `b`. Either mask may be `None` when the caller needs only one side.
/// The two backtraces break ties in opposite directions, so the marked
/// positions can differ between sides even though both trace an LCS.
fn fill_lcs_keep_masks(
    a: &[&str],
    b: &[&str],
    mut keep_a: Option<&mut [bool]>,
    mut keep_b: Option<&mut [bool]>,
) {
    // No common subsequence is possible; leave the masks all-false.
    if a.is_empty() || b.is_empty() {
        return;
    }

    // Identical sequences: every token on both sides is kept.
    if a == b {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a.fill(true);
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b.fill(true);
        }
        return;
    }

    // Trim the shared prefix and suffix so the quadratic DP only runs over
    // the differing middle sections.
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        // Cap the suffix so it cannot overlap the prefix on either side.
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    // Prefix tokens are common to both sides by construction.
    for index in 0..prefix_len {
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[index] = true;
        }
    }

    // Suffix tokens likewise, mapped back to each side's own indices.
    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        if let Some(keep_a) = keep_a.as_mut() {
            keep_a[a_index] = true;
        }
        if let Some(keep_b) = keep_b.as_mut() {
            keep_b[b_index] = true;
        }
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    // A pure insertion/deletion leaves one middle empty: nothing in the
    // middles can match, so only the prefix/suffix marks above apply.
    if a_mid.is_empty() || b_mid.is_empty() {
        return;
    }

    // Standard LCS length table over the middles, stored row-major in a flat
    // vector: dp[(i, j)] = LCS length of a_mid[..i] and b_mid[..j].
    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

    // Backtrace for the `a` side: on ties it steps up first (skipping the
    // current `a` token before skipping a `b` token).
    if let Some(keep_a) = keep_a.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_a[prefix_len + i - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if up >= left {
                    i -= 1;
                } else {
                    j -= 1;
                }
            }
        }
    }

    // Backtrace for the `b` side: the mirrored tie-break (step left first),
    // which reproduces what a one-sided computation over (b, a) would mark —
    // see `test_lcs_keep_masks_matches_historical_one_sided_masks`.
    if let Some(keep_b) = keep_b.as_mut() {
        let mut i = a_mid.len();
        let mut j = b_mid.len();

        while i > 0 && j > 0 {
            if a_mid[i - 1] == b_mid[j - 1] {
                keep_b[prefix_len + j - 1] = true;
                i -= 1;
                j -= 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                if left >= up {
                    j -= 1;
                } else {
                    i -= 1;
                }
            }
        }
    }
}
181
182fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
183 let mut keep_a = vec![false; a.len()];
184 fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
185 keep_a
186}
187
188fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
189 let mut keep_a = vec![false; a.len()];
190 let mut keep_b = vec![false; b.len()];
191 fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
192 (keep_a, keep_b)
193}
194
/// Partition `tokens` by `mask`: returns the tokens whose mask entry is
/// `false` (in order) together with their total character count, plus the
/// character count of the masked (`true`) tokens.
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut survivors = Vec::with_capacity(tokens.len());
    let mut survivor_chars = 0usize;
    let mut covered_chars = 0usize;

    for (&token, &covered) in tokens.iter().zip(mask) {
        if covered {
            covered_chars += token.len();
        } else {
            survivor_chars += token.len();
            survivors.push(token);
        }
    }

    (survivors, survivor_chars, covered_chars)
}
211
/// Total character count of the tokens whose mask entry is `false`.
fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
    let mut total = 0;
    for (&token, &covered) in tokens.iter().zip(mask) {
        if !covered {
            total += token.len();
        }
    }
    total
}
219
220fn should_bail_for_dirty_final(base: &str, candidate: &str, reference: &str) -> bool {
221 let candidate_delta_chars = candidate.len().abs_diff(base.len());
222 let reference_delta_chars = reference.len().abs_diff(base.len());
223 candidate_delta_chars.abs_diff(reference_delta_chars) > MAX_DIRTY_LENGTH_DELTA_CHARS
224}
225
226pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
227 if base == candidate && candidate == reference {
228 let candidate_tokens = tokenize(candidate);
229 let context_chars = candidate_tokens.iter().map(|token| token.len()).sum();
230 return KeptRateResult {
231 candidate_new_chars: 0,
232 reference_new_chars: 0,
233 candidate_deleted_chars: 0,
234 reference_deleted_chars: 0,
235 kept_chars: 0,
236 correctly_deleted_chars: 0,
237 discarded_chars: 0,
238 context_chars,
239 kept_rate: 1.0,
240 recall_rate: 1.0,
241 #[cfg(test)]
242 token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
243 };
244 }
245
246 if should_bail_for_dirty_final(base, candidate, reference) {
247 let candidate_new_chars = candidate.len().abs_diff(base.len());
248 let reference_new_chars = reference.len().abs_diff(base.len());
249 return KeptRateResult {
250 candidate_new_chars,
251 reference_new_chars,
252 candidate_deleted_chars: 0,
253 reference_deleted_chars: 0,
254 kept_chars: 0,
255 correctly_deleted_chars: 0,
256 discarded_chars: candidate_new_chars,
257 context_chars: 0,
258 kept_rate: 0.0,
259 recall_rate: 0.0,
260 #[cfg(test)]
261 token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
262 };
263 }
264
265 let base_tokens = tokenize(base);
266 let candidate_tokens = tokenize(candidate);
267 let reference_tokens = tokenize(reference);
268
269 let (candidate_base_mask, base_candidate_mask) =
270 lcs_keep_masks(&candidate_tokens, &base_tokens);
271 let (candidate_reference_mask, reference_candidate_mask) =
272 lcs_keep_masks(&candidate_tokens, &reference_tokens);
273 let context_mask: Vec<bool> = candidate_base_mask
274 .iter()
275 .zip(candidate_reference_mask.iter())
276 .map(|(&in_base, &in_reference)| in_base && in_reference)
277 .collect();
278
279 let (stripped_candidate, candidate_new_chars, context_chars) =
280 analyze_masked_tokens(&candidate_tokens, &context_mask);
281
282 let (reference_base_mask, base_reference_mask) =
283 lcs_keep_masks(&reference_tokens, &base_tokens);
284 let reference_context_mask: Vec<bool> = reference_base_mask
285 .iter()
286 .zip(reference_candidate_mask.iter())
287 .map(|(&in_base, &in_candidate)| in_base && in_candidate)
288 .collect();
289
290 let (stripped_reference, reference_new_chars, _) =
291 analyze_masked_tokens(&reference_tokens, &reference_context_mask);
292
293 let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
294
295 let kept_chars: usize = stripped_candidate
296 .iter()
297 .zip(keep_mask.iter())
298 .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
299 .sum();
300
301 let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
302 let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
303 let correctly_deleted_chars: usize = base_tokens
304 .iter()
305 .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
306 .filter_map(|(&token, (&in_candidate, &in_reference))| {
307 (!in_candidate && !in_reference).then_some(token.len())
308 })
309 .sum();
310
311 let discarded_chars = candidate_new_chars - kept_chars;
312 let matched_edit_chars = kept_chars + correctly_deleted_chars;
313 let candidate_edit_chars = candidate_new_chars + candidate_deleted_chars;
314 let reference_edit_chars = reference_new_chars + reference_deleted_chars;
315
316 let kept_rate = if candidate_edit_chars == 0 {
317 if reference_edit_chars == 0 { 1.0 } else { 0.0 }
318 } else {
319 matched_edit_chars as f64 / candidate_edit_chars as f64
320 };
321
322 let recall_rate = if reference_edit_chars == 0 {
323 if candidate_edit_chars == 0 { 1.0 } else { 0.0 }
324 } else {
325 matched_edit_chars as f64 / reference_edit_chars as f64
326 };
327
328 #[cfg(test)]
329 let token_annotations = {
330 let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
331 let mut new_index = 0;
332 for (token_index, _token) in candidate_tokens.iter().enumerate() {
333 if context_mask[token_index] {
334 token_annotations.push(TokenAnnotation::Context);
335 } else {
336 let annotation = if keep_mask[new_index] {
337 TokenAnnotation::Kept
338 } else {
339 TokenAnnotation::Discarded
340 };
341 #[cfg(test)]
342 token_annotations.push(annotation);
343 new_index += 1;
344 }
345 }
346 token_annotations
347 };
348
349 KeptRateResult {
350 candidate_new_chars,
351 reference_new_chars,
352 candidate_deleted_chars,
353 reference_deleted_chars,
354 kept_chars,
355 correctly_deleted_chars,
356 discarded_chars,
357 context_chars,
358 kept_rate,
359 recall_rate,
360 #[cfg(test)]
361 token_annotations,
362 }
363}
364
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    // The longer side's mask marks exactly the LCS tokens, a fully contained
    // side is kept entirely, and an empty side leaves the other all-false.
    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    // The shared-DP two-sided masks must agree with two independent
    // one-sided computations (the historical behavior before sharing).
    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        assert_eq!(a_mask, lcs_keep_mask(&a, &b));
        assert_eq!(b_mask, lcs_keep_mask(&b, &a));
    }

    // No-op, fully accepted, and fully discarded edits hit the rate extremes.
    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert!((no_change.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.candidate_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);
        assert!((accepted.recall_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
        assert!((discarded.recall_rate - 0.0).abs() < 1e-6);
    }

    // Insertions into an empty base are all kept when the reference matches,
    // and only partially kept otherwise.
    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.candidate_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.candidate_new_chars);
    }

    // A decoy token in the base must not absorb the rename: the renamed
    // identifiers count as new characters and the replaced ones as deletions.
    #[test]
    fn test_decoy_when_base_excluded() {
        let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let candidate = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let reference = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, candidate, reference);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.candidate_new_chars, expected_new);
        assert!(result.correctly_deleted_chars > 0);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
    }

    // Keeping a base fragment the reference deleted must lower the kept rate
    // and register discarded characters.
    #[test]
    fn test_missing_deletion() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
        let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    // An empty candidate is a pure deletion: partially correct when the
    // reference also deletes some of the base.
    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert_eq!(result.candidate_new_chars, 0);
        assert!(result.candidate_deleted_chars > 0);
        assert!(result.correctly_deleted_chars > 0);
        assert!(result.correctly_deleted_chars < result.candidate_deleted_chars);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert!(result.recall_rate > 0.0 && result.recall_rate < 1.0);
    }

    // Some new tokens kept, some discarded: the rate is strictly between
    // the extremes.
    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    // A reference far longer than the candidate trips the dirty-length bail
    // path: zero rates and everything discarded.
    #[test]
    fn test_bails_for_dirty_final() {
        let base = "fn example() {\n work();\n}\n";
        let candidate = "fn example() {\n work();\n predicted();\n}\n";
        let reference = format!(
            "fn example() {{\n work();\n {}\n}}\n",
            "settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
        );

        let result = compute_kept_rate(base, candidate, &reference);
        assert_eq!(result.kept_rate, 0.0);
        assert_eq!(result.recall_rate, 0.0);
        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
    }

    // Token-level alignment inside a partially matching eprintln! call:
    // exact kept/discarded character counts are pinned down.
    #[test]
    fn test_eprintln_token_alignment() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    // A simple rename annotates exactly the new identifier as Kept and
    // everything else as Context.
    #[test]
    fn test_annotations_rename() {
        let base = " foo(old_name)\n";
        let candidate = " foo(new_name)\n";
        let reference = " foo(new_name)\n";
        let result = compute_kept_rate(base, candidate, reference);

        assert_eq!(result.candidate_new_chars, "new_name".len());
        assert_eq!(result.candidate_deleted_chars, "old_name".len());
        assert_eq!(result.reference_deleted_chars, "old_name".len());
        assert_eq!(result.correctly_deleted_chars, "old_name".len());
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(result.token_annotations.len(), tokenize(candidate).len());

        for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    // Exact per-token coloring around the eprintln! call: kept call tokens,
    // discarded string-content tokens, then kept closing tokens.
    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        let candidate_tokens = tokenize(candidate);

        let eprintln_index = candidate_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    // Highly repetitive context must not let the LCS alignment "launder" a
    // wrong prediction token into a kept one.
    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let candidate = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let reference = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &candidate, &reference);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.correctly_deleted_chars, "foo".len() * 16);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
        assert_eq!(result.candidate_new_chars, "prediction_token".len() * 16);
        assert!(result.kept_rate > 0.0);
        assert!(result.recall_rate > 0.0);
    }
}