1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use language::{char_diff, text_diff};
6use zeta_prompt::udiff::apply_diff_to_string;
7
8use zeta_prompt::ZetaPromptInput;
9
10fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
11 let hunks = parse_diff_hunks(diff_str);
12 let mut result = text.to_string();
13
14 for hunk in hunks {
15 let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
16 if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
17 result = updated;
18 }
19 }
20
21 result
22}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25struct ParsedHunk {
26 old_start: u32,
27 old_count: u32,
28 new_start: u32,
29 new_count: u32,
30 lines: Vec<HunkLine>,
31}
32
/// A single body line of a diff hunk, stored without its marker character.
#[derive(Clone, Debug, Eq, PartialEq)]
enum HunkLine {
    /// Unchanged line (marker ` `).
    Context(String),
    /// Added line (marker `+`).
    Addition(String),
    /// Removed line (marker `-`).
    Deletion(String),
}
39
/// Parses a unified-diff hunk header such as `@@ -12,3 +14,4 @@ fn main()`.
///
/// Returns `(old_start, old_count, new_start, new_count)`. A side written without
/// a `,count` part gets a count of 1. Returns `None` when the line is not a hunk
/// header or any number fails to parse as `u32`.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    // Parses `START,COUNT`, or a bare `START` whose count defaults to 1.
    fn parse_range(spec: &str) -> Option<(u32, u32)> {
        match spec.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((spec.parse().ok()?, 1)),
        }
    }

    let body = line.strip_prefix("@@ -")?;
    let (old_spec, remainder) = body.split_once(' ')?;
    // Anything after the closing ` @@` (e.g. a section heading) is ignored.
    let (new_spec, _trailing) = remainder.strip_prefix('+')?.split_once(" @@")?;

    let (old_start, old_count) = parse_range(old_spec)?;
    let (new_start, new_count) = parse_range(new_spec)?;
    Some((old_start, old_count, new_start, new_count))
}
60
61fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
62 let mut hunks = Vec::new();
63 let mut current_hunk: Option<ParsedHunk> = None;
64
65 for line in diff.lines() {
66 if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
67 if let Some(hunk) = current_hunk.take() {
68 hunks.push(hunk);
69 }
70 current_hunk = Some(ParsedHunk {
71 old_start,
72 old_count,
73 new_start,
74 new_count,
75 lines: Vec::new(),
76 });
77 } else if let Some(ref mut hunk) = current_hunk {
78 if let Some(stripped) = line.strip_prefix('+') {
79 hunk.lines.push(HunkLine::Addition(stripped.to_string()));
80 } else if let Some(stripped) = line.strip_prefix('-') {
81 hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
82 } else if let Some(stripped) = line.strip_prefix(' ') {
83 hunk.lines.push(HunkLine::Context(stripped.to_string()));
84 } else if line.is_empty() {
85 hunk.lines.push(HunkLine::Context(String::new()));
86 }
87 }
88 }
89
90 if let Some(hunk) = current_hunk {
91 hunks.push(hunk);
92 }
93
94 hunks
95}
96
97fn format_hunk(hunk: &ParsedHunk) -> String {
98 let mut result = format!(
99 "@@ -{},{} +{},{} @@\n",
100 hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
101 );
102 for line in &hunk.lines {
103 match line {
104 HunkLine::Context(text) => {
105 result.push(' ');
106 result.push_str(text);
107 result.push('\n');
108 }
109 HunkLine::Addition(text) => {
110 result.push('+');
111 result.push_str(text);
112 result.push('\n');
113 }
114 HunkLine::Deletion(text) => {
115 result.push('-');
116 result.push_str(text);
117 result.push('\n');
118 }
119 }
120 }
121 result
122}
123
124fn filter_diff_hunks_by_excerpt(
125 diff: &str,
126 excerpt_start_row: u32,
127 excerpt_row_count: u32,
128) -> (String, i32) {
129 let hunks = parse_diff_hunks(diff);
130 let excerpt_start_0based = excerpt_start_row;
131 let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
132
133 let mut filtered_hunks = Vec::new();
134 let mut cumulative_line_offset: i32 = 0;
135
136 for hunk in hunks {
137 let hunk_start_0based = hunk.new_start.saturating_sub(1);
138 let hunk_end_0based = hunk_start_0based + hunk.new_count;
139
140 let additions: i32 = hunk
141 .lines
142 .iter()
143 .filter(|l| matches!(l, HunkLine::Addition(_)))
144 .count() as i32;
145 let deletions: i32 = hunk
146 .lines
147 .iter()
148 .filter(|l| matches!(l, HunkLine::Deletion(_)))
149 .count() as i32;
150 let hunk_line_delta = additions - deletions;
151
152 if hunk_end_0based <= excerpt_start_0based {
153 cumulative_line_offset += hunk_line_delta;
154 continue;
155 }
156
157 if hunk_start_0based >= excerpt_end_0based {
158 continue;
159 }
160
161 let mut filtered_lines = Vec::new();
162 let mut current_row_0based = hunk_start_0based;
163 let mut filtered_old_count = 0u32;
164 let mut filtered_new_count = 0u32;
165 let mut first_included_row: Option<u32> = None;
166
167 for line in &hunk.lines {
168 match line {
169 HunkLine::Context(text) => {
170 if current_row_0based >= excerpt_start_0based
171 && current_row_0based < excerpt_end_0based
172 {
173 if first_included_row.is_none() {
174 first_included_row = Some(current_row_0based);
175 }
176 filtered_lines.push(HunkLine::Context(text.clone()));
177 filtered_old_count += 1;
178 filtered_new_count += 1;
179 }
180 current_row_0based += 1;
181 }
182 HunkLine::Addition(text) => {
183 if current_row_0based >= excerpt_start_0based
184 && current_row_0based < excerpt_end_0based
185 {
186 if first_included_row.is_none() {
187 first_included_row = Some(current_row_0based);
188 }
189 filtered_lines.push(HunkLine::Addition(text.clone()));
190 filtered_new_count += 1;
191 }
192 current_row_0based += 1;
193 }
194 HunkLine::Deletion(text) => {
195 if current_row_0based >= excerpt_start_0based
196 && current_row_0based < excerpt_end_0based
197 {
198 if first_included_row.is_none() {
199 first_included_row = Some(current_row_0based);
200 }
201 filtered_lines.push(HunkLine::Deletion(text.clone()));
202 filtered_old_count += 1;
203 }
204 }
205 }
206 }
207
208 if !filtered_lines.is_empty() {
209 let first_row = first_included_row.unwrap_or(excerpt_start_0based);
210 let new_start_1based = (first_row - excerpt_start_0based) + 1;
211
212 filtered_hunks.push(ParsedHunk {
213 old_start: new_start_1based,
214 old_count: filtered_old_count,
215 new_start: new_start_1based,
216 new_count: filtered_new_count,
217 lines: filtered_lines,
218 });
219 }
220
221 cumulative_line_offset += hunk_line_delta;
222 }
223
224 let mut result = String::new();
225 for hunk in &filtered_hunks {
226 result.push_str(&format_hunk(hunk));
227 }
228
229 (result, cumulative_line_offset)
230}
231
232fn compute_excerpt_aware_reversal_overlap(
233 edit_history_diffs: &[&str],
234 excerpt_content: &str,
235 excerpt_start_row: u32,
236 predicted_content: &str,
237) -> ReversalOverlap {
238 let mut current_content = excerpt_content.to_string();
239 let mut current_excerpt_start_row = excerpt_start_row;
240
241 for diff in edit_history_diffs.iter().rev() {
242 if diff.is_empty() {
243 continue;
244 }
245
246 let current_row_count = current_content.lines().count() as u32;
247 let (filtered_diff, _line_offset) =
248 filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
249
250 if filtered_diff.is_empty() {
251 let hunks = parse_diff_hunks(diff);
252 for hunk in hunks {
253 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
254 if hunk_end <= current_excerpt_start_row {
255 let additions: u32 = hunk
256 .lines
257 .iter()
258 .filter(|l| matches!(l, HunkLine::Addition(_)))
259 .count() as u32;
260 let deletions: u32 = hunk
261 .lines
262 .iter()
263 .filter(|l| matches!(l, HunkLine::Deletion(_)))
264 .count() as u32;
265 if additions >= deletions {
266 current_excerpt_start_row =
267 current_excerpt_start_row.saturating_sub(additions - deletions);
268 } else {
269 current_excerpt_start_row += deletions - additions;
270 }
271 }
272 }
273 continue;
274 }
275
276 let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
277 match apply_diff_to_string(&reversed, ¤t_content) {
278 Ok(updated) => {
279 current_content = updated;
280 }
281 Err(_) => {
282 continue;
283 }
284 }
285
286 let hunks = parse_diff_hunks(diff);
287 for hunk in hunks {
288 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
289 if hunk_end <= current_excerpt_start_row {
290 let additions: u32 = hunk
291 .lines
292 .iter()
293 .filter(|l| matches!(l, HunkLine::Addition(_)))
294 .count() as u32;
295 let deletions: u32 = hunk
296 .lines
297 .iter()
298 .filter(|l| matches!(l, HunkLine::Deletion(_)))
299 .count() as u32;
300 if additions >= deletions {
301 current_excerpt_start_row =
302 current_excerpt_start_row.saturating_sub(additions - deletions);
303 } else {
304 current_excerpt_start_row += deletions - additions;
305 }
306 }
307 }
308 }
309
310 compute_reversal_overlap(¤t_content, excerpt_content, predicted_content)
311}
312
/// Produces the inverse of a unified diff: applying the result undoes the input.
///
/// * `--- path` and `+++ path` header lines swap their markers.
/// * Every other line starting with `+` becomes a `-` line and vice versa, including
///   lines whose content itself begins with `+` or `-` (an added `++i;` appears as
///   `+++i;`, a removed `--` comment as `---`).
/// * All remaining lines, hunk headers included, are copied unchanged.
///
/// A trailing newline is kept if and only if the input had one.
fn reverse_diff(diff: &str) -> String {
    let mut result = diff
        .lines()
        .map(|line| {
            // File headers are recognized by the marker followed by a space; they
            // must be checked first because they also start with `-` / `+`.
            if let Some(path) = line.strip_prefix("--- ") {
                format!("+++ {}", path)
            } else if let Some(path) = line.strip_prefix("+++ ") {
                format!("--- {}", path)
            } else if let Some(text) = line.strip_prefix('+') {
                format!("-{}", text)
            } else if let Some(text) = line.strip_prefix('-') {
                format!("+{}", text)
            } else {
                line.to_string()
            }
        })
        .collect::<Vec<_>>()
        .join("\n");

    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}
