1use crate::word_diff::tokenize;
2
/// Per-token classification of candidate tokens, produced by
/// `compute_kept_rate` for test assertions only (`#[cfg(test)]`).
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    /// Token survives the LCS against both `base` and `reference`, i.e. it
    /// is treated as unchanged context.
    Context,
    /// Newly introduced candidate token that also appears in the
    /// reference's newly introduced tokens.
    Kept,
    /// Newly introduced candidate token that the reference does not keep.
    Discarded,
}
10
/// Character-level comparison of a candidate edit (`base -> candidate`)
/// against a reference edit (`base -> reference`), as computed by
/// `compute_kept_rate`. All counts are `str::len` sums over tokens.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    /// Characters newly introduced by the candidate
    pub candidate_new_chars: usize,
    /// Characters newly introduced by the reference
    pub reference_new_chars: usize,
    /// Characters from `base` that are deleted by the candidate.
    pub candidate_deleted_chars: usize,
    /// Characters from `base` that are deleted by the reference.
    pub reference_deleted_chars: usize,
    /// Candidate new characters that are also present in the reference.
    pub kept_chars: usize,
    /// Base characters deleted by both the candidate and the reference.
    pub correctly_deleted_chars: usize,
    /// Candidate new characters that are not kept in the reference.
    pub discarded_chars: usize,
    /// Candidate characters treated as unchanged context
    pub context_chars: usize,
    /// Fraction of candidate edit characters that match the reference edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub kept_rate: f64,
    /// Fraction of reference edit characters covered by the candidate edit.
    ///
    /// This includes both kept newly introduced characters and correctly
    /// deleted base characters.
    pub recall_rate: f64,
    /// Per-token classification for candidate tokens used by tests.
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}
44
/// Map a `(row, column)` coordinate to a flat index in a row-major table
/// whose rows are `width` cells wide.
fn dp_index(width: usize, row: usize, column: usize) -> usize {
    width * row + column
}
48
/// Return masks over `a` and `b` using one-sided LCS tie-breaking for each
/// side while sharing a single DP table construction.
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
    if a.is_empty() || b.is_empty() {
        return (vec![false; a.len()], vec![false; b.len()]);
    }
    if a == b {
        return (vec![true; a.len()], vec![true; b.len()]);
    }

    let mut mask_a = vec![false; a.len()];
    let mut mask_b = vec![false; b.len()];

    // A shared prefix is always part of the LCS on both sides.
    let common_prefix = a.iter().zip(b).take_while(|(x, y)| x == y).count();

    // Shared suffix, capped so it can never overlap the prefix.
    let suffix_budget = (a.len() - common_prefix).min(b.len() - common_prefix);
    let mut common_suffix = 0;
    while common_suffix < suffix_budget
        && a[a.len() - 1 - common_suffix] == b[b.len() - 1 - common_suffix]
    {
        common_suffix += 1;
    }

    mask_a[..common_prefix].fill(true);
    mask_b[..common_prefix].fill(true);
    mask_a[a.len() - common_suffix..].fill(true);
    mask_b[b.len() - common_suffix..].fill(true);

    // Run the quadratic DP only on the middle section.
    let core_a = &a[common_prefix..a.len() - common_suffix];
    let core_b = &b[common_prefix..b.len() - common_suffix];
    if core_a.is_empty() || core_b.is_empty() {
        return (mask_a, mask_b);
    }

    // Row-major LCS-length table: table[r * stride + c] holds the LCS
    // length of core_a[..r] and core_b[..c].
    let stride = core_b.len() + 1;
    let mut table = vec![0u32; (core_a.len() + 1) * stride];
    for (r, &token) in core_a.iter().enumerate() {
        for (c, &other) in core_b.iter().enumerate() {
            table[(r + 1) * stride + (c + 1)] = if token == other {
                table[r * stride + c] + 1
            } else {
                // max(delete from a, delete from b)
                table[r * stride + (c + 1)].max(table[(r + 1) * stride + c])
            };
        }
    }

    // First backtrack: mark `a`-side matches, preferring to move up the
    // rows on ties (one-sided tie-break for `a`).
    let (mut r, mut c) = (core_a.len(), core_b.len());
    while r > 0 && c > 0 {
        if core_a[r - 1] == core_b[c - 1] {
            mask_a[common_prefix + r - 1] = true;
            r -= 1;
            c -= 1;
        } else if table[(r - 1) * stride + c] >= table[r * stride + (c - 1)] {
            r -= 1;
        } else {
            c -= 1;
        }
    }

    // Second backtrack over the same table: mark `b`-side matches,
    // preferring to move left across the columns on ties.
    let (mut r, mut c) = (core_a.len(), core_b.len());
    while r > 0 && c > 0 {
        if core_a[r - 1] == core_b[c - 1] {
            mask_b[common_prefix + c - 1] = true;
            r -= 1;
            c -= 1;
        } else if table[r * stride + (c - 1)] >= table[(r - 1) * stride + c] {
            c -= 1;
        } else {
            r -= 1;
        }
    }

    (mask_a, mask_b)
}
161
/// Split `tokens` by `mask`: returns the tokens whose mask entry is
/// `false`, the total character count of those tokens, and the total
/// character count of the masked (`true`) tokens.
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut survivors = Vec::with_capacity(tokens.len());
    let mut survivor_chars = 0;
    let mut hidden_chars = 0;

    for (&token, &hidden) in tokens.iter().zip(mask) {
        if hidden {
            hidden_chars += token.len();
        } else {
            survivor_chars += token.len();
            survivors.push(token);
        }
    }

    (survivors, survivor_chars, hidden_chars)
}
178
/// Sum the lengths of the tokens whose mask entry is `false`.
fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
    let mut total = 0;
    for (token, &is_masked) in tokens.iter().zip(mask) {
        if !is_masked {
            total += token.len();
        }
    }
    total
}
186
187pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRateResult {
188 if base == candidate && candidate == reference {
189 let candidate_tokens = tokenize(candidate);
190 let context_chars = candidate_tokens.iter().map(|token| token.len()).sum();
191 return KeptRateResult {
192 candidate_new_chars: 0,
193 reference_new_chars: 0,
194 candidate_deleted_chars: 0,
195 reference_deleted_chars: 0,
196 kept_chars: 0,
197 correctly_deleted_chars: 0,
198 discarded_chars: 0,
199 context_chars,
200 kept_rate: 1.0,
201 recall_rate: 1.0,
202 #[cfg(test)]
203 token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
204 };
205 }
206
207 let base_tokens = tokenize(base);
208 let candidate_tokens = tokenize(candidate);
209 let reference_tokens = tokenize(reference);
210
211 let (candidate_base_mask, base_candidate_mask) =
212 lcs_keep_masks(&candidate_tokens, &base_tokens);
213 let (candidate_reference_mask, reference_candidate_mask) =
214 lcs_keep_masks(&candidate_tokens, &reference_tokens);
215 let context_mask: Vec<bool> = candidate_base_mask
216 .iter()
217 .zip(candidate_reference_mask.iter())
218 .map(|(&in_base, &in_final)| in_base && in_final)
219 .collect();
220
221 let (stripped_candidate, candidate_new_chars, context_chars) =
222 analyze_masked_tokens(&candidate_tokens, &context_mask);
223
224 let (reference_base_mask, base_reference_mask) =
225 lcs_keep_masks(&reference_tokens, &base_tokens);
226 let reference_context_mask: Vec<bool> = reference_base_mask
227 .iter()
228 .zip(reference_candidate_mask.iter())
229 .map(|(&in_base, &in_candidate)| in_base && in_candidate)
230 .collect();
231
232 let (stripped_reference, reference_new_chars, _) =
233 analyze_masked_tokens(&reference_tokens, &reference_context_mask);
234
235 let keep_mask = lcs_keep_masks(&stripped_candidate, &stripped_reference).0;
236
237 let kept_chars: usize = stripped_candidate
238 .iter()
239 .zip(keep_mask.iter())
240 .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
241 .sum();
242
243 let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
244 let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
245 let correctly_deleted_chars: usize = base_tokens
246 .iter()
247 .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
248 .filter_map(|(&token, (&in_candidate, &in_reference))| {
249 (!in_candidate && !in_reference).then_some(token.len())
250 })
251 .sum();
252
253 let discarded_chars = candidate_new_chars - kept_chars;
254 let matched_edit_chars = kept_chars + correctly_deleted_chars;
255 let candidate_edit_chars = candidate_new_chars + candidate_deleted_chars;
256 let reference_edit_chars = reference_new_chars + reference_deleted_chars;
257
258 let kept_rate = if candidate_edit_chars == 0 {
259 if reference_edit_chars == 0 { 1.0 } else { 0.0 }
260 } else {
261 matched_edit_chars as f64 / candidate_edit_chars as f64
262 };
263
264 let recall_rate = if reference_edit_chars == 0 {
265 if candidate_edit_chars == 0 { 1.0 } else { 0.0 }
266 } else {
267 matched_edit_chars as f64 / reference_edit_chars as f64
268 };
269
270 #[cfg(test)]
271 let token_annotations = {
272 let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
273 let mut new_index = 0;
274 for (token_index, _token) in candidate_tokens.iter().enumerate() {
275 if context_mask[token_index] {
276 token_annotations.push(TokenAnnotation::Context);
277 } else {
278 let annotation = if keep_mask[new_index] {
279 TokenAnnotation::Kept
280 } else {
281 TokenAnnotation::Discarded
282 };
283 #[cfg(test)]
284 token_annotations.push(annotation);
285 new_index += 1;
286 }
287 }
288 token_annotations
289 };
290
291 KeptRateResult {
292 candidate_new_chars,
293 reference_new_chars,
294 candidate_deleted_chars,
295 reference_deleted_chars,
296 kept_chars,
297 correctly_deleted_chars,
298 discarded_chars,
299 context_chars,
300 kept_rate,
301 recall_rate,
302 #[cfg(test)]
303 token_annotations,
304 }
305}
306
#[cfg(test)]
mod test_kept_rate {
    use super::*;

    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        // Each side's mask must agree with the opposite-side mask of the
        // transposed call. (The previous version compared `a_mask` against
        // `lcs_keep_masks(&a, &b).0` — the very value it was destructured
        // from — which was trivially true and tested nothing.)
        assert_eq!(a_mask, lcs_keep_masks(&b, &a).1);
        assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
    }

    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert!((no_change.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.candidate_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);
        assert!((accepted.recall_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
        assert!((discarded.recall_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.candidate_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.candidate_new_chars);
    }

    #[test]
    fn test_decoy_when_base_excluded() {
        let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let candidate = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let reference = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, candidate, reference);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.candidate_new_chars, expected_new);
        assert!(result.correctly_deleted_chars > 0);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_missing_deletion() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
        let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert_eq!(result.candidate_new_chars, 0);
        assert!(result.candidate_deleted_chars > 0);
        assert!(result.correctly_deleted_chars > 0);
        assert!(result.correctly_deleted_chars < result.candidate_deleted_chars);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert!(result.recall_rate > 0.0 && result.recall_rate < 1.0);
    }

    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    #[test]
    fn test_eprintln_token_alignment() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    #[test]
    fn test_annotations_rename() {
        let base = " foo(old_name)\n";
        let candidate = " foo(new_name)\n";
        let reference = " foo(new_name)\n";
        let result = compute_kept_rate(base, candidate, reference);

        assert_eq!(result.candidate_new_chars, "new_name".len());
        assert_eq!(result.candidate_deleted_chars, "old_name".len());
        assert_eq!(result.reference_deleted_chars, "old_name".len());
        assert_eq!(result.correctly_deleted_chars, "old_name".len());
        assert!((result.recall_rate - 1.0).abs() < 1e-6);
        assert_eq!(result.token_annotations.len(), tokenize(candidate).len());

        for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, candidate, reference);
        let candidate_tokens = tokenize(candidate);

        let eprintln_index = candidate_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let candidate = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let reference = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &candidate, &reference);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.correctly_deleted_chars, "foo".len() * 16);
        assert_eq!(result.discarded_chars, result.candidate_new_chars);
        assert_eq!(result.candidate_new_chars, "prediction_token".len() * 16);
        assert!(result.kept_rate > 0.0);
        assert!(result.recall_rate > 0.0);
    }
}