//! Measures how much of a predicted edit survives into the final text by
//! comparing the base, predicted, and final versions token by token.

use crate::word_diff::tokenize;

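/// Per-token classification used by tests: `Context` tokens are matched in
/// both the base and the final text, `Kept` tokens are new in the prediction
/// and survive into the final text, and `Discarded` tokens are new in the
/// prediction but absent from the final text.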
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAnnotation {
    Context,
    Kept,
    Discarded,
}

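/// Character-level statistics for a single prediction. All counts are sums of
/// token lengths as produced by `tokenize`: `context_chars` covers predicted
/// tokens matched in both the base and the final text, `predicted_new_chars`
/// covers the remaining new predicted tokens, and `kept_chars` plus
/// `discarded_chars` split those new tokens by whether they also appear among
/// the final text's new tokens. `kept_rate` is `kept_chars / predicted_new_chars`,
/// with the empty-prediction cases handled in `compute_kept_rate`.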
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KeptRateResult {
    pub predicted_new_chars: usize,
    pub final_new_chars: usize,
    pub kept_chars: usize,
    pub discarded_chars: usize,
    pub context_chars: usize,
    pub kept_rate: f64,
    #[cfg(test)]
    pub token_annotations: Vec<TokenAnnotation>,
}

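/// Index into the row-major DP table with the given `width` (column count).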
fn dp_index(width: usize, row: usize, column: usize) -> usize {
    row * width + column
}

/// Return masks over `a` and `b` marking the tokens that belong to a longest
/// common subsequence. Traceback ties are broken one-sidedly for each mask,
/// while both tracebacks share a single DP table construction.
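///
/// For example, `lcs_keep_masks(&["a", "b", "c"], &["a", "c"])` returns
/// `([true, false, true], [true, true])`.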
fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
    if a.is_empty() || b.is_empty() {
        return (vec![false; a.len()], vec![false; b.len()]);
    }

    if a == b {
        return (vec![true; a.len()], vec![true; b.len()]);
    }

    let mut keep_a = vec![false; a.len()];
    let mut keep_b = vec![false; b.len()];
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .take_while(|(left, right)| left == right)
        .count();
    let suffix_len = {
        let max_suffix = (a.len() - prefix_len).min(b.len() - prefix_len);
        let mut suffix_len = 0;

        while suffix_len < max_suffix {
            let a_index = a.len() - 1 - suffix_len;
            let b_index = b.len() - 1 - suffix_len;
            if a[a_index] != b[b_index] {
                break;
            }
            suffix_len += 1;
        }

        suffix_len
    };

    for index in 0..prefix_len {
        keep_a[index] = true;
        keep_b[index] = true;
    }

    for offset in 0..suffix_len {
        let a_index = a.len() - suffix_len + offset;
        let b_index = b.len() - suffix_len + offset;
        keep_a[a_index] = true;
        keep_b[b_index] = true;
    }

    let a_mid = &a[prefix_len..a.len() - suffix_len];
    let b_mid = &b[prefix_len..b.len() - suffix_len];

    if a_mid.is_empty() || b_mid.is_empty() {
        return (keep_a, keep_b);
    }

    let row_count = a_mid.len() + 1;
    let column_count = b_mid.len() + 1;
    let mut dp = vec![0u32; row_count * column_count];

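    // Standard LCS table: dp[i][j] holds the LCS length of a_mid[..i] and
    // b_mid[..j].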
    for i in 1..row_count {
        let token_a = a_mid[i - 1];
        for j in 1..column_count {
            let index = dp_index(column_count, i, j);
            if token_a == b_mid[j - 1] {
                dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
            } else {
                let up = dp[dp_index(column_count, i - 1, j)];
                let left = dp[dp_index(column_count, i, j - 1)];
                dp[index] = up.max(left);
            }
        }
    }

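    // Trace back through the table to mark kept positions in `a`; ties between
    // moving up and moving left are broken toward moving up.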
    let mut i = a_mid.len();
    let mut j = b_mid.len();

    while i > 0 && j > 0 {
        if a_mid[i - 1] == b_mid[j - 1] {
            keep_a[prefix_len + i - 1] = true;
            i -= 1;
            j -= 1;
        } else {
            let up = dp[dp_index(column_count, i - 1, j)];
            let left = dp[dp_index(column_count, i, j - 1)];
            if up >= left {
                i -= 1;
            } else {
                j -= 1;
            }
        }
    }

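    // A second traceback marks kept positions in `b`; it mirrors the first but
    // breaks ties toward moving left, which matches the `a`-side mask that a
    // call with the arguments swapped would produce.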
    let mut i = a_mid.len();
    let mut j = b_mid.len();

    while i > 0 && j > 0 {
        if a_mid[i - 1] == b_mid[j - 1] {
            keep_b[prefix_len + j - 1] = true;
            i -= 1;
            j -= 1;
        } else {
            let up = dp[dp_index(column_count, i - 1, j)];
            let left = dp[dp_index(column_count, i, j - 1)];
            if left >= up {
                j -= 1;
            } else {
                i -= 1;
            }
        }
    }

    (keep_a, keep_b)
}

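/// Split `tokens` by `mask`: return the unmasked tokens together with the total
/// character counts of the unmasked and masked tokens, respectively.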
fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
    let mut unmasked_tokens = Vec::with_capacity(tokens.len());
    let mut unmasked_chars = 0;
    let mut masked_chars = 0;

    for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
        if is_masked {
            masked_chars += token.len();
        } else {
            unmasked_tokens.push(token);
            unmasked_chars += token.len();
        }
    }

    (unmasked_tokens, unmasked_chars, masked_chars)
}

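/// Compute kept-rate statistics for one prediction: given the `base` text the
/// prediction was made against, the `predicted` text, and the `final_text` the
/// user ended up with, measure what fraction of the characters the prediction
/// added over `base` also survives in `final_text`. For example,
/// `compute_kept_rate("old", "new", "new")` yields a `kept_rate` of 1.0, while
/// `compute_kept_rate("old", "old", "new")` yields 0.0.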
pub fn compute_kept_rate(base: &str, predicted: &str, final_text: &str) -> KeptRateResult {
    if base == predicted && predicted == final_text {
        let predicted_tokens = tokenize(predicted);
        let context_chars = predicted_tokens.iter().map(|token| token.len()).sum();
        return KeptRateResult {
            predicted_new_chars: 0,
            final_new_chars: 0,
            kept_chars: 0,
            discarded_chars: 0,
            context_chars,
            kept_rate: 1.0,
            #[cfg(test)]
            token_annotations: vec![TokenAnnotation::Context; predicted_tokens.len()],
        };
    }

    let base_tokens = tokenize(base);
    let predicted_tokens = tokenize(predicted);
    let final_tokens = tokenize(final_text);

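    // A predicted token is context when it is matched (via LCS) against both
    // the base and the final text; the remaining predicted tokens are new.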
    let (pred_base_mask, _) = lcs_keep_masks(&predicted_tokens, &base_tokens);
    let (pred_final_mask, final_pred_mask) = lcs_keep_masks(&predicted_tokens, &final_tokens);
    let context_mask: Vec<bool> = pred_base_mask
        .iter()
        .zip(pred_final_mask.iter())
        .map(|(&in_base, &in_final)| in_base && in_final)
        .collect();

    let (stripped_predicted, predicted_new_chars, context_chars) =
        analyze_masked_tokens(&predicted_tokens, &context_mask);

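    // Symmetrically, a final token is context when matched against both the
    // base and the prediction; the remaining final tokens are new.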
    let (final_base_mask, _) = lcs_keep_masks(&final_tokens, &base_tokens);
    let final_context_mask: Vec<bool> = final_base_mask
        .iter()
        .zip(final_pred_mask.iter())
        .map(|(&in_base, &in_predicted)| in_base && in_predicted)
        .collect();

    let (stripped_final, final_new_chars, _) =
        analyze_masked_tokens(&final_tokens, &final_context_mask);

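    // New predicted tokens that align (via LCS) with new final tokens count as
    // kept; the rest of the new predicted characters were discarded.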
    let keep_mask = lcs_keep_masks(&stripped_predicted, &stripped_final).0;

    let kept_chars: usize = stripped_predicted
        .iter()
        .zip(keep_mask.iter())
        .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
        .sum();

    let discarded_chars = predicted_new_chars - kept_chars;

    let kept_rate = if predicted_new_chars == 0 {
        if final_new_chars == 0 { 1.0 } else { 0.0 }
    } else {
        kept_chars as f64 / predicted_new_chars as f64
    };

    #[cfg(test)]
    let token_annotations = {
        let mut token_annotations = Vec::with_capacity(predicted_tokens.len());
        let mut new_index = 0;
        for (token_index, _token) in predicted_tokens.iter().enumerate() {
            if context_mask[token_index] {
                token_annotations.push(TokenAnnotation::Context);
            } else {
                let annotation = if keep_mask[new_index] {
                    TokenAnnotation::Kept
                } else {
                    TokenAnnotation::Discarded
                };
                token_annotations.push(annotation);
                new_index += 1;
            }
        }
        token_annotations
    };

    KeptRateResult {
        predicted_new_chars,
        final_new_chars,
        kept_chars,
        discarded_chars,
        context_chars,
        kept_rate,
        #[cfg(test)]
        token_annotations,
    }
}

#[cfg(test)]
mod test_kept_rate {
    use super::*;

    #[test]
    fn test_lcs_keep_masks() {
        let (a_mask, b_mask) = lcs_keep_masks(&["a", "b", "c", "d", "e"], &["a", "c", "e"]);
        assert_eq!(a_mask, vec![true, false, true, false, true]);
        assert_eq!(b_mask, vec![true, true, true]);

        let (a_mask, b_mask) = lcs_keep_masks(&[], &["x"]);
        assert!(a_mask.is_empty());
        assert_eq!(b_mask, vec![false]);
    }

    #[test]
    fn test_lcs_keep_masks_matches_historical_one_sided_masks() {
        let a = ["x", "a", "x", "b"];
        let b = ["a", "x", "b", "x"];
        let (a_mask, b_mask) = lcs_keep_masks(&a, &b);
        // Each side's mask must match the historical one-sided mask: `a_mask`
        // from a direct call, `b_mask` from a call with the arguments swapped.
        assert_eq!(a_mask, lcs_keep_masks(&a, &b).0);
        assert_eq!(b_mask, lcs_keep_masks(&b, &a).0);
    }

    #[test]
    fn test_rate_extremes() {
        let no_change = compute_kept_rate("foo bar", "foo bar", "foo bar");
        assert!((no_change.kept_rate - 1.0).abs() < 1e-6);
        assert_eq!(no_change.predicted_new_chars, 0);
        assert!(
            no_change
                .token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Context)
        );

        let accepted = compute_kept_rate("old", "new", "new");
        assert!((accepted.kept_rate - 1.0).abs() < 1e-6);

        let discarded = compute_kept_rate("old", "old", "new");
        assert!((discarded.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pure_addition() {
        let kept = compute_kept_rate("", "brand new line\n", "brand new line\n");
        assert_eq!(kept.kept_chars, kept.predicted_new_chars);
        assert!(
            kept.token_annotations
                .iter()
                .all(|&annotation| annotation == TokenAnnotation::Kept)
        );

        let discarded =
            compute_kept_rate("", "brand new line\n", "something completely different\n");
        assert!(discarded.kept_chars < discarded.predicted_new_chars);
    }

    #[test]
    fn test_decoy_when_base_excluded() {
        let base = " decoy.when(mock_sync_hardware_api.sp()).then_return(SpeedStatus.IDLE)\n";
        let predicted = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let final_text = " decoy.when(mock_sync_module_hardware.speed_status).then_return(SpeedStatus.IDLE)\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let expected_new = "mock_sync_module_hardware".len() + "speed_status".len();
        assert_eq!(result.predicted_new_chars, expected_new);
        assert!((result.kept_rate - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_missing_deletion() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(
            result.kept_rate < 0.85,
            "expected kept_rate < 0.85, got {}",
            result.kept_rate
        );
        assert!(result.discarded_chars > 0);
    }

    #[test]
    fn test_empty_prediction() {
        let result = compute_kept_rate("old line\n", "", "new line\n");
        assert!((result.kept_rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_partial_kept() {
        let result = compute_kept_rate("old\n", "alpha\nbeta\ngamma\n", "alpha\ngamma\n");
        assert!(result.kept_chars > 0);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
    }

    #[test]
    fn test_eprintln_token_alignment() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        assert!(result.discarded_chars > 0);
        assert!(result.kept_chars > 0);
        assert!(result.kept_rate > 0.0 && result.kept_rate < 1.0);
        assert_eq!(result.kept_chars, 14);
        assert_eq!(result.discarded_chars, 12);
    }

    #[test]
    fn test_annotations_rename() {
        let base = " foo(old_name)\n";
        let predicted = " foo(new_name)\n";
        let final_text = " foo(new_name)\n";
        let result = compute_kept_rate(base, predicted, final_text);

        assert_eq!(result.predicted_new_chars, "new_name".len());
        assert_eq!(result.token_annotations.len(), tokenize(predicted).len());

        for (&token, &annotation) in tokenize(predicted).iter().zip(&result.token_annotations) {
            if token == "new_name" {
                assert_eq!(annotation, TokenAnnotation::Kept);
            } else {
                assert_eq!(annotation, TokenAnnotation::Context);
            }
        }
    }

    #[test]
    fn test_annotations_eprintln_coloring() {
        let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
        let predicted = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
        let final_text = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
        let result = compute_kept_rate(base, predicted, final_text);
        let predicted_tokens = tokenize(predicted);

        let eprintln_index = predicted_tokens
            .iter()
            .position(|&token| token == "eprintln")
            .expect("eprintln token not found");

        for annotation in &result.token_annotations[..eprintln_index] {
            assert_eq!(*annotation, TokenAnnotation::Context);
        }

        assert_eq!(
            &result.token_annotations[eprintln_index..=eprintln_index + 10],
            &[
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Discarded,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
                TokenAnnotation::Kept,
            ]
        );
        assert_eq!(
            result.token_annotations.last(),
            Some(&TokenAnnotation::Context)
        );
    }

    #[test]
    fn test_repetitive_tokens_remain_discarded() {
        let base = "foo + foo + foo + foo + foo\n".repeat(16);
        let predicted = "foo + foo + prediction_token + foo + foo\n".repeat(16);
        let final_text = "foo + foo + kept_token + foo + foo\n".repeat(16);
        let result = compute_kept_rate(&base, &predicted, &final_text);

        assert_eq!(result.kept_chars, 0);
        assert_eq!(result.discarded_chars, result.predicted_new_chars);
        assert_eq!(result.predicted_new_chars, "prediction_token".len() * 16);
    }
}