336
/// A single replacement produced by diffing two texts.
#[derive(Clone, Debug, Eq, PartialEq)]
struct GranularEdit {
    /// Byte range in the old text that this edit replaces.
    range: Range<usize>,
    /// The old text covered by `range`.
    old_text: String,
    /// The text that takes its place.
    new_text: String,
}
343
344fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
345 text_diff(old_text, new_text)
346 .into_iter()
347 .map(|(range, new_text)| GranularEdit {
348 old_text: old_text[range.clone()].to_string(),
349 range,
350 new_text: new_text.to_string(),
351 })
352 .collect()
353}
354
/// Text added by a history edit, located in the current (post-edit) content.
#[derive(Clone, Debug)]
struct HistoryAdditionRange {
    /// Byte range of the added text within the current content.
    range_in_current: Range<usize>,
}
359
/// Text removed by a history edit, together with where it used to sit.
#[derive(Clone, Debug)]
struct HistoryDeletionRange {
    /// The text that was removed.
    deleted_text: String,
    /// Byte offset in the current (post-edit) content where the removal happened.
    position_in_current: usize,
}
365
366fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
367 let mut result = Vec::new();
368 let mut offset_delta: isize = 0;
369
370 for edit in history_edits {
371 if !edit.new_text.is_empty() {
372 let new_start = (edit.range.start as isize + offset_delta) as usize;
373 let new_end = new_start + edit.new_text.len();
374 result.push(HistoryAdditionRange {
375 range_in_current: new_start..new_end,
376 });
377 }
378
379 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
380 }
381
382 result
383}
384
385fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
386 let mut result = Vec::new();
387 let mut offset_delta: isize = 0;
388
389 for edit in history_edits {
390 if !edit.old_text.is_empty() {
391 let position_in_current = (edit.range.start as isize + offset_delta) as usize;
392 result.push(HistoryDeletionRange {
393 deleted_text: edit.old_text.clone(),
394 position_in_current,
395 });
396 }
397
398 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
399 }
400
401 result
402}
403
/// How much of a prediction undoes edits found in the history.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct ReversalOverlap {
    /// Characters of the prediction that reverse a history edit.
    chars_reversing_user_edits: usize,
    /// Total characters touched by the prediction.
    total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Fraction of the prediction's characters that reverse history edits.
    ///
    /// Returns `0.0` when the prediction touches no characters at all.
    fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
