1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use edit_prediction::udiff::apply_diff_to_string;
6use language::text_diff;
7
8use crate::example::ExamplePromptInputs;
9
/// Reverses a unified diff so that applying the result undoes the original change.
///
/// * Inside a hunk, `+` lines become `-` lines and vice versa; context lines and
///   `\ No newline at end of file` markers are kept as-is.
/// * Hunk headers `@@ -a,b +c,d @@ trailer` become `@@ -c,d +a,b @@ trailer`.
/// * A `--- old` / `+++ new` file-header pair becomes `--- new` / `+++ old`.
///
/// The line counts in each hunk header are used to tell hunk content apart from
/// file headers. That way a removed or added line whose own text starts with `--`
/// or `++` (a SQL comment, a markdown `---` rule, ...) is still reversed as
/// content. When a hunk header cannot be parsed, lines are classified by their
/// prefix alone. A trailing newline in `diff` is preserved.
pub fn reverse_diff(diff: &str) -> String {
    let mut reversed: Vec<String> = Vec::new();
    // Lines of the current hunk still expected on the old / new side.
    let mut old_remaining = 0usize;
    let mut new_remaining = 0usize;
    let mut lines = diff.lines().peekable();

    while let Some(line) = lines.next() {
        if line.starts_with("@@") {
            match reverse_hunk_header(line) {
                Some((header, old_len, new_len)) => {
                    old_remaining = old_len;
                    new_remaining = new_len;
                    reversed.push(header);
                }
                None => {
                    // Unparseable header: fall back to prefix-based classification.
                    old_remaining = 0;
                    new_remaining = 0;
                    reversed.push(line.to_string());
                }
            }
        } else if old_remaining > 0 || new_remaining > 0 {
            reversed.push(reverse_hunk_line(
                line,
                &mut old_remaining,
                &mut new_remaining,
            ));
        } else if let Some(old_path) = line.strip_prefix("--- ") {
            let paired_new_path = lines
                .peek()
                .copied()
                .and_then(|next| next.strip_prefix("+++ "));
            if let Some(new_path) = paired_new_path {
                // Keep the conventional `---` then `+++` order, with the paths swapped.
                reversed.push(format!("--- {new_path}"));
                reversed.push(format!("+++ {old_path}"));
                lines.next();
            } else {
                reversed.push(format!("+++ {old_path}"));
            }
        } else if let Some(new_path) = line.strip_prefix("+++ ") {
            reversed.push(format!("--- {new_path}"));
        } else if let Some(added) = line.strip_prefix('+') {
            reversed.push(format!("-{added}"));
        } else if let Some(removed) = line.strip_prefix('-') {
            reversed.push(format!("+{removed}"));
        } else {
            reversed.push(line.to_string());
        }
    }

    let mut result = reversed.join("\n");
    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}

/// Swaps the old and new ranges of a hunk header such as `@@ -1,3 +1,4 @@ fn main`.
///
/// Returns the reversed header together with the line counts of the ORIGINAL
/// old and new sides, or `None` if `line` is not a well-formed hunk header.
fn reverse_hunk_header(line: &str) -> Option<(String, usize, usize)> {
    let rest = line.strip_prefix("@@ -")?;
    let (old_range, rest) = rest.split_once(" +")?;
    let (new_range, trailer) = rest.split_once(" @@")?;
    let old_len = hunk_range_len(old_range)?;
    let new_len = hunk_range_len(new_range)?;
    Some((
        format!("@@ -{new_range} +{old_range} @@{trailer}"),
        old_len,
        new_len,
    ))
}

/// Parses the line count of a hunk range `start[,count]`; a missing count means 1.
fn hunk_range_len(range: &str) -> Option<usize> {
    let (start, count) = range.split_once(',').unwrap_or((range, "1"));
    let _start: usize = start.parse().ok()?;
    count.parse().ok()
}

