1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use edit_prediction::udiff::apply_diff_to_string;
6use language::{char_diff, text_diff};
7
8use crate::example::ExamplePromptInputs;
9
10fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
11 let hunks = parse_diff_hunks(diff_str);
12 let mut result = text.to_string();
13
14 for hunk in hunks {
15 let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
16 if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
17 result = updated;
18 }
19 }
20
21 result
22}
23
/// One hunk of a unified diff, as produced by `parse_diff_hunks`.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ParsedHunk {
    /// First line of the hunk in the old file (the `-start` part of the `@@` header).
    old_start: u32,
    /// Number of old-file lines the hunk covers (defaults to 1 when the header omits it).
    old_count: u32,
    /// First line of the hunk in the new file (the `+start` part of the `@@` header).
    new_start: u32,
    /// Number of new-file lines the hunk covers (defaults to 1 when the header omits it).
    new_count: u32,
    /// Body lines in diff order, with their `+` / `-` / ` ` prefix removed.
    lines: Vec<HunkLine>,
}

/// A single body line of a hunk, tagged by the diff prefix it carried.
#[derive(Clone, Debug, PartialEq, Eq)]
enum HunkLine {
    /// Unchanged line (` ` prefix, or an entirely empty line).
    Context(String),
    /// Line present only in the new file (`+` prefix).
    Addition(String),
    /// Line present only in the old file (`-` prefix).
    Deletion(String),
}
39
/// Parses a unified-diff hunk header such as `@@ -12,3 +13,4 @@ optional heading`.
///
/// Returns `(old_start, old_count, new_start, new_count)`. A range written without a
/// count (e.g. `-5`) has an implied count of 1. Returns `None` if the line is not a
/// well-formed header or any number fails to parse as `u32`.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    /// Parses `start[,count]` into `(start, count)`.
    fn parse_range(range: &str) -> Option<(u32, u32)> {
        match range.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((range.parse().ok()?, 1)),
        }
    }

    let body = line.strip_prefix("@@ -")?;
    let (old_range, after_old) = body.split_once(' ')?;
    let (new_range, _heading) = after_old.strip_prefix('+')?.split_once(" @@")?;

    let (old_start, old_count) = parse_range(old_range)?;
    let (new_start, new_count) = parse_range(new_range)?;
    Some((old_start, old_count, new_start, new_count))
}
60
61fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
62 let mut hunks = Vec::new();
63 let mut current_hunk: Option<ParsedHunk> = None;
64
65 for line in diff.lines() {
66 if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
67 if let Some(hunk) = current_hunk.take() {
68 hunks.push(hunk);
69 }
70 current_hunk = Some(ParsedHunk {
71 old_start,
72 old_count,
73 new_start,
74 new_count,
75 lines: Vec::new(),
76 });
77 } else if let Some(ref mut hunk) = current_hunk {
78 if let Some(stripped) = line.strip_prefix('+') {
79 hunk.lines.push(HunkLine::Addition(stripped.to_string()));
80 } else if let Some(stripped) = line.strip_prefix('-') {
81 hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
82 } else if let Some(stripped) = line.strip_prefix(' ') {
83 hunk.lines.push(HunkLine::Context(stripped.to_string()));
84 } else if line.is_empty() {
85 hunk.lines.push(HunkLine::Context(String::new()));
86 }
87 }
88 }
89
90 if let Some(hunk) = current_hunk {
91 hunks.push(hunk);
92 }
93
94 hunks
95}
96
97fn format_hunk(hunk: &ParsedHunk) -> String {
98 let mut result = format!(
99 "@@ -{},{} +{},{} @@\n",
100 hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
101 );
102 for line in &hunk.lines {
103 match line {
104 HunkLine::Context(text) => {
105 result.push(' ');
106 result.push_str(text);
107 result.push('\n');
108 }
109 HunkLine::Addition(text) => {
110 result.push('+');
111 result.push_str(text);
112 result.push('\n');
113 }
114 HunkLine::Deletion(text) => {
115 result.push('-');
116 result.push_str(text);
117 result.push('\n');
118 }
119 }
120 }
121 result
122}
123
124fn filter_diff_hunks_by_excerpt(
125 diff: &str,
126 excerpt_start_row: u32,
127 excerpt_row_count: u32,
128) -> (String, i32) {
129 let hunks = parse_diff_hunks(diff);
130 let excerpt_start_0based = excerpt_start_row;
131 let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
132
133 let mut filtered_hunks = Vec::new();
134 let mut cumulative_line_offset: i32 = 0;
135
136 for hunk in hunks {
137 let hunk_start_0based = hunk.new_start.saturating_sub(1);
138 let hunk_end_0based = hunk_start_0based + hunk.new_count;
139
140 let additions: i32 = hunk
141 .lines
142 .iter()
143 .filter(|l| matches!(l, HunkLine::Addition(_)))
144 .count() as i32;
145 let deletions: i32 = hunk
146 .lines
147 .iter()
148 .filter(|l| matches!(l, HunkLine::Deletion(_)))
149 .count() as i32;
150 let hunk_line_delta = additions - deletions;
151
152 if hunk_end_0based <= excerpt_start_0based {
153 cumulative_line_offset += hunk_line_delta;
154 continue;
155 }
156
157 if hunk_start_0based >= excerpt_end_0based {
158 continue;
159 }
160
161 let mut filtered_lines = Vec::new();
162 let mut current_row_0based = hunk_start_0based;
163 let mut filtered_old_count = 0u32;
164 let mut filtered_new_count = 0u32;
165 let mut first_included_row: Option<u32> = None;
166
167 for line in &hunk.lines {
168 match line {
169 HunkLine::Context(text) => {
170 if current_row_0based >= excerpt_start_0based
171 && current_row_0based < excerpt_end_0based
172 {
173 if first_included_row.is_none() {
174 first_included_row = Some(current_row_0based);
175 }
176 filtered_lines.push(HunkLine::Context(text.clone()));
177 filtered_old_count += 1;
178 filtered_new_count += 1;
179 }
180 current_row_0based += 1;
181 }
182 HunkLine::Addition(text) => {
183 if current_row_0based >= excerpt_start_0based
184 && current_row_0based < excerpt_end_0based
185 {
186 if first_included_row.is_none() {
187 first_included_row = Some(current_row_0based);
188 }
189 filtered_lines.push(HunkLine::Addition(text.clone()));
190 filtered_new_count += 1;
191 }
192 current_row_0based += 1;
193 }
194 HunkLine::Deletion(text) => {
195 if current_row_0based >= excerpt_start_0based
196 && current_row_0based < excerpt_end_0based
197 {
198 if first_included_row.is_none() {
199 first_included_row = Some(current_row_0based);
200 }
201 filtered_lines.push(HunkLine::Deletion(text.clone()));
202 filtered_old_count += 1;
203 }
204 }
205 }
206 }
207
208 if !filtered_lines.is_empty() {
209 let first_row = first_included_row.unwrap_or(excerpt_start_0based);
210 let new_start_1based = (first_row - excerpt_start_0based) + 1;
211
212 filtered_hunks.push(ParsedHunk {
213 old_start: new_start_1based,
214 old_count: filtered_old_count,
215 new_start: new_start_1based,
216 new_count: filtered_new_count,
217 lines: filtered_lines,
218 });
219 }
220
221 cumulative_line_offset += hunk_line_delta;
222 }
223
224 let mut result = String::new();
225 for hunk in &filtered_hunks {
226 result.push_str(&format_hunk(hunk));
227 }
228
229 (result, cumulative_line_offset)
230}
231
232fn compute_excerpt_aware_reversal_overlap(
233 edit_history_diffs: &[&str],
234 excerpt_content: &str,
235 excerpt_start_row: u32,
236 predicted_content: &str,
237) -> ReversalOverlap {
238 let mut current_content = excerpt_content.to_string();
239 let mut current_excerpt_start_row = excerpt_start_row;
240
241 for diff in edit_history_diffs.iter().rev() {
242 if diff.is_empty() {
243 continue;
244 }
245
246 let current_row_count = current_content.lines().count() as u32;
247 let (filtered_diff, _line_offset) =
248 filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
249
250 if filtered_diff.is_empty() {
251 let hunks = parse_diff_hunks(diff);
252 for hunk in hunks {
253 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
254 if hunk_end <= current_excerpt_start_row {
255 let additions: u32 = hunk
256 .lines
257 .iter()
258 .filter(|l| matches!(l, HunkLine::Addition(_)))
259 .count() as u32;
260 let deletions: u32 = hunk
261 .lines
262 .iter()
263 .filter(|l| matches!(l, HunkLine::Deletion(_)))
264 .count() as u32;
265 if additions >= deletions {
266 current_excerpt_start_row =
267 current_excerpt_start_row.saturating_sub(additions - deletions);
268 } else {
269 current_excerpt_start_row += deletions - additions;
270 }
271 }
272 }
273 continue;
274 }
275
276 let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
277 match apply_diff_to_string(&reversed, ¤t_content) {
278 Ok(updated) => {
279 current_content = updated;
280 }
281 Err(_) => {
282 continue;
283 }
284 }
285
286 let hunks = parse_diff_hunks(diff);
287 for hunk in hunks {
288 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
289 if hunk_end <= current_excerpt_start_row {
290 let additions: u32 = hunk
291 .lines
292 .iter()
293 .filter(|l| matches!(l, HunkLine::Addition(_)))
294 .count() as u32;
295 let deletions: u32 = hunk
296 .lines
297 .iter()
298 .filter(|l| matches!(l, HunkLine::Deletion(_)))
299 .count() as u32;
300 if additions >= deletions {
301 current_excerpt_start_row =
302 current_excerpt_start_row.saturating_sub(additions - deletions);
303 } else {
304 current_excerpt_start_row += deletions - additions;
305 }
306 }
307 }
308 }
309
310 compute_reversal_overlap(¤t_content, excerpt_content, predicted_content)
311}
312
/// Reverses a unified diff, so that applying the result undoes the original change.
///
/// The transformation is:
/// - `--- ` and `+++ ` file headers are swapped.
/// - The old and new ranges of every `@@` hunk header are exchanged. Any trailing section
///   heading after the closing `@@` is kept.
/// - `+` and `-` body lines are flipped.
/// - Context lines, `\ No newline at end of file` markers and unrecognised lines are copied
///   through unchanged.
///
/// A trailing newline on the input is preserved.
///
/// File headers are only recognised outside a hunk body. The body's extent is tracked with
/// the line counts declared by the preceding `@@` header. As a result, an added line whose
/// text starts with `++` or a deleted line whose text starts with `--` is flipped like any
/// other change, instead of being mistaken for a file header.
fn reverse_diff(diff: &str) -> String {
    // Body lines still expected in the current hunk on the old / new side.
    let mut remaining_old: u32 = 0;
    let mut remaining_new: u32 = 0;
    let mut reversed_lines: Vec<String> = Vec::new();

    for line in diff.lines() {
        let in_hunk_body = remaining_old > 0 || remaining_new > 0;
        let reversed = if !in_hunk_body && line.starts_with("--- ") {
            format!("+++ {}", &line[4..])
        } else if !in_hunk_body && line.starts_with("+++ ") {
            format!("--- {}", &line[4..])
        } else if let Some((header, old_count, new_count)) = reverse_hunk_header(line) {
            remaining_old = old_count;
            remaining_new = new_count;
            header
        } else if let Some(text) = line.strip_prefix('+') {
            remaining_new = remaining_new.saturating_sub(1);
            format!("-{}", text)
        } else if let Some(text) = line.strip_prefix('-') {
            remaining_old = remaining_old.saturating_sub(1);
            format!("+{}", text)
        } else {
            // Context lines (including fully empty ones) consume a line on both sides.
            if line.is_empty() || line.starts_with(' ') {
                remaining_old = remaining_old.saturating_sub(1);
                remaining_new = remaining_new.saturating_sub(1);
            }
            line.to_string()
        };
        reversed_lines.push(reversed);
    }

    let mut result = reversed_lines.join("\n");
    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}

