1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use edit_prediction::udiff::apply_diff_to_string;
6use language::{char_diff, text_diff};
7
8use crate::example::ExamplePromptInputs;
9
10fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
11 let hunks = parse_diff_hunks(diff_str);
12 let mut result = text.to_string();
13
14 for hunk in hunks {
15 let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
16 if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
17 result = updated;
18 }
19 }
20
21 result
22}
23
/// One hunk of a unified diff: the four numbers from its
/// `@@ -old_start,old_count +new_start,new_count @@` header plus its body lines.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParsedHunk {
    /// Start line of the hunk on the old side, as written in the header.
    old_start: u32,
    /// Line count on the old side, as written in the header.
    old_count: u32,
    /// Start line of the hunk on the new side, as written in the header.
    new_start: u32,
    /// Line count on the new side, as written in the header.
    new_count: u32,
    /// Body lines in the order they appeared, with their marker character removed.
    lines: Vec<HunkLine>,
}
32
/// A single body line of a hunk, holding the line's text without its leading
/// marker character.
#[derive(Debug, Clone, PartialEq, Eq)]
enum HunkLine {
    /// Unchanged line (` ` marker); completely empty diff lines are also stored as context.
    Context(String),
    /// Added line (`+` marker).
    Addition(String),
    /// Removed line (`-` marker).
    Deletion(String),
}
39
/// Parses a unified-diff hunk header such as `@@ -12,3 +12,4 @@`.
///
/// Returns `(old_start, old_count, new_start, new_count)`. A range written
/// without a `,count` part (e.g. `-7`) gets a count of 1, and anything after
/// the closing ` @@` is ignored. Returns `None` when the line does not have
/// this shape or one of the numbers does not parse as `u32`.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    /// Parses `start` or `start,count` into `(start, count)`.
    fn parse_range(range: &str) -> Option<(u32, u32)> {
        match range.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((range.parse().ok()?, 1)),
        }
    }

    let body = line.strip_prefix("@@ -")?;
    let (old_range, remainder) = body.split_once(' ')?;
    let (new_range, _trailing) = remainder.strip_prefix('+')?.split_once(" @@")?;

    let (old_start, old_count) = parse_range(old_range)?;
    let (new_start, new_count) = parse_range(new_range)?;
    Some((old_start, old_count, new_start, new_count))
}
60
61fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
62 let mut hunks = Vec::new();
63 let mut current_hunk: Option<ParsedHunk> = None;
64
65 for line in diff.lines() {
66 if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
67 if let Some(hunk) = current_hunk.take() {
68 hunks.push(hunk);
69 }
70 current_hunk = Some(ParsedHunk {
71 old_start,
72 old_count,
73 new_start,
74 new_count,
75 lines: Vec::new(),
76 });
77 } else if let Some(ref mut hunk) = current_hunk {
78 if let Some(stripped) = line.strip_prefix('+') {
79 hunk.lines.push(HunkLine::Addition(stripped.to_string()));
80 } else if let Some(stripped) = line.strip_prefix('-') {
81 hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
82 } else if let Some(stripped) = line.strip_prefix(' ') {
83 hunk.lines.push(HunkLine::Context(stripped.to_string()));
84 } else if line.is_empty() {
85 hunk.lines.push(HunkLine::Context(String::new()));
86 }
87 }
88 }
89
90 if let Some(hunk) = current_hunk {
91 hunks.push(hunk);
92 }
93
94 hunks
95}
96
97fn format_hunk(hunk: &ParsedHunk) -> String {
98 let mut result = format!(
99 "@@ -{},{} +{},{} @@\n",
100 hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
101 );
102 for line in &hunk.lines {
103 match line {
104 HunkLine::Context(text) => {
105 result.push(' ');
106 result.push_str(text);
107 result.push('\n');
108 }
109 HunkLine::Addition(text) => {
110 result.push('+');
111 result.push_str(text);
112 result.push('\n');
113 }
114 HunkLine::Deletion(text) => {
115 result.push('-');
116 result.push_str(text);
117 result.push('\n');
118 }
119 }
120 }
121 result
122}
123
124fn filter_diff_hunks_by_excerpt(
125 diff: &str,
126 excerpt_start_row: u32,
127 excerpt_row_count: u32,
128) -> (String, i32) {
129 let hunks = parse_diff_hunks(diff);
130 let excerpt_start_0based = excerpt_start_row;
131 let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
132
133 let mut filtered_hunks = Vec::new();
134 let mut cumulative_line_offset: i32 = 0;
135
136 for hunk in hunks {
137 let hunk_start_0based = hunk.new_start.saturating_sub(1);
138 let hunk_end_0based = hunk_start_0based + hunk.new_count;
139
140 let additions: i32 = hunk
141 .lines
142 .iter()
143 .filter(|l| matches!(l, HunkLine::Addition(_)))
144 .count() as i32;
145 let deletions: i32 = hunk
146 .lines
147 .iter()
148 .filter(|l| matches!(l, HunkLine::Deletion(_)))
149 .count() as i32;
150 let hunk_line_delta = additions - deletions;
151
152 if hunk_end_0based <= excerpt_start_0based {
153 cumulative_line_offset += hunk_line_delta;
154 continue;
155 }
156
157 if hunk_start_0based >= excerpt_end_0based {
158 continue;
159 }
160
161 let mut filtered_lines = Vec::new();
162 let mut current_row_0based = hunk_start_0based;
163 let mut filtered_old_count = 0u32;
164 let mut filtered_new_count = 0u32;
165 let mut first_included_row: Option<u32> = None;
166
167 for line in &hunk.lines {
168 match line {
169 HunkLine::Context(text) => {
170 if current_row_0based >= excerpt_start_0based
171 && current_row_0based < excerpt_end_0based
172 {
173 if first_included_row.is_none() {
174 first_included_row = Some(current_row_0based);
175 }
176 filtered_lines.push(HunkLine::Context(text.clone()));
177 filtered_old_count += 1;
178 filtered_new_count += 1;
179 }
180 current_row_0based += 1;
181 }
182 HunkLine::Addition(text) => {
183 if current_row_0based >= excerpt_start_0based
184 && current_row_0based < excerpt_end_0based
185 {
186 if first_included_row.is_none() {
187 first_included_row = Some(current_row_0based);
188 }
189 filtered_lines.push(HunkLine::Addition(text.clone()));
190 filtered_new_count += 1;
191 }
192 current_row_0based += 1;
193 }
194 HunkLine::Deletion(text) => {
195 if current_row_0based >= excerpt_start_0based
196 && current_row_0based < excerpt_end_0based
197 {
198 if first_included_row.is_none() {
199 first_included_row = Some(current_row_0based);
200 }
201 filtered_lines.push(HunkLine::Deletion(text.clone()));
202 filtered_old_count += 1;
203 }
204 }
205 }
206 }
207
208 if !filtered_lines.is_empty() {
209 let first_row = first_included_row.unwrap_or(excerpt_start_0based);
210 let new_start_1based = (first_row - excerpt_start_0based) + 1;
211
212 filtered_hunks.push(ParsedHunk {
213 old_start: new_start_1based,
214 old_count: filtered_old_count,
215 new_start: new_start_1based,
216 new_count: filtered_new_count,
217 lines: filtered_lines,
218 });
219 }
220
221 cumulative_line_offset += hunk_line_delta;
222 }
223
224 let mut result = String::new();
225 for hunk in &filtered_hunks {
226 result.push_str(&format_hunk(hunk));
227 }
228
229 (result, cumulative_line_offset)
230}
231
232fn compute_excerpt_aware_reversal_overlap(
233 edit_history_diffs: &[&str],
234 excerpt_content: &str,
235 excerpt_start_row: u32,
236 predicted_content: &str,
237) -> ReversalOverlap {
238 let mut current_content = excerpt_content.to_string();
239 let mut current_excerpt_start_row = excerpt_start_row;
240
241 for diff in edit_history_diffs.iter().rev() {
242 if diff.is_empty() {
243 continue;
244 }
245
246 let current_row_count = current_content.lines().count() as u32;
247 let (filtered_diff, _line_offset) =
248 filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
249
250 if filtered_diff.is_empty() {
251 let hunks = parse_diff_hunks(diff);
252 for hunk in hunks {
253 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
254 if hunk_end <= current_excerpt_start_row {
255 let additions: u32 = hunk
256 .lines
257 .iter()
258 .filter(|l| matches!(l, HunkLine::Addition(_)))
259 .count() as u32;
260 let deletions: u32 = hunk
261 .lines
262 .iter()
263 .filter(|l| matches!(l, HunkLine::Deletion(_)))
264 .count() as u32;
265 if additions >= deletions {
266 current_excerpt_start_row =
267 current_excerpt_start_row.saturating_sub(additions - deletions);
268 } else {
269 current_excerpt_start_row += deletions - additions;
270 }
271 }
272 }
273 continue;
274 }
275
276 let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
277 match apply_diff_to_string(&reversed, ¤t_content) {
278 Ok(updated) => {
279 current_content = updated;
280 }
281 Err(_) => {
282 continue;
283 }
284 }
285
286 let hunks = parse_diff_hunks(diff);
287 for hunk in hunks {
288 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
289 if hunk_end <= current_excerpt_start_row {
290 let additions: u32 = hunk
291 .lines
292 .iter()
293 .filter(|l| matches!(l, HunkLine::Addition(_)))
294 .count() as u32;
295 let deletions: u32 = hunk
296 .lines
297 .iter()
298 .filter(|l| matches!(l, HunkLine::Deletion(_)))
299 .count() as u32;
300 if additions >= deletions {
301 current_excerpt_start_row =
302 current_excerpt_start_row.saturating_sub(additions - deletions);
303 } else {
304 current_excerpt_start_row += deletions - additions;
305 }
306 }
307 }
308 }
309
310 compute_reversal_overlap(¤t_content, excerpt_content, predicted_content)
311}
312
/// Produces the reverse of a unified diff: additions become deletions and
/// vice versa, and `--- ` / `+++ ` file headers swap their markers.
///
/// Inside a hunk body every line starting with `+` or `-` is a change line,
/// even when its content makes it look like a file header (an added `++i;`
/// appears as `+++i;`, a removed `-- comment` as `--- comment`). The line counts
/// from each `@@ -a,b +c,d @@` header are used to tell how far the hunk body
/// extends. Outside a hunk body (before the first header, after the counted
/// lines are consumed, or after a header that could not be parsed), lines
/// starting with `--- ` / `+++ ` are treated as file headers, other lines
/// starting with `---` / `+++` are left alone, and remaining `+` / `-` lines
/// are swapped.
///
/// Hunk header lines are copied unchanged (their ranges are not swapped).
/// A trailing newline is preserved if and only if `diff` has one.
///
/// NOTE(review): a header whose counts overstate the hunk length would cause a
/// following file header to be treated as a change line.
fn reverse_diff(diff: &str) -> String {
    /// Extracts `(old_count, new_count)` from a `@@ -a,b +c,d @@` header; a
    /// range without `,count` counts as one line.
    fn hunk_line_counts(line: &str) -> Option<(u32, u32)> {
        fn count_of(range: &str) -> Option<u32> {
            match range.split_once(',') {
                Some((start, count)) => {
                    start.parse::<u32>().ok()?;
                    count.parse().ok()
                }
                None => range.parse::<u32>().ok().map(|_| 1),
            }
        }

        let body = line.strip_prefix("@@ -")?;
        let (old_range, rest) = body.split_once(' ')?;
        let (new_range, _) = rest.strip_prefix('+')?.split_once(" @@")?;
        Some((count_of(old_range)?, count_of(new_range)?))
    }

    // Lines still expected on each side of the hunk currently being read.
    let mut old_remaining = 0u32;
    let mut new_remaining = 0u32;
    let mut reversed_lines: Vec<String> = Vec::new();

    for line in diff.lines() {
        let reversed_line = if let Some((old_count, new_count)) = hunk_line_counts(line) {
            old_remaining = old_count;
            new_remaining = new_count;
            line.to_string()
        } else if old_remaining > 0 || new_remaining > 0 {
            // Hunk body: the first character alone decides the line's role.
            if let Some(text) = line.strip_prefix('+') {
                new_remaining = new_remaining.saturating_sub(1);
                format!("-{}", text)
            } else if let Some(text) = line.strip_prefix('-') {
                old_remaining = old_remaining.saturating_sub(1);
                format!("+{}", text)
            } else {
                // "\ No newline at end of file" markers are not counted by the header.
                if !line.starts_with('\\') {
                    old_remaining = old_remaining.saturating_sub(1);
                    new_remaining = new_remaining.saturating_sub(1);
                }
                line.to_string()
            }
        } else if let Some(path) = line.strip_prefix("--- ") {
            format!("+++ {}", path)
        } else if let Some(path) = line.strip_prefix("+++ ") {
            format!("--- {}", path)
        } else if line.starts_with("+++") || line.starts_with("---") {
            line.to_string()
        } else if let Some(text) = line.strip_prefix('+') {
            format!("-{}", text)
        } else if let Some(text) = line.strip_prefix('-') {
            format!("+{}", text)
        } else {
            line.to_string()
        };
        reversed_lines.push(reversed_line);
    }

    let mut result = reversed_lines.join("\n");
    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}