419
420/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
421/// or where `new_text` appears as a subsequence within `old_text` (reduction).
422///
423/// For extensions: when the user's text is preserved (in order) within the prediction,
424/// we only count the newly inserted characters, not the preserved ones.
425/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
426/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
427///
428/// For reductions: when the prediction's text is preserved (in order) within the original,
429/// we only count the deleted characters, not the preserved ones.
430/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
431fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
432 edits
433 .into_iter()
434 .flat_map(|edit| {
435 if edit.old_text.is_empty() || edit.new_text.is_empty() {
436 return vec![edit];
437 }
438
439 // Use character-wise diff to find exact byte ranges of changes
440 let char_edits = char_diff(&edit.old_text, &edit.new_text);
441
442 let all_deletions = !char_edits.is_empty()
443 && char_edits
444 .iter()
445 .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
446 let all_insertions = !char_edits.is_empty()
447 && char_edits
448 .iter()
449 .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
450 if all_deletions || all_insertions {
451 return char_edits
452 .into_iter()
453 .map(|(range, replacement)| GranularEdit {
454 range: edit.range.start + range.start..edit.range.start + range.end,
455 old_text: edit.old_text[range].to_string(),
456 new_text: replacement.to_string(),
457 })
458 .collect();
459 }
460
461 // Otherwise, keep the original edit (mixed changes)
462 vec![edit]
463 })
464 .collect()
465}
466
467fn compute_reversal_overlap(
468 original_content: &str,
469 current_content: &str,
470 predicted_content: &str,
471) -> ReversalOverlap {
472 let history_edits =
473 normalize_extension_edits(compute_granular_edits(original_content, current_content));
474 let prediction_edits =
475 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
476
477 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
478 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
479
480 let reversed_additions =
481 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
482 let restored_deletions =
483 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
484
485 let total_chars_in_prediction: usize = prediction_edits
486 .iter()
487 .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
488 .sum();
489
490 ReversalOverlap {
491 chars_reversing_user_edits: reversed_additions + restored_deletions,
492 total_chars_in_prediction,
493 }
494}
495
496fn compute_reversed_additions(
497 history_addition_ranges: &[HistoryAdditionRange],
498 prediction_edits: &[GranularEdit],
499) -> usize {
500 let mut reversed_chars = 0;
501
502 for pred_edit in prediction_edits {
503 for history_addition in history_addition_ranges {
504 let overlap_start = pred_edit
505 .range
506 .start
507 .max(history_addition.range_in_current.start);
508 let overlap_end = pred_edit
509 .range
510 .end
511 .min(history_addition.range_in_current.end);
512
513 if overlap_start < overlap_end {
514 let relative_start = overlap_start - pred_edit.range.start;
515 let relative_end = overlap_end - pred_edit.range.start;
516 let overlap_text = &pred_edit.old_text[relative_start..relative_end];
517 reversed_chars += overlap_text.chars().count();
518 }
519 }
520 }
521
522 reversed_chars
523}
524
525fn compute_restored_deletions(
526 history_deletion_ranges: &[HistoryDeletionRange],
527 prediction_edits: &[GranularEdit],
528) -> usize {
529 let mut restored = 0;
530
531 for pred_edit in prediction_edits {
532 if pred_edit.new_text.is_empty() {
533 continue;
534 }
535
536 for deletion in history_deletion_ranges {
537 if pred_edit.range.contains(&deletion.position_in_current)
538 || deletion.position_in_current == pred_edit.range.start
539 {
540 restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
541 }
542 }
543 }
544
545 restored
546}
547
/// Length, in `char`s, of the longest common subsequence of `a` and `b`.
///
/// Returns 0 when either string is empty.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let left: Vec<char> = a.chars().collect();
    let right: Vec<char> = b.chars().collect();

    if left.is_empty() || right.is_empty() {
        return 0;
    }

    // `row[j]` holds the LCS length of the processed prefix of `left` and `right[..j]`.
    let mut row = vec![0usize; right.len() + 1];

    for &left_char in &left {
        // Value of `row[j]` from the previous pass, i.e. the diagonal neighbour.
        let mut diagonal = 0;
        for (j, &right_char) in right.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if left_char == right_char {
                diagonal + 1
            } else {
                above.max(row[j])
            };
            diagonal = above;
        }
    }

    row[right.len()]
}
575
576fn filter_edit_history_by_path<'a>(
577 edit_history: &'a [Arc<zeta_prompt::Event>],
578 cursor_path: &std::path::Path,
579) -> Vec<&'a zeta_prompt::Event> {
580 edit_history
581 .iter()
582 .filter(|event| match event.as_ref() {
583 zeta_prompt::Event::BufferChange { path, .. } => {
584 let event_path = path.as_ref();
585 if event_path == cursor_path {
586 return true;
587 }
588 let stripped = event_path
589 .components()
590 .skip(1)
591 .collect::<std::path::PathBuf>();
592 stripped == cursor_path
593 }
594 })
595 .map(|arc| arc.as_ref())
596 .collect()
597}
598
599fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
600 match event {
601 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
602 }
603}
604
605fn is_predicted_event(event: &zeta_prompt::Event) -> bool {
606 match event {
607 zeta_prompt::Event::BufferChange { predicted, .. } => *predicted,
608 }
609}
610
611pub fn compute_prediction_reversal_ratio(
612 prompt_inputs: &ZetaPromptInput,
613 predicted_content: &str,
614 cursor_path: &Path,
615) -> f32 {
616 let current_content: &str = prompt_inputs.cursor_excerpt.as_ref();
617
618 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.events;
619 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
620
621 let most_recent = match relevant_events.last() {
622 Some(event) if !is_predicted_event(event) => *event,
623 _ => return 0.0,
624 };
625
626 let diff = extract_diff_from_event(most_recent);
627 if diff.is_empty() {
628 return 0.0;
629 }
630
631 if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row {
632 let diffs = vec![diff];
633 let overlap = compute_excerpt_aware_reversal_overlap(
634 &diffs,
635 current_content,
636 excerpt_start_row,
637 predicted_content,
638 );
639 return overlap.ratio();
640 }
641
642 let reversed = reverse_diff(diff);
643 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
644 let original_content = match apply_diff_to_string(&with_headers, current_content) {
645 Ok(updated_content) => updated_content,
646 Err(_) => apply_diff_to_string_lenient(&reversed, current_content),
647 };
648
649 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
650 overlap.ratio()
651}
652
653#[cfg(test)]
654mod tests {
655 use super::*;
656 use indoc::indoc;
657 use zeta_prompt::ExcerptRanges;
658 use zeta_prompt::udiff::apply_diff_to_string;
659
660 fn make_test_prompt_inputs(
661 content: &str,
662 events: Vec<Arc<zeta_prompt::Event>>,
663 excerpt_start_row: Option<u32>,
664 ) -> ZetaPromptInput {
665 ZetaPromptInput {
666 cursor_path: Arc::from(Path::new("src/test.rs")),
667 cursor_excerpt: content.into(),
668 cursor_offset_in_excerpt: 0,
669 excerpt_start_row,
670 events,
671 related_files: Some(Vec::new()),
672 active_buffer_diagnostics: Vec::new(),
673 excerpt_ranges: ExcerptRanges {
674 editable_150: 0..content.len(),
675 editable_180: 0..content.len(),
676 editable_350: 0..content.len(),
677 editable_150_context_350: 0..content.len(),
678 editable_180_context_350: 0..content.len(),
679 editable_350_context_150: 0..content.len(),
680 ..Default::default()
681 },
682 syntax_ranges: None,
683 experiment: None,
684 in_open_source_repo: false,
685 can_collect_data: false,
686 repo_url: None,
687 }
688 }
689
690 #[test]
691 fn test_reversal_overlap() {
692 struct Case {
693 name: &'static str,
694 original: &'static str,
695 current: &'static str,
696 predicted: &'static str,
697 expected_reversal_chars: usize,
698 expected_total_chars: usize,
699 }
700
701 let cases = [
702 Case {
703 name: "user_adds_line_prediction_removes_it",
704 original: indoc! {"
705 a
706 b
707 c"},
708 current: indoc! {"
709 a
710 new line
711 b
712 c"},
713 predicted: indoc! {"
714 a
715 b
716 c"},
717 expected_reversal_chars: 9,
718 expected_total_chars: 9,
719 },
720 Case {
721 name: "user_deletes_line_prediction_restores_it",
722 original: indoc! {"
723 a
724 deleted
725 b"},
726 current: indoc! {"
727 a
728 b"},
729 predicted: indoc! {"
730 a
731 deleted
732 b"},
733 expected_reversal_chars: 8,
734 expected_total_chars: 8,
735 },
736 Case {
737 name: "user_deletes_text_prediction_restores_partial",
738 original: "hello beautiful world",
739 current: "hello world",
740 predicted: "hello beautiful world",
741 expected_reversal_chars: 10,
742 expected_total_chars: 10,
743 },
744 Case {
745 name: "user_deletes_foo_prediction_adds_bar",
746 original: "foo",
747 current: "",
748 predicted: "bar",
749 expected_reversal_chars: 0,
750 expected_total_chars: 3,
751 },
752 Case {
753 name: "independent_edits_different_locations",
754 original: indoc! {"
755 line1
756 line2
757 line3"},
758 current: indoc! {"
759 LINE1
760 line2
761 line3"},
762 predicted: indoc! {"
763 LINE1
764 line2
765 LINE3"},
766 expected_reversal_chars: 0,
767 expected_total_chars: 10,
768 },
769 Case {
770 name: "no_history_edits",
771 original: "same",
772 current: "same",
773 predicted: "different",
774 expected_reversal_chars: 0,
775 expected_total_chars: 13,
776 },
777 Case {
778 name: "user_replaces_text_prediction_reverses",
779 original: indoc! {"
780 keep
781 delete_me
782 keep2"},
783 current: indoc! {"
784 keep
785 added
786 keep2"},
787 predicted: indoc! {"
788 keep
789 delete_me
790 keep2"},
791 expected_reversal_chars: 14,
792 expected_total_chars: 14,
793 },
794 Case {
795 name: "user_modifies_word_prediction_modifies_differently",
796 original: "the quick brown fox",
797 current: "the slow brown fox",
798 predicted: "the fast brown fox",
799 expected_reversal_chars: 4,
800 expected_total_chars: 8,
801 },
802 Case {
803 name: "user finishes function name (suffix)",
804 original: "",
805 current: "epr",
806 predicted: "eprintln!()",
807 expected_reversal_chars: 0,
808 expected_total_chars: 8,
809 },
810 Case {
811 name: "user starts function name (prefix)",
812 original: "",
813 current: "my_function()",
814 predicted: "test_my_function()",
815 expected_reversal_chars: 0,
816 expected_total_chars: 5,
817 },
818 Case {
819 name: "user types partial, prediction extends in multiple places",
820 original: "",
821 current: "test_my_function",
822 predicted: "a_test_for_my_special_function_plz",
823 expected_reversal_chars: 0,
824 expected_total_chars: 18,
825 },
826 // Edge cases for subsequence matching
827 Case {
828 name: "subsequence with interleaved underscores",
829 original: "",
830 current: "a_b_c",
831 predicted: "_a__b__c__",
832 expected_reversal_chars: 0,
833 expected_total_chars: 5,
834 },
835 Case {
836 name: "not a subsequence - different characters",
837 original: "",
838 current: "abc",
839 predicted: "xyz",
840 expected_reversal_chars: 3,
841 expected_total_chars: 6,
842 },
843 Case {
844 name: "not a subsequence - wrong order",
845 original: "",
846 current: "abc",
847 predicted: "cba",
848 expected_reversal_chars: 3,
849 expected_total_chars: 6,
850 },
851 Case {
852 name: "partial subsequence - only some chars match",
853 original: "",
854 current: "abcd",
855 predicted: "axbx",
856 expected_reversal_chars: 4,
857 expected_total_chars: 8,
858 },
859 // Common completion patterns
860 Case {
861 name: "completing a method call",
862 original: "",
863 current: "vec.pu",
864 predicted: "vec.push(item)",
865 expected_reversal_chars: 0,
866 expected_total_chars: 8,
867 },
868 Case {
869 name: "completing an import statement",
870 original: "",
871 current: "use std::col",
872 predicted: "use std::collections::HashMap",
873 expected_reversal_chars: 0,
874 expected_total_chars: 17,
875 },
876 Case {
877 name: "completing a struct field",
878 original: "",
879 current: "name: St",
880 predicted: "name: String",
881 expected_reversal_chars: 0,
882 expected_total_chars: 4,
883 },
884 Case {
885 name: "prediction replaces with completely different text",
886 original: "",
887 current: "hello",
888 predicted: "world",
889 expected_reversal_chars: 5,
890 expected_total_chars: 10,
891 },
892 Case {
893 name: "empty prediction removes user text",
894 original: "",
895 current: "mistake",
896 predicted: "",
897 expected_reversal_chars: 7,
898 expected_total_chars: 7,
899 },
900 Case {
901 name: "fixing typo is not reversal",
902 original: "",
903 current: "<dv",
904 predicted: "<div>",
905 expected_reversal_chars: 0,
906 expected_total_chars: 2,
907 },
908 Case {
909 name: "infix insertion not reversal",
910 original: indoc! {"
911 from my_project import Foo
912 "},
913 current: indoc! {"
914 ifrom my_project import Foo
915 "},
916 predicted: indoc! {"
917 import
918 from my_project import Foo
919 "},
920 expected_reversal_chars: 0,
921 expected_total_chars: 6,
922 },
923 Case {
924 name: "non-word based reversal",
925 original: "from",
926 current: "ifrom",
927 predicted: "from",
928 expected_reversal_chars: 1,
929 expected_total_chars: 1,
930 },
931 Case {
932 name: "multiple insertions no reversal",
933 original: "print(\"Hello, World!\")",
934 current: "sys.(\"Hello, World!\")",
935 predicted: "sys.stdout.write(\"Hello, World!\\n\")",
936 expected_reversal_chars: 0,
937 expected_total_chars: 14,
938 },
939 ];
940
941 for case in &cases {
942 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
943 assert_eq!(
944 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
945 "Test '{}': expected {} reversal chars, got {}",
946 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
947 );
948 assert_eq!(
949 overlap.total_chars_in_prediction, case.expected_total_chars,
950 "Test '{}': expected {} total chars, got {}",
951 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
952 );
953 }
954 }
955
956 #[test]
957 fn test_reverse_diff() {
958 let forward_diff = indoc! {"
959 --- a/file.rs
960 +++ b/file.rs
961 @@ -1,3 +1,4 @@
962 fn main() {
963 + let x = 42;
964 println!(\"hello\");
965 }"};
966
967 let reversed = reverse_diff(forward_diff);
968
969 assert!(
970 reversed.contains("+++ a/file.rs"),
971 "Should have +++ for old path"
972 );
973 assert!(
974 reversed.contains("--- b/file.rs"),
975 "Should have --- for new path"
976 );
977 assert!(
978 reversed.contains("- let x = 42;"),
979 "Added line should become deletion"
980 );
981 assert!(
982 reversed.contains(" fn main()"),
983 "Context lines should be unchanged"
984 );
985 }
986
987 #[test]
988 fn test_reverse_diff_roundtrip() {
989 // Applying a diff and then its reverse should get back to original
990 let original = indoc! {"
991 first line
992 hello world
993 last line
994 "};
995 let modified = indoc! {"
996 first line
997 hello beautiful world
998 last line
999 "};
1000
1001 // unified_diff doesn't include file headers, but apply_diff_to_string needs them
1002 let diff_body = language::unified_diff(original, modified);
1003 let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
1004 let reversed_diff = reverse_diff(&forward_diff);
1005
1006 // Apply forward diff to original
1007 let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
1008 assert_eq!(after_forward, modified);
1009
1010 // Apply reversed diff to modified
1011 let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
1012 assert_eq!(after_reverse, original);
1013 }
1014
1015 #[test]
1016 fn test_filter_edit_history_by_path() {
1017 // Test that filter_edit_history_by_path correctly matches paths when
1018 // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
1019 // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
1020 let events = vec![
1021 Arc::new(zeta_prompt::Event::BufferChange {
1022 path: Arc::from(Path::new("myrepo/src/file.rs")),
1023 old_path: Arc::from(Path::new("myrepo/src/file.rs")),
1024 diff: indoc! {"
1025 @@ -1 +1 @@
1026 -old
1027 +new"}
1028 .into(),
1029 predicted: false,
1030 in_open_source_repo: true,
1031 }),
1032 Arc::new(zeta_prompt::Event::BufferChange {
1033 path: Arc::from(Path::new("myrepo/other.rs")),
1034 old_path: Arc::from(Path::new("myrepo/other.rs")),
1035 diff: indoc! {"
1036 @@ -1 +1 @@
1037 -a
1038 +b"}
1039 .into(),
1040 predicted: false,
1041 in_open_source_repo: true,
1042 }),
1043 Arc::new(zeta_prompt::Event::BufferChange {
1044 path: Arc::from(Path::new("src/file.rs")),
1045 old_path: Arc::from(Path::new("src/file.rs")),
1046 diff: indoc! {"
1047 @@ -1 +1 @@
1048 -x
1049 +y"}
1050 .into(),
1051 predicted: false,
1052 in_open_source_repo: true,
1053 }),
1054 ];
1055
1056 // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
1057 // "src/file.rs" exact match
1058 let cursor_path = Path::new("src/file.rs");
1059 let filtered = filter_edit_history_by_path(&events, cursor_path);
1060 assert_eq!(
1061 filtered.len(),
1062 2,
1063 "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
1064 );
1065
1066 // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
1067 // "src/file.rs" stripped -> "file.rs" == "file.rs"
1068 let cursor_path = Path::new("file.rs");
1069 let filtered = filter_edit_history_by_path(&events, cursor_path);
1070 assert_eq!(
1071 filtered.len(),
1072 1,
1073 "Should only match src/file.rs (stripped to file.rs)"
1074 );
1075
1076 // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
1077 let cursor_path = Path::new("other.rs");
1078 let filtered = filter_edit_history_by_path(&events, cursor_path);
1079 assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
1080 }
1081
1082 #[test]
1083 fn test_reverse_diff_preserves_trailing_newline() {
1084 let diff_with_trailing_newline = indoc! {"
1085 --- a/file
1086 +++ b/file
1087 @@ -1 +1 @@
1088 -old
1089 +new
1090 "};
1091 let reversed = reverse_diff(diff_with_trailing_newline);
1092 assert!(
1093 reversed.ends_with('\n'),
1094 "Reversed diff should preserve trailing newline"
1095 );
1096
1097 let diff_without_trailing_newline = indoc! {"
1098 --- a/file
1099 +++ b/file
1100 @@ -1 +1 @@
1101 -old
1102 +new"};
1103 let reversed = reverse_diff(diff_without_trailing_newline);
1104 assert!(
1105 !reversed.ends_with('\n'),
1106 "Reversed diff should not add trailing newline if original didn't have one"
1107 );
1108 }
1109
/// Table-driven test for `filter_diff_hunks_by_excerpt`.
///
/// Every scenario passes a unified diff together with an excerpt start row and
/// row count, then compares both halves of the returned pair against the
/// expected filtered diff text and the expected line offset.
#[test]
fn test_filter_hunks_by_excerpt_region() {
    /// Inputs and expected outputs for a single filtering scenario.
    struct Scenario {
        name: &'static str,
        diff: &'static str,
        excerpt_start_row: u32,
        excerpt_row_count: u32,
        expected_filtered_diff: &'static str,
        expected_line_offset: i32,
    }

    let scenarios = [
        // Hunk above the excerpt: nothing is kept, but a +1 offset is expected.
        Scenario {
            name: "hunk_entirely_before_excerpt",
            diff: indoc! {"
                @@ -1,3 +1,4 @@
                 line1
                +inserted
                 line2
                 line3
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 5,
            expected_filtered_diff: "",
            expected_line_offset: 1,
        },
        // Hunk inside the excerpt: kept, with its header renumbered.
        Scenario {
            name: "hunk_entirely_inside_excerpt",
            diff: indoc! {"
                @@ -12,3 +12,4 @@
                 line12
                +inserted
                 line13
                 line14
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 10,
            expected_filtered_diff: indoc! {"
                @@ -2,3 +2,4 @@
                 line12
                +inserted
                 line13
                 line14
            "},
            expected_line_offset: 1,
        },
        // Hunk below the excerpt: dropped and expected to contribute no offset.
        Scenario {
            name: "hunk_entirely_after_excerpt",
            diff: indoc! {"
                @@ -50,3 +50,4 @@
                 line50
                +inserted
                 line51
                 line52
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 5,
            expected_filtered_diff: "",
            expected_line_offset: 0,
        },
        // Hunk begins above the excerpt: only the trailing context is expected
        // to survive, while the insertion still counts towards the offset.
        Scenario {
            name: "hunk_straddles_excerpt_start",
            diff: indoc! {"
                @@ -8,5 +8,6 @@
                 line8
                 line9
                +inserted
                 line10
                 line11
                 line12
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 10,
            expected_filtered_diff: indoc! {"
                @@ -1,3 +1,3 @@
                 line10
                 line11
                 line12
            "},
            expected_line_offset: 1,
        },
        // Hunk runs past the end of the excerpt: the tail is expected to be cut.
        Scenario {
            name: "hunk_straddles_excerpt_end",
            diff: indoc! {"
                @@ -18,5 +18,6 @@
                 line18
                 line19
                +inserted
                 line20
                 line21
                 line22
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 10,
            expected_filtered_diff: indoc! {"
                @@ -8,2 +8,3 @@
                 line18
                 line19
                +inserted
            "},
            expected_line_offset: 1,
        },
        // Three hunks (before, inside, after): only the middle one is expected
        // in the output, with an offset of 2.
        Scenario {
            name: "multiple_hunks_mixed",
            diff: indoc! {"
                @@ -1,2 +1,3 @@
                 line1
                +before_excerpt
                 line2
                @@ -12,2 +13,3 @@
                 line12
                +inside_excerpt
                 line13
                @@ -50,2 +52,3 @@
                 line50
                +after_excerpt
                 line51
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 10,
            expected_filtered_diff: indoc! {"
                @@ -3,2 +3,3 @@
                 line12
                +inside_excerpt
                 line13
            "},
            expected_line_offset: 2,
        },
        // A deletion above the excerpt is expected to yield a negative offset.
        Scenario {
            name: "deletion_before_excerpt",
            diff: indoc! {"
                @@ -1,4 +1,3 @@
                 line1
                -deleted
                 line2
                 line3
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 5,
            expected_filtered_diff: "",
            expected_line_offset: -1,
        },
        // A deletion inside the excerpt is kept and renumbered.
        Scenario {
            name: "deletion_inside_excerpt",
            diff: indoc! {"
                @@ -12,4 +12,3 @@
                 line12
                -deleted
                 line13
                 line14
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 10,
            expected_filtered_diff: indoc! {"
                @@ -2,4 +2,3 @@
                 line12
                -deleted
                 line13
                 line14
            "},
            expected_line_offset: -1,
        },
        // An empty diff is expected to produce empty output and zero offset.
        Scenario {
            name: "empty_diff",
            diff: "",
            excerpt_start_row: 10,
            excerpt_row_count: 5,
            expected_filtered_diff: "",
            expected_line_offset: 0,
        },
        // Hunk is larger than the excerpt on both sides: expected to be trimmed
        // at both ends.
        Scenario {
            name: "hunk_spans_entire_excerpt",
            diff: indoc! {"
                @@ -8,10 +8,12 @@
                 line8
                 line9
                 line10
                 line11
                +inserted1
                 line12
                 line13
                +inserted2
                 line14
                 line15
                 line16
                 line17
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 5,
            expected_filtered_diff: indoc! {"
                @@ -1,3 +1,5 @@
                 line11
                +inserted1
                 line12
                 line13
                +inserted2
            "},
            expected_line_offset: 2,
        },
        // A one-for-one replacement is expected to leave the offset at zero.
        Scenario {
            name: "replacement_inside_excerpt",
            diff: indoc! {"
                @@ -12,3 +12,3 @@
                 line12
                -old_text
                +new_text
                 line14
            "},
            excerpt_start_row: 10,
            excerpt_row_count: 10,
            expected_filtered_diff: indoc! {"
                @@ -2,3 +2,3 @@
                 line12
                -old_text
                +new_text
                 line14
            "},
            expected_line_offset: 0,
        },
    ];

    for scenario in scenarios.iter() {
        let (filtered_diff, reported_offset) = filter_diff_hunks_by_excerpt(
            scenario.diff,
            scenario.excerpt_start_row,
            scenario.excerpt_row_count,
        );

        assert_eq!(
            filtered_diff, scenario.expected_filtered_diff,
            "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}",
            scenario.name, scenario.expected_filtered_diff, filtered_diff
        );
        assert_eq!(
            reported_offset, scenario.expected_line_offset,
            "Test '{}': line offset mismatch. Expected {}, got {}",
            scenario.name, scenario.expected_line_offset, reported_offset
        );
    }
}
1349
/// Table-driven test for `compute_excerpt_aware_reversal_overlap`.
///
/// Each scenario supplies a list of edit-history diffs, the excerpt text with
/// its start row, and a predicted excerpt. The returned overlap is checked on
/// its `chars_reversing_user_edits` and `total_chars_in_prediction` fields.
#[test]
fn test_excerpt_aware_reversal_tracking() {
    /// Inputs and expected character counts for one scenario.
    struct Scenario {
        name: &'static str,
        edit_history_diffs: Vec<&'static str>,
        excerpt_content: &'static str,
        excerpt_start_row: u32,
        predicted_content: &'static str,
        expected_reversal_chars: usize,
        expected_total_chars: usize,
    }

    let scenarios = [
        // The only history edit is above the excerpt, so no reversal is expected.
        Scenario {
            name: "edit_outside_excerpt_no_reversal",
            edit_history_diffs: vec![indoc! {"
                @@ -1,2 +1,3 @@
                 line1
                +added_outside
                 line2
            "}],
            excerpt_content: indoc! {"
                line10
                line11
                line12
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                modified
                line12
            "},
            expected_reversal_chars: 0,
            expected_total_chars: 14,
        },
        // The prediction removes the line the user added inside the excerpt.
        Scenario {
            name: "edit_inside_excerpt_with_reversal",
            edit_history_diffs: vec![indoc! {"
                @@ -10,3 +10,4 @@
                 line10
                +user_added
                 line11
                 line12
            "}],
            excerpt_content: indoc! {"
                line10
                user_added
                line11
                line12
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                line11
                line12
            "},
            expected_reversal_chars: 11,
            expected_total_chars: 11,
        },
        // One hunk adds a line before and a line inside the excerpt; only the
        // inside addition is removed by the prediction.
        Scenario {
            name: "straddling_edit_partial_reversal",
            edit_history_diffs: vec![indoc! {"
                @@ -8,6 +8,8 @@
                 line8
                 line9
                +before_excerpt
                 line10
                +inside_excerpt
                 line11
                 line12
                 line13
            "}],
            excerpt_content: indoc! {"
                line10
                inside_excerpt
                line11
                line12
                line13
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                line11
                line12
                line13
            "},
            expected_reversal_chars: 15,
            expected_total_chars: 15,
        },
        // Two separate diffs: one outside the excerpt, one inside it.
        Scenario {
            name: "multiple_edits_mixed_locations",
            edit_history_diffs: vec![
                indoc! {"
                    @@ -1,2 +1,3 @@
                     line1
                    +outside1
                     line2
                "},
                indoc! {"
                    @@ -11,2 +12,3 @@
                     line11
                    +inside1
                     line12
                "},
            ],
            excerpt_content: indoc! {"
                line10
                line11
                inside1
                line12
                line13
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                line11
                line12
                line13
            "},
            expected_reversal_chars: 8,
            expected_total_chars: 8,
        },
        // With an empty history nothing can be reversed.
        Scenario {
            name: "no_edit_history",
            edit_history_diffs: vec![],
            excerpt_content: indoc! {"
                line10
                line11
                line12
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                modified
                line12
            "},
            expected_reversal_chars: 0,
            expected_total_chars: 14,
        },
        // The only history edit is far below the excerpt.
        Scenario {
            name: "edit_after_excerpt_no_effect",
            edit_history_diffs: vec![indoc! {"
                @@ -50,2 +50,3 @@
                 line50
                +added_after
                 line51
            "}],
            excerpt_content: indoc! {"
                line10
                line11
                line12
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                changed
                line12
            "},
            expected_reversal_chars: 0,
            expected_total_chars: 13,
        },
        // An earlier two-line insertion shifts the new-side numbering of the
        // later hunk that lands inside the excerpt.
        Scenario {
            name: "line_offset_tracking_across_hunks",
            edit_history_diffs: vec![
                indoc! {"
                    @@ -1,2 +1,4 @@
                     line1
                    +added1
                    +added2
                     line2
                "},
                indoc! {"
                    @@ -12,2 +14,3 @@
                     line12
                    +inside_after_offset
                     line13
                "},
            ],
            excerpt_content: indoc! {"
                line10
                line11
                line12
                inside_after_offset
                line13
            "},
            excerpt_start_row: 10,
            predicted_content: indoc! {"
                line10
                line11
                line12
                line13
            "},
            expected_reversal_chars: 20,
            expected_total_chars: 20,
        },
    ];

    for scenario in scenarios.iter() {
        let measured = compute_excerpt_aware_reversal_overlap(
            &scenario.edit_history_diffs,
            scenario.excerpt_content,
            scenario.excerpt_start_row,
            scenario.predicted_content,
        );

        assert_eq!(
            measured.chars_reversing_user_edits, scenario.expected_reversal_chars,
            "Test '{}': expected {} reversal chars, got {}",
            scenario.name, scenario.expected_reversal_chars, measured.chars_reversing_user_edits
        );
        assert_eq!(
            measured.total_chars_in_prediction, scenario.expected_total_chars,
            "Test '{}': expected {} total chars, got {}",
            scenario.name, scenario.expected_total_chars, measured.total_chars_in_prediction
        );
    }
}
1566
/// Table-driven test for `apply_diff_to_string_lenient`.
///
/// Each scenario applies a diff to some content and compares the returned
/// string with the expected text. Hunks whose context is absent from the
/// content are expected to leave it untouched.
#[test]
fn test_lenient_diff_application() {
    /// A diff, the text it is applied to, and the text expected afterwards.
    struct Scenario {
        name: &'static str,
        diff: &'static str,
        content: &'static str,
        expected_result: &'static str,
    }

    let scenarios = [
        // No context line matches, so the content should come back unchanged.
        Scenario {
            name: "hunk_context_not_found_skipped",
            diff: indoc! {"
                @@ -1,3 +1,4 @@
                 context_not_in_content
                +added_line
                 more_context
                 final_context
            "},
            content: indoc! {"
                completely
                different
                content
            "},
            expected_result: indoc! {"
                completely
                different
                content
            "},
        },
        // Context matches, so the insertion should be applied.
        Scenario {
            name: "hunk_context_found_applied",
            diff: indoc! {"
                @@ -1,3 +1,4 @@
                 line1
                +inserted
                 line2
                 line3
            "},
            content: indoc! {"
                line1
                line2
                line3
            "},
            expected_result: indoc! {"
                line1
                inserted
                line2
                line3
            "},
        },
        // The first hunk cannot match; the second one should still be applied.
        Scenario {
            name: "multiple_hunks_partial_match",
            diff: indoc! {"
                @@ -1,2 +1,3 @@
                 not_found
                +skipped
                 also_not_found
                @@ -5,2 +6,3 @@
                 line5
                +applied
                 line6
            "},
            content: indoc! {"
                line1
                line2
                line3
                line4
                line5
                line6
            "},
            expected_result: indoc! {"
                line1
                line2
                line3
                line4
                line5
                applied
                line6
            "},
        },
        // An empty diff should leave the content as it was.
        Scenario {
            name: "empty_diff",
            diff: "",
            content: indoc! {"
                unchanged
                content
            "},
            expected_result: indoc! {"
                unchanged
                content
            "},
        },
    ];

    for scenario in scenarios.iter() {
        let patched = apply_diff_to_string_lenient(scenario.diff, scenario.content);
        assert_eq!(
            patched, scenario.expected_result,
            "Test '{}': expected:\n{}\ngot:\n{}",
            scenario.name, scenario.expected_result, patched
        );
    }
}
1671
/// Table-driven test for `compute_reversal_overlap` on non-ASCII text.
///
/// The expected counts are stated in characters, not bytes: every scenario
/// uses multi-byte characters (CJK, emoji, or accented Latin letters).
#[test]
fn test_unicode_reversal_overlap() {
    /// Original / current / predicted texts plus the expected character counts.
    struct Scenario {
        name: &'static str,
        original: &'static str,
        current: &'static str,
        predicted: &'static str,
        expected_reversal_chars: usize,
        expected_total_chars: usize,
    }

    let scenarios = [
        Scenario {
            name: "unicode_extension_cjk",
            original: "",
            current: "日",       // a single character
            predicted: "日本語", // three characters: two appended
            expected_reversal_chars: 0,
            expected_total_chars: 2, // the appended "本語"
        },
        Scenario {
            name: "unicode_extension_emoji",
            original: "",
            current: "🎉",       // a single character
            predicted: "🎉🎊🎈", // three characters: two appended
            expected_reversal_chars: 0,
            expected_total_chars: 2, // the appended "🎊🎈"
        },
        Scenario {
            name: "unicode_deletion_restored",
            original: "héllo wörld",    // 11 characters
            current: "héllo",           // 5 characters
            predicted: "héllo wörld",   // puts back " wörld" (6 characters)
            expected_reversal_chars: 6, // all 6 restored characters count
            expected_total_chars: 6,
        },
        Scenario {
            name: "unicode_addition_reversed",
            original: "café",           // 4 characters
            current: "café latté",      // 10 characters: " latté" was added
            predicted: "café",          // takes " latté" away again
            expected_reversal_chars: 6, // the 6 removed characters
            expected_total_chars: 6,
        },
        Scenario {
            name: "mixed_ascii_unicode",
            original: "",
            current: "test日本",         // 6 characters
            predicted: "test日本語です", // 9 characters
            expected_reversal_chars: 0,
            expected_total_chars: 3, // only the three appended characters
        },
        Scenario {
            name: "unicode_replacement_not_subsequence",
            original: "",
            current: "日本",            // 2 characters
            predicted: "中国",          // 2 entirely different characters
            expected_reversal_chars: 2, // "日本" is removed
            expected_total_chars: 4,    // 2 removed plus 2 added
        },
    ];

    for scenario in scenarios.iter() {
        let measured =
            compute_reversal_overlap(scenario.original, scenario.current, scenario.predicted);

        assert_eq!(
            measured.chars_reversing_user_edits, scenario.expected_reversal_chars,
            "Test '{}': expected {} reversal chars, got {}",
            scenario.name, scenario.expected_reversal_chars, measured.chars_reversing_user_edits
        );
        assert_eq!(
            measured.total_chars_in_prediction, scenario.expected_total_chars,
            "Test '{}': expected {} total chars, got {}",
            scenario.name, scenario.expected_total_chars, measured.total_chars_in_prediction
        );
    }
}
1748
/// Checks `compute_lcs_length` against a table of `(left, right, expected)`
/// triples: empty inputs, identical and disjoint strings, classic subsequence
/// examples, and one multi-byte (CJK) pair.
#[test]
fn test_compute_lcs_length() {
    let expectations = [
        ("", "", 0),
        ("abc", "", 0),
        ("", "abc", 0),
        ("abc", "abc", 3),
        ("abc", "def", 0),
        ("abcdef", "ace", 3),
        ("AGGTAB", "GXTXAYB", 4),
        ("日本語", "日語", 2),
    ];

    for (left, right, expected_length) in expectations {
        assert_eq!(compute_lcs_length(left, right), expected_length);
    }
}
1760
/// With no excerpt start row (`None`), a prediction that deletes the line
/// added by the single history event should yield a ratio above 0.9.
#[test]
fn test_compute_prediction_reversal_ratio_full_file() {
    let target_path = Path::new("src/test.rs");

    // The one recorded edit: the user inserted `user_added` between two lines.
    let history_diff = indoc! {"
        @@ -1,2 +1,3 @@
         line1
        +user_added
         line2
    "};
    let history = vec![Arc::new(zeta_prompt::Event::BufferChange {
        path: Arc::from(target_path),
        old_path: Arc::from(target_path),
        diff: history_diff.into(),
        predicted: false,
        in_open_source_repo: false,
    })];

    let current_content = indoc! {"
        line1
        user_added
        line2
    "};
    let prompt_inputs = make_test_prompt_inputs(current_content, history, None);

    // The prediction drops the user's line again.
    let predicted = indoc! {"
        line1
        line2
    "};
    let ratio = compute_prediction_reversal_ratio(&prompt_inputs, predicted, target_path);

    assert!(
        ratio > 0.9,
        "Expected high reversal ratio when prediction removes user addition, got {}",
        ratio
    );
}
1798
/// Same shape as the full-file test, but the prompt inputs are built with an
/// excerpt start row of `Some(10)` and the history hunk is numbered to match.
#[test]
fn test_compute_prediction_reversal_ratio_with_excerpt() {
    let target_path = Path::new("src/test.rs");

    // The recorded edit inserts `user_added` at line 10 of the file.
    let history_diff = indoc! {"
        @@ -10,2 +10,3 @@
         line10
        +user_added
         line11
    "};
    let history = vec![Arc::new(zeta_prompt::Event::BufferChange {
        path: Arc::from(target_path),
        old_path: Arc::from(target_path),
        diff: history_diff.into(),
        predicted: false,
        in_open_source_repo: false,
    })];

    let excerpt_text = indoc! {"
        line10
        user_added
        line11
    "};
    let prompt_inputs = make_test_prompt_inputs(excerpt_text, history, Some(10));

    // The prediction drops the user's line again.
    let predicted = indoc! {"
        line10
        line11
    "};
    let ratio = compute_prediction_reversal_ratio(&prompt_inputs, predicted, target_path);

    assert!(
        ratio > 0.9,
        "Expected high reversal ratio for excerpt-aware computation, got {}",
        ratio
    );
}
1836
/// With an empty event list the ratio must be exactly zero, even though the
/// prediction replaces the whole content.
#[test]
fn test_compute_prediction_reversal_ratio_no_history() {
    let current_content = indoc! {"
        original content
    "};
    let prompt_inputs = make_test_prompt_inputs(current_content, vec![], None);

    let predicted = indoc! {"
        completely different
    "};
    let target_path = Path::new("src/test.rs");
    let ratio = compute_prediction_reversal_ratio(&prompt_inputs, predicted, target_path);

    assert_eq!(
        ratio, 0.0,
        "Expected zero reversal ratio with no edit history"
    );
}
1858
/// The only history event is recorded for `src/other.rs`, while the ratio is
/// requested for `src/test.rs`; the expected result is exactly zero.
#[test]
fn test_compute_prediction_reversal_ratio_path_filtering() {
    // Path the history event belongs to: deliberately not the queried file.
    let unrelated_path = Path::new("src/other.rs");

    let history_diff = indoc! {"
        @@ -1,2 +1,3 @@
         line1
        +user_added
         line2
    "};
    let history = vec![Arc::new(zeta_prompt::Event::BufferChange {
        path: Arc::from(unrelated_path),
        old_path: Arc::from(unrelated_path),
        diff: history_diff.into(),
        predicted: false,
        in_open_source_repo: false,
    })];

    let current_content = indoc! {"
        line1
        user_added
        line2
    "};
    let prompt_inputs = make_test_prompt_inputs(current_content, history, None);

    let predicted = indoc! {"
        line1
        line2
    "};
    let queried_path = Path::new("src/test.rs");
    let ratio = compute_prediction_reversal_ratio(&prompt_inputs, predicted, queried_path);

    assert_eq!(
        ratio, 0.0,
        "Expected zero reversal when edit history is for different file"
    );
}
1895
/// The history diff's context lines (`wrong_context`, `more_wrong`) do not
/// occur in the current content, so the diff cannot be matched against it.
/// The test only requires that the computed ratio is still a value in the
/// closed interval `[0.0, 1.0]`.
#[test]
fn test_compute_prediction_reversal_ratio_lenient_fallback() {
    let prompt_inputs = make_test_prompt_inputs(
        indoc! {"
            actual_line1
            user_added
            actual_line2
        "},
        vec![Arc::new(zeta_prompt::Event::BufferChange {
            path: Arc::from(Path::new("src/test.rs")),
            old_path: Arc::from(Path::new("src/test.rs")),
            diff: indoc! {"
                @@ -1,2 +1,3 @@
                 wrong_context
                +user_added
                 more_wrong
            "}
            .into(),
            predicted: false,
            in_open_source_repo: false,
        })],
        None,
    );

    let predicted = indoc! {"
        actual_line1
        actual_line2
    "};
    let ratio =
        compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));

    // A range `contains` check states the interval directly; like the manual
    // comparison pair it replaces, it is false for NaN.
    assert!(
        (0.0..=1.0).contains(&ratio),
        "Ratio should be valid even with lenient fallback, got {}",
        ratio
    );
}
1933
/// Passes `compute_excerpt_aware_reversal_overlap` a history diff whose
/// context lines do not occur in the excerpt. The test only requires that the
/// resulting overlap reports a ratio in the closed interval `[0.0, 1.0]`.
#[test]
fn test_excerpt_aware_reversal_error_recovery() {
    let diffs = vec![indoc! {"
        @@ -1,2 +1,3 @@
         nonexistent_context
        +added
         more_nonexistent
    "}];
    let excerpt_content = indoc! {"
        completely
        different
        content
    "};
    let predicted_content = indoc! {"
        completely
        modified
        content
    "};

    let overlap =
        compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content);

    // Compute the ratio once so the value that is checked is also the value
    // that is reported on failure.
    let ratio = overlap.ratio();
    assert!(
        (0.0..=1.0).contains(&ratio),
        "Should handle failed diff application gracefully, got ratio {}",
        ratio
    );
}
1961
/// Builds a history of two consecutive insertions (`first_add`, then
/// `second_add`) for the same file and predicts content with only
/// `second_add` removed. The ratio is expected to exceed 0.9.
#[test]
fn test_only_most_recent_edit_tracked() {
    let target_path = Path::new("src/test.rs");

    // Older edit: `first_add` is inserted after `line1`.
    let older_diff = indoc! {"
        @@ -1,2 +1,3 @@
         line1
        +first_add
         line2
    "};
    // Newer edit: `second_add` is inserted after `first_add`.
    let newer_diff = indoc! {"
        @@ -2,2 +2,3 @@
         first_add
        +second_add
         line2
    "};

    let history = vec![
        Arc::new(zeta_prompt::Event::BufferChange {
            path: Arc::from(target_path),
            old_path: Arc::from(target_path),
            diff: older_diff.into(),
            predicted: false,
            in_open_source_repo: false,
        }),
        Arc::new(zeta_prompt::Event::BufferChange {
            path: Arc::from(target_path),
            old_path: Arc::from(target_path),
            diff: newer_diff.into(),
            predicted: false,
            in_open_source_repo: false,
        }),
    ];

    let current_content = indoc! {"
        line1
        first_add
        second_add
        line2
    "};
    let prompt_inputs = make_test_prompt_inputs(current_content, history, None);

    // The prediction undoes only the newer of the two insertions.
    let predicted = indoc! {"
        line1
        first_add
        line2
    "};
    let ratio = compute_prediction_reversal_ratio(&prompt_inputs, predicted, target_path);

    assert!(
        ratio > 0.9,
        "Expected high reversal ratio when prediction exactly reverses the most recent edit, got {}",
        ratio
    );
}
2016}