/// Reverses a single line inside a hunk and decrements the remaining side counts.
fn reverse_hunk_line(line: &str, old_remaining: &mut usize, new_remaining: &mut usize) -> String {
    if let Some(added) = line.strip_prefix('+') {
        *new_remaining = new_remaining.saturating_sub(1);
        format!("-{added}")
    } else if let Some(removed) = line.strip_prefix('-') {
        *old_remaining = old_remaining.saturating_sub(1);
        format!("+{removed}")
    } else if line.starts_with('\\') {
        // "\ No newline at end of file" belongs to the previous line and counts for neither side.
        line.to_string()
    } else {
        // Context line; this also covers an empty line whose leading space was trimmed.
        *old_remaining = old_remaining.saturating_sub(1);
        *new_remaining = new_remaining.saturating_sub(1);
        line.to_string()
    }
}
33
/// One contiguous replacement between two versions of a text.
///
/// As built by `compute_granular_edits`, `range` is a byte range into the older
/// text, `old_text` is the slice of the older text at `range`, and `new_text` is
/// what replaces it. `normalize_extension_edits` may rewrite an edit into a pure
/// insertion (empty `old_text`) or a pure deletion (empty `new_text`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct GranularEdit {
    // Byte offsets into the older text.
    range: Range<usize>,
    // Text removed from the older version.
    old_text: String,
    // Text inserted in its place.
    new_text: String,
}
40
41fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
42 text_diff(old_text, new_text)
43 .into_iter()
44 .map(|(range, new_text)| GranularEdit {
45 old_text: old_text[range.clone()].to_string(),
46 range,
47 new_text: new_text.to_string(),
48 })
49 .collect()
50}
51
/// Location, in the current content, of text that the user's edits inserted.
#[derive(Debug, Clone)]
struct HistoryAdditionRange {
    // Byte range in the current content. For edits rewritten by
    // `normalize_extension_edits` this is anchored at the edit's start and may not
    // match exactly where the (possibly scattered) inserted characters sit.
    range_in_current: Range<usize>,
}
56
/// Text that the user's edits removed from the original content.
#[derive(Debug, Clone)]
struct HistoryDeletionRange {
    // The removed text (the history edit's `old_text`).
    deleted_text: String,
}
61
62fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
63 let mut result = Vec::new();
64 let mut offset_delta: isize = 0;
65
66 for edit in history_edits {
67 if !edit.new_text.is_empty() {
68 let new_start = (edit.range.start as isize + offset_delta) as usize;
69 let new_end = new_start + edit.new_text.len();
70 result.push(HistoryAdditionRange {
71 range_in_current: new_start..new_end,
72 });
73 }
74
75 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
76 }
77
78 result
79}
80
81fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
82 history_edits
83 .iter()
84 .filter(|edit| !edit.old_text.is_empty())
85 .map(|edit| HistoryDeletionRange {
86 deleted_text: edit.old_text.clone(),
87 })
88 .collect()
89}
90
/// How much of a prediction undoes the user's recent edits.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ReversalOverlap {
    /// Amount of the prediction that removes user-added text or restores user-deleted text.
    chars_reversing_user_edits: usize,
    /// Sum of inserted and deleted text lengths across the prediction's edits.
    total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Fraction of the prediction that reverses user edits.
    ///
    /// Returns `0.0` when the prediction changes nothing, avoiding a division by zero.
    fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