336
/// A single replacement within a text: the bytes at `range` of the old text
/// (`old_text`) are replaced by `new_text`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GranularEdit {
    /// Byte range in the old text that this edit replaces.
    range: Range<usize>,
    /// The old text covered by `range`; empty for a pure insertion.
    old_text: String,
    /// The replacement text; empty for a pure deletion.
    new_text: String,
}
343
344fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
345 text_diff(old_text, new_text)
346 .into_iter()
347 .map(|(range, new_text)| GranularEdit {
348 old_text: old_text[range.clone()].to_string(),
349 range,
350 new_text: new_text.to_string(),
351 })
352 .collect()
353}
354
/// Location of text that a history edit inserted, expressed in the text as it
/// is after the history edits.
#[derive(Debug, Clone)]
struct HistoryAdditionRange {
    /// Byte range of the inserted text in the post-edit text.
    range_in_current: Range<usize>,
}
359
/// Text that a history edit removed, together with where the removal sits in
/// the text as it is after the history edits.
#[derive(Debug, Clone)]
struct HistoryDeletionRange {
    /// The text that was removed.
    deleted_text: String,
    /// Byte offset in the post-edit text at which the removed text used to start.
    position_in_current: usize,
}
365
366fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
367 let mut result = Vec::new();
368 let mut offset_delta: isize = 0;
369
370 for edit in history_edits {
371 if !edit.new_text.is_empty() {
372 let new_start = (edit.range.start as isize + offset_delta) as usize;
373 let new_end = new_start + edit.new_text.len();
374 result.push(HistoryAdditionRange {
375 range_in_current: new_start..new_end,
376 });
377 }
378
379 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
380 }
381
382 result
383}
384
385fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
386 let mut result = Vec::new();
387 let mut offset_delta: isize = 0;
388
389 for edit in history_edits {
390 if !edit.old_text.is_empty() {
391 let position_in_current = (edit.range.start as isize + offset_delta) as usize;
392 result.push(HistoryDeletionRange {
393 deleted_text: edit.old_text.clone(),
394 position_in_current,
395 });
396 }
397
398 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
399 }
400
401 result
402}
403
/// How much of a prediction undoes the user's recent edits, measured in characters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ReversalOverlap {
    /// Characters of the prediction that remove user-added text or restore
    /// user-deleted text.
    chars_reversing_user_edits: usize,
    /// Total characters inserted plus deleted by the prediction.
    total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Fraction of the prediction's characters that reverse user edits.
    ///
    /// Returns `0.0` when the prediction changes nothing, avoiding a division by zero.
    fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
