1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use language::{char_diff, text_diff};
6use zeta_prompt::udiff::apply_diff_to_string;
7
8fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
9 let hunks = parse_diff_hunks(diff_str);
10 let mut result = text.to_string();
11
12 for hunk in hunks {
13 let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
14 if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
15 result = updated;
16 }
17 }
18
19 result
20}
21
/// A single hunk parsed from a unified diff.
///
/// The `*_start` values are the 1-based line numbers written in the
/// `@@ -old_start,old_count +new_start,new_count @@` header, and the `*_count` values are the
/// line counts from that same header. They are stored as written and are not re-validated
/// against `lines`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ParsedHunk {
    /// 1-based starting line on the old side, as written in the header.
    old_start: u32,
    /// Number of old-side lines claimed by the header (1 when the header omits it).
    old_count: u32,
    /// 1-based starting line on the new side, as written in the header.
    new_start: u32,
    /// Number of new-side lines claimed by the header (1 when the header omits it).
    new_count: u32,
    /// Body lines of the hunk, in diff order.
    lines: Vec<HunkLine>,
}
30
/// One body line of a unified-diff hunk, with its one-character prefix removed.
#[derive(Debug, Clone, PartialEq, Eq)]
enum HunkLine {
    /// A line prefixed with ` ` (or an empty line), present on both sides.
    Context(String),
    /// A line prefixed with `+`, present only on the new side.
    Addition(String),
    /// A line prefixed with `-`, present only on the old side.
    Deletion(String),
}
37
/// Parses a unified-diff hunk header such as `@@ -12,3 +12,4 @@ fn foo()`.
///
/// Returns `(old_start, old_count, new_start, new_count)`. A range written without a count
/// (e.g. `-7`) has a count of 1. Anything after the closing ` @@` is ignored. Returns `None`
/// if the line is not a well-formed hunk header.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    // Parses `start[,count]`, defaulting the count to 1.
    fn parse_range(range: &str) -> Option<(u32, u32)> {
        match range.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((range.parse().ok()?, 1)),
        }
    }

    let body = line.strip_prefix("@@ -")?;
    let (old_range, remainder) = body.split_once(' ')?;
    let (new_range, _) = remainder.strip_prefix('+')?.split_once(" @@")?;

    let (old_start, old_count) = parse_range(old_range)?;
    let (new_start, new_count) = parse_range(new_range)?;
    Some((old_start, old_count, new_start, new_count))
}
58
59fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
60 let mut hunks = Vec::new();
61 let mut current_hunk: Option<ParsedHunk> = None;
62
63 for line in diff.lines() {
64 if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
65 if let Some(hunk) = current_hunk.take() {
66 hunks.push(hunk);
67 }
68 current_hunk = Some(ParsedHunk {
69 old_start,
70 old_count,
71 new_start,
72 new_count,
73 lines: Vec::new(),
74 });
75 } else if let Some(ref mut hunk) = current_hunk {
76 if let Some(stripped) = line.strip_prefix('+') {
77 hunk.lines.push(HunkLine::Addition(stripped.to_string()));
78 } else if let Some(stripped) = line.strip_prefix('-') {
79 hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
80 } else if let Some(stripped) = line.strip_prefix(' ') {
81 hunk.lines.push(HunkLine::Context(stripped.to_string()));
82 } else if line.is_empty() {
83 hunk.lines.push(HunkLine::Context(String::new()));
84 }
85 }
86 }
87
88 if let Some(hunk) = current_hunk {
89 hunks.push(hunk);
90 }
91
92 hunks
93}
94
95fn format_hunk(hunk: &ParsedHunk) -> String {
96 let mut result = format!(
97 "@@ -{},{} +{},{} @@\n",
98 hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
99 );
100 for line in &hunk.lines {
101 match line {
102 HunkLine::Context(text) => {
103 result.push(' ');
104 result.push_str(text);
105 result.push('\n');
106 }
107 HunkLine::Addition(text) => {
108 result.push('+');
109 result.push_str(text);
110 result.push('\n');
111 }
112 HunkLine::Deletion(text) => {
113 result.push('-');
114 result.push_str(text);
115 result.push('\n');
116 }
117 }
118 }
119 result
120}
121
122fn filter_diff_hunks_by_excerpt(
123 diff: &str,
124 excerpt_start_row: u32,
125 excerpt_row_count: u32,
126) -> (String, i32) {
127 let hunks = parse_diff_hunks(diff);
128 let excerpt_start_0based = excerpt_start_row;
129 let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
130
131 let mut filtered_hunks = Vec::new();
132 let mut cumulative_line_offset: i32 = 0;
133
134 for hunk in hunks {
135 let hunk_start_0based = hunk.new_start.saturating_sub(1);
136 let hunk_end_0based = hunk_start_0based + hunk.new_count;
137
138 let additions: i32 = hunk
139 .lines
140 .iter()
141 .filter(|l| matches!(l, HunkLine::Addition(_)))
142 .count() as i32;
143 let deletions: i32 = hunk
144 .lines
145 .iter()
146 .filter(|l| matches!(l, HunkLine::Deletion(_)))
147 .count() as i32;
148 let hunk_line_delta = additions - deletions;
149
150 if hunk_end_0based <= excerpt_start_0based {
151 cumulative_line_offset += hunk_line_delta;
152 continue;
153 }
154
155 if hunk_start_0based >= excerpt_end_0based {
156 continue;
157 }
158
159 let mut filtered_lines = Vec::new();
160 let mut current_row_0based = hunk_start_0based;
161 let mut filtered_old_count = 0u32;
162 let mut filtered_new_count = 0u32;
163 let mut first_included_row: Option<u32> = None;
164
165 for line in &hunk.lines {
166 match line {
167 HunkLine::Context(text) => {
168 if current_row_0based >= excerpt_start_0based
169 && current_row_0based < excerpt_end_0based
170 {
171 if first_included_row.is_none() {
172 first_included_row = Some(current_row_0based);
173 }
174 filtered_lines.push(HunkLine::Context(text.clone()));
175 filtered_old_count += 1;
176 filtered_new_count += 1;
177 }
178 current_row_0based += 1;
179 }
180 HunkLine::Addition(text) => {
181 if current_row_0based >= excerpt_start_0based
182 && current_row_0based < excerpt_end_0based
183 {
184 if first_included_row.is_none() {
185 first_included_row = Some(current_row_0based);
186 }
187 filtered_lines.push(HunkLine::Addition(text.clone()));
188 filtered_new_count += 1;
189 }
190 current_row_0based += 1;
191 }
192 HunkLine::Deletion(text) => {
193 if current_row_0based >= excerpt_start_0based
194 && current_row_0based < excerpt_end_0based
195 {
196 if first_included_row.is_none() {
197 first_included_row = Some(current_row_0based);
198 }
199 filtered_lines.push(HunkLine::Deletion(text.clone()));
200 filtered_old_count += 1;
201 }
202 }
203 }
204 }
205
206 if !filtered_lines.is_empty() {
207 let first_row = first_included_row.unwrap_or(excerpt_start_0based);
208 let new_start_1based = (first_row - excerpt_start_0based) + 1;
209
210 filtered_hunks.push(ParsedHunk {
211 old_start: new_start_1based,
212 old_count: filtered_old_count,
213 new_start: new_start_1based,
214 new_count: filtered_new_count,
215 lines: filtered_lines,
216 });
217 }
218
219 cumulative_line_offset += hunk_line_delta;
220 }
221
222 let mut result = String::new();
223 for hunk in &filtered_hunks {
224 result.push_str(&format_hunk(hunk));
225 }
226
227 (result, cumulative_line_offset)
228}
229
230fn compute_excerpt_aware_reversal_overlap(
231 edit_history_diffs: &[&str],
232 excerpt_content: &str,
233 excerpt_start_row: u32,
234 predicted_content: &str,
235) -> ReversalOverlap {
236 let mut current_content = excerpt_content.to_string();
237 let mut current_excerpt_start_row = excerpt_start_row;
238
239 for diff in edit_history_diffs.iter().rev() {
240 if diff.is_empty() {
241 continue;
242 }
243
244 let current_row_count = current_content.lines().count() as u32;
245 let (filtered_diff, _line_offset) =
246 filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
247
248 if filtered_diff.is_empty() {
249 let hunks = parse_diff_hunks(diff);
250 for hunk in hunks {
251 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
252 if hunk_end <= current_excerpt_start_row {
253 let additions: u32 = hunk
254 .lines
255 .iter()
256 .filter(|l| matches!(l, HunkLine::Addition(_)))
257 .count() as u32;
258 let deletions: u32 = hunk
259 .lines
260 .iter()
261 .filter(|l| matches!(l, HunkLine::Deletion(_)))
262 .count() as u32;
263 if additions >= deletions {
264 current_excerpt_start_row =
265 current_excerpt_start_row.saturating_sub(additions - deletions);
266 } else {
267 current_excerpt_start_row += deletions - additions;
268 }
269 }
270 }
271 continue;
272 }
273
274 let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
275 match apply_diff_to_string(&reversed, ¤t_content) {
276 Ok(updated) => {
277 current_content = updated;
278 }
279 Err(_) => {
280 continue;
281 }
282 }
283
284 let hunks = parse_diff_hunks(diff);
285 for hunk in hunks {
286 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
287 if hunk_end <= current_excerpt_start_row {
288 let additions: u32 = hunk
289 .lines
290 .iter()
291 .filter(|l| matches!(l, HunkLine::Addition(_)))
292 .count() as u32;
293 let deletions: u32 = hunk
294 .lines
295 .iter()
296 .filter(|l| matches!(l, HunkLine::Deletion(_)))
297 .count() as u32;
298 if additions >= deletions {
299 current_excerpt_start_row =
300 current_excerpt_start_row.saturating_sub(additions - deletions);
301 } else {
302 current_excerpt_start_row += deletions - additions;
303 }
304 }
305 }
306 }
307
308 compute_reversal_overlap(¤t_content, excerpt_content, predicted_content)
309}
310
/// Inverts a unified diff so that applying it undoes the original change.
///
/// - `--- `/`+++ ` file headers outside hunks are swapped.
/// - Each hunk header's old and new ranges are swapped; any trailing section text after
///   ` @@` is kept.
/// - Inside hunks, additions become deletions and vice versa.
///
/// Hunk bodies are tracked with the line counts from each hunk header. This means content
/// lines that happen to begin with `--- ` or `+++ ` are still treated as content, not as file
/// headers. An example is a removed `-- comment` line, which renders as `--- comment`.
/// Outside a counted hunk body, `+`/`-` lines are still flipped unless they start with
/// `+++`/`---`. Any other line is kept as-is. A trailing newline on `diff` is preserved.
fn reverse_diff(diff: &str) -> String {
    // Returns the hunk header with its old/new ranges swapped, plus the old and new line
    // counts it declares (a range without a count covers one line).
    fn reverse_hunk_header(line: &str) -> Option<(String, u32, u32)> {
        let body = line.strip_prefix("@@ -")?;
        let (old_range, rest) = body.split_once(' ')?;
        let (new_range, trailer) = rest.strip_prefix('+')?.split_once(" @@")?;
        let line_count = |range: &str| -> Option<u32> {
            match range.split_once(',') {
                Some((start, count)) => {
                    start.parse::<u32>().ok()?;
                    count.parse().ok()
                }
                None => range.parse::<u32>().ok().map(|_| 1),
            }
        };
        let old_count = line_count(old_range)?;
        let new_count = line_count(new_range)?;
        Some((
            format!("@@ -{new_range} +{old_range} @@{trailer}"),
            old_count,
            new_count,
        ))
    }

    // Lines still expected in the current hunk body, per its header.
    let mut old_remaining = 0u32;
    let mut new_remaining = 0u32;
    let mut reversed_lines: Vec<String> = Vec::new();

    for line in diff.lines() {
        if let Some((header, old_count, new_count)) = reverse_hunk_header(line) {
            old_remaining = old_count;
            new_remaining = new_count;
            reversed_lines.push(header);
            continue;
        }

        let in_hunk_body = old_remaining > 0 || new_remaining > 0;
        let reversed = if in_hunk_body {
            if let Some(text) = line.strip_prefix('+') {
                new_remaining = new_remaining.saturating_sub(1);
                format!("-{text}")
            } else if let Some(text) = line.strip_prefix('-') {
                old_remaining = old_remaining.saturating_sub(1);
                format!("+{text}")
            } else if line.is_empty() || line.starts_with(' ') {
                old_remaining = old_remaining.saturating_sub(1);
                new_remaining = new_remaining.saturating_sub(1);
                line.to_string()
            } else {
                // e.g. `\ No newline at end of file`: not counted, kept verbatim.
                line.to_string()
            }
        } else if let Some(path) = line.strip_prefix("--- ") {
            format!("+++ {path}")
        } else if let Some(path) = line.strip_prefix("+++ ") {
            format!("--- {path}")
        } else if line.starts_with('+') && !line.starts_with("+++") {
            format!("-{}", &line[1..])
        } else if line.starts_with('-') && !line.starts_with("---") {
            format!("+{}", &line[1..])
        } else {
            line.to_string()
        };
        reversed_lines.push(reversed);
    }

    let mut result = reversed_lines.join("\n");
    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}