/// Swaps the ranges of a `@@ -old +new @@[heading]` hunk header.
///
/// Returns the swapped header together with the (old, new) line counts the original header
/// declares. A range without an explicit count has a count of 1. Returns `None` when the
/// line is not a well-formed hunk header.
fn reverse_hunk_header(line: &str) -> Option<(String, u32, u32)> {
    let rest = line.strip_prefix("@@ -")?;
    let (old_range, rest) = rest.split_once(" +")?;
    let (new_range, heading) = rest.split_once(" @@")?;

    let line_count = |range: &str| -> Option<u32> {
        match range.split_once(',') {
            Some((start, count)) => {
                start.parse::<u32>().ok()?;
                count.parse().ok()
            }
            None => range.parse::<u32>().ok().map(|_| 1),
        }
    };
    let old_count = line_count(old_range)?;
    let new_count = line_count(new_range)?;

    Some((
        format!("@@ -{} +{} @@{}", new_range, old_range, heading),
        old_count,
        new_count,
    ))
}
336
/// A single text replacement produced by diffing two strings.
#[derive(Clone, Debug, PartialEq, Eq)]
struct GranularEdit {
    /// Byte range in the old text that is replaced.
    range: Range<usize>,
    /// The old text covered by `range`.
    old_text: String,
    /// The text that replaces `old_text`.
    new_text: String,
}
343
344fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
345 text_diff(old_text, new_text)
346 .into_iter()
347 .map(|(range, new_text)| GranularEdit {
348 old_text: old_text[range.clone()].to_string(),
349 range,
350 new_text: new_text.to_string(),
351 })
352 .collect()
353}
354
/// Text the user inserted, located in the current content.
#[derive(Clone, Debug)]
struct HistoryAdditionRange {
    /// Byte range of the inserted text within the current content.
    range_in_current: Range<usize>,
}
359
/// Text the user deleted, anchored at the position it used to occupy.
#[derive(Clone, Debug)]
struct HistoryDeletionRange {
    /// The text that was removed.
    deleted_text: String,
    /// Byte offset in the current content where the removed text used to start.
    position_in_current: usize,
}
365
366fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
367 let mut result = Vec::new();
368 let mut offset_delta: isize = 0;
369
370 for edit in history_edits {
371 if !edit.new_text.is_empty() {
372 let new_start = (edit.range.start as isize + offset_delta) as usize;
373 let new_end = new_start + edit.new_text.len();
374 result.push(HistoryAdditionRange {
375 range_in_current: new_start..new_end,
376 });
377 }
378
379 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
380 }
381
382 result
383}
384
385fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
386 let mut result = Vec::new();
387 let mut offset_delta: isize = 0;
388
389 for edit in history_edits {
390 if !edit.old_text.is_empty() {
391 let position_in_current = (edit.range.start as isize + offset_delta) as usize;
392 result.push(HistoryDeletionRange {
393 deleted_text: edit.old_text.clone(),
394 position_in_current,
395 });
396 }
397
398 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
399 }
400
401 result
402}
403
/// Character counts describing how much of a prediction undoes the user's own edits.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct ReversalOverlap {
    /// Characters of the prediction that remove user-added text or restore user-deleted text.
    chars_reversing_user_edits: usize,
    /// Characters the prediction removes plus characters it inserts.
    total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Returns `chars_reversing_user_edits / total_chars_in_prediction`, or `0.0` when the
    /// prediction changes nothing.
    fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