419
420/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
421/// or where `new_text` appears as a subsequence within `old_text` (reduction).
422///
423/// For extensions: when the user's text is preserved (in order) within the prediction,
424/// we only count the newly inserted characters, not the preserved ones.
425/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
426/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
427///
428/// For reductions: when the prediction's text is preserved (in order) within the original,
429/// we only count the deleted characters, not the preserved ones.
430/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
431fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
432 edits
433 .into_iter()
434 .flat_map(|edit| {
435 if edit.old_text.is_empty() || edit.new_text.is_empty() {
436 return vec![edit];
437 }
438
439 // Use character-wise diff to find exact byte ranges of changes
440 let char_edits = char_diff(&edit.old_text, &edit.new_text);
441
442 let all_deletions = !char_edits.is_empty()
443 && char_edits
444 .iter()
445 .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
446 let all_insertions = !char_edits.is_empty()
447 && char_edits
448 .iter()
449 .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
450 if all_deletions || all_insertions {
451 return char_edits
452 .into_iter()
453 .map(|(range, replacement)| GranularEdit {
454 range: edit.range.start + range.start..edit.range.start + range.end,
455 old_text: edit.old_text[range].to_string(),
456 new_text: replacement.to_string(),
457 })
458 .collect();
459 }
460
461 // Otherwise, keep the original edit (mixed changes)
462 vec![edit]
463 })
464 .collect()
465}
466
467fn compute_reversal_overlap(
468 original_content: &str,
469 current_content: &str,
470 predicted_content: &str,
471) -> ReversalOverlap {
472 let history_edits =
473 normalize_extension_edits(compute_granular_edits(original_content, current_content));
474 let prediction_edits =
475 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
476
477 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
478 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
479
480 let reversed_additions =
481 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
482 let restored_deletions =
483 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
484
485 let total_chars_in_prediction: usize = prediction_edits
486 .iter()
487 .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
488 .sum();
489
490 ReversalOverlap {
491 chars_reversing_user_edits: reversed_additions + restored_deletions,
492 total_chars_in_prediction,
493 }
494}
495
496fn compute_reversed_additions(
497 history_addition_ranges: &[HistoryAdditionRange],
498 prediction_edits: &[GranularEdit],
499) -> usize {
500 let mut reversed_chars = 0;
501
502 for pred_edit in prediction_edits {
503 for history_addition in history_addition_ranges {
504 let overlap_start = pred_edit
505 .range
506 .start
507 .max(history_addition.range_in_current.start);
508 let overlap_end = pred_edit
509 .range
510 .end
511 .min(history_addition.range_in_current.end);
512
513 if overlap_start < overlap_end {
514 let relative_start = overlap_start - pred_edit.range.start;
515 let relative_end = overlap_end - pred_edit.range.start;
516 let overlap_text = &pred_edit.old_text[relative_start..relative_end];
517 reversed_chars += overlap_text.chars().count();
518 }
519 }
520 }
521
522 reversed_chars
523}
524
525fn compute_restored_deletions(
526 history_deletion_ranges: &[HistoryDeletionRange],
527 prediction_edits: &[GranularEdit],
528) -> usize {
529 let mut restored = 0;
530
531 for pred_edit in prediction_edits {
532 if pred_edit.new_text.is_empty() {
533 continue;
534 }
535
536 for deletion in history_deletion_ranges {
537 if pred_edit.range.contains(&deletion.position_in_current)
538 || deletion.position_in_current == pred_edit.range.start
539 {
540 restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
541 }
542 }
543 }
544
545 restored
546}
547
/// Returns the length, in `char`s, of the longest common subsequence of `a` and `b`.
///
/// Uses the classic dynamic-programming recurrence while keeping a single row
/// of the table; the result is 0 when either string is empty.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();

    // `row[j]` holds the LCS length of the processed prefix of `a` and `b[..j]`.
    let mut row = vec![0usize; b_chars.len() + 1];

    for &a_char in &a_chars {
        // Value of the cell diagonally up-left of the one being computed.
        let mut diagonal = 0;
        for (j, &b_char) in b_chars.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if a_char == b_char {
                diagonal + 1
            } else {
                above.max(row[j])
            };
            diagonal = above;
        }
    }

    row[b_chars.len()]
}
575
576fn filter_edit_history_by_path<'a>(
577 edit_history: &'a [Arc<zeta_prompt::Event>],
578 cursor_path: &std::path::Path,
579) -> Vec<&'a zeta_prompt::Event> {
580 edit_history
581 .iter()
582 .filter(|event| match event.as_ref() {
583 zeta_prompt::Event::BufferChange { path, .. } => {
584 let event_path = path.as_ref();
585 if event_path == cursor_path {
586 return true;
587 }
588 let stripped = event_path
589 .components()
590 .skip(1)
591 .collect::<std::path::PathBuf>();
592 stripped == cursor_path
593 }
594 })
595 .map(|arc| arc.as_ref())
596 .collect()
597}
598
599fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
600 match event {
601 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
602 }
603}
604
605fn is_predicted_event(event: &zeta_prompt::Event) -> bool {
606 match event {
607 zeta_prompt::Event::BufferChange { predicted, .. } => *predicted,
608 }
609}
610
611pub fn compute_prediction_reversal_ratio(
612 prompt_inputs: &ExamplePromptInputs,
613 predicted_content: &str,
614 cursor_path: &Path,
615) -> f32 {
616 let current_content = &prompt_inputs.content;
617
618 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
619 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
620
621 let most_recent = match relevant_events.last() {
622 Some(event) if !is_predicted_event(event) => *event,
623 _ => return 0.0,
624 };
625
626 let diff = extract_diff_from_event(most_recent);
627 if diff.is_empty() {
628 return 0.0;
629 }
630
631 if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row {
632 let diffs = vec![diff];
633 let overlap = compute_excerpt_aware_reversal_overlap(
634 &diffs,
635 current_content,
636 excerpt_start_row,
637 predicted_content,
638 );
639 return overlap.ratio();
640 }
641
642 let reversed = reverse_diff(diff);
643 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
644 let original_content = match apply_diff_to_string(&with_headers, current_content) {
645 Ok(updated_content) => updated_content,
646 Err(_) => apply_diff_to_string_lenient(&reversed, current_content),
647 };
648
649 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
650 overlap.ratio()
651}
652
653#[cfg(test)]
654mod tests {
655 use super::*;
656 use edit_prediction::udiff::apply_diff_to_string;
657 use indoc::indoc;
658
659 #[test]
660 fn test_reversal_overlap() {
661 struct Case {
662 name: &'static str,
663 original: &'static str,
664 current: &'static str,
665 predicted: &'static str,
666 expected_reversal_chars: usize,
667 expected_total_chars: usize,
668 }
669
670 let cases = [
671 Case {
672 name: "user_adds_line_prediction_removes_it",
673 original: indoc! {"
674 a
675 b
676 c"},
677 current: indoc! {"
678 a
679 new line
680 b
681 c"},
682 predicted: indoc! {"
683 a
684 b
685 c"},
686 expected_reversal_chars: 9,
687 expected_total_chars: 9,
688 },
689 Case {
690 name: "user_deletes_line_prediction_restores_it",
691 original: indoc! {"
692 a
693 deleted
694 b"},
695 current: indoc! {"
696 a
697 b"},
698 predicted: indoc! {"
699 a
700 deleted
701 b"},
702 expected_reversal_chars: 8,
703 expected_total_chars: 8,
704 },
705 Case {
706 name: "user_deletes_text_prediction_restores_partial",
707 original: "hello beautiful world",
708 current: "hello world",
709 predicted: "hello beautiful world",
710 expected_reversal_chars: 10,
711 expected_total_chars: 10,
712 },
713 Case {
714 name: "user_deletes_foo_prediction_adds_bar",
715 original: "foo",
716 current: "",
717 predicted: "bar",
718 expected_reversal_chars: 0,
719 expected_total_chars: 3,
720 },
721 Case {
722 name: "independent_edits_different_locations",
723 original: indoc! {"
724 line1
725 line2
726 line3"},
727 current: indoc! {"
728 LINE1
729 line2
730 line3"},
731 predicted: indoc! {"
732 LINE1
733 line2
734 LINE3"},
735 expected_reversal_chars: 0,
736 expected_total_chars: 10,
737 },
738 Case {
739 name: "no_history_edits",
740 original: "same",
741 current: "same",
742 predicted: "different",
743 expected_reversal_chars: 0,
744 expected_total_chars: 13,
745 },
746 Case {
747 name: "user_replaces_text_prediction_reverses",
748 original: indoc! {"
749 keep
750 delete_me
751 keep2"},
752 current: indoc! {"
753 keep
754 added
755 keep2"},
756 predicted: indoc! {"
757 keep
758 delete_me
759 keep2"},
760 expected_reversal_chars: 14,
761 expected_total_chars: 14,
762 },
763 Case {
764 name: "user_modifies_word_prediction_modifies_differently",
765 original: "the quick brown fox",
766 current: "the slow brown fox",
767 predicted: "the fast brown fox",
768 expected_reversal_chars: 4,
769 expected_total_chars: 8,
770 },
771 Case {
772 name: "user finishes function name (suffix)",
773 original: "",
774 current: "epr",
775 predicted: "eprintln!()",
776 expected_reversal_chars: 0,
777 expected_total_chars: 8,
778 },
779 Case {
780 name: "user starts function name (prefix)",
781 original: "",
782 current: "my_function()",
783 predicted: "test_my_function()",
784 expected_reversal_chars: 0,
785 expected_total_chars: 5,
786 },
787 Case {
788 name: "user types partial, prediction extends in multiple places",
789 original: "",
790 current: "test_my_function",
791 predicted: "a_test_for_my_special_function_plz",
792 expected_reversal_chars: 0,
793 expected_total_chars: 18,
794 },
795 // Edge cases for subsequence matching
796 Case {
797 name: "subsequence with interleaved underscores",
798 original: "",
799 current: "a_b_c",
800 predicted: "_a__b__c__",
801 expected_reversal_chars: 0,
802 expected_total_chars: 5,
803 },
804 Case {
805 name: "not a subsequence - different characters",
806 original: "",
807 current: "abc",
808 predicted: "xyz",
809 expected_reversal_chars: 3,
810 expected_total_chars: 6,
811 },
812 Case {
813 name: "not a subsequence - wrong order",
814 original: "",
815 current: "abc",
816 predicted: "cba",
817 expected_reversal_chars: 3,
818 expected_total_chars: 6,
819 },
820 Case {
821 name: "partial subsequence - only some chars match",
822 original: "",
823 current: "abcd",
824 predicted: "axbx",
825 expected_reversal_chars: 4,
826 expected_total_chars: 8,
827 },
828 // Common completion patterns
829 Case {
830 name: "completing a method call",
831 original: "",
832 current: "vec.pu",
833 predicted: "vec.push(item)",
834 expected_reversal_chars: 0,
835 expected_total_chars: 8,
836 },
837 Case {
838 name: "completing an import statement",
839 original: "",
840 current: "use std::col",
841 predicted: "use std::collections::HashMap",
842 expected_reversal_chars: 0,
843 expected_total_chars: 17,
844 },
845 Case {
846 name: "completing a struct field",
847 original: "",
848 current: "name: St",
849 predicted: "name: String",
850 expected_reversal_chars: 0,
851 expected_total_chars: 4,
852 },
853 Case {
854 name: "prediction replaces with completely different text",
855 original: "",
856 current: "hello",
857 predicted: "world",
858 expected_reversal_chars: 5,
859 expected_total_chars: 10,
860 },
861 Case {
862 name: "empty prediction removes user text",
863 original: "",
864 current: "mistake",
865 predicted: "",
866 expected_reversal_chars: 7,
867 expected_total_chars: 7,
868 },
869 Case {
870 name: "fixing typo is not reversal",
871 original: "",
872 current: "<dv",
873 predicted: "<div>",
874 expected_reversal_chars: 0,
875 expected_total_chars: 2,
876 },
877 Case {
878 name: "infix insertion not reversal",
879 original: indoc! {"
880 from my_project import Foo
881 "},
882 current: indoc! {"
883 ifrom my_project import Foo
884 "},
885 predicted: indoc! {"
886 import
887 from my_project import Foo
888 "},
889 expected_reversal_chars: 0,
890 expected_total_chars: 6,
891 },
892 Case {
893 name: "non-word based reversal",
894 original: "from",
895 current: "ifrom",
896 predicted: "from",
897 expected_reversal_chars: 1,
898 expected_total_chars: 1,
899 },
900 Case {
901 name: "multiple insertions no reversal",
902 original: "print(\"Hello, World!\")",
903 current: "sys.(\"Hello, World!\")",
904 predicted: "sys.stdout.write(\"Hello, World!\\n\")",
905 expected_reversal_chars: 0,
906 expected_total_chars: 14,
907 },
908 ];
909
910 for case in &cases {
911 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
912 assert_eq!(
913 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
914 "Test '{}': expected {} reversal chars, got {}",
915 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
916 );
917 assert_eq!(
918 overlap.total_chars_in_prediction, case.expected_total_chars,
919 "Test '{}': expected {} total chars, got {}",
920 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
921 );
922 }
923 }
924
925 #[test]
926 fn test_reverse_diff() {
927 let forward_diff = indoc! {"
928 --- a/file.rs
929 +++ b/file.rs
930 @@ -1,3 +1,4 @@
931 fn main() {
932 + let x = 42;
933 println!(\"hello\");
934 }"};
935
936 let reversed = reverse_diff(forward_diff);
937
938 assert!(
939 reversed.contains("+++ a/file.rs"),
940 "Should have +++ for old path"
941 );
942 assert!(
943 reversed.contains("--- b/file.rs"),
944 "Should have --- for new path"
945 );
946 assert!(
947 reversed.contains("- let x = 42;"),
948 "Added line should become deletion"
949 );
950 assert!(
951 reversed.contains(" fn main()"),
952 "Context lines should be unchanged"
953 );
954 }
955
956 #[test]
957 fn test_reverse_diff_roundtrip() {
958 // Applying a diff and then its reverse should get back to original
959 let original = indoc! {"
960 first line
961 hello world
962 last line
963 "};
964 let modified = indoc! {"
965 first line
966 hello beautiful world
967 last line
968 "};
969
970 // unified_diff doesn't include file headers, but apply_diff_to_string needs them
971 let diff_body = language::unified_diff(original, modified);
972 let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
973 let reversed_diff = reverse_diff(&forward_diff);
974
975 // Apply forward diff to original
976 let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
977 assert_eq!(after_forward, modified);
978
979 // Apply reversed diff to modified
980 let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
981 assert_eq!(after_reverse, original);
982 }
983
984 #[test]
985 fn test_filter_edit_history_by_path() {
986 // Test that filter_edit_history_by_path correctly matches paths when
987 // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
988 // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
989 let events = vec![
990 Arc::new(zeta_prompt::Event::BufferChange {
991 path: Arc::from(Path::new("myrepo/src/file.rs")),
992 old_path: Arc::from(Path::new("myrepo/src/file.rs")),
993 diff: indoc! {"
994 @@ -1 +1 @@
995 -old
996 +new"}
997 .into(),
998 predicted: false,
999 in_open_source_repo: true,
1000 }),
1001 Arc::new(zeta_prompt::Event::BufferChange {
1002 path: Arc::from(Path::new("myrepo/other.rs")),
1003 old_path: Arc::from(Path::new("myrepo/other.rs")),
1004 diff: indoc! {"
1005 @@ -1 +1 @@
1006 -a
1007 +b"}
1008 .into(),
1009 predicted: false,
1010 in_open_source_repo: true,
1011 }),
1012 Arc::new(zeta_prompt::Event::BufferChange {
1013 path: Arc::from(Path::new("src/file.rs")),
1014 old_path: Arc::from(Path::new("src/file.rs")),
1015 diff: indoc! {"
1016 @@ -1 +1 @@
1017 -x
1018 +y"}
1019 .into(),
1020 predicted: false,
1021 in_open_source_repo: true,
1022 }),
1023 ];
1024
1025 // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
1026 // "src/file.rs" exact match
1027 let cursor_path = Path::new("src/file.rs");
1028 let filtered = filter_edit_history_by_path(&events, cursor_path);
1029 assert_eq!(
1030 filtered.len(),
1031 2,
1032 "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
1033 );
1034
1035 // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
1036 // "src/file.rs" stripped -> "file.rs" == "file.rs"
1037 let cursor_path = Path::new("file.rs");
1038 let filtered = filter_edit_history_by_path(&events, cursor_path);
1039 assert_eq!(
1040 filtered.len(),
1041 1,
1042 "Should only match src/file.rs (stripped to file.rs)"
1043 );
1044
1045 // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
1046 let cursor_path = Path::new("other.rs");
1047 let filtered = filter_edit_history_by_path(&events, cursor_path);
1048 assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
1049 }
1050
1051 #[test]
1052 fn test_reverse_diff_preserves_trailing_newline() {
1053 let diff_with_trailing_newline = indoc! {"
1054 --- a/file
1055 +++ b/file
1056 @@ -1 +1 @@
1057 -old
1058 +new
1059 "};
1060 let reversed = reverse_diff(diff_with_trailing_newline);
1061 assert!(
1062 reversed.ends_with('\n'),
1063 "Reversed diff should preserve trailing newline"
1064 );
1065
1066 let diff_without_trailing_newline = indoc! {"
1067 --- a/file
1068 +++ b/file
1069 @@ -1 +1 @@
1070 -old
1071 +new"};
1072 let reversed = reverse_diff(diff_without_trailing_newline);
1073 assert!(
1074 !reversed.ends_with('\n'),
1075 "Reversed diff should not add trailing newline if original didn't have one"
1076 );
1077 }
1078
1079 #[test]
1080 fn test_filter_hunks_by_excerpt_region() {
1081 struct Case {
1082 name: &'static str,
1083 diff: &'static str,
1084 excerpt_start_row: u32,
1085 excerpt_row_count: u32,
1086 expected_filtered_diff: &'static str,
1087 expected_line_offset: i32,
1088 }
1089
1090 let cases = [
1091 Case {
1092 name: "hunk_entirely_before_excerpt",
1093 diff: indoc! {"
1094 @@ -1,3 +1,4 @@
1095 line1
1096 +inserted
1097 line2
1098 line3
1099 "},
1100 excerpt_start_row: 10,
1101 excerpt_row_count: 5,
1102 expected_filtered_diff: "",
1103 expected_line_offset: 1,
1104 },
1105 Case {
1106 name: "hunk_entirely_inside_excerpt",
1107 diff: indoc! {"
1108 @@ -12,3 +12,4 @@
1109 line12
1110 +inserted
1111 line13
1112 line14
1113 "},
1114 excerpt_start_row: 10,
1115 excerpt_row_count: 10,
1116 expected_filtered_diff: indoc! {"
1117 @@ -2,3 +2,4 @@
1118 line12
1119 +inserted
1120 line13
1121 line14
1122 "},
1123 expected_line_offset: 1,
1124 },
1125 Case {
1126 name: "hunk_entirely_after_excerpt",
1127 diff: indoc! {"
1128 @@ -50,3 +50,4 @@
1129 line50
1130 +inserted
1131 line51
1132 line52
1133 "},
1134 excerpt_start_row: 10,
1135 excerpt_row_count: 5,
1136 expected_filtered_diff: "",
1137 expected_line_offset: 0,
1138 },
1139 Case {
1140 name: "hunk_straddles_excerpt_start",
1141 diff: indoc! {"
1142 @@ -8,5 +8,6 @@
1143 line8
1144 line9
1145 +inserted
1146 line10
1147 line11
1148 line12
1149 "},
1150 excerpt_start_row: 10,
1151 excerpt_row_count: 10,
1152 expected_filtered_diff: indoc! {"
1153 @@ -1,3 +1,3 @@
1154 line10
1155 line11
1156 line12
1157 "},
1158 expected_line_offset: 1,
1159 },
1160 Case {
1161 name: "hunk_straddles_excerpt_end",
1162 diff: indoc! {"
1163 @@ -18,5 +18,6 @@
1164 line18
1165 line19
1166 +inserted
1167 line20
1168 line21
1169 line22
1170 "},
1171 excerpt_start_row: 10,
1172 excerpt_row_count: 10,
1173 expected_filtered_diff: indoc! {"
1174 @@ -8,2 +8,3 @@
1175 line18
1176 line19
1177 +inserted
1178 "},
1179 expected_line_offset: 1,
1180 },
1181 Case {
1182 name: "multiple_hunks_mixed",
1183 diff: indoc! {"
1184 @@ -1,2 +1,3 @@
1185 line1
1186 +before_excerpt
1187 line2
1188 @@ -12,2 +13,3 @@
1189 line12
1190 +inside_excerpt
1191 line13
1192 @@ -50,2 +52,3 @@
1193 line50
1194 +after_excerpt
1195 line51
1196 "},
1197 excerpt_start_row: 10,
1198 excerpt_row_count: 10,
1199 expected_filtered_diff: indoc! {"
1200 @@ -3,2 +3,3 @@
1201 line12
1202 +inside_excerpt
1203 line13
1204 "},
1205 expected_line_offset: 2,
1206 },
1207 Case {
1208 name: "deletion_before_excerpt",
1209 diff: indoc! {"
1210 @@ -1,4 +1,3 @@
1211 line1
1212 -deleted
1213 line2
1214 line3
1215 "},
1216 excerpt_start_row: 10,
1217 excerpt_row_count: 5,
1218 expected_filtered_diff: "",
1219 expected_line_offset: -1,
1220 },
1221 Case {
1222 name: "deletion_inside_excerpt",
1223 diff: indoc! {"
1224 @@ -12,4 +12,3 @@
1225 line12
1226 -deleted
1227 line13
1228 line14
1229 "},
1230 excerpt_start_row: 10,
1231 excerpt_row_count: 10,
1232 expected_filtered_diff: indoc! {"
1233 @@ -2,4 +2,3 @@
1234 line12
1235 -deleted
1236 line13
1237 line14
1238 "},
1239 expected_line_offset: -1,
1240 },
1241 Case {
1242 name: "empty_diff",
1243 diff: "",
1244 excerpt_start_row: 10,
1245 excerpt_row_count: 5,
1246 expected_filtered_diff: "",
1247 expected_line_offset: 0,
1248 },
1249 Case {
1250 name: "hunk_spans_entire_excerpt",
1251 diff: indoc! {"
1252 @@ -8,10 +8,12 @@
1253 line8
1254 line9
1255 line10
1256 line11
1257 +inserted1
1258 line12
1259 line13
1260 +inserted2
1261 line14
1262 line15
1263 line16
1264 line17
1265 "},
1266 excerpt_start_row: 10,
1267 excerpt_row_count: 5,
1268 expected_filtered_diff: indoc! {"
1269 @@ -1,3 +1,5 @@
1270 line11
1271 +inserted1
1272 line12
1273 line13
1274 +inserted2
1275 "},
1276 expected_line_offset: 2,
1277 },
1278 Case {
1279 name: "replacement_inside_excerpt",
1280 diff: indoc! {"
1281 @@ -12,3 +12,3 @@
1282 line12
1283 -old_text
1284 +new_text
1285 line14
1286 "},
1287 excerpt_start_row: 10,
1288 excerpt_row_count: 10,
1289 expected_filtered_diff: indoc! {"
1290 @@ -2,3 +2,3 @@
1291 line12
1292 -old_text
1293 +new_text
1294 line14
1295 "},
1296 expected_line_offset: 0,
1297 },
1298 ];
1299
1300 for case in &cases {
1301 let (filtered, line_offset) = filter_diff_hunks_by_excerpt(
1302 case.diff,
1303 case.excerpt_start_row,
1304 case.excerpt_row_count,
1305 );
1306 assert_eq!(
1307 filtered, case.expected_filtered_diff,
1308 "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}",
1309 case.name, case.expected_filtered_diff, filtered
1310 );
1311 assert_eq!(
1312 line_offset, case.expected_line_offset,
1313 "Test '{}': line offset mismatch. Expected {}, got {}",
1314 case.name, case.expected_line_offset, line_offset
1315 );
1316 }
1317 }
1318
1319 #[test]
1320 fn test_excerpt_aware_reversal_tracking() {
1321 struct Case {
1322 name: &'static str,
1323 edit_history_diffs: Vec<&'static str>,
1324 excerpt_content: &'static str,
1325 excerpt_start_row: u32,
1326 predicted_content: &'static str,
1327 expected_reversal_chars: usize,
1328 expected_total_chars: usize,
1329 }
1330
1331 let cases = [
1332 Case {
1333 name: "edit_outside_excerpt_no_reversal",
1334 edit_history_diffs: vec![indoc! {"
1335 @@ -1,2 +1,3 @@
1336 line1
1337 +added_outside
1338 line2
1339 "}],
1340 excerpt_content: indoc! {"
1341 line10
1342 line11
1343 line12
1344 "},
1345 excerpt_start_row: 10,
1346 predicted_content: indoc! {"
1347 line10
1348 modified
1349 line12
1350 "},
1351 expected_reversal_chars: 0,
1352 expected_total_chars: 14,
1353 },
1354 Case {
1355 name: "edit_inside_excerpt_with_reversal",
1356 edit_history_diffs: vec![indoc! {"
1357 @@ -10,3 +10,4 @@
1358 line10
1359 +user_added
1360 line11
1361 line12
1362 "}],
1363 excerpt_content: indoc! {"
1364 line10
1365 user_added
1366 line11
1367 line12
1368 "},
1369 excerpt_start_row: 10,
1370 predicted_content: indoc! {"
1371 line10
1372 line11
1373 line12
1374 "},
1375 expected_reversal_chars: 11,
1376 expected_total_chars: 11,
1377 },
1378 Case {
1379 name: "straddling_edit_partial_reversal",
1380 edit_history_diffs: vec![indoc! {"
1381 @@ -8,6 +8,8 @@
1382 line8
1383 line9
1384 +before_excerpt
1385 line10
1386 +inside_excerpt
1387 line11
1388 line12
1389 line13
1390 "}],
1391 excerpt_content: indoc! {"
1392 line10
1393 inside_excerpt
1394 line11
1395 line12
1396 line13
1397 "},
1398 excerpt_start_row: 10,
1399 predicted_content: indoc! {"
1400 line10
1401 line11
1402 line12
1403 line13
1404 "},
1405 expected_reversal_chars: 15,
1406 expected_total_chars: 15,
1407 },
1408 Case {
1409 name: "multiple_edits_mixed_locations",
1410 edit_history_diffs: vec![
1411 indoc! {"
1412 @@ -1,2 +1,3 @@
1413 line1
1414 +outside1
1415 line2
1416 "},
1417 indoc! {"
1418 @@ -11,2 +12,3 @@
1419 line11
1420 +inside1
1421 line12
1422 "},
1423 ],
1424 excerpt_content: indoc! {"
1425 line10
1426 line11
1427 inside1
1428 line12
1429 line13
1430 "},
1431 excerpt_start_row: 10,
1432 predicted_content: indoc! {"
1433 line10
1434 line11
1435 line12
1436 line13
1437 "},
1438 expected_reversal_chars: 8,
1439 expected_total_chars: 8,
1440 },
1441 Case {
1442 name: "no_edit_history",
1443 edit_history_diffs: vec![],
1444 excerpt_content: indoc! {"
1445 line10
1446 line11
1447 line12
1448 "},
1449 excerpt_start_row: 10,
1450 predicted_content: indoc! {"
1451 line10
1452 modified
1453 line12
1454 "},
1455 expected_reversal_chars: 0,
1456 expected_total_chars: 14,
1457 },
1458 Case {
1459 name: "edit_after_excerpt_no_effect",
1460 edit_history_diffs: vec![indoc! {"
1461 @@ -50,2 +50,3 @@
1462 line50
1463 +added_after
1464 line51
1465 "}],
1466 excerpt_content: indoc! {"
1467 line10
1468 line11
1469 line12
1470 "},
1471 excerpt_start_row: 10,
1472 predicted_content: indoc! {"
1473 line10
1474 changed
1475 line12
1476 "},
1477 expected_reversal_chars: 0,
1478 expected_total_chars: 13,
1479 },
1480 Case {
1481 name: "line_offset_tracking_across_hunks",
1482 edit_history_diffs: vec![
1483 indoc! {"
1484 @@ -1,2 +1,4 @@
1485 line1
1486 +added1
1487 +added2
1488 line2
1489 "},
1490 indoc! {"
1491 @@ -12,2 +14,3 @@
1492 line12
1493 +inside_after_offset
1494 line13
1495 "},
1496 ],
1497 excerpt_content: indoc! {"
1498 line10
1499 line11
1500 line12
1501 inside_after_offset
1502 line13
1503 "},
1504 excerpt_start_row: 10,
1505 predicted_content: indoc! {"
1506 line10
1507 line11
1508 line12
1509 line13
1510 "},
1511 expected_reversal_chars: 20,
1512 expected_total_chars: 20,
1513 },
1514 ];
1515
1516 for case in &cases {
1517 let overlap = compute_excerpt_aware_reversal_overlap(
1518 &case.edit_history_diffs,
1519 case.excerpt_content,
1520 case.excerpt_start_row,
1521 case.predicted_content,
1522 );
1523 assert_eq!(
1524 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1525 "Test '{}': expected {} reversal chars, got {}",
1526 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1527 );
1528 assert_eq!(
1529 overlap.total_chars_in_prediction, case.expected_total_chars,
1530 "Test '{}': expected {} total chars, got {}",
1531 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1532 );
1533 }
1534 }
1535
1536 #[test]
1537 fn test_lenient_diff_application() {
1538 struct Case {
1539 name: &'static str,
1540 diff: &'static str,
1541 content: &'static str,
1542 expected_result: &'static str,
1543 }
1544
1545 let cases = [
1546 Case {
1547 name: "hunk_context_not_found_skipped",
1548 diff: indoc! {"
1549 @@ -1,3 +1,4 @@
1550 context_not_in_content
1551 +added_line
1552 more_context
1553 final_context
1554 "},
1555 content: indoc! {"
1556 completely
1557 different
1558 content
1559 "},
1560 expected_result: indoc! {"
1561 completely
1562 different
1563 content
1564 "},
1565 },
1566 Case {
1567 name: "hunk_context_found_applied",
1568 diff: indoc! {"
1569 @@ -1,3 +1,4 @@
1570 line1
1571 +inserted
1572 line2
1573 line3
1574 "},
1575 content: indoc! {"
1576 line1
1577 line2
1578 line3
1579 "},
1580 expected_result: indoc! {"
1581 line1
1582 inserted
1583 line2
1584 line3
1585 "},
1586 },
1587 Case {
1588 name: "multiple_hunks_partial_match",
1589 diff: indoc! {"
1590 @@ -1,2 +1,3 @@
1591 not_found
1592 +skipped
1593 also_not_found
1594 @@ -5,2 +6,3 @@
1595 line5
1596 +applied
1597 line6
1598 "},
1599 content: indoc! {"
1600 line1
1601 line2
1602 line3
1603 line4
1604 line5
1605 line6
1606 "},
1607 expected_result: indoc! {"
1608 line1
1609 line2
1610 line3
1611 line4
1612 line5
1613 applied
1614 line6
1615 "},
1616 },
1617 Case {
1618 name: "empty_diff",
1619 diff: "",
1620 content: indoc! {"
1621 unchanged
1622 content
1623 "},
1624 expected_result: indoc! {"
1625 unchanged
1626 content
1627 "},
1628 },
1629 ];
1630
1631 for case in &cases {
1632 let result = apply_diff_to_string_lenient(case.diff, case.content);
1633 assert_eq!(
1634 result, case.expected_result,
1635 "Test '{}': expected:\n{}\ngot:\n{}",
1636 case.name, case.expected_result, result
1637 );
1638 }
1639 }
1640
1641 #[test]
1642 fn test_unicode_reversal_overlap() {
1643 struct Case {
1644 name: &'static str,
1645 original: &'static str,
1646 current: &'static str,
1647 predicted: &'static str,
1648 expected_reversal_chars: usize,
1649 expected_total_chars: usize,
1650 }
1651
1652 let cases = [
1653 Case {
1654 name: "unicode_extension_cjk",
1655 original: "",
1656 current: "日", // 1 char
1657 predicted: "日本語", // 3 chars, adds 2 chars
1658 expected_reversal_chars: 0,
1659 expected_total_chars: 2, // "本語" = 2 chars added
1660 },
1661 Case {
1662 name: "unicode_extension_emoji",
1663 original: "",
1664 current: "🎉", // 1 char
1665 predicted: "🎉🎊🎈", // 3 chars, adds 2 chars
1666 expected_reversal_chars: 0,
1667 expected_total_chars: 2, // "🎊🎈" = 2 chars added
1668 },
1669 Case {
1670 name: "unicode_deletion_restored",
1671 original: "héllo wörld", // 11 chars
1672 current: "héllo", // 5 chars
1673 predicted: "héllo wörld", // restores " wörld" = 6 chars
1674 expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars
1675 expected_total_chars: 6,
1676 },
1677 Case {
1678 name: "unicode_addition_reversed",
1679 original: "café", // 4 chars
1680 current: "café latté", // 10 chars, added " latté" = 6 chars
1681 predicted: "café", // removes " latté"
1682 expected_reversal_chars: 6, // 6 chars removed
1683 expected_total_chars: 6,
1684 },
1685 Case {
1686 name: "mixed_ascii_unicode",
1687 original: "",
1688 current: "test日本", // 6 chars
1689 predicted: "test日本語です", // 9 chars
1690 expected_reversal_chars: 0,
1691 expected_total_chars: 3, // 3 new chars after subsequence normalization
1692 },
1693 Case {
1694 name: "unicode_replacement_not_subsequence",
1695 original: "",
1696 current: "日本", // 2 chars
1697 predicted: "中国", // 2 chars, different
1698 expected_reversal_chars: 2, // removes "日本" = 2 chars
1699 expected_total_chars: 4, // 2 removed + 2 added
1700 },
1701 ];
1702
1703 for case in &cases {
1704 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
1705 assert_eq!(
1706 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1707 "Test '{}': expected {} reversal chars, got {}",
1708 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1709 );
1710 assert_eq!(
1711 overlap.total_chars_in_prediction, case.expected_total_chars,
1712 "Test '{}': expected {} total chars, got {}",
1713 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1714 );
1715 }
1716 }
1717
1718 #[test]
1719 fn test_compute_lcs_length() {
1720 assert_eq!(compute_lcs_length("", ""), 0);
1721 assert_eq!(compute_lcs_length("abc", ""), 0);
1722 assert_eq!(compute_lcs_length("", "abc"), 0);
1723 assert_eq!(compute_lcs_length("abc", "abc"), 3);
1724 assert_eq!(compute_lcs_length("abc", "def"), 0);
1725 assert_eq!(compute_lcs_length("abcdef", "ace"), 3);
1726 assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4);
1727 assert_eq!(compute_lcs_length("日本語", "日語"), 2);
1728 }
1729
1730 #[test]
1731 fn test_compute_prediction_reversal_ratio_full_file() {
1732 let prompt_inputs = ExamplePromptInputs {
1733 content: indoc! {"
1734 line1
1735 user_added
1736 line2
1737 "}
1738 .to_string(),
1739 cursor_row: 0,
1740 cursor_column: 0,
1741 cursor_offset: 0,
1742 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1743 path: Arc::from(Path::new("src/test.rs")),
1744 old_path: Arc::from(Path::new("src/test.rs")),
1745 diff: indoc! {"
1746 @@ -1,2 +1,3 @@
1747 line1
1748 +user_added
1749 line2
1750 "}
1751 .into(),
1752 predicted: false,
1753 in_open_source_repo: false,
1754 })],
1755 excerpt_start_row: None,
1756 related_files: None,
1757 };
1758
1759 let predicted = indoc! {"
1760 line1
1761 line2
1762 "};
1763 let ratio =
1764 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1765
1766 assert!(
1767 ratio > 0.9,
1768 "Expected high reversal ratio when prediction removes user addition, got {}",
1769 ratio
1770 );
1771 }
1772
1773 #[test]
1774 fn test_compute_prediction_reversal_ratio_with_excerpt() {
1775 let prompt_inputs = ExamplePromptInputs {
1776 content: indoc! {"
1777 line10
1778 user_added
1779 line11
1780 "}
1781 .to_string(),
1782 cursor_row: 0,
1783 cursor_column: 0,
1784 cursor_offset: 0,
1785 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1786 path: Arc::from(Path::new("src/test.rs")),
1787 old_path: Arc::from(Path::new("src/test.rs")),
1788 diff: indoc! {"
1789 @@ -10,2 +10,3 @@
1790 line10
1791 +user_added
1792 line11
1793 "}
1794 .into(),
1795 predicted: false,
1796 in_open_source_repo: false,
1797 })],
1798 excerpt_start_row: Some(10),
1799 related_files: None,
1800 };
1801
1802 let predicted = indoc! {"
1803 line10
1804 line11
1805 "};
1806 let ratio =
1807 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1808
1809 assert!(
1810 ratio > 0.9,
1811 "Expected high reversal ratio for excerpt-aware computation, got {}",
1812 ratio
1813 );
1814 }
1815
1816 #[test]
1817 fn test_compute_prediction_reversal_ratio_no_history() {
1818 let prompt_inputs = ExamplePromptInputs {
1819 content: indoc! {"
1820 original content
1821 "}
1822 .to_string(),
1823 cursor_row: 0,
1824 cursor_column: 0,
1825 cursor_offset: 0,
1826 edit_history: vec![],
1827 excerpt_start_row: None,
1828 related_files: None,
1829 };
1830
1831 let predicted = indoc! {"
1832 completely different
1833 "};
1834 let ratio =
1835 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1836
1837 assert_eq!(
1838 ratio, 0.0,
1839 "Expected zero reversal ratio with no edit history"
1840 );
1841 }
1842
1843 #[test]
1844 fn test_compute_prediction_reversal_ratio_path_filtering() {
1845 let prompt_inputs = ExamplePromptInputs {
1846 content: indoc! {"
1847 line1
1848 user_added
1849 line2
1850 "}
1851 .to_string(),
1852 cursor_row: 0,
1853 cursor_column: 0,
1854 cursor_offset: 0,
1855 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1856 path: Arc::from(Path::new("src/other.rs")),
1857 old_path: Arc::from(Path::new("src/other.rs")),
1858 diff: indoc! {"
1859 @@ -1,2 +1,3 @@
1860 line1
1861 +user_added
1862 line2
1863 "}
1864 .into(),
1865 predicted: false,
1866 in_open_source_repo: false,
1867 })],
1868 excerpt_start_row: None,
1869 related_files: None,
1870 };
1871
1872 let predicted = indoc! {"
1873 line1
1874 line2
1875 "};
1876 let ratio =
1877 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1878
1879 assert_eq!(
1880 ratio, 0.0,
1881 "Expected zero reversal when edit history is for different file"
1882 );
1883 }
1884
1885 #[test]
1886 fn test_compute_prediction_reversal_ratio_lenient_fallback() {
1887 let prompt_inputs = ExamplePromptInputs {
1888 content: indoc! {"
1889 actual_line1
1890 user_added
1891 actual_line2
1892 "}
1893 .to_string(),
1894 cursor_row: 0,
1895 cursor_column: 0,
1896 cursor_offset: 0,
1897 edit_history: vec![Arc::new(zeta_prompt::Event::BufferChange {
1898 path: Arc::from(Path::new("src/test.rs")),
1899 old_path: Arc::from(Path::new("src/test.rs")),
1900 diff: indoc! {"
1901 @@ -1,2 +1,3 @@
1902 wrong_context
1903 +user_added
1904 more_wrong
1905 "}
1906 .into(),
1907 predicted: false,
1908 in_open_source_repo: false,
1909 })],
1910 excerpt_start_row: None,
1911 related_files: None,
1912 };
1913
1914 let predicted = indoc! {"
1915 actual_line1
1916 actual_line2
1917 "};
1918 let ratio =
1919 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1920
1921 assert!(
1922 ratio >= 0.0 && ratio <= 1.0,
1923 "Ratio should be valid even with lenient fallback, got {}",
1924 ratio
1925 );
1926 }
1927
1928 #[test]
1929 fn test_excerpt_aware_reversal_error_recovery() {
1930 let diffs = vec![indoc! {"
1931 @@ -1,2 +1,3 @@
1932 nonexistent_context
1933 +added
1934 more_nonexistent
1935 "}];
1936 let excerpt_content = indoc! {"
1937 completely
1938 different
1939 content
1940 "};
1941 let predicted_content = indoc! {"
1942 completely
1943 modified
1944 content
1945 "};
1946
1947 let overlap =
1948 compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content);
1949
1950 assert!(
1951 overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0,
1952 "Should handle failed diff application gracefully"
1953 );
1954 }
1955
1956 #[test]
1957 fn test_only_most_recent_edit_tracked() {
1958 let prompt_inputs = ExamplePromptInputs {
1959 content: indoc! {"
1960 line1
1961 first_add
1962 second_add
1963 line2
1964 "}
1965 .to_string(),
1966 cursor_row: 0,
1967 cursor_column: 0,
1968 cursor_offset: 0,
1969 edit_history: vec![
1970 Arc::new(zeta_prompt::Event::BufferChange {
1971 path: Arc::from(Path::new("src/test.rs")),
1972 old_path: Arc::from(Path::new("src/test.rs")),
1973 diff: indoc! {"
1974 @@ -1,2 +1,3 @@
1975 line1
1976 +first_add
1977 line2
1978 "}
1979 .into(),
1980 predicted: false,
1981 in_open_source_repo: false,
1982 }),
1983 Arc::new(zeta_prompt::Event::BufferChange {
1984 path: Arc::from(Path::new("src/test.rs")),
1985 old_path: Arc::from(Path::new("src/test.rs")),
1986 diff: indoc! {"
1987 @@ -2,2 +2,3 @@
1988 first_add
1989 +second_add
1990 line2
1991 "}
1992 .into(),
1993 predicted: false,
1994 in_open_source_repo: false,
1995 }),
1996 ],
1997 excerpt_start_row: None,
1998 related_files: None,
1999 };
2000
2001 let predicted = indoc! {"
2002 line1
2003 first_add
2004 line2
2005 "};
2006 let ratio =
2007 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
2008
2009 assert!(
2010 ratio > 0.9,
2011 "Expected high reversal ratio when prediction exactly reverses the most recent edit, got {}",
2012 ratio
2013 );
2014 }
2015}