106
/// Returns `true` when every character of `needle` occurs in `haystack` in the
/// same order, though not necessarily contiguously (greedy leftmost matching).
fn is_subsequence(needle: &str, haystack: &str) -> bool {
    let mut unread = haystack.chars();
    // `any` consumes `unread` up to and including the match, so order is enforced.
    needle
        .chars()
        .all(|wanted| unread.any(|candidate| candidate == wanted))
}
117
118/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
119/// or where `new_text` appears as a subsequence within `old_text` (reduction).
120///
121/// For extensions: when the user's text is preserved (in order) within the prediction,
122/// we only count the newly inserted characters, not the preserved ones.
123/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
124/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
125///
126/// For reductions: when the prediction's text is preserved (in order) within the original,
127/// we only count the deleted characters, not the preserved ones.
128/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
129fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
130 edits
131 .into_iter()
132 .map(|edit| {
133 if edit.old_text.is_empty() || edit.new_text.is_empty() {
134 return edit;
135 }
136
137 if is_subsequence(&edit.old_text, &edit.new_text) {
138 let inserted_len = edit.new_text.len() - edit.old_text.len();
139 GranularEdit {
140 range: edit.range.start..edit.range.start,
141 old_text: String::new(),
142 new_text: edit.new_text.chars().take(inserted_len).collect(),
143 }
144 } else if is_subsequence(&edit.new_text, &edit.old_text) {
145 let deleted_len = edit.old_text.len() - edit.new_text.len();
146 GranularEdit {
147 range: edit.range.start..edit.range.start + deleted_len,
148 old_text: edit.old_text.chars().take(deleted_len).collect(),
149 new_text: String::new(),
150 }
151 } else {
152 edit
153 }
154 })
155 .collect()
156}
157
158fn compute_reversal_overlap(
159 original_content: &str,
160 current_content: &str,
161 predicted_content: &str,
162) -> ReversalOverlap {
163 let history_edits =
164 normalize_extension_edits(compute_granular_edits(original_content, current_content));
165 let prediction_edits =
166 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
167
168 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
169 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
170
171 let reversed_additions =
172 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
173 let restored_deletions =
174 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
175
176 let total_chars_in_prediction: usize = prediction_edits
177 .iter()
178 .map(|e| e.new_text.len() + e.old_text.len())
179 .sum();
180
181 ReversalOverlap {
182 chars_reversing_user_edits: reversed_additions + restored_deletions,
183 total_chars_in_prediction,
184 }
185}
186
187fn compute_reversed_additions(
188 history_addition_ranges: &[HistoryAdditionRange],
189 prediction_edits: &[GranularEdit],
190) -> usize {
191 let mut reversed_chars = 0;
192
193 for pred_edit in prediction_edits {
194 for history_addition in history_addition_ranges {
195 let overlap_start = pred_edit
196 .range
197 .start
198 .max(history_addition.range_in_current.start);
199 let overlap_end = pred_edit
200 .range
201 .end
202 .min(history_addition.range_in_current.end);
203
204 if overlap_start < overlap_end {
205 reversed_chars += overlap_end - overlap_start;
206 }
207 }
208 }
209
210 reversed_chars
211}
212
213fn compute_restored_deletions(
214 history_deletion_ranges: &[HistoryDeletionRange],
215 prediction_edits: &[GranularEdit],
216) -> usize {
217 let history_deleted_text: String = history_deletion_ranges
218 .iter()
219 .map(|r| r.deleted_text.as_str())
220 .collect();
221
222 let prediction_added_text: String = prediction_edits
223 .iter()
224 .map(|e| e.new_text.as_str())
225 .collect();
226
227 compute_lcs_length(&history_deleted_text, &prediction_added_text)
228}
229
/// Returns the length, in UTF-8 bytes, of a longest common subsequence of `a` and `b`.
///
/// Characters are compared one `char` at a time. Each matched character contributes
/// its UTF-8 byte length, so the result is in the same unit as `str::len`. That is
/// the unit of every other count folded into the reversal ratio (byte-range overlaps
/// and `old_text.len() + new_text.len()` totals). For ASCII input this equals the
/// ordinary LCS character count.
///
/// Classic dynamic program with two rolling rows sized by the shorter input:
/// O(|a|·|b|) time, O(min(|a|, |b|)) extra memory.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    // LCS is symmetric, so iterate the longer input in the outer loop to keep rows short.
    let (outer, inner) = if a_chars.len() >= b_chars.len() {
        (a_chars, b_chars)
    } else {
        (b_chars, a_chars)
    };

    if inner.is_empty() {
        return 0;
    }

    // prev[j] / curr[j]: best score for the processed prefix of `outer` against inner[..j].
    let mut prev = vec![0usize; inner.len() + 1];
    let mut curr = vec![0usize; inner.len() + 1];

    for &outer_char in &outer {
        for (j, &inner_char) in inner.iter().enumerate() {
            curr[j + 1] = if outer_char == inner_char {
                prev[j] + outer_char.len_utf8()
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        // curr[0] stays 0 and every other cell is overwritten before it is read,
        // so no clearing is needed after the swap.
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[inner.len()]
}
257
258pub fn filter_edit_history_by_path<'a>(
259 edit_history: &'a [Arc<zeta_prompt::Event>],
260 cursor_path: &std::path::Path,
261) -> Vec<&'a zeta_prompt::Event> {
262 edit_history
263 .iter()
264 .filter(|event| match event.as_ref() {
265 zeta_prompt::Event::BufferChange { path, .. } => {
266 let event_path = path.as_ref();
267 if event_path == cursor_path {
268 return true;
269 }
270 let stripped = event_path
271 .components()
272 .skip(1)
273 .collect::<std::path::PathBuf>();
274 stripped == cursor_path
275 }
276 })
277 .map(|arc| arc.as_ref())
278 .collect()
279}
280
281pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
282 match event {
283 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
284 }
285}
286
287pub fn compute_prediction_reversal_ratio(
288 prompt_inputs: &ExamplePromptInputs,
289 predicted_content: &str,
290 cursor_path: &Path,
291) -> f32 {
292 let current_content = &prompt_inputs.content;
293
294 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
295 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
296
297 let mut original_content = current_content.to_string();
298 for event in relevant_events.into_iter().rev() {
299 let diff = extract_diff_from_event(event);
300 if diff.is_empty() {
301 continue;
302 }
303 let reversed = reverse_diff(diff);
304 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
305 match apply_diff_to_string(&with_headers, &original_content) {
306 Ok(updated_content) => original_content = updated_content,
307 Err(err) => {
308 log::warn!(
309 "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}",
310 err
311 );
312 return 0.0;
313 }
314 }
315 }
316
317 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
318 overlap.ratio()
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use edit_prediction::udiff::apply_diff_to_string;

    /// Table-driven checks of `compute_reversal_overlap`.
    ///
    /// In every case the user's edit is `original -> current` and the prediction's
    /// edit is `current -> predicted`. Each case pins both the reversal count and
    /// the total prediction size.
    #[test]
    fn test_reversal_overlap() {
        struct Case {
            name: &'static str,
            original: &'static str,
            current: &'static str,
            predicted: &'static str,
            expected_reversal_chars: usize,
            expected_total_chars: usize,
        }

        let cases = [
            Case {
                name: "user_adds_line_prediction_removes_it",
                original: "a\nb\nc",
                current: "a\nnew line\nb\nc",
                predicted: "a\nb\nc",
                expected_reversal_chars: 9,
                expected_total_chars: 9,
            },
            Case {
                name: "user_deletes_line_prediction_restores_it",
                original: "a\ndeleted\nb",
                current: "a\nb",
                predicted: "a\ndeleted\nb",
                expected_reversal_chars: 8,
                expected_total_chars: 8,
            },
            Case {
                name: "user_deletes_text_prediction_restores_partial",
                original: "hello beautiful world",
                current: "hello world",
                predicted: "hello beautiful world",
                expected_reversal_chars: 10,
                expected_total_chars: 10,
            },
            Case {
                name: "user_deletes_foo_prediction_adds_bar",
                original: "foo",
                current: "",
                predicted: "bar",
                expected_reversal_chars: 0,
                expected_total_chars: 3,
            },
            Case {
                name: "independent_edits_different_locations",
                original: "line1\nline2\nline3",
                current: "LINE1\nline2\nline3",
                predicted: "LINE1\nline2\nLINE3",
                expected_reversal_chars: 0,
                expected_total_chars: 10,
            },
            Case {
                name: "no_history_edits",
                original: "same",
                current: "same",
                predicted: "different",
                expected_reversal_chars: 0,
                expected_total_chars: 13,
            },
            Case {
                name: "user_replaces_text_prediction_reverses",
                original: "keep\ndelete_me\nkeep2",
                current: "keep\nadded\nkeep2",
                predicted: "keep\ndelete_me\nkeep2",
                expected_reversal_chars: 14,
                expected_total_chars: 14,
            },
            Case {
                name: "user_modifies_word_prediction_modifies_differently",
                original: "the quick brown fox",
                current: "the slow brown fox",
                predicted: "the fast brown fox",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
            Case {
                name: "user finishes function name (suffix)",
                original: "",
                current: "epr",
                predicted: "eprintln!()",
                expected_reversal_chars: 0,
                expected_total_chars: 8,
            },
            Case {
                name: "user starts function name (prefix)",
                original: "",
                current: "my_function()",
                predicted: "test_my_function()",
                expected_reversal_chars: 0,
                expected_total_chars: 5,
            },
            Case {
                name: "user types partial, prediction extends in multiple places",
                original: "",
                current: "test_my_function",
                predicted: "a_test_for_my_special_function_plz",
                expected_reversal_chars: 0,
                expected_total_chars: 18,
            },
            // Edge cases for subsequence matching
            Case {
                name: "subsequence with interleaved underscores",
                original: "",
                current: "a_b_c",
                predicted: "_a__b__c__",
                expected_reversal_chars: 0,
                expected_total_chars: 5,
            },
            Case {
                name: "not a subsequence - different characters",
                original: "",
                current: "abc",
                predicted: "xyz",
                expected_reversal_chars: 3,
                expected_total_chars: 6,
            },
            Case {
                name: "not a subsequence - wrong order",
                original: "",
                current: "abc",
                predicted: "cba",
                expected_reversal_chars: 3,
                expected_total_chars: 6,
            },
            Case {
                name: "partial subsequence - only some chars match",
                original: "",
                current: "abcd",
                predicted: "axbx",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
            // Common completion patterns
            Case {
                name: "completing a method call",
                original: "",
                current: "vec.pu",
                predicted: "vec.push(item)",
                expected_reversal_chars: 0,
                expected_total_chars: 8,
            },
            Case {
                name: "completing an import statement",
                original: "",
                current: "use std::col",
                predicted: "use std::collections::HashMap",
                expected_reversal_chars: 0,
                expected_total_chars: 17,
            },
            Case {
                name: "completing a struct field",
                original: "",
                current: "name: St",
                predicted: "name: String",
                expected_reversal_chars: 0,
                expected_total_chars: 4,
            },
            Case {
                name: "prediction replaces with completely different text",
                original: "",
                current: "hello",
                predicted: "world",
                expected_reversal_chars: 5,
                expected_total_chars: 10,
            },
            Case {
                name: "empty prediction removes user text",
                original: "",
                current: "mistake",
                predicted: "",
                expected_reversal_chars: 7,
                expected_total_chars: 7,
            },
            Case {
                name: "fixing typo is not reversal",
                original: "",
                current: "<dv",
                predicted: "<div>",
                expected_reversal_chars: 0,
                expected_total_chars: 2,
            },
            Case {
                name: "infix insertion not reversal",
                original: "from my_project import Foo\n",
                current: "ifrom my_project import Foo\n",
                predicted: indoc::indoc! {"
                    import
                    from my_project import Foo
                "},
                expected_reversal_chars: 0,
                expected_total_chars: 6,
            },
            Case {
                name: "non-word based reversal",
                original: "from",
                current: "ifrom",
                predicted: "from",
                expected_reversal_chars: 1,
                expected_total_chars: 1,
            },
        ];

        for case in &cases {
            let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
            assert_eq!(
                overlap.chars_reversing_user_edits, case.expected_reversal_chars,
                "Test '{}': expected {} reversal chars, got {}",
                case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
            );
            assert_eq!(
                overlap.total_chars_in_prediction, case.expected_total_chars,
                "Test '{}': expected {} total chars, got {}",
                case.name, case.expected_total_chars, overlap.total_chars_in_prediction
            );
        }
    }

    /// `reverse_diff` swaps the file-header markers, turns added lines into removed
    /// lines, and leaves context lines alone.
    #[test]
    fn test_reverse_diff() {
        // Diff text starts at column 0 inside the literal; context lines begin with a space.
        let forward_diff = "\
--- a/file.rs
+++ b/file.rs
@@ -1,3 +1,4 @@
 fn main() {
+    let x = 42;
     println!(\"hello\");
}";

        let reversed = reverse_diff(forward_diff);

        assert!(
            reversed.contains("+++ a/file.rs"),
            "Should have +++ for old path"
        );
        assert!(
            reversed.contains("--- b/file.rs"),
            "Should have --- for new path"
        );
        assert!(
            reversed.contains("-    let x = 42;"),
            "Added line should become deletion"
        );
        assert!(
            reversed.contains(" fn main()"),
            "Context lines should be unchanged"
        );
    }

    #[test]
    fn test_reverse_diff_roundtrip() {
        // Applying a diff and then its reverse should get back to original
        let original = "first line\nhello world\nlast line\n";
        let modified = "first line\nhello beautiful world\nlast line\n";

        // unified_diff doesn't include file headers, but apply_diff_to_string needs them
        let diff_body = language::unified_diff(original, modified);
        let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
        let reversed_diff = reverse_diff(&forward_diff);

        // Apply forward diff to original
        let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
        assert_eq!(after_forward, modified);

        // Apply reversed diff to modified
        let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
        assert_eq!(after_reverse, original);
    }

    #[test]
    fn test_filter_edit_history_by_path() {
        // Test that filter_edit_history_by_path correctly matches paths when
        // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
        // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
        let events = vec![
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/src/file.rs")),
                old_path: Arc::from(Path::new("myrepo/src/file.rs")),
                diff: "@@ -1 +1 @@\n-old\n+new".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/other.rs")),
                old_path: Arc::from(Path::new("myrepo/other.rs")),
                diff: "@@ -1 +1 @@\n-a\n+b".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("src/file.rs")),
                old_path: Arc::from(Path::new("src/file.rs")),
                diff: "@@ -1 +1 @@\n-x\n+y".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
        ];

        // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
        // "src/file.rs" exact match
        let cursor_path = Path::new("src/file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            2,
            "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
        );

        // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
        // "src/file.rs" stripped -> "file.rs" == "file.rs"
        let cursor_path = Path::new("file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            1,
            "Should only match src/file.rs (stripped to file.rs)"
        );

        // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
        let cursor_path = Path::new("other.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
    }

    /// `reverse_diff` keeps the input's trailing-newline state: it neither drops nor adds one.
    #[test]
    fn test_reverse_diff_preserves_trailing_newline() {
        let diff_with_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n";
        let reversed = reverse_diff(diff_with_trailing_newline);
        assert!(
            reversed.ends_with('\n'),
            "Reversed diff should preserve trailing newline"
        );

        let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new";
        let reversed = reverse_diff(diff_without_trailing_newline);
        assert!(
            !reversed.ends_with('\n'),
            "Reversed diff should not add trailing newline if original didn't have one"
        );
    }
}