419
420/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
421/// or where `new_text` appears as a subsequence within `old_text` (reduction).
422///
423/// For extensions: when the user's text is preserved (in order) within the prediction,
424/// we only count the newly inserted characters, not the preserved ones.
425/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
426/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
427///
428/// For reductions: when the prediction's text is preserved (in order) within the original,
429/// we only count the deleted characters, not the preserved ones.
430/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
431fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
432 edits
433 .into_iter()
434 .flat_map(|edit| {
435 if edit.old_text.is_empty() || edit.new_text.is_empty() {
436 return vec![edit];
437 }
438
439 // Use character-wise diff to find exact byte ranges of changes
440 let char_edits = char_diff(&edit.old_text, &edit.new_text);
441
442 let all_deletions = !char_edits.is_empty()
443 && char_edits
444 .iter()
445 .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
446 let all_insertions = !char_edits.is_empty()
447 && char_edits
448 .iter()
449 .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
450 if all_deletions || all_insertions {
451 return char_edits
452 .into_iter()
453 .map(|(range, replacement)| GranularEdit {
454 range: edit.range.start + range.start..edit.range.start + range.end,
455 old_text: edit.old_text[range].to_string(),
456 new_text: replacement.to_string(),
457 })
458 .collect();
459 }
460
461 // Otherwise, keep the original edit (mixed changes)
462 vec![edit]
463 })
464 .collect()
465}
466
467fn compute_reversal_overlap(
468 original_content: &str,
469 current_content: &str,
470 predicted_content: &str,
471) -> ReversalOverlap {
472 let history_edits =
473 normalize_extension_edits(compute_granular_edits(original_content, current_content));
474 let prediction_edits =
475 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
476
477 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
478 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
479
480 let reversed_additions =
481 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
482 let restored_deletions =
483 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
484
485 let total_chars_in_prediction: usize = prediction_edits
486 .iter()
487 .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
488 .sum();
489
490 ReversalOverlap {
491 chars_reversing_user_edits: reversed_additions + restored_deletions,
492 total_chars_in_prediction,
493 }
494}
495
496fn compute_reversed_additions(
497 history_addition_ranges: &[HistoryAdditionRange],
498 prediction_edits: &[GranularEdit],
499) -> usize {
500 let mut reversed_chars = 0;
501
502 for pred_edit in prediction_edits {
503 for history_addition in history_addition_ranges {
504 let overlap_start = pred_edit
505 .range
506 .start
507 .max(history_addition.range_in_current.start);
508 let overlap_end = pred_edit
509 .range
510 .end
511 .min(history_addition.range_in_current.end);
512
513 if overlap_start < overlap_end {
514 let relative_start = overlap_start - pred_edit.range.start;
515 let relative_end = overlap_end - pred_edit.range.start;
516 let overlap_text = &pred_edit.old_text[relative_start..relative_end];
517 reversed_chars += overlap_text.chars().count();
518 }
519 }
520 }
521
522 reversed_chars
523}
524
525fn compute_restored_deletions(
526 history_deletion_ranges: &[HistoryDeletionRange],
527 prediction_edits: &[GranularEdit],
528) -> usize {
529 let mut restored = 0;
530
531 for pred_edit in prediction_edits {
532 if pred_edit.new_text.is_empty() {
533 continue;
534 }
535
536 for deletion in history_deletion_ranges {
537 if pred_edit.range.contains(&deletion.position_in_current)
538 || deletion.position_in_current == pred_edit.range.start
539 {
540 restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
541 }
542 }
543 }
544
545 restored
546}
547
/// Returns the length, in `char`s, of the longest common subsequence of `a` and `b`.
///
/// Classic dynamic programming kept to a single row: O(|a|·|b|) time, O(|b|) extra space.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let left: Vec<char> = a.chars().collect();
    let right: Vec<char> = b.chars().collect();
    if left.is_empty() || right.is_empty() {
        return 0;
    }

    // row[j] = LCS length of the processed prefix of `left` and right[..j].
    let mut row = vec![0usize; right.len() + 1];
    for &left_char in &left {
        // Value of row[j] from the previous outer iteration (the diagonal cell).
        let mut diagonal = 0;
        for (j, &right_char) in right.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if left_char == right_char {
                diagonal + 1
            } else {
                above.max(row[j])
            };
            diagonal = above;
        }
    }

    row[right.len()]
}
575
576fn filter_edit_history_by_path<'a>(
577 edit_history: &'a [Arc<zeta_prompt::Event>],
578 cursor_path: &std::path::Path,
579) -> Vec<&'a zeta_prompt::Event> {
580 edit_history
581 .iter()
582 .filter(|event| match event.as_ref() {
583 zeta_prompt::Event::BufferChange { path, .. } => {
584 let event_path = path.as_ref();
585 if event_path == cursor_path {
586 return true;
587 }
588 let stripped = event_path
589 .components()
590 .skip(1)
591 .collect::<std::path::PathBuf>();
592 stripped == cursor_path
593 }
594 })
595 .map(|arc| arc.as_ref())
596 .collect()
597}
598
599fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
600 match event {
601 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
602 }
603}
604
605pub fn compute_prediction_reversal_ratio(
606 prompt_inputs: &ExamplePromptInputs,
607 predicted_content: &str,
608 cursor_path: &Path,
609) -> f32 {
610 let current_content = &prompt_inputs.content;
611
612 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
613 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
614
615 if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row {
616 let diffs: Vec<&str> = relevant_events
617 .iter()
618 .map(|e| extract_diff_from_event(e))
619 .collect();
620 let overlap = compute_excerpt_aware_reversal_overlap(
621 &diffs,
622 current_content,
623 excerpt_start_row,
624 predicted_content,
625 );
626 return overlap.ratio();
627 }
628
629 let mut original_content = current_content.to_string();
630 for event in relevant_events.into_iter().rev() {
631 let diff = extract_diff_from_event(event);
632 if diff.is_empty() {
633 continue;
634 }
635 let reversed = reverse_diff(diff);
636 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
637 match apply_diff_to_string(&with_headers, &original_content) {
638 Ok(updated_content) => original_content = updated_content,
639 Err(_) => {
640 original_content = apply_diff_to_string_lenient(&reversed, &original_content);
641 }
642 }
643 }
644
645 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
646 overlap.ratio()
647}
648
649#[cfg(test)]
650mod tests {
651 use super::*;
652 use edit_prediction::udiff::apply_diff_to_string;
653 use indoc::indoc;
654
655 #[test]
656 fn test_reversal_overlap() {
657 struct Case {
658 name: &'static str,
659 original: &'static str,
660 current: &'static str,
661 predicted: &'static str,
662 expected_reversal_chars: usize,
663 expected_total_chars: usize,
664 }
665
666 let cases = [
667 Case {
668 name: "user_adds_line_prediction_removes_it",
669 original: indoc! {"
670 a
671 b
672 c"},
673 current: indoc! {"
674 a
675 new line
676 b
677 c"},
678 predicted: indoc! {"
679 a
680 b
681 c"},
682 expected_reversal_chars: 9,
683 expected_total_chars: 9,
684 },
685 Case {
686 name: "user_deletes_line_prediction_restores_it",
687 original: indoc! {"
688 a
689 deleted
690 b"},
691 current: indoc! {"
692 a
693 b"},
694 predicted: indoc! {"
695 a
696 deleted
697 b"},
698 expected_reversal_chars: 8,
699 expected_total_chars: 8,
700 },
701 Case {
702 name: "user_deletes_text_prediction_restores_partial",
703 original: "hello beautiful world",
704 current: "hello world",
705 predicted: "hello beautiful world",
706 expected_reversal_chars: 10,
707 expected_total_chars: 10,
708 },
709 Case {
710 name: "user_deletes_foo_prediction_adds_bar",
711 original: "foo",
712 current: "",
713 predicted: "bar",
714 expected_reversal_chars: 0,
715 expected_total_chars: 3,
716 },
717 Case {
718 name: "independent_edits_different_locations",
719 original: indoc! {"
720 line1
721 line2
722 line3"},
723 current: indoc! {"
724 LINE1
725 line2
726 line3"},
727 predicted: indoc! {"
728 LINE1
729 line2
730 LINE3"},
731 expected_reversal_chars: 0,
732 expected_total_chars: 10,
733 },
734 Case {
735 name: "no_history_edits",
736 original: "same",
737 current: "same",
738 predicted: "different",
739 expected_reversal_chars: 0,
740 expected_total_chars: 13,
741 },
742 Case {
743 name: "user_replaces_text_prediction_reverses",
744 original: indoc! {"
745 keep
746 delete_me
747 keep2"},
748 current: indoc! {"
749 keep
750 added
751 keep2"},
752 predicted: indoc! {"
753 keep
754 delete_me
755 keep2"},
756 expected_reversal_chars: 14,
757 expected_total_chars: 14,
758 },
759 Case {
760 name: "user_modifies_word_prediction_modifies_differently",
761 original: "the quick brown fox",
762 current: "the slow brown fox",
763 predicted: "the fast brown fox",
764 expected_reversal_chars: 4,
765 expected_total_chars: 8,
766 },
767 Case {
768 name: "user finishes function name (suffix)",
769 original: "",
770 current: "epr",
771 predicted: "eprintln!()",
772 expected_reversal_chars: 0,
773 expected_total_chars: 8,
774 },
775 Case {
776 name: "user starts function name (prefix)",
777 original: "",
778 current: "my_function()",
779 predicted: "test_my_function()",
780 expected_reversal_chars: 0,
781 expected_total_chars: 5,
782 },
783 Case {
784 name: "user types partial, prediction extends in multiple places",
785 original: "",
786 current: "test_my_function",
787 predicted: "a_test_for_my_special_function_plz",
788 expected_reversal_chars: 0,
789 expected_total_chars: 18,
790 },
791 // Edge cases for subsequence matching
792 Case {
793 name: "subsequence with interleaved underscores",
794 original: "",
795 current: "a_b_c",
796 predicted: "_a__b__c__",
797 expected_reversal_chars: 0,
798 expected_total_chars: 5,
799 },
800 Case {
801 name: "not a subsequence - different characters",
802 original: "",
803 current: "abc",
804 predicted: "xyz",
805 expected_reversal_chars: 3,
806 expected_total_chars: 6,
807 },
808 Case {
809 name: "not a subsequence - wrong order",
810 original: "",
811 current: "abc",
812 predicted: "cba",
813 expected_reversal_chars: 3,
814 expected_total_chars: 6,
815 },
816 Case {
817 name: "partial subsequence - only some chars match",
818 original: "",
819 current: "abcd",
820 predicted: "axbx",
821 expected_reversal_chars: 4,
822 expected_total_chars: 8,
823 },
824 // Common completion patterns
825 Case {
826 name: "completing a method call",
827 original: "",
828 current: "vec.pu",
829 predicted: "vec.push(item)",
830 expected_reversal_chars: 0,
831 expected_total_chars: 8,
832 },
833 Case {
834 name: "completing an import statement",
835 original: "",
836 current: "use std::col",
837 predicted: "use std::collections::HashMap",
838 expected_reversal_chars: 0,
839 expected_total_chars: 17,
840 },
841 Case {
842 name: "completing a struct field",
843 original: "",
844 current: "name: St",
845 predicted: "name: String",
846 expected_reversal_chars: 0,
847 expected_total_chars: 4,
848 },
849 Case {
850 name: "prediction replaces with completely different text",
851 original: "",
852 current: "hello",
853 predicted: "world",
854 expected_reversal_chars: 5,
855 expected_total_chars: 10,
856 },
857 Case {
858 name: "empty prediction removes user text",
859 original: "",
860 current: "mistake",
861 predicted: "",
862 expected_reversal_chars: 7,
863 expected_total_chars: 7,
864 },
865 Case {
866 name: "fixing typo is not reversal",
867 original: "",
868 current: "<dv",
869 predicted: "<div>",
870 expected_reversal_chars: 0,
871 expected_total_chars: 2,
872 },
873 Case {
874 name: "infix insertion not reversal",
875 original: indoc! {"
876 from my_project import Foo
877 "},
878 current: indoc! {"
879 ifrom my_project import Foo
880 "},
881 predicted: indoc! {"
882 import
883 from my_project import Foo
884 "},
885 expected_reversal_chars: 0,
886 expected_total_chars: 6,
887 },
888 Case {
889 name: "non-word based reversal",
890 original: "from",
891 current: "ifrom",
892 predicted: "from",
893 expected_reversal_chars: 1,
894 expected_total_chars: 1,
895 },
896 Case {
897 name: "multiple insertions no reversal",
898 original: "print(\"Hello, World!\")",
899 current: "sys.(\"Hello, World!\")",
900 predicted: "sys.stdout.write(\"Hello, World!\\n\")",
901 expected_reversal_chars: 0,
902 expected_total_chars: 14,
903 },
904 ];
905
906 for case in &cases {
907 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
908 assert_eq!(
909 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
910 "Test '{}': expected {} reversal chars, got {}",
911 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
912 );
913 assert_eq!(
914 overlap.total_chars_in_prediction, case.expected_total_chars,
915 "Test '{}': expected {} total chars, got {}",
916 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
917 );
918 }
919 }
920
921 #[test]
922 fn test_reverse_diff() {
923 let forward_diff = indoc! {"
924 --- a/file.rs
925 +++ b/file.rs
926 @@ -1,3 +1,4 @@
927 fn main() {
928 + let x = 42;
929 println!(\"hello\");
930 }"};
931
932 let reversed = reverse_diff(forward_diff);
933
934 assert!(
935 reversed.contains("+++ a/file.rs"),
936 "Should have +++ for old path"
937 );
938 assert!(
939 reversed.contains("--- b/file.rs"),
940 "Should have --- for new path"
941 );
942 assert!(
943 reversed.contains("- let x = 42;"),
944 "Added line should become deletion"
945 );
946 assert!(
947 reversed.contains(" fn main()"),
948 "Context lines should be unchanged"
949 );
950 }
951
952 #[test]
953 fn test_reverse_diff_roundtrip() {
954 // Applying a diff and then its reverse should get back to original
955 let original = indoc! {"
956 first line
957 hello world
958 last line
959 "};
960 let modified = indoc! {"
961 first line
962 hello beautiful world
963 last line
964 "};
965
966 // unified_diff doesn't include file headers, but apply_diff_to_string needs them
967 let diff_body = language::unified_diff(original, modified);
968 let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
969 let reversed_diff = reverse_diff(&forward_diff);
970
971 // Apply forward diff to original
972 let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
973 assert_eq!(after_forward, modified);
974
975 // Apply reversed diff to modified
976 let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
977 assert_eq!(after_reverse, original);
978 }
979
980 #[test]
981 fn test_filter_edit_history_by_path() {
982 // Test that filter_edit_history_by_path correctly matches paths when
983 // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
984 // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
985 let events = vec![
986 Arc::new(zeta_prompt::Event::BufferChange {
987 path: Arc::from(Path::new("myrepo/src/file.rs")),
988 old_path: Arc::from(Path::new("myrepo/src/file.rs")),
989 diff: indoc! {"
990 @@ -1 +1 @@
991 -old
992 +new"}
993 .into(),
994 predicted: false,
995 in_open_source_repo: true,
996 }),
997 Arc::new(zeta_prompt::Event::BufferChange {
998 path: Arc::from(Path::new("myrepo/other.rs")),
999 old_path: Arc::from(Path::new("myrepo/other.rs")),
1000 diff: indoc! {"
1001 @@ -1 +1 @@
1002 -a
1003 +b"}
1004 .into(),
1005 predicted: false,
1006 in_open_source_repo: true,
1007 }),
1008 Arc::new(zeta_prompt::Event::BufferChange {
1009 path: Arc::from(Path::new("src/file.rs")),
1010 old_path: Arc::from(Path::new("src/file.rs")),
1011 diff: indoc! {"
1012 @@ -1 +1 @@
1013 -x
1014 +y"}
1015 .into(),
1016 predicted: false,
1017 in_open_source_repo: true,
1018 }),
1019 ];
1020
1021 // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
1022 // "src/file.rs" exact match
1023 let cursor_path = Path::new("src/file.rs");
1024 let filtered = filter_edit_history_by_path(&events, cursor_path);
1025 assert_eq!(
1026 filtered.len(),
1027 2,
1028 "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
1029 );
1030
1031 // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
1032 // "src/file.rs" stripped -> "file.rs" == "file.rs"
1033 let cursor_path = Path::new("file.rs");
1034 let filtered = filter_edit_history_by_path(&events, cursor_path);
1035 assert_eq!(
1036 filtered.len(),
1037 1,
1038 "Should only match src/file.rs (stripped to file.rs)"
1039 );
1040
1041 // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
1042 let cursor_path = Path::new("other.rs");
1043 let filtered = filter_edit_history_by_path(&events, cursor_path);
1044 assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
1045 }
1046
1047 #[test]
1048 fn test_reverse_diff_preserves_trailing_newline() {
1049 let diff_with_trailing_newline = indoc! {"
1050 --- a/file
1051 +++ b/file
1052 @@ -1 +1 @@
1053 -old
1054 +new
1055 "};
1056 let reversed = reverse_diff(diff_with_trailing_newline);
1057 assert!(
1058 reversed.ends_with('\n'),
1059 "Reversed diff should preserve trailing newline"
1060 );
1061
1062 let diff_without_trailing_newline = indoc! {"
1063 --- a/file
1064 +++ b/file
1065 @@ -1 +1 @@
1066 -old
1067 +new"};
1068 let reversed = reverse_diff(diff_without_trailing_newline);
1069 assert!(
1070 !reversed.ends_with('\n'),
1071 "Reversed diff should not add trailing newline if original didn't have one"
1072 );
1073 }
1074
1075 #[test]
1076 fn test_filter_hunks_by_excerpt_region() {
1077 struct Case {
1078 name: &'static str,
1079 diff: &'static str,
1080 excerpt_start_row: u32,
1081 excerpt_row_count: u32,
1082 expected_filtered_diff: &'static str,
1083 expected_line_offset: i32,
1084 }
1085
1086 let cases = [
1087 Case {
1088 name: "hunk_entirely_before_excerpt",
1089 diff: indoc! {"
1090 @@ -1,3 +1,4 @@
1091 line1
1092 +inserted
1093 line2
1094 line3
1095 "},
1096 excerpt_start_row: 10,
1097 excerpt_row_count: 5,
1098 expected_filtered_diff: "",
1099 expected_line_offset: 1,
1100 },
1101 Case {
1102 name: "hunk_entirely_inside_excerpt",
1103 diff: indoc! {"
1104 @@ -12,3 +12,4 @@
1105 line12
1106 +inserted
1107 line13
1108 line14
1109 "},
1110 excerpt_start_row: 10,
1111 excerpt_row_count: 10,
1112 expected_filtered_diff: indoc! {"
1113 @@ -2,3 +2,4 @@
1114 line12
1115 +inserted
1116 line13
1117 line14
1118 "},
1119 expected_line_offset: 1,
1120 },
1121 Case {
1122 name: "hunk_entirely_after_excerpt",
1123 diff: indoc! {"
1124 @@ -50,3 +50,4 @@
1125 line50
1126 +inserted
1127 line51
1128 line52
1129 "},
1130 excerpt_start_row: 10,
1131 excerpt_row_count: 5,
1132 expected_filtered_diff: "",
1133 expected_line_offset: 0,
1134 },
1135 Case {
1136 name: "hunk_straddles_excerpt_start",
1137 diff: indoc! {"
1138 @@ -8,5 +8,6 @@
1139 line8
1140 line9
1141 +inserted
1142 line10
1143 line11
1144 line12
1145 "},
1146 excerpt_start_row: 10,
1147 excerpt_row_count: 10,
1148 expected_filtered_diff: indoc! {"
1149 @@ -1,3 +1,3 @@
1150 line10
1151 line11
1152 line12
1153 "},
1154 expected_line_offset: 1,
1155 },
1156 Case {
1157 name: "hunk_straddles_excerpt_end",
1158 diff: indoc! {"
1159 @@ -18,5 +18,6 @@
1160 line18
1161 line19
1162 +inserted
1163 line20
1164 line21
1165 line22
1166 "},
1167 excerpt_start_row: 10,
1168 excerpt_row_count: 10,
1169 expected_filtered_diff: indoc! {"
1170 @@ -8,2 +8,3 @@
1171 line18
1172 line19
1173 +inserted
1174 "},
1175 expected_line_offset: 1,
1176 },
1177 Case {
1178 name: "multiple_hunks_mixed",
1179 diff: indoc! {"
1180 @@ -1,2 +1,3 @@
1181 line1
1182 +before_excerpt
1183 line2
1184 @@ -12,2 +13,3 @@
1185 line12
1186 +inside_excerpt
1187 line13
1188 @@ -50,2 +52,3 @@
1189 line50
1190 +after_excerpt
1191 line51
1192 "},
1193 excerpt_start_row: 10,
1194 excerpt_row_count: 10,
1195 expected_filtered_diff: indoc! {"
1196 @@ -3,2 +3,3 @@
1197 line12
1198 +inside_excerpt
1199 line13
1200 "},
1201 expected_line_offset: 2,
1202 },
1203 Case {
1204 name: "deletion_before_excerpt",
1205 diff: indoc! {"
1206 @@ -1,4 +1,3 @@
1207 line1
1208 -deleted
1209 line2
1210 line3
1211 "},
1212 excerpt_start_row: 10,
1213 excerpt_row_count: 5,
1214 expected_filtered_diff: "",
1215 expected_line_offset: -1,
1216 },
1217 Case {
1218 name: "deletion_inside_excerpt",
1219 diff: indoc! {"
1220 @@ -12,4 +12,3 @@
1221 line12
1222 -deleted
1223 line13
1224 line14
1225 "},
1226 excerpt_start_row: 10,
1227 excerpt_row_count: 10,
1228 expected_filtered_diff: indoc! {"
1229 @@ -2,4 +2,3 @@
1230 line12
1231 -deleted
1232 line13
1233 line14
1234 "},
1235 expected_line_offset: -1,
1236 },
1237 Case {
1238 name: "empty_diff",
1239 diff: "",
1240 excerpt_start_row: 10,
1241 excerpt_row_count: 5,
1242 expected_filtered_diff: "",
1243 expected_line_offset: 0,
1244 },
1245 Case {
1246 name: "hunk_spans_entire_excerpt",
1247 diff: indoc! {"
1248 @@ -8,10 +8,12 @@
1249 line8
1250 line9
1251 line10
1252 line11
1253 +inserted1
1254 line12
1255 line13
1256 +inserted2
1257 line14
1258 line15
1259 line16
1260 line17
1261 "},
1262 excerpt_start_row: 10,
1263 excerpt_row_count: 5,
1264 expected_filtered_diff: indoc! {"
1265 @@ -1,3 +1,5 @@
1266 line11
1267 +inserted1
1268 line12
1269 line13
1270 +inserted2
1271 "},
1272 expected_line_offset: 2,
1273 },
1274 Case {
1275 name: "replacement_inside_excerpt",
1276 diff: indoc! {"
1277 @@ -12,3 +12,3 @@
1278 line12
1279 -old_text
1280 +new_text
1281 line14
1282 "},
1283 excerpt_start_row: 10,
1284 excerpt_row_count: 10,
1285 expected_filtered_diff: indoc! {"
1286 @@ -2,3 +2,3 @@
1287 line12
1288 -old_text
1289 +new_text
1290 line14
1291 "},
1292 expected_line_offset: 0,
1293 },
1294 ];
1295
1296 for case in &cases {
1297 let (filtered, line_offset) = filter_diff_hunks_by_excerpt(
1298 case.diff,
1299 case.excerpt_start_row,
1300 case.excerpt_row_count,
1301 );
1302 assert_eq!(
1303 filtered, case.expected_filtered_diff,
1304 "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}",
1305 case.name, case.expected_filtered_diff, filtered
1306 );
1307 assert_eq!(
1308 line_offset, case.expected_line_offset,
1309 "Test '{}': line offset mismatch. Expected {}, got {}",
1310 case.name, case.expected_line_offset, line_offset
1311 );
1312 }
1313 }
1314
1315 #[test]
1316 fn test_excerpt_aware_reversal_tracking() {
1317 struct Case {
1318 name: &'static str,
1319 edit_history_diffs: Vec<&'static str>,
1320 excerpt_content: &'static str,
1321 excerpt_start_row: u32,
1322 predicted_content: &'static str,
1323 expected_reversal_chars: usize,
1324 expected_total_chars: usize,
1325 }
1326
1327 let cases = [
1328 Case {
1329 name: "edit_outside_excerpt_no_reversal",
1330 edit_history_diffs: vec![indoc! {"
1331 @@ -1,2 +1,3 @@
1332 line1
1333 +added_outside
1334 line2
1335 "}],
1336 excerpt_content: indoc! {"
1337 line10
1338 line11
1339 line12
1340 "},
1341 excerpt_start_row: 10,
1342 predicted_content: indoc! {"
1343 line10
1344 modified
1345 line12
1346 "},
1347 expected_reversal_chars: 0,
1348 expected_total_chars: 14,
1349 },
1350 Case {
1351 name: "edit_inside_excerpt_with_reversal",
1352 edit_history_diffs: vec![indoc! {"
1353 @@ -10,3 +10,4 @@
1354 line10
1355 +user_added
1356 line11
1357 line12
1358 "}],
1359 excerpt_content: indoc! {"
1360 line10
1361 user_added
1362 line11
1363 line12
1364 "},
1365 excerpt_start_row: 10,
1366 predicted_content: indoc! {"
1367 line10
1368 line11
1369 line12
1370 "},
1371 expected_reversal_chars: 11,
1372 expected_total_chars: 11,
1373 },
1374 Case {
1375 name: "straddling_edit_partial_reversal",
1376 edit_history_diffs: vec![indoc! {"
1377 @@ -8,6 +8,8 @@
1378 line8
1379 line9
1380 +before_excerpt
1381 line10
1382 +inside_excerpt
1383 line11
1384 line12
1385 line13
1386 "}],
1387 excerpt_content: indoc! {"
1388 line10
1389 inside_excerpt
1390 line11
1391 line12
1392 line13
1393 "},
1394 excerpt_start_row: 10,
1395 predicted_content: indoc! {"
1396 line10
1397 line11
1398 line12
1399 line13
1400 "},
1401 expected_reversal_chars: 15,
1402 expected_total_chars: 15,
1403 },
1404 Case {
1405 name: "multiple_edits_mixed_locations",
1406 edit_history_diffs: vec![
1407 indoc! {"
1408 @@ -1,2 +1,3 @@
1409 line1
1410 +outside1
1411 line2
1412 "},
1413 indoc! {"
1414 @@ -11,2 +12,3 @@
1415 line11
1416 +inside1
1417 line12
1418 "},
1419 ],
1420 excerpt_content: indoc! {"
1421 line10
1422 line11
1423 inside1
1424 line12
1425 line13
1426 "},
1427 excerpt_start_row: 10,
1428 predicted_content: indoc! {"
1429 line10
1430 line11
1431 line12
1432 line13
1433 "},
1434 expected_reversal_chars: 8,
1435 expected_total_chars: 8,
1436 },
1437 Case {
1438 name: "no_edit_history",
1439 edit_history_diffs: vec![],
1440 excerpt_content: indoc! {"
1441 line10
1442 line11
1443 line12
1444 "},
1445 excerpt_start_row: 10,
1446 predicted_content: indoc! {"
1447 line10
1448 modified
1449 line12
1450 "},
1451 expected_reversal_chars: 0,
1452 expected_total_chars: 14,
1453 },
1454 Case {
1455 name: "edit_after_excerpt_no_effect",
1456 edit_history_diffs: vec![indoc! {"
1457 @@ -50,2 +50,3 @@
1458 line50
1459 +added_after
1460 line51
1461 "}],
1462 excerpt_content: indoc! {"
1463 line10
1464 line11
1465 line12
1466 "},
1467 excerpt_start_row: 10,
1468 predicted_content: indoc! {"
1469 line10
1470 changed
1471 line12
1472 "},
1473 expected_reversal_chars: 0,
1474 expected_total_chars: 13,
1475 },
1476 Case {
1477 name: "line_offset_tracking_across_hunks",
1478 edit_history_diffs: vec![
1479 indoc! {"
1480 @@ -1,2 +1,4 @@
1481 line1
1482 +added1
1483 +added2
1484 line2
1485 "},
1486 indoc! {"
1487 @@ -12,2 +14,3 @@
1488 line12
1489 +inside_after_offset
1490 line13
1491 "},
1492 ],
1493 excerpt_content: indoc! {"
1494 line10
1495 line11
1496 line12
1497 inside_after_offset
1498 line13
1499 "},
1500 excerpt_start_row: 10,
1501 predicted_content: indoc! {"
1502 line10
1503 line11
1504 line12
1505 line13
1506 "},
1507 expected_reversal_chars: 20,
1508 expected_total_chars: 20,
1509 },
1510 ];
1511
1512 for case in &cases {
1513 let overlap = compute_excerpt_aware_reversal_overlap(
1514 &case.edit_history_diffs,
1515 case.excerpt_content,
1516 case.excerpt_start_row,
1517 case.predicted_content,
1518 );
1519 assert_eq!(
1520 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1521 "Test '{}': expected {} reversal chars, got {}",
1522 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1523 );
1524 assert_eq!(
1525 overlap.total_chars_in_prediction, case.expected_total_chars,
1526 "Test '{}': expected {} total chars, got {}",
1527 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1528 );
1529 }
1530 }
1531
1532 #[test]
1533 fn test_lenient_diff_application() {
1534 struct Case {
1535 name: &'static str,
1536 diff: &'static str,
1537 content: &'static str,
1538 expected_result: &'static str,
1539 }
1540
1541 let cases = [
1542 Case {
1543 name: "hunk_context_not_found_skipped",
1544 diff: indoc! {"
1545 @@ -1,3 +1,4 @@
1546 context_not_in_content
1547 +added_line
1548 more_context
1549 final_context
1550 "},
1551 content: indoc! {"
1552 completely
1553 different
1554 content
1555 "},
1556 expected_result: indoc! {"
1557 completely
1558 different
1559 content
1560 "},
1561 },
1562 Case {
1563 name: "hunk_context_found_applied",
1564 diff: indoc! {"
1565 @@ -1,3 +1,4 @@
1566 line1
1567 +inserted
1568 line2
1569 line3
1570 "},
1571 content: indoc! {"
1572 line1
1573 line2
1574 line3
1575 "},
1576 expected_result: indoc! {"
1577 line1
1578 inserted
1579 line2
1580 line3
1581 "},
1582 },
1583 Case {
1584 name: "multiple_hunks_partial_match",
1585 diff: indoc! {"
1586 @@ -1,2 +1,3 @@
1587 not_found
1588 +skipped
1589 also_not_found
1590 @@ -5,2 +6,3 @@
1591 line5
1592 +applied
1593 line6
1594 "},
1595 content: indoc! {"
1596 line1
1597 line2
1598 line3
1599 line4
1600 line5
1601 line6
1602 "},
1603 expected_result: indoc! {"
1604 line1
1605 line2
1606 line3
1607 line4
1608 line5
1609 applied
1610 line6
1611 "},
1612 },
1613 Case {
1614 name: "empty_diff",
1615 diff: "",
1616 content: indoc! {"
1617 unchanged
1618 content
1619 "},
1620 expected_result: indoc! {"
1621 unchanged
1622 content
1623 "},
1624 },
1625 ];
1626
1627 for case in &cases {
1628 let result = apply_diff_to_string_lenient(case.diff, case.content);
1629 assert_eq!(
1630 result, case.expected_result,
1631 "Test '{}': expected:\n{}\ngot:\n{}",
1632 case.name, case.expected_result, result
1633 );
1634 }
1635 }
1636
1637 #[test]
1638 fn test_unicode_reversal_overlap() {
1639 struct Case {
1640 name: &'static str,
1641 original: &'static str,
1642 current: &'static str,
1643 predicted: &'static str,
1644 expected_reversal_chars: usize,
1645 expected_total_chars: usize,
1646 }
1647
1648 let cases = [
1649 Case {
1650 name: "unicode_extension_cjk",
1651 original: "",
1652 current: "日", // 1 char
1653 predicted: "日本語", // 3 chars, adds 2 chars
1654 expected_reversal_chars: 0,
1655 expected_total_chars: 2, // "本語" = 2 chars added
1656 },
1657 Case {
1658 name: "unicode_extension_emoji",
1659 original: "",
1660 current: "🎉", // 1 char
1661 predicted: "🎉🎊🎈", // 3 chars, adds 2 chars
1662 expected_reversal_chars: 0,
1663 expected_total_chars: 2, // "🎊🎈" = 2 chars added
1664 },
1665 Case {
1666 name: "unicode_deletion_restored",
1667 original: "héllo wörld", // 11 chars
1668 current: "héllo", // 5 chars
1669 predicted: "héllo wörld", // restores " wörld" = 6 chars
1670 expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars
1671 expected_total_chars: 6,
1672 },
1673 Case {
1674 name: "unicode_addition_reversed",
1675 original: "café", // 4 chars
1676 current: "café latté", // 10 chars, added " latté" = 6 chars
1677 predicted: "café", // removes " latté"
1678 expected_reversal_chars: 6, // 6 chars removed
1679 expected_total_chars: 6,
1680 },
1681 Case {
1682 name: "mixed_ascii_unicode",
1683 original: "",
1684 current: "test日本", // 6 chars
1685 predicted: "test日本語です", // 9 chars
1686 expected_reversal_chars: 0,
1687 expected_total_chars: 3, // 3 new chars after subsequence normalization
1688 },
1689 Case {
1690 name: "unicode_replacement_not_subsequence",
1691 original: "",
1692 current: "日本", // 2 chars
1693 predicted: "中国", // 2 chars, different
1694 expected_reversal_chars: 2, // removes "日本" = 2 chars
1695 expected_total_chars: 4, // 2 removed + 2 added
1696 },
1697 ];
1698
1699 for case in &cases {
1700 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
1701 assert_eq!(
1702 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1703 "Test '{}': expected {} reversal chars, got {}",
1704 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1705 );
1706 assert_eq!(
1707 overlap.total_chars_in_prediction, case.expected_total_chars,
1708 "Test '{}': expected {} total chars, got {}",
1709 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1710 );
1711 }
1712 }
1713
1714 #[test]
1715 fn test_compute_lcs_length() {
1716 assert_eq!(compute_lcs_length("", ""), 0);
1717 assert_eq!(compute_lcs_length("abc", ""), 0);
1718 assert_eq!(compute_lcs_length("", "abc"), 0);
1719 assert_eq!(compute_lcs_length("abc", "abc"), 3);
1720 assert_eq!(compute_lcs_length("abc", "def"), 0);
1721 assert_eq!(compute_lcs_length("abcdef", "ace"), 3);
1722 assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4);
1723 assert_eq!(compute_lcs_length("日本語", "日語"), 2);
1724 }
1725
1726 #[test]
1727 fn test_compute_prediction_reversal_ratio_full_file() {
1728 let prompt_inputs = ExamplePromptInputs {
1729 content: indoc! {"
1730 line1
1731 user_added
1732 line2
1733 "}
1734 .to_string(),
1735 cursor_row: 0,
1736 cursor_column: 0,
1737 cursor_offset: 0,
1738 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1739 path: Arc::from(Path::new("src/test.rs")),
1740 old_path: Arc::from(Path::new("src/test.rs")),
1741 diff: indoc! {"
1742 @@ -1,2 +1,3 @@
1743 line1
1744 +user_added
1745 line2
1746 "}
1747 .into(),
1748 predicted: false,
1749 in_open_source_repo: false,
1750 })],
1751 excerpt_start_row: None,
1752 related_files: None,
1753 };
1754
1755 let predicted = indoc! {"
1756 line1
1757 line2
1758 "};
1759 let ratio =
1760 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1761
1762 assert!(
1763 ratio > 0.9,
1764 "Expected high reversal ratio when prediction removes user addition, got {}",
1765 ratio
1766 );
1767 }
1768
1769 #[test]
1770 fn test_compute_prediction_reversal_ratio_with_excerpt() {
1771 let prompt_inputs = ExamplePromptInputs {
1772 content: indoc! {"
1773 line10
1774 user_added
1775 line11
1776 "}
1777 .to_string(),
1778 cursor_row: 0,
1779 cursor_column: 0,
1780 cursor_offset: 0,
1781 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1782 path: Arc::from(Path::new("src/test.rs")),
1783 old_path: Arc::from(Path::new("src/test.rs")),
1784 diff: indoc! {"
1785 @@ -10,2 +10,3 @@
1786 line10
1787 +user_added
1788 line11
1789 "}
1790 .into(),
1791 predicted: false,
1792 in_open_source_repo: false,
1793 })],
1794 excerpt_start_row: Some(10),
1795 related_files: None,
1796 };
1797
1798 let predicted = indoc! {"
1799 line10
1800 line11
1801 "};
1802 let ratio =
1803 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1804
1805 assert!(
1806 ratio > 0.9,
1807 "Expected high reversal ratio for excerpt-aware computation, got {}",
1808 ratio
1809 );
1810 }
1811
1812 #[test]
1813 fn test_compute_prediction_reversal_ratio_no_history() {
1814 let prompt_inputs = ExamplePromptInputs {
1815 content: indoc! {"
1816 original content
1817 "}
1818 .to_string(),
1819 cursor_row: 0,
1820 cursor_column: 0,
1821 cursor_offset: 0,
1822 edit_history: vec![],
1823 excerpt_start_row: None,
1824 related_files: None,
1825 };
1826
1827 let predicted = indoc! {"
1828 completely different
1829 "};
1830 let ratio =
1831 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1832
1833 assert_eq!(
1834 ratio, 0.0,
1835 "Expected zero reversal ratio with no edit history"
1836 );
1837 }
1838
1839 #[test]
1840 fn test_compute_prediction_reversal_ratio_path_filtering() {
1841 let prompt_inputs = ExamplePromptInputs {
1842 content: indoc! {"
1843 line1
1844 user_added
1845 line2
1846 "}
1847 .to_string(),
1848 cursor_row: 0,
1849 cursor_column: 0,
1850 cursor_offset: 0,
1851 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1852 path: Arc::from(Path::new("src/other.rs")),
1853 old_path: Arc::from(Path::new("src/other.rs")),
1854 diff: indoc! {"
1855 @@ -1,2 +1,3 @@
1856 line1
1857 +user_added
1858 line2
1859 "}
1860 .into(),
1861 predicted: false,
1862 in_open_source_repo: false,
1863 })],
1864 excerpt_start_row: None,
1865 related_files: None,
1866 };
1867
1868 let predicted = indoc! {"
1869 line1
1870 line2
1871 "};
1872 let ratio =
1873 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1874
1875 assert_eq!(
1876 ratio, 0.0,
1877 "Expected zero reversal when edit history is for different file"
1878 );
1879 }
1880
1881 #[test]
1882 fn test_compute_prediction_reversal_ratio_lenient_fallback() {
1883 let prompt_inputs = ExamplePromptInputs {
1884 content: indoc! {"
1885 actual_line1
1886 user_added
1887 actual_line2
1888 "}
1889 .to_string(),
1890 cursor_row: 0,
1891 cursor_column: 0,
1892 cursor_offset: 0,
1893 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1894 path: Arc::from(Path::new("src/test.rs")),
1895 old_path: Arc::from(Path::new("src/test.rs")),
1896 diff: indoc! {"
1897 @@ -1,2 +1,3 @@
1898 wrong_context
1899 +user_added
1900 more_wrong
1901 "}
1902 .into(),
1903 predicted: false,
1904 in_open_source_repo: false,
1905 })],
1906 excerpt_start_row: None,
1907 related_files: None,
1908 };
1909
1910 let predicted = indoc! {"
1911 actual_line1
1912 actual_line2
1913 "};
1914 let ratio =
1915 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1916
1917 assert!(
1918 ratio >= 0.0 && ratio <= 1.0,
1919 "Ratio should be valid even with lenient fallback, got {}",
1920 ratio
1921 );
1922 }
1923
1924 #[test]
1925 fn test_excerpt_aware_reversal_error_recovery() {
1926 let diffs = vec![indoc! {"
1927 @@ -1,2 +1,3 @@
1928 nonexistent_context
1929 +added
1930 more_nonexistent
1931 "}];
1932 let excerpt_content = indoc! {"
1933 completely
1934 different
1935 content
1936 "};
1937 let predicted_content = indoc! {"
1938 completely
1939 modified
1940 content
1941 "};
1942
1943 let overlap =
1944 compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content);
1945
1946 assert!(
1947 overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0,
1948 "Should handle failed diff application gracefully"
1949 );
1950 }
1951
1952 #[test]
1953 fn test_multiple_sequential_diffs() {
1954 let prompt_inputs = ExamplePromptInputs {
1955 content: indoc! {"
1956 line1
1957 first_add
1958 second_add
1959 line2
1960 "}
1961 .to_string(),
1962 cursor_row: 0,
1963 cursor_column: 0,
1964 cursor_offset: 0,
1965 edit_history: vec![
1966 Arc::new(zeta_prompt::Event::BufferChange {
1967 path: Arc::from(Path::new("src/test.rs")),
1968 old_path: Arc::from(Path::new("src/test.rs")),
1969 diff: indoc! {"
1970 @@ -1,2 +1,3 @@
1971 line1
1972 +first_add
1973 line2
1974 "}
1975 .into(),
1976 predicted: false,
1977 in_open_source_repo: false,
1978 }),
1979 Arc::new(zeta_prompt::Event::BufferChange {
1980 path: Arc::from(Path::new("src/test.rs")),
1981 old_path: Arc::from(Path::new("src/test.rs")),
1982 diff: indoc! {"
1983 @@ -2,2 +2,3 @@
1984 first_add
1985 +second_add
1986 line2
1987 "}
1988 .into(),
1989 predicted: false,
1990 in_open_source_repo: false,
1991 }),
1992 ],
1993 excerpt_start_row: None,
1994 related_files: None,
1995 };
1996
1997 let predicted = indoc! {"
1998 line1
1999 line2
2000 "};
2001 let ratio =
2002 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
2003
2004 assert!(
2005 ratio > 0.9,
2006 "Expected high reversal ratio when reversing multiple sequential edits, got {}",
2007 ratio
2008 );
2009 }
2010}