334
/// A single replacement turning one text into another.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GranularEdit {
    /// Byte range in the old text that is replaced.
    range: Range<usize>,
    /// The old text's contents at `range`.
    old_text: String,
    /// The text that replaces `range` (empty for a pure deletion).
    new_text: String,
}
341
342fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
343 text_diff(old_text, new_text)
344 .into_iter()
345 .map(|(range, new_text)| GranularEdit {
346 old_text: old_text[range.clone()].to_string(),
347 range,
348 new_text: new_text.to_string(),
349 })
350 .collect()
351}
352
/// Text inserted by a history edit, located in the current content.
#[derive(Debug, Clone)]
struct HistoryAdditionRange {
    /// Byte range in the current content occupied by the inserted text.
    range_in_current: Range<usize>,
}
357
/// Text removed by a history edit, together with where it used to be.
#[derive(Debug, Clone)]
struct HistoryDeletionRange {
    /// The text that the history edit replaced.
    deleted_text: String,
    /// Byte offset in the current content where the replaced text started.
    position_in_current: usize,
}
363
364fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
365 let mut result = Vec::new();
366 let mut offset_delta: isize = 0;
367
368 for edit in history_edits {
369 if !edit.new_text.is_empty() {
370 let new_start = (edit.range.start as isize + offset_delta) as usize;
371 let new_end = new_start + edit.new_text.len();
372 result.push(HistoryAdditionRange {
373 range_in_current: new_start..new_end,
374 });
375 }
376
377 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
378 }
379
380 result
381}
382
383fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
384 let mut result = Vec::new();
385 let mut offset_delta: isize = 0;
386
387 for edit in history_edits {
388 if !edit.old_text.is_empty() {
389 let position_in_current = (edit.range.start as isize + offset_delta) as usize;
390 result.push(HistoryDeletionRange {
391 deleted_text: edit.old_text.clone(),
392 position_in_current,
393 });
394 }
395
396 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
397 }
398
399 result
400}
401
/// Character counts describing how much of a prediction reverses the user's edits.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ReversalOverlap {
    /// Characters of the prediction's edits that remove text the user added or restore text
    /// the user deleted.
    chars_reversing_user_edits: usize,
    /// Sum of inserted and removed characters across all of the prediction's edits.
    total_chars_in_prediction: usize,
}
407
408impl ReversalOverlap {
409 fn ratio(&self) -> f32 {
410 if self.total_chars_in_prediction == 0 {
411 0.0
412 } else {
413 self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32
414 }
415 }
416}
417
418/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
419/// or where `new_text` appears as a subsequence within `old_text` (reduction).
420///
421/// For extensions: when the user's text is preserved (in order) within the prediction,
422/// we only count the newly inserted characters, not the preserved ones.
423/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
424/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
425///
426/// For reductions: when the prediction's text is preserved (in order) within the original,
427/// we only count the deleted characters, not the preserved ones.
428/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
429fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
430 edits
431 .into_iter()
432 .flat_map(|edit| {
433 if edit.old_text.is_empty() || edit.new_text.is_empty() {
434 return vec![edit];
435 }
436
437 // Use character-wise diff to find exact byte ranges of changes
438 let char_edits = char_diff(&edit.old_text, &edit.new_text);
439
440 let all_deletions = !char_edits.is_empty()
441 && char_edits
442 .iter()
443 .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
444 let all_insertions = !char_edits.is_empty()
445 && char_edits
446 .iter()
447 .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
448 if all_deletions || all_insertions {
449 return char_edits
450 .into_iter()
451 .map(|(range, replacement)| GranularEdit {
452 range: edit.range.start + range.start..edit.range.start + range.end,
453 old_text: edit.old_text[range].to_string(),
454 new_text: replacement.to_string(),
455 })
456 .collect();
457 }
458
459 // Otherwise, keep the original edit (mixed changes)
460 vec![edit]
461 })
462 .collect()
463}
464
465fn compute_reversal_overlap(
466 original_content: &str,
467 current_content: &str,
468 predicted_content: &str,
469) -> ReversalOverlap {
470 let history_edits =
471 normalize_extension_edits(compute_granular_edits(original_content, current_content));
472 let prediction_edits =
473 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
474
475 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
476 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
477
478 let reversed_additions =
479 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
480 let restored_deletions =
481 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
482
483 let total_chars_in_prediction: usize = prediction_edits
484 .iter()
485 .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
486 .sum();
487
488 ReversalOverlap {
489 chars_reversing_user_edits: reversed_additions + restored_deletions,
490 total_chars_in_prediction,
491 }
492}
493
494fn compute_reversed_additions(
495 history_addition_ranges: &[HistoryAdditionRange],
496 prediction_edits: &[GranularEdit],
497) -> usize {
498 let mut reversed_chars = 0;
499
500 for pred_edit in prediction_edits {
501 for history_addition in history_addition_ranges {
502 let overlap_start = pred_edit
503 .range
504 .start
505 .max(history_addition.range_in_current.start);
506 let overlap_end = pred_edit
507 .range
508 .end
509 .min(history_addition.range_in_current.end);
510
511 if overlap_start < overlap_end {
512 let relative_start = overlap_start - pred_edit.range.start;
513 let relative_end = overlap_end - pred_edit.range.start;
514 let overlap_text = &pred_edit.old_text[relative_start..relative_end];
515 reversed_chars += overlap_text.chars().count();
516 }
517 }
518 }
519
520 reversed_chars
521}
522
523fn compute_restored_deletions(
524 history_deletion_ranges: &[HistoryDeletionRange],
525 prediction_edits: &[GranularEdit],
526) -> usize {
527 let mut restored = 0;
528
529 for pred_edit in prediction_edits {
530 if pred_edit.new_text.is_empty() {
531 continue;
532 }
533
534 for deletion in history_deletion_ranges {
535 if pred_edit.range.contains(&deletion.position_in_current)
536 || deletion.position_in_current == pred_edit.range.start
537 {
538 restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
539 }
540 }
541 }
542
543 restored
544}
545
/// Returns the length, in `char`s, of the longest common subsequence of `a` and `b`.
///
/// Uses the classic dynamic program with two rolling rows, so memory is proportional to the
/// number of characters in `b`.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();
    if a.is_empty() || b_chars.is_empty() {
        return 0;
    }

    let width = b_chars.len() + 1;
    let mut previous_row = vec![0usize; width];
    let mut current_row = vec![0usize; width];

    for a_char in a.chars() {
        // Column 0 is never written and stays 0; every other cell read below was written
        // earlier in this same pass, so no clearing between rows is needed.
        for (j, &b_char) in b_chars.iter().enumerate() {
            current_row[j + 1] = if a_char == b_char {
                previous_row[j] + 1
            } else {
                previous_row[j + 1].max(current_row[j])
            };
        }
        std::mem::swap(&mut previous_row, &mut current_row);
    }

    previous_row[b_chars.len()]
}
573
574fn filter_edit_history_by_path<'a>(
575 edit_history: &'a [Arc<zeta_prompt::Event>],
576 cursor_path: &std::path::Path,
577) -> Vec<&'a zeta_prompt::Event> {
578 edit_history
579 .iter()
580 .filter(|event| match event.as_ref() {
581 zeta_prompt::Event::BufferChange { path, .. } => {
582 let event_path = path.as_ref();
583 if event_path == cursor_path {
584 return true;
585 }
586 let stripped = event_path
587 .components()
588 .skip(1)
589 .collect::<std::path::PathBuf>();
590 stripped == cursor_path
591 }
592 })
593 .map(|arc| arc.as_ref())
594 .collect()
595}
596
597fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
598 match event {
599 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
600 }
601}
602
603fn is_predicted_event(event: &zeta_prompt::Event) -> bool {
604 match event {
605 zeta_prompt::Event::BufferChange { predicted, .. } => *predicted,
606 }
607}
608
609pub fn compute_prediction_reversal_ratio_from_history(
610 current_content: &str,
611 edit_history: &[Arc<zeta_prompt::Event>],
612 excerpt_start_row: Option<u32>,
613 predicted_content: &str,
614 cursor_path: &Path,
615) -> f32 {
616 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
617
618 let most_recent = match relevant_events.last() {
619 Some(event) if !is_predicted_event(event) => *event,
620 _ => return 0.0,
621 };
622
623 let diff = extract_diff_from_event(most_recent);
624 if diff.is_empty() {
625 return 0.0;
626 }
627
628 if let Some(excerpt_start_row) = excerpt_start_row {
629 let diffs = vec![diff];
630 let overlap = compute_excerpt_aware_reversal_overlap(
631 &diffs,
632 current_content,
633 excerpt_start_row,
634 predicted_content,
635 );
636 return overlap.ratio();
637 }
638
639 let reversed = reverse_diff(diff);
640 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
641 let original_content = match apply_diff_to_string(&with_headers, current_content) {
642 Ok(updated_content) => updated_content,
643 Err(_) => apply_diff_to_string_lenient(&reversed, current_content),
644 };
645
646 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
647 overlap.ratio()
648}
649
650#[cfg(test)]
651mod tests {
652 use super::*;
653 use indoc::indoc;
654 use zeta_prompt::udiff::apply_diff_to_string;
655 use zeta_prompt::{ExcerptRanges, ZetaPromptInput};
656
657 fn compute_prediction_reversal_ratio(
658 prompt_inputs: &ZetaPromptInput,
659 predicted_content: &str,
660 cursor_path: &Path,
661 ) -> f32 {
662 compute_prediction_reversal_ratio_from_history(
663 prompt_inputs.cursor_excerpt.as_ref(),
664 &prompt_inputs.events,
665 prompt_inputs.excerpt_start_row,
666 predicted_content,
667 cursor_path,
668 )
669 }
670
671 fn make_test_prompt_inputs(
672 content: &str,
673 events: Vec<Arc<zeta_prompt::Event>>,
674 excerpt_start_row: Option<u32>,
675 ) -> ZetaPromptInput {
676 ZetaPromptInput {
677 cursor_path: Arc::from(Path::new("src/test.rs")),
678 cursor_excerpt: content.into(),
679 cursor_offset_in_excerpt: 0,
680 excerpt_start_row,
681 events,
682 related_files: Some(Vec::new()),
683 active_buffer_diagnostics: Vec::new(),
684 excerpt_ranges: ExcerptRanges {
685 editable_150: 0..content.len(),
686 editable_180: 0..content.len(),
687 editable_350: 0..content.len(),
688 editable_150_context_350: 0..content.len(),
689 editable_180_context_350: 0..content.len(),
690 editable_350_context_150: 0..content.len(),
691 ..Default::default()
692 },
693 syntax_ranges: None,
694 in_open_source_repo: false,
695 can_collect_data: false,
696 repo_url: None,
697 }
698 }
699
700 #[test]
701 fn test_reversal_overlap() {
702 struct Case {
703 name: &'static str,
704 original: &'static str,
705 current: &'static str,
706 predicted: &'static str,
707 expected_reversal_chars: usize,
708 expected_total_chars: usize,
709 }
710
711 let cases = [
712 Case {
713 name: "user_adds_line_prediction_removes_it",
714 original: indoc! {"
715 a
716 b
717 c"},
718 current: indoc! {"
719 a
720 new line
721 b
722 c"},
723 predicted: indoc! {"
724 a
725 b
726 c"},
727 expected_reversal_chars: 9,
728 expected_total_chars: 9,
729 },
730 Case {
731 name: "user_deletes_line_prediction_restores_it",
732 original: indoc! {"
733 a
734 deleted
735 b"},
736 current: indoc! {"
737 a
738 b"},
739 predicted: indoc! {"
740 a
741 deleted
742 b"},
743 expected_reversal_chars: 8,
744 expected_total_chars: 8,
745 },
746 Case {
747 name: "user_deletes_text_prediction_restores_partial",
748 original: "hello beautiful world",
749 current: "hello world",
750 predicted: "hello beautiful world",
751 expected_reversal_chars: 10,
752 expected_total_chars: 10,
753 },
754 Case {
755 name: "user_deletes_foo_prediction_adds_bar",
756 original: "foo",
757 current: "",
758 predicted: "bar",
759 expected_reversal_chars: 0,
760 expected_total_chars: 3,
761 },
762 Case {
763 name: "independent_edits_different_locations",
764 original: indoc! {"
765 line1
766 line2
767 line3"},
768 current: indoc! {"
769 LINE1
770 line2
771 line3"},
772 predicted: indoc! {"
773 LINE1
774 line2
775 LINE3"},
776 expected_reversal_chars: 0,
777 expected_total_chars: 10,
778 },
779 Case {
780 name: "no_history_edits",
781 original: "same",
782 current: "same",
783 predicted: "different",
784 expected_reversal_chars: 0,
785 expected_total_chars: 13,
786 },
787 Case {
788 name: "user_replaces_text_prediction_reverses",
789 original: indoc! {"
790 keep
791 delete_me
792 keep2"},
793 current: indoc! {"
794 keep
795 added
796 keep2"},
797 predicted: indoc! {"
798 keep
799 delete_me
800 keep2"},
801 expected_reversal_chars: 14,
802 expected_total_chars: 14,
803 },
804 Case {
805 name: "user_modifies_word_prediction_modifies_differently",
806 original: "the quick brown fox",
807 current: "the slow brown fox",
808 predicted: "the fast brown fox",
809 expected_reversal_chars: 4,
810 expected_total_chars: 8,
811 },
812 Case {
813 name: "user finishes function name (suffix)",
814 original: "",
815 current: "epr",
816 predicted: "eprintln!()",
817 expected_reversal_chars: 0,
818 expected_total_chars: 8,
819 },
820 Case {
821 name: "user starts function name (prefix)",
822 original: "",
823 current: "my_function()",
824 predicted: "test_my_function()",
825 expected_reversal_chars: 0,
826 expected_total_chars: 5,
827 },
828 Case {
829 name: "user types partial, prediction extends in multiple places",
830 original: "",
831 current: "test_my_function",
832 predicted: "a_test_for_my_special_function_plz",
833 expected_reversal_chars: 0,
834 expected_total_chars: 18,
835 },
836 // Edge cases for subsequence matching
837 Case {
838 name: "subsequence with interleaved underscores",
839 original: "",
840 current: "a_b_c",
841 predicted: "_a__b__c__",
842 expected_reversal_chars: 0,
843 expected_total_chars: 5,
844 },
845 Case {
846 name: "not a subsequence - different characters",
847 original: "",
848 current: "abc",
849 predicted: "xyz",
850 expected_reversal_chars: 3,
851 expected_total_chars: 6,
852 },
853 Case {
854 name: "not a subsequence - wrong order",
855 original: "",
856 current: "abc",
857 predicted: "cba",
858 expected_reversal_chars: 3,
859 expected_total_chars: 6,
860 },
861 Case {
862 name: "partial subsequence - only some chars match",
863 original: "",
864 current: "abcd",
865 predicted: "axbx",
866 expected_reversal_chars: 4,
867 expected_total_chars: 8,
868 },
869 // Common completion patterns
870 Case {
871 name: "completing a method call",
872 original: "",
873 current: "vec.pu",
874 predicted: "vec.push(item)",
875 expected_reversal_chars: 0,
876 expected_total_chars: 8,
877 },
878 Case {
879 name: "completing an import statement",
880 original: "",
881 current: "use std::col",
882 predicted: "use std::collections::HashMap",
883 expected_reversal_chars: 0,
884 expected_total_chars: 17,
885 },
886 Case {
887 name: "completing a struct field",
888 original: "",
889 current: "name: St",
890 predicted: "name: String",
891 expected_reversal_chars: 0,
892 expected_total_chars: 4,
893 },
894 Case {
895 name: "prediction replaces with completely different text",
896 original: "",
897 current: "hello",
898 predicted: "world",
899 expected_reversal_chars: 5,
900 expected_total_chars: 10,
901 },
902 Case {
903 name: "empty prediction removes user text",
904 original: "",
905 current: "mistake",
906 predicted: "",
907 expected_reversal_chars: 7,
908 expected_total_chars: 7,
909 },
910 Case {
911 name: "fixing typo is not reversal",
912 original: "",
913 current: "<dv",
914 predicted: "<div>",
915 expected_reversal_chars: 0,
916 expected_total_chars: 2,
917 },
918 Case {
919 name: "infix insertion not reversal",
920 original: indoc! {"
921 from my_project import Foo
922 "},
923 current: indoc! {"
924 ifrom my_project import Foo
925 "},
926 predicted: indoc! {"
927 import
928 from my_project import Foo
929 "},
930 expected_reversal_chars: 0,
931 expected_total_chars: 6,
932 },
933 Case {
934 name: "non-word based reversal",
935 original: "from",
936 current: "ifrom",
937 predicted: "from",
938 expected_reversal_chars: 1,
939 expected_total_chars: 1,
940 },
941 Case {
942 name: "multiple insertions no reversal",
943 original: "print(\"Hello, World!\")",
944 current: "sys.(\"Hello, World!\")",
945 predicted: "sys.stdout.write(\"Hello, World!\\n\")",
946 expected_reversal_chars: 0,
947 expected_total_chars: 14,
948 },
949 ];
950
951 for case in &cases {
952 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
953 assert_eq!(
954 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
955 "Test '{}': expected {} reversal chars, got {}",
956 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
957 );
958 assert_eq!(
959 overlap.total_chars_in_prediction, case.expected_total_chars,
960 "Test '{}': expected {} total chars, got {}",
961 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
962 );
963 }
964 }
965
966 #[test]
967 fn test_reverse_diff() {
968 let forward_diff = indoc! {"
969 --- a/file.rs
970 +++ b/file.rs
971 @@ -1,3 +1,4 @@
972 fn main() {
973 + let x = 42;
974 println!(\"hello\");
975 }"};
976
977 let reversed = reverse_diff(forward_diff);
978
979 assert!(
980 reversed.contains("+++ a/file.rs"),
981 "Should have +++ for old path"
982 );
983 assert!(
984 reversed.contains("--- b/file.rs"),
985 "Should have --- for new path"
986 );
987 assert!(
988 reversed.contains("- let x = 42;"),
989 "Added line should become deletion"
990 );
991 assert!(
992 reversed.contains(" fn main()"),
993 "Context lines should be unchanged"
994 );
995 }
996
997 #[test]
998 fn test_reverse_diff_roundtrip() {
999 // Applying a diff and then its reverse should get back to original
1000 let original = indoc! {"
1001 first line
1002 hello world
1003 last line
1004 "};
1005 let modified = indoc! {"
1006 first line
1007 hello beautiful world
1008 last line
1009 "};
1010
1011 // unified_diff doesn't include file headers, but apply_diff_to_string needs them
1012 let diff_body = language::unified_diff(original, modified);
1013 let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
1014 let reversed_diff = reverse_diff(&forward_diff);
1015
1016 // Apply forward diff to original
1017 let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
1018 assert_eq!(after_forward, modified);
1019
1020 // Apply reversed diff to modified
1021 let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
1022 assert_eq!(after_reverse, original);
1023 }
1024
1025 #[test]
1026 fn test_filter_edit_history_by_path() {
1027 // Test that filter_edit_history_by_path correctly matches paths when
1028 // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
1029 // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
1030 let events = vec![
1031 Arc::new(zeta_prompt::Event::BufferChange {
1032 path: Arc::from(Path::new("myrepo/src/file.rs")),
1033 old_path: Arc::from(Path::new("myrepo/src/file.rs")),
1034 diff: indoc! {"
1035 @@ -1 +1 @@
1036 -old
1037 +new"}
1038 .into(),
1039 predicted: false,
1040 in_open_source_repo: true,
1041 }),
1042 Arc::new(zeta_prompt::Event::BufferChange {
1043 path: Arc::from(Path::new("myrepo/other.rs")),
1044 old_path: Arc::from(Path::new("myrepo/other.rs")),
1045 diff: indoc! {"
1046 @@ -1 +1 @@
1047 -a
1048 +b"}
1049 .into(),
1050 predicted: false,
1051 in_open_source_repo: true,
1052 }),
1053 Arc::new(zeta_prompt::Event::BufferChange {
1054 path: Arc::from(Path::new("src/file.rs")),
1055 old_path: Arc::from(Path::new("src/file.rs")),
1056 diff: indoc! {"
1057 @@ -1 +1 @@
1058 -x
1059 +y"}
1060 .into(),
1061 predicted: false,
1062 in_open_source_repo: true,
1063 }),
1064 ];
1065
1066 // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
1067 // "src/file.rs" exact match
1068 let cursor_path = Path::new("src/file.rs");
1069 let filtered = filter_edit_history_by_path(&events, cursor_path);
1070 assert_eq!(
1071 filtered.len(),
1072 2,
1073 "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
1074 );
1075
1076 // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
1077 // "src/file.rs" stripped -> "file.rs" == "file.rs"
1078 let cursor_path = Path::new("file.rs");
1079 let filtered = filter_edit_history_by_path(&events, cursor_path);
1080 assert_eq!(
1081 filtered.len(),
1082 1,
1083 "Should only match src/file.rs (stripped to file.rs)"
1084 );
1085
1086 // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
1087 let cursor_path = Path::new("other.rs");
1088 let filtered = filter_edit_history_by_path(&events, cursor_path);
1089 assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
1090 }
1091
1092 #[test]
1093 fn test_reverse_diff_preserves_trailing_newline() {
1094 let diff_with_trailing_newline = indoc! {"
1095 --- a/file
1096 +++ b/file
1097 @@ -1 +1 @@
1098 -old
1099 +new
1100 "};
1101 let reversed = reverse_diff(diff_with_trailing_newline);
1102 assert!(
1103 reversed.ends_with('\n'),
1104 "Reversed diff should preserve trailing newline"
1105 );
1106
1107 let diff_without_trailing_newline = indoc! {"
1108 --- a/file
1109 +++ b/file
1110 @@ -1 +1 @@
1111 -old
1112 +new"};
1113 let reversed = reverse_diff(diff_without_trailing_newline);
1114 assert!(
1115 !reversed.ends_with('\n'),
1116 "Reversed diff should not add trailing newline if original didn't have one"
1117 );
1118 }
1119
    /// Table-driven check of `filter_diff_hunks_by_excerpt`: each case feeds a unified diff
    /// and an excerpt window (`excerpt_start_row`, `excerpt_row_count`) and pins both the
    /// filtered diff text and the returned line offset.
    ///
    /// NOTE(review): the fixtures are consistent with the window being 0-based rows in
    /// post-edit (new-side) coordinates, and with surviving hunks being re-headed 1-based
    /// relative to the window start (e.g. in `multiple_hunks_mixed`, `line12` is new-side
    /// line 13, i.e. row 12, and comes out under `@@ -3,2 +3,3 @@`) — confirm against the
    /// implementation.
    #[test]
    fn test_filter_hunks_by_excerpt_region() {
        struct Case {
            /// Label reported in assertion failures.
            name: &'static str,
            /// Input diff whose hunk headers use whole-file line numbers.
            diff: &'static str,
            excerpt_start_row: u32,
            excerpt_row_count: u32,
            /// Only the parts of hunks that fall inside the window, re-headed.
            expected_filtered_diff: &'static str,
            /// NOTE(review): every value below equals the net line delta (additions minus
            /// deletions) of the hunks that are not entirely after the window — confirm that
            /// this is the intended contract.
            expected_line_offset: i32,
        }

        let cases = [
            // Dropped from the output, but its one added line still counts toward the offset.
            Case {
                name: "hunk_entirely_before_excerpt",
                diff: indoc! {"
                    @@ -1,3 +1,4 @@
                     line1
                    +inserted
                     line2
                     line3
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 5,
                expected_filtered_diff: "",
                expected_line_offset: 1,
            },
            Case {
                name: "hunk_entirely_inside_excerpt",
                diff: indoc! {"
                    @@ -12,3 +12,4 @@
                     line12
                    +inserted
                     line13
                     line14
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 10,
                expected_filtered_diff: indoc! {"
                    @@ -2,3 +2,4 @@
                     line12
                    +inserted
                     line13
                     line14
                "},
                expected_line_offset: 1,
            },
            // Dropped from the output and does not contribute to the offset.
            Case {
                name: "hunk_entirely_after_excerpt",
                diff: indoc! {"
                    @@ -50,3 +50,4 @@
                     line50
                    +inserted
                     line51
                     line52
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 5,
                expected_filtered_diff: "",
                expected_line_offset: 0,
            },
            // Lines before the window (including `+inserted`) are trimmed from the output;
            // the insertion still counts toward the offset.
            Case {
                name: "hunk_straddles_excerpt_start",
                diff: indoc! {"
                    @@ -8,5 +8,6 @@
                     line8
                     line9
                    +inserted
                     line10
                     line11
                     line12
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 10,
                expected_filtered_diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     line10
                     line11
                     line12
                "},
                expected_line_offset: 1,
            },
            // Lines past the window end are trimmed from the output.
            Case {
                name: "hunk_straddles_excerpt_end",
                diff: indoc! {"
                    @@ -18,5 +18,6 @@
                     line18
                     line19
                    +inserted
                     line20
                     line21
                     line22
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 10,
                expected_filtered_diff: indoc! {"
                    @@ -8,2 +8,3 @@
                     line18
                     line19
                    +inserted
                "},
                expected_line_offset: 1,
            },
            // Only the middle hunk survives; the first and second both feed the offset.
            Case {
                name: "multiple_hunks_mixed",
                diff: indoc! {"
                    @@ -1,2 +1,3 @@
                     line1
                    +before_excerpt
                     line2
                    @@ -12,2 +13,3 @@
                     line12
                    +inside_excerpt
                     line13
                    @@ -50,2 +52,3 @@
                     line50
                    +after_excerpt
                     line51
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 10,
                expected_filtered_diff: indoc! {"
                    @@ -3,2 +3,3 @@
                     line12
                    +inside_excerpt
                     line13
                "},
                expected_line_offset: 2,
            },
            Case {
                name: "deletion_before_excerpt",
                diff: indoc! {"
                    @@ -1,4 +1,3 @@
                     line1
                    -deleted
                     line2
                     line3
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 5,
                expected_filtered_diff: "",
                expected_line_offset: -1,
            },
            Case {
                name: "deletion_inside_excerpt",
                diff: indoc! {"
                    @@ -12,4 +12,3 @@
                     line12
                    -deleted
                     line13
                     line14
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 10,
                expected_filtered_diff: indoc! {"
                    @@ -2,4 +2,3 @@
                     line12
                    -deleted
                     line13
                     line14
                "},
                expected_line_offset: -1,
            },
            Case {
                name: "empty_diff",
                diff: "",
                excerpt_start_row: 10,
                excerpt_row_count: 5,
                expected_filtered_diff: "",
                expected_line_offset: 0,
            },
            // Trimmed at both ends; both insertions fall inside the window.
            Case {
                name: "hunk_spans_entire_excerpt",
                diff: indoc! {"
                    @@ -8,10 +8,12 @@
                     line8
                     line9
                     line10
                     line11
                    +inserted1
                     line12
                     line13
                    +inserted2
                     line14
                     line15
                     line16
                     line17
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 5,
                expected_filtered_diff: indoc! {"
                    @@ -1,3 +1,5 @@
                     line11
                    +inserted1
                     line12
                     line13
                    +inserted2
                "},
                expected_line_offset: 2,
            },
            Case {
                name: "replacement_inside_excerpt",
                diff: indoc! {"
                    @@ -12,3 +12,3 @@
                     line12
                    -old_text
                    +new_text
                     line14
                "},
                excerpt_start_row: 10,
                excerpt_row_count: 10,
                expected_filtered_diff: indoc! {"
                    @@ -2,3 +2,3 @@
                     line12
                    -old_text
                    +new_text
                     line14
                "},
                expected_line_offset: 0,
            },
        ];

        for case in &cases {
            let (filtered, line_offset) = filter_diff_hunks_by_excerpt(
                case.diff,
                case.excerpt_start_row,
                case.excerpt_row_count,
            );
            assert_eq!(
                filtered, case.expected_filtered_diff,
                "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}",
                case.name, case.expected_filtered_diff, filtered
            );
            assert_eq!(
                line_offset, case.expected_line_offset,
                "Test '{}': line offset mismatch. Expected {}, got {}",
                case.name, case.expected_line_offset, line_offset
            );
        }
    }
1359
    /// Table-driven check of `compute_excerpt_aware_reversal_overlap`: each case supplies the
    /// user's edit-history diffs, the current excerpt text and its starting row, and a
    /// predicted replacement for the excerpt, and pins how many predicted chars undo user
    /// edits versus the total char count of the prediction's change.
    #[test]
    fn test_excerpt_aware_reversal_tracking() {
        struct Case {
            /// Label reported in assertion failures.
            name: &'static str,
            /// History diffs with whole-file line numbers in their headers; presumably
            /// oldest first (later diffs' new-side headers account for earlier insertions).
            edit_history_diffs: Vec<&'static str>,
            /// Excerpt text after the history edits were made.
            excerpt_content: &'static str,
            excerpt_start_row: u32,
            /// Predicted text for the same excerpt.
            predicted_content: &'static str,
            expected_reversal_chars: usize,
            expected_total_chars: usize,
        }

        let cases = [
            // The only history hunk is above the excerpt, so nothing counts as a reversal.
            // The total (14) matches 6 deleted chars (`line11`) + 8 inserted (`modified`).
            Case {
                name: "edit_outside_excerpt_no_reversal",
                edit_history_diffs: vec![indoc! {"
                    @@ -1,2 +1,3 @@
                     line1
                    +added_outside
                     line2
                "}],
                excerpt_content: indoc! {"
                    line10
                    line11
                    line12
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    modified
                    line12
                "},
                expected_reversal_chars: 0,
                expected_total_chars: 14,
            },
            // The prediction deletes exactly the line the user added: 11 == "user_added\n".
            Case {
                name: "edit_inside_excerpt_with_reversal",
                edit_history_diffs: vec![indoc! {"
                    @@ -10,3 +10,4 @@
                     line10
                    +user_added
                     line11
                     line12
                "}],
                excerpt_content: indoc! {"
                    line10
                    user_added
                    line11
                    line12
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    line11
                    line12
                "},
                expected_reversal_chars: 11,
                expected_total_chars: 11,
            },
            // `before_excerpt` lies above the excerpt; only the removal of `inside_excerpt`
            // is counted: 15 == "inside_excerpt\n".
            Case {
                name: "straddling_edit_partial_reversal",
                edit_history_diffs: vec![indoc! {"
                    @@ -8,6 +8,8 @@
                     line8
                     line9
                    +before_excerpt
                     line10
                    +inside_excerpt
                     line11
                     line12
                     line13
                "}],
                excerpt_content: indoc! {"
                    line10
                    inside_excerpt
                    line11
                    line12
                    line13
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    line11
                    line12
                    line13
                "},
                expected_reversal_chars: 15,
                expected_total_chars: 15,
            },
            // The first diff is outside the excerpt; the second diff's header already
            // accounts for the line added by the first (`-11` -> `+12`). The prediction
            // removes `inside1`: 8 == "inside1\n".
            Case {
                name: "multiple_edits_mixed_locations",
                edit_history_diffs: vec![
                    indoc! {"
                        @@ -1,2 +1,3 @@
                         line1
                        +outside1
                         line2
                    "},
                    indoc! {"
                        @@ -11,2 +12,3 @@
                         line11
                        +inside1
                         line12
                    "},
                ],
                excerpt_content: indoc! {"
                    line10
                    line11
                    inside1
                    line12
                    line13
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    line11
                    line12
                    line13
                "},
                expected_reversal_chars: 8,
                expected_total_chars: 8,
            },
            // Nothing to reverse; the total matches 6 deleted + 8 inserted chars.
            Case {
                name: "no_edit_history",
                edit_history_diffs: vec![],
                excerpt_content: indoc! {"
                    line10
                    line11
                    line12
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    modified
                    line12
                "},
                expected_reversal_chars: 0,
                expected_total_chars: 14,
            },
            // The history hunk is below the excerpt. The total (13) matches 6 deleted chars
            // (`line11`) + 7 inserted (`changed`).
            Case {
                name: "edit_after_excerpt_no_effect",
                edit_history_diffs: vec![indoc! {"
                    @@ -50,2 +50,3 @@
                     line50
                    +added_after
                     line51
                "}],
                excerpt_content: indoc! {"
                    line10
                    line11
                    line12
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    changed
                    line12
                "},
                expected_reversal_chars: 0,
                expected_total_chars: 13,
            },
            // The first diff adds two lines above the excerpt (hence `-12` vs `+14` in the
            // second header). The prediction removes `inside_after_offset`:
            // 20 == "inside_after_offset\n".
            // NOTE(review): the second diff is a standalone single-hunk diff yet shifts only
            // its new-side start — confirm the fixture intends that.
            Case {
                name: "line_offset_tracking_across_hunks",
                edit_history_diffs: vec![
                    indoc! {"
                        @@ -1,2 +1,4 @@
                         line1
                        +added1
                        +added2
                         line2
                    "},
                    indoc! {"
                        @@ -12,2 +14,3 @@
                         line12
                        +inside_after_offset
                         line13
                    "},
                ],
                excerpt_content: indoc! {"
                    line10
                    line11
                    line12
                    inside_after_offset
                    line13
                "},
                excerpt_start_row: 10,
                predicted_content: indoc! {"
                    line10
                    line11
                    line12
                    line13
                "},
                expected_reversal_chars: 20,
                expected_total_chars: 20,
            },
        ];

        for case in &cases {
            let overlap = compute_excerpt_aware_reversal_overlap(
                &case.edit_history_diffs,
                case.excerpt_content,
                case.excerpt_start_row,
                case.predicted_content,
            );
            assert_eq!(
                overlap.chars_reversing_user_edits, case.expected_reversal_chars,
                "Test '{}': expected {} reversal chars, got {}",
                case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
            );
            assert_eq!(
                overlap.total_chars_in_prediction, case.expected_total_chars,
                "Test '{}': expected {} total chars, got {}",
                case.name, case.expected_total_chars, overlap.total_chars_in_prediction
            );
        }
    }
1576
    /// Checks `apply_diff_to_string_lenient`, which (per its definition at the top of this
    /// file) applies each parsed hunk on its own and keeps the previous text whenever a hunk
    /// fails to apply, so unmatched hunks are skipped instead of aborting the whole diff.
    #[test]
    fn test_lenient_diff_application() {
        struct Case {
            /// Label reported in assertion failures.
            name: &'static str,
            diff: &'static str,
            content: &'static str,
            expected_result: &'static str,
        }

        let cases = [
            // None of the hunk's context lines exist in the content, so it comes back
            // unchanged.
            Case {
                name: "hunk_context_not_found_skipped",
                diff: indoc! {"
                    @@ -1,3 +1,4 @@
                     context_not_in_content
                    +added_line
                     more_context
                     final_context
                "},
                content: indoc! {"
                    completely
                    different
                    content
                "},
                expected_result: indoc! {"
                    completely
                    different
                    content
                "},
            },
            Case {
                name: "hunk_context_found_applied",
                diff: indoc! {"
                    @@ -1,3 +1,4 @@
                     line1
                    +inserted
                     line2
                     line3
                "},
                content: indoc! {"
                    line1
                    line2
                    line3
                "},
                expected_result: indoc! {"
                    line1
                    inserted
                    line2
                    line3
                "},
            },
            // The first hunk's context is absent and it is skipped; the second hunk still
            // applies.
            Case {
                name: "multiple_hunks_partial_match",
                diff: indoc! {"
                    @@ -1,2 +1,3 @@
                     not_found
                    +skipped
                     also_not_found
                    @@ -5,2 +6,3 @@
                     line5
                    +applied
                     line6
                "},
                content: indoc! {"
                    line1
                    line2
                    line3
                    line4
                    line5
                    line6
                "},
                expected_result: indoc! {"
                    line1
                    line2
                    line3
                    line4
                    line5
                    applied
                    line6
                "},
            },
            Case {
                name: "empty_diff",
                diff: "",
                content: indoc! {"
                    unchanged
                    content
                "},
                expected_result: indoc! {"
                    unchanged
                    content
                "},
            },
        ];

        for case in &cases {
            let result = apply_diff_to_string_lenient(case.diff, case.content);
            assert_eq!(
                result, case.expected_result,
                "Test '{}': expected:\n{}\ngot:\n{}",
                case.name, case.expected_result, result
            );
        }
    }
1681
1682 #[test]
1683 fn test_unicode_reversal_overlap() {
1684 struct Case {
1685 name: &'static str,
1686 original: &'static str,
1687 current: &'static str,
1688 predicted: &'static str,
1689 expected_reversal_chars: usize,
1690 expected_total_chars: usize,
1691 }
1692
1693 let cases = [
1694 Case {
1695 name: "unicode_extension_cjk",
1696 original: "",
1697 current: "日", // 1 char
1698 predicted: "日本語", // 3 chars, adds 2 chars
1699 expected_reversal_chars: 0,
1700 expected_total_chars: 2, // "本語" = 2 chars added
1701 },
1702 Case {
1703 name: "unicode_extension_emoji",
1704 original: "",
1705 current: "🎉", // 1 char
1706 predicted: "🎉🎊🎈", // 3 chars, adds 2 chars
1707 expected_reversal_chars: 0,
1708 expected_total_chars: 2, // "🎊🎈" = 2 chars added
1709 },
1710 Case {
1711 name: "unicode_deletion_restored",
1712 original: "héllo wörld", // 11 chars
1713 current: "héllo", // 5 chars
1714 predicted: "héllo wörld", // restores " wörld" = 6 chars
1715 expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars
1716 expected_total_chars: 6,
1717 },
1718 Case {
1719 name: "unicode_addition_reversed",
1720 original: "café", // 4 chars
1721 current: "café latté", // 10 chars, added " latté" = 6 chars
1722 predicted: "café", // removes " latté"
1723 expected_reversal_chars: 6, // 6 chars removed
1724 expected_total_chars: 6,
1725 },
1726 Case {
1727 name: "mixed_ascii_unicode",
1728 original: "",
1729 current: "test日本", // 6 chars
1730 predicted: "test日本語です", // 9 chars
1731 expected_reversal_chars: 0,
1732 expected_total_chars: 3, // 3 new chars after subsequence normalization
1733 },
1734 Case {
1735 name: "unicode_replacement_not_subsequence",
1736 original: "",
1737 current: "日本", // 2 chars
1738 predicted: "中国", // 2 chars, different
1739 expected_reversal_chars: 2, // removes "日本" = 2 chars
1740 expected_total_chars: 4, // 2 removed + 2 added
1741 },
1742 ];
1743
1744 for case in &cases {
1745 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
1746 assert_eq!(
1747 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1748 "Test '{}': expected {} reversal chars, got {}",
1749 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1750 );
1751 assert_eq!(
1752 overlap.total_chars_in_prediction, case.expected_total_chars,
1753 "Test '{}': expected {} total chars, got {}",
1754 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1755 );
1756 }
1757 }
1758
1759 #[test]
1760 fn test_compute_lcs_length() {
1761 assert_eq!(compute_lcs_length("", ""), 0);
1762 assert_eq!(compute_lcs_length("abc", ""), 0);
1763 assert_eq!(compute_lcs_length("", "abc"), 0);
1764 assert_eq!(compute_lcs_length("abc", "abc"), 3);
1765 assert_eq!(compute_lcs_length("abc", "def"), 0);
1766 assert_eq!(compute_lcs_length("abcdef", "ace"), 3);
1767 assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4);
1768 assert_eq!(compute_lcs_length("日本語", "日語"), 2);
1769 }
1770
1771 #[test]
1772 fn test_compute_prediction_reversal_ratio_full_file() {
1773 let prompt_inputs = make_test_prompt_inputs(
1774 indoc! {"
1775 line1
1776 user_added
1777 line2
1778 "},
1779 vec![Arc::new(zeta_prompt::Event::BufferChange {
1780 path: Arc::from(Path::new("src/test.rs")),
1781 old_path: Arc::from(Path::new("src/test.rs")),
1782 diff: indoc! {"
1783 @@ -1,2 +1,3 @@
1784 line1
1785 +user_added
1786 line2
1787 "}
1788 .into(),
1789 predicted: false,
1790 in_open_source_repo: false,
1791 })],
1792 None,
1793 );
1794
1795 let predicted = indoc! {"
1796 line1
1797 line2
1798 "};
1799 let ratio =
1800 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1801
1802 assert!(
1803 ratio > 0.9,
1804 "Expected high reversal ratio when prediction removes user addition, got {}",
1805 ratio
1806 );
1807 }
1808
1809 #[test]
1810 fn test_compute_prediction_reversal_ratio_with_excerpt() {
1811 let prompt_inputs = make_test_prompt_inputs(
1812 indoc! {"
1813 line10
1814 user_added
1815 line11
1816 "},
1817 vec![Arc::new(zeta_prompt::Event::BufferChange {
1818 path: Arc::from(Path::new("src/test.rs")),
1819 old_path: Arc::from(Path::new("src/test.rs")),
1820 diff: indoc! {"
1821 @@ -10,2 +10,3 @@
1822 line10
1823 +user_added
1824 line11
1825 "}
1826 .into(),
1827 predicted: false,
1828 in_open_source_repo: false,
1829 })],
1830 Some(10),
1831 );
1832
1833 let predicted = indoc! {"
1834 line10
1835 line11
1836 "};
1837 let ratio =
1838 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1839
1840 assert!(
1841 ratio > 0.9,
1842 "Expected high reversal ratio for excerpt-aware computation, got {}",
1843 ratio
1844 );
1845 }
1846
1847 #[test]
1848 fn test_compute_prediction_reversal_ratio_no_history() {
1849 let prompt_inputs = make_test_prompt_inputs(
1850 indoc! {"
1851 original content
1852 "},
1853 vec![],
1854 None,
1855 );
1856
1857 let predicted = indoc! {"
1858 completely different
1859 "};
1860 let ratio =
1861 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1862
1863 assert_eq!(
1864 ratio, 0.0,
1865 "Expected zero reversal ratio with no edit history"
1866 );
1867 }
1868
1869 #[test]
1870 fn test_compute_prediction_reversal_ratio_path_filtering() {
1871 let prompt_inputs = make_test_prompt_inputs(
1872 indoc! {"
1873 line1
1874 user_added
1875 line2
1876 "},
1877 vec![Arc::new(zeta_prompt::Event::BufferChange {
1878 path: Arc::from(Path::new("src/other.rs")),
1879 old_path: Arc::from(Path::new("src/other.rs")),
1880 diff: indoc! {"
1881 @@ -1,2 +1,3 @@
1882 line1
1883 +user_added
1884 line2
1885 "}
1886 .into(),
1887 predicted: false,
1888 in_open_source_repo: false,
1889 })],
1890 None,
1891 );
1892
1893 let predicted = indoc! {"
1894 line1
1895 line2
1896 "};
1897 let ratio =
1898 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1899
1900 assert_eq!(
1901 ratio, 0.0,
1902 "Expected zero reversal when edit history is for different file"
1903 );
1904 }
1905
1906 #[test]
1907 fn test_compute_prediction_reversal_ratio_lenient_fallback() {
1908 let prompt_inputs = make_test_prompt_inputs(
1909 indoc! {"
1910 actual_line1
1911 user_added
1912 actual_line2
1913 "},
1914 vec![Arc::new(zeta_prompt::Event::BufferChange {
1915 path: Arc::from(Path::new("src/test.rs")),
1916 old_path: Arc::from(Path::new("src/test.rs")),
1917 diff: indoc! {"
1918 @@ -1,2 +1,3 @@
1919 wrong_context
1920 +user_added
1921 more_wrong
1922 "}
1923 .into(),
1924 predicted: false,
1925 in_open_source_repo: false,
1926 })],
1927 None,
1928 );
1929
1930 let predicted = indoc! {"
1931 actual_line1
1932 actual_line2
1933 "};
1934 let ratio =
1935 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1936
1937 assert!(
1938 ratio >= 0.0 && ratio <= 1.0,
1939 "Ratio should be valid even with lenient fallback, got {}",
1940 ratio
1941 );
1942 }
1943
1944 #[test]
1945 fn test_excerpt_aware_reversal_error_recovery() {
1946 let diffs = vec![indoc! {"
1947 @@ -1,2 +1,3 @@
1948 nonexistent_context
1949 +added
1950 more_nonexistent
1951 "}];
1952 let excerpt_content = indoc! {"
1953 completely
1954 different
1955 content
1956 "};
1957 let predicted_content = indoc! {"
1958 completely
1959 modified
1960 content
1961 "};
1962
1963 let overlap =
1964 compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content);
1965
1966 assert!(
1967 overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0,
1968 "Should handle failed diff application gracefully"
1969 );
1970 }
1971
1972 #[test]
1973 fn test_only_most_recent_edit_tracked() {
1974 let prompt_inputs = make_test_prompt_inputs(
1975 indoc! {"
1976 line1
1977 first_add
1978 second_add
1979 line2
1980 "},
1981 vec![
1982 Arc::new(zeta_prompt::Event::BufferChange {
1983 path: Arc::from(Path::new("src/test.rs")),
1984 old_path: Arc::from(Path::new("src/test.rs")),
1985 diff: indoc! {"
1986 @@ -1,2 +1,3 @@
1987 line1
1988 +first_add
1989 line2
1990 "}
1991 .into(),
1992 predicted: false,
1993 in_open_source_repo: false,
1994 }),
1995 Arc::new(zeta_prompt::Event::BufferChange {
1996 path: Arc::from(Path::new("src/test.rs")),
1997 old_path: Arc::from(Path::new("src/test.rs")),
1998 diff: indoc! {"
1999 @@ -2,2 +2,3 @@
2000 first_add
2001 +second_add
2002 line2
2003 "}
2004 .into(),
2005 predicted: false,
2006 in_open_source_repo: false,
2007 }),
2008 ],
2009 None,
2010 );
2011
2012 let predicted = indoc! {"
2013 line1
2014 first_add
2015 line2
2016 "};
2017 let ratio =
2018 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
2019
2020 assert!(
2021 ratio > 0.9,
2022 "Expected high reversal ratio when prediction exactly reverses the most recent edit, got {}",
2023 ratio
2024 );
2025 }
2026}