1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use edit_prediction::udiff::apply_diff_to_string;
6use language::{char_diff, text_diff};
7
8use zeta_prompt::ZetaPromptInput;
9
10fn apply_diff_to_string_lenient(diff_str: &str, text: &str) -> String {
11 let hunks = parse_diff_hunks(diff_str);
12 let mut result = text.to_string();
13
14 for hunk in hunks {
15 let hunk_diff = format!("--- a/file\n+++ b/file\n{}", format_hunk(&hunk));
16 if let Ok(updated) = apply_diff_to_string(&hunk_diff, &result) {
17 result = updated;
18 }
19 }
20
21 result
22}
23
/// One hunk of a unified diff: the numbers from its
/// `@@ -old_start,old_count +new_start,new_count @@` header plus the body lines after it.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ParsedHunk {
    old_start: u32,
    old_count: u32,
    new_start: u32,
    new_count: u32,
    lines: Vec<HunkLine>,
}

/// A single hunk body line, stored without its leading ` `, `+`, or `-` marker.
#[derive(Clone, Debug, Eq, PartialEq)]
enum HunkLine {
    /// Line shared by both sides (` ` marker, or an empty body line).
    Context(String),
    /// Line only present on the new side (`+` marker).
    Addition(String),
    /// Line only present on the old side (`-` marker).
    Deletion(String),
}
39
/// Parses a unified-diff hunk header such as `@@ -12,3 +12,4 @@ fn foo()` into
/// `(old_start, old_count, new_start, new_count)`.
///
/// A range without an explicit count (`-7`) is treated as a count of 1. Returns `None` if
/// the line is not a hunk header or any number fails to parse.
fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
    /// Parses `start` or `start,count`.
    fn parse_range(range: &str) -> Option<(u32, u32)> {
        match range.split_once(',') {
            Some((start, count)) => Some((start.parse().ok()?, count.parse().ok()?)),
            None => Some((range.parse().ok()?, 1)),
        }
    }

    let body = line.strip_prefix("@@ -")?;
    let (old_range, after_old) = body.split_once(' ')?;
    let new_and_tail = after_old.strip_prefix('+')?;
    let (new_range, _section) = new_and_tail.split_once(" @@")?;

    let (old_start, old_count) = parse_range(old_range)?;
    let (new_start, new_count) = parse_range(new_range)?;
    Some((old_start, old_count, new_start, new_count))
}
60
61fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
62 let mut hunks = Vec::new();
63 let mut current_hunk: Option<ParsedHunk> = None;
64
65 for line in diff.lines() {
66 if let Some((old_start, old_count, new_start, new_count)) = parse_hunk_header(line) {
67 if let Some(hunk) = current_hunk.take() {
68 hunks.push(hunk);
69 }
70 current_hunk = Some(ParsedHunk {
71 old_start,
72 old_count,
73 new_start,
74 new_count,
75 lines: Vec::new(),
76 });
77 } else if let Some(ref mut hunk) = current_hunk {
78 if let Some(stripped) = line.strip_prefix('+') {
79 hunk.lines.push(HunkLine::Addition(stripped.to_string()));
80 } else if let Some(stripped) = line.strip_prefix('-') {
81 hunk.lines.push(HunkLine::Deletion(stripped.to_string()));
82 } else if let Some(stripped) = line.strip_prefix(' ') {
83 hunk.lines.push(HunkLine::Context(stripped.to_string()));
84 } else if line.is_empty() {
85 hunk.lines.push(HunkLine::Context(String::new()));
86 }
87 }
88 }
89
90 if let Some(hunk) = current_hunk {
91 hunks.push(hunk);
92 }
93
94 hunks
95}
96
97fn format_hunk(hunk: &ParsedHunk) -> String {
98 let mut result = format!(
99 "@@ -{},{} +{},{} @@\n",
100 hunk.old_start, hunk.old_count, hunk.new_start, hunk.new_count
101 );
102 for line in &hunk.lines {
103 match line {
104 HunkLine::Context(text) => {
105 result.push(' ');
106 result.push_str(text);
107 result.push('\n');
108 }
109 HunkLine::Addition(text) => {
110 result.push('+');
111 result.push_str(text);
112 result.push('\n');
113 }
114 HunkLine::Deletion(text) => {
115 result.push('-');
116 result.push_str(text);
117 result.push('\n');
118 }
119 }
120 }
121 result
122}
123
124fn filter_diff_hunks_by_excerpt(
125 diff: &str,
126 excerpt_start_row: u32,
127 excerpt_row_count: u32,
128) -> (String, i32) {
129 let hunks = parse_diff_hunks(diff);
130 let excerpt_start_0based = excerpt_start_row;
131 let excerpt_end_0based = excerpt_start_row + excerpt_row_count;
132
133 let mut filtered_hunks = Vec::new();
134 let mut cumulative_line_offset: i32 = 0;
135
136 for hunk in hunks {
137 let hunk_start_0based = hunk.new_start.saturating_sub(1);
138 let hunk_end_0based = hunk_start_0based + hunk.new_count;
139
140 let additions: i32 = hunk
141 .lines
142 .iter()
143 .filter(|l| matches!(l, HunkLine::Addition(_)))
144 .count() as i32;
145 let deletions: i32 = hunk
146 .lines
147 .iter()
148 .filter(|l| matches!(l, HunkLine::Deletion(_)))
149 .count() as i32;
150 let hunk_line_delta = additions - deletions;
151
152 if hunk_end_0based <= excerpt_start_0based {
153 cumulative_line_offset += hunk_line_delta;
154 continue;
155 }
156
157 if hunk_start_0based >= excerpt_end_0based {
158 continue;
159 }
160
161 let mut filtered_lines = Vec::new();
162 let mut current_row_0based = hunk_start_0based;
163 let mut filtered_old_count = 0u32;
164 let mut filtered_new_count = 0u32;
165 let mut first_included_row: Option<u32> = None;
166
167 for line in &hunk.lines {
168 match line {
169 HunkLine::Context(text) => {
170 if current_row_0based >= excerpt_start_0based
171 && current_row_0based < excerpt_end_0based
172 {
173 if first_included_row.is_none() {
174 first_included_row = Some(current_row_0based);
175 }
176 filtered_lines.push(HunkLine::Context(text.clone()));
177 filtered_old_count += 1;
178 filtered_new_count += 1;
179 }
180 current_row_0based += 1;
181 }
182 HunkLine::Addition(text) => {
183 if current_row_0based >= excerpt_start_0based
184 && current_row_0based < excerpt_end_0based
185 {
186 if first_included_row.is_none() {
187 first_included_row = Some(current_row_0based);
188 }
189 filtered_lines.push(HunkLine::Addition(text.clone()));
190 filtered_new_count += 1;
191 }
192 current_row_0based += 1;
193 }
194 HunkLine::Deletion(text) => {
195 if current_row_0based >= excerpt_start_0based
196 && current_row_0based < excerpt_end_0based
197 {
198 if first_included_row.is_none() {
199 first_included_row = Some(current_row_0based);
200 }
201 filtered_lines.push(HunkLine::Deletion(text.clone()));
202 filtered_old_count += 1;
203 }
204 }
205 }
206 }
207
208 if !filtered_lines.is_empty() {
209 let first_row = first_included_row.unwrap_or(excerpt_start_0based);
210 let new_start_1based = (first_row - excerpt_start_0based) + 1;
211
212 filtered_hunks.push(ParsedHunk {
213 old_start: new_start_1based,
214 old_count: filtered_old_count,
215 new_start: new_start_1based,
216 new_count: filtered_new_count,
217 lines: filtered_lines,
218 });
219 }
220
221 cumulative_line_offset += hunk_line_delta;
222 }
223
224 let mut result = String::new();
225 for hunk in &filtered_hunks {
226 result.push_str(&format_hunk(hunk));
227 }
228
229 (result, cumulative_line_offset)
230}
231
232fn compute_excerpt_aware_reversal_overlap(
233 edit_history_diffs: &[&str],
234 excerpt_content: &str,
235 excerpt_start_row: u32,
236 predicted_content: &str,
237) -> ReversalOverlap {
238 let mut current_content = excerpt_content.to_string();
239 let mut current_excerpt_start_row = excerpt_start_row;
240
241 for diff in edit_history_diffs.iter().rev() {
242 if diff.is_empty() {
243 continue;
244 }
245
246 let current_row_count = current_content.lines().count() as u32;
247 let (filtered_diff, _line_offset) =
248 filter_diff_hunks_by_excerpt(diff, current_excerpt_start_row, current_row_count.max(1));
249
250 if filtered_diff.is_empty() {
251 let hunks = parse_diff_hunks(diff);
252 for hunk in hunks {
253 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
254 if hunk_end <= current_excerpt_start_row {
255 let additions: u32 = hunk
256 .lines
257 .iter()
258 .filter(|l| matches!(l, HunkLine::Addition(_)))
259 .count() as u32;
260 let deletions: u32 = hunk
261 .lines
262 .iter()
263 .filter(|l| matches!(l, HunkLine::Deletion(_)))
264 .count() as u32;
265 if additions >= deletions {
266 current_excerpt_start_row =
267 current_excerpt_start_row.saturating_sub(additions - deletions);
268 } else {
269 current_excerpt_start_row += deletions - additions;
270 }
271 }
272 }
273 continue;
274 }
275
276 let reversed = reverse_diff(&format!("--- a/file\n+++ b/file\n{}", filtered_diff));
277 match apply_diff_to_string(&reversed, ¤t_content) {
278 Ok(updated) => {
279 current_content = updated;
280 }
281 Err(_) => {
282 continue;
283 }
284 }
285
286 let hunks = parse_diff_hunks(diff);
287 for hunk in hunks {
288 let hunk_end = hunk.new_start.saturating_sub(1) + hunk.new_count;
289 if hunk_end <= current_excerpt_start_row {
290 let additions: u32 = hunk
291 .lines
292 .iter()
293 .filter(|l| matches!(l, HunkLine::Addition(_)))
294 .count() as u32;
295 let deletions: u32 = hunk
296 .lines
297 .iter()
298 .filter(|l| matches!(l, HunkLine::Deletion(_)))
299 .count() as u32;
300 if additions >= deletions {
301 current_excerpt_start_row =
302 current_excerpt_start_row.saturating_sub(additions - deletions);
303 } else {
304 current_excerpt_start_row += deletions - additions;
305 }
306 }
307 }
308 }
309
310 compute_reversal_overlap(¤t_content, excerpt_content, predicted_content)
311}
312
/// Reverses a unified diff so that applying the result undoes the original change.
///
/// The rewrite rules are:
/// - File headers swap markers (`--- old` <-> `+++ new`).
/// - Hunk headers swap their ranges (`@@ -a,b +c,d @@ tail` becomes `@@ -c,d +a,b @@ tail`).
/// - Inside a hunk body, every `+` line becomes `-` and vice versa.
///
/// Hunk bodies are delimited using the line counts in each hunk header. A deleted line
/// whose text starts with `--` (body line `--- x`), or an added line starting with `++`, is
/// therefore flipped like any other body line instead of being mistaken for a file header.
/// Outside hunk bodies, `---`/`+++` lines not followed by a space are left untouched.
/// A trailing newline on the input is preserved.
fn reverse_diff(diff: &str) -> String {
    /// Line count of a hunk-header range: `start,count`, or bare `start` meaning one line.
    fn range_len(range: &str) -> usize {
        match range.split_once(',') {
            Some((_, count)) => count.parse().unwrap_or(0),
            None => 1,
        }
    }

    /// Swaps the old and new ranges of a hunk header, keeping any trailing section text.
    /// Returns the rewritten header with the header's old/new line counts, or `None` if
    /// `line` is not a hunk header.
    fn reverse_hunk_header(line: &str) -> Option<(String, usize, usize)> {
        let rest = line.strip_prefix("@@ -")?;
        let (old_range, rest) = rest.split_once(" +")?;
        let (new_range, tail) = rest.split_once(" @@")?;
        Some((
            format!("@@ -{} +{} @@{}", new_range, old_range, tail),
            range_len(old_range),
            range_len(new_range),
        ))
    }

    // Lines still expected on each side of the hunk body currently being read.
    let mut old_remaining = 0usize;
    let mut new_remaining = 0usize;

    let mut result = diff
        .lines()
        .map(|line| {
            // A hunk header always starts a new hunk, even if the previous one's counts
            // were not fully consumed.
            if let Some((header, old_len, new_len)) = reverse_hunk_header(line) {
                old_remaining = old_len;
                new_remaining = new_len;
                return header;
            }

            if old_remaining > 0 || new_remaining > 0 {
                if let Some(text) = line.strip_prefix('+') {
                    new_remaining = new_remaining.saturating_sub(1);
                    return format!("-{}", text);
                }
                if let Some(text) = line.strip_prefix('-') {
                    old_remaining = old_remaining.saturating_sub(1);
                    return format!("+{}", text);
                }
                if line.is_empty() || line.starts_with(' ') {
                    old_remaining = old_remaining.saturating_sub(1);
                    new_remaining = new_remaining.saturating_sub(1);
                }
                // Context lines and markers like `\ No newline at end of file` are kept.
                return line.to_string();
            }

            if let Some(path) = line.strip_prefix("--- ") {
                format!("+++ {}", path)
            } else if let Some(path) = line.strip_prefix("+++ ") {
                format!("--- {}", path)
            } else if line.starts_with("---") || line.starts_with("+++") {
                line.to_string()
            } else if let Some(text) = line.strip_prefix('+') {
                format!("-{}", text)
            } else if let Some(text) = line.strip_prefix('-') {
                format!("+{}", text)
            } else {
                line.to_string()
            }
        })
        .collect::<Vec<_>>()
        .join("\n");

    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}
336
/// One contiguous edit between two texts. `range` is the byte range in the old text that
/// was replaced, `old_text` is the text that used to occupy it, and `new_text` is its
/// replacement.
#[derive(Clone, Debug, Eq, PartialEq)]
struct GranularEdit {
    range: Range<usize>,
    old_text: String,
    new_text: String,
}
343
344fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
345 text_diff(old_text, new_text)
346 .into_iter()
347 .map(|(range, new_text)| GranularEdit {
348 old_text: old_text[range.clone()].to_string(),
349 range,
350 new_text: new_text.to_string(),
351 })
352 .collect()
353}
354
/// Byte range in the current text occupied by text that the edit history inserted.
#[derive(Clone, Debug)]
struct HistoryAdditionRange {
    range_in_current: Range<usize>,
}

/// Text that the edit history removed, together with the byte offset in the current text
/// where it used to start.
#[derive(Clone, Debug)]
struct HistoryDeletionRange {
    deleted_text: String,
    position_in_current: usize,
}
365
366fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
367 let mut result = Vec::new();
368 let mut offset_delta: isize = 0;
369
370 for edit in history_edits {
371 if !edit.new_text.is_empty() {
372 let new_start = (edit.range.start as isize + offset_delta) as usize;
373 let new_end = new_start + edit.new_text.len();
374 result.push(HistoryAdditionRange {
375 range_in_current: new_start..new_end,
376 });
377 }
378
379 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
380 }
381
382 result
383}
384
385fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
386 let mut result = Vec::new();
387 let mut offset_delta: isize = 0;
388
389 for edit in history_edits {
390 if !edit.old_text.is_empty() {
391 let position_in_current = (edit.range.start as isize + offset_delta) as usize;
392 result.push(HistoryDeletionRange {
393 deleted_text: edit.old_text.clone(),
394 position_in_current,
395 });
396 }
397
398 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
399 }
400
401 result
402}
403
/// Character counts describing how much of a prediction undoes the user's recent edits.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct ReversalOverlap {
    chars_reversing_user_edits: usize,
    total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Fraction of the prediction's changed characters that reverse user edits.
    /// Returns 0.0 when the prediction changes nothing.
    fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
419
420/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
421/// or where `new_text` appears as a subsequence within `old_text` (reduction).
422///
423/// For extensions: when the user's text is preserved (in order) within the prediction,
424/// we only count the newly inserted characters, not the preserved ones.
425/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
426/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
427///
428/// For reductions: when the prediction's text is preserved (in order) within the original,
429/// we only count the deleted characters, not the preserved ones.
430/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
431fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
432 edits
433 .into_iter()
434 .flat_map(|edit| {
435 if edit.old_text.is_empty() || edit.new_text.is_empty() {
436 return vec![edit];
437 }
438
439 // Use character-wise diff to find exact byte ranges of changes
440 let char_edits = char_diff(&edit.old_text, &edit.new_text);
441
442 let all_deletions = !char_edits.is_empty()
443 && char_edits
444 .iter()
445 .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
446 let all_insertions = !char_edits.is_empty()
447 && char_edits
448 .iter()
449 .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
450 if all_deletions || all_insertions {
451 return char_edits
452 .into_iter()
453 .map(|(range, replacement)| GranularEdit {
454 range: edit.range.start + range.start..edit.range.start + range.end,
455 old_text: edit.old_text[range].to_string(),
456 new_text: replacement.to_string(),
457 })
458 .collect();
459 }
460
461 // Otherwise, keep the original edit (mixed changes)
462 vec![edit]
463 })
464 .collect()
465}
466
467fn compute_reversal_overlap(
468 original_content: &str,
469 current_content: &str,
470 predicted_content: &str,
471) -> ReversalOverlap {
472 let history_edits =
473 normalize_extension_edits(compute_granular_edits(original_content, current_content));
474 let prediction_edits =
475 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
476
477 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
478 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
479
480 let reversed_additions =
481 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
482 let restored_deletions =
483 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
484
485 let total_chars_in_prediction: usize = prediction_edits
486 .iter()
487 .map(|e| e.new_text.chars().count() + e.old_text.chars().count())
488 .sum();
489
490 ReversalOverlap {
491 chars_reversing_user_edits: reversed_additions + restored_deletions,
492 total_chars_in_prediction,
493 }
494}
495
496fn compute_reversed_additions(
497 history_addition_ranges: &[HistoryAdditionRange],
498 prediction_edits: &[GranularEdit],
499) -> usize {
500 let mut reversed_chars = 0;
501
502 for pred_edit in prediction_edits {
503 for history_addition in history_addition_ranges {
504 let overlap_start = pred_edit
505 .range
506 .start
507 .max(history_addition.range_in_current.start);
508 let overlap_end = pred_edit
509 .range
510 .end
511 .min(history_addition.range_in_current.end);
512
513 if overlap_start < overlap_end {
514 let relative_start = overlap_start - pred_edit.range.start;
515 let relative_end = overlap_end - pred_edit.range.start;
516 let overlap_text = &pred_edit.old_text[relative_start..relative_end];
517 reversed_chars += overlap_text.chars().count();
518 }
519 }
520 }
521
522 reversed_chars
523}
524
525fn compute_restored_deletions(
526 history_deletion_ranges: &[HistoryDeletionRange],
527 prediction_edits: &[GranularEdit],
528) -> usize {
529 let mut restored = 0;
530
531 for pred_edit in prediction_edits {
532 if pred_edit.new_text.is_empty() {
533 continue;
534 }
535
536 for deletion in history_deletion_ranges {
537 if pred_edit.range.contains(&deletion.position_in_current)
538 || deletion.position_in_current == pred_edit.range.start
539 {
540 restored += compute_lcs_length(&deletion.deleted_text, &pred_edit.new_text);
541 }
542 }
543 }
544
545 restored
546}
547
/// Length, in `char`s, of the longest common subsequence of `a` and `b`.
///
/// Classic O(len(a) * len(b)) dynamic programming that keeps a single row plus the
/// diagonal value.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let left: Vec<char> = a.chars().collect();
    let right: Vec<char> = b.chars().collect();
    if left.is_empty() || right.is_empty() {
        return 0;
    }

    // row[j] = LCS of the processed prefix of `left` with right[..j].
    let mut row = vec![0usize; right.len() + 1];
    for &left_char in &left {
        // Value of the previous row at column j (the "diagonal" for column j + 1).
        let mut diagonal = 0;
        for (j, &right_char) in right.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if left_char == right_char {
                diagonal + 1
            } else {
                above.max(row[j])
            };
            diagonal = above;
        }
    }

    row[right.len()]
}
575
576fn filter_edit_history_by_path<'a>(
577 edit_history: &'a [Arc<zeta_prompt::Event>],
578 cursor_path: &std::path::Path,
579) -> Vec<&'a zeta_prompt::Event> {
580 edit_history
581 .iter()
582 .filter(|event| match event.as_ref() {
583 zeta_prompt::Event::BufferChange { path, .. } => {
584 let event_path = path.as_ref();
585 if event_path == cursor_path {
586 return true;
587 }
588 let stripped = event_path
589 .components()
590 .skip(1)
591 .collect::<std::path::PathBuf>();
592 stripped == cursor_path
593 }
594 })
595 .map(|arc| arc.as_ref())
596 .collect()
597}
598
599fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
600 match event {
601 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
602 }
603}
604
605fn is_predicted_event(event: &zeta_prompt::Event) -> bool {
606 match event {
607 zeta_prompt::Event::BufferChange { predicted, .. } => *predicted,
608 }
609}
610
611pub fn compute_prediction_reversal_ratio(
612 prompt_inputs: &ZetaPromptInput,
613 predicted_content: &str,
614 cursor_path: &Path,
615) -> f32 {
616 let current_content: &str = prompt_inputs.cursor_excerpt.as_ref();
617
618 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.events;
619 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
620
621 let most_recent = match relevant_events.last() {
622 Some(event) if !is_predicted_event(event) => *event,
623 _ => return 0.0,
624 };
625
626 let diff = extract_diff_from_event(most_recent);
627 if diff.is_empty() {
628 return 0.0;
629 }
630
631 if let Some(excerpt_start_row) = prompt_inputs.excerpt_start_row {
632 let diffs = vec![diff];
633 let overlap = compute_excerpt_aware_reversal_overlap(
634 &diffs,
635 current_content,
636 excerpt_start_row,
637 predicted_content,
638 );
639 return overlap.ratio();
640 }
641
642 let reversed = reverse_diff(diff);
643 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
644 let original_content = match apply_diff_to_string(&with_headers, current_content) {
645 Ok(updated_content) => updated_content,
646 Err(_) => apply_diff_to_string_lenient(&reversed, current_content),
647 };
648
649 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
650 overlap.ratio()
651}
652
653#[cfg(test)]
654mod tests {
655 use super::*;
656 use edit_prediction::udiff::apply_diff_to_string;
657 use indoc::indoc;
658 use zeta_prompt::ExcerptRanges;
659
660 fn make_test_prompt_inputs(
661 content: &str,
662 events: Vec<Arc<zeta_prompt::Event>>,
663 excerpt_start_row: Option<u32>,
664 ) -> ZetaPromptInput {
665 ZetaPromptInput {
666 cursor_path: Arc::from(Path::new("src/test.rs")),
667 cursor_excerpt: content.into(),
668 cursor_offset_in_excerpt: 0,
669 excerpt_start_row,
670 events,
671 related_files: Vec::new(),
672 excerpt_ranges: ExcerptRanges {
673 editable_150: 0..content.len(),
674 editable_180: 0..content.len(),
675 editable_350: 0..content.len(),
676 editable_150_context_350: 0..content.len(),
677 editable_180_context_350: 0..content.len(),
678 editable_350_context_150: 0..content.len(),
679 ..Default::default()
680 },
681 experiment: None,
682 in_open_source_repo: false,
683 can_collect_data: false,
684 repo_url: None,
685 }
686 }
687
688 #[test]
689 fn test_reversal_overlap() {
690 struct Case {
691 name: &'static str,
692 original: &'static str,
693 current: &'static str,
694 predicted: &'static str,
695 expected_reversal_chars: usize,
696 expected_total_chars: usize,
697 }
698
699 let cases = [
700 Case {
701 name: "user_adds_line_prediction_removes_it",
702 original: indoc! {"
703 a
704 b
705 c"},
706 current: indoc! {"
707 a
708 new line
709 b
710 c"},
711 predicted: indoc! {"
712 a
713 b
714 c"},
715 expected_reversal_chars: 9,
716 expected_total_chars: 9,
717 },
718 Case {
719 name: "user_deletes_line_prediction_restores_it",
720 original: indoc! {"
721 a
722 deleted
723 b"},
724 current: indoc! {"
725 a
726 b"},
727 predicted: indoc! {"
728 a
729 deleted
730 b"},
731 expected_reversal_chars: 8,
732 expected_total_chars: 8,
733 },
734 Case {
735 name: "user_deletes_text_prediction_restores_partial",
736 original: "hello beautiful world",
737 current: "hello world",
738 predicted: "hello beautiful world",
739 expected_reversal_chars: 10,
740 expected_total_chars: 10,
741 },
742 Case {
743 name: "user_deletes_foo_prediction_adds_bar",
744 original: "foo",
745 current: "",
746 predicted: "bar",
747 expected_reversal_chars: 0,
748 expected_total_chars: 3,
749 },
750 Case {
751 name: "independent_edits_different_locations",
752 original: indoc! {"
753 line1
754 line2
755 line3"},
756 current: indoc! {"
757 LINE1
758 line2
759 line3"},
760 predicted: indoc! {"
761 LINE1
762 line2
763 LINE3"},
764 expected_reversal_chars: 0,
765 expected_total_chars: 10,
766 },
767 Case {
768 name: "no_history_edits",
769 original: "same",
770 current: "same",
771 predicted: "different",
772 expected_reversal_chars: 0,
773 expected_total_chars: 13,
774 },
775 Case {
776 name: "user_replaces_text_prediction_reverses",
777 original: indoc! {"
778 keep
779 delete_me
780 keep2"},
781 current: indoc! {"
782 keep
783 added
784 keep2"},
785 predicted: indoc! {"
786 keep
787 delete_me
788 keep2"},
789 expected_reversal_chars: 14,
790 expected_total_chars: 14,
791 },
792 Case {
793 name: "user_modifies_word_prediction_modifies_differently",
794 original: "the quick brown fox",
795 current: "the slow brown fox",
796 predicted: "the fast brown fox",
797 expected_reversal_chars: 4,
798 expected_total_chars: 8,
799 },
800 Case {
801 name: "user finishes function name (suffix)",
802 original: "",
803 current: "epr",
804 predicted: "eprintln!()",
805 expected_reversal_chars: 0,
806 expected_total_chars: 8,
807 },
808 Case {
809 name: "user starts function name (prefix)",
810 original: "",
811 current: "my_function()",
812 predicted: "test_my_function()",
813 expected_reversal_chars: 0,
814 expected_total_chars: 5,
815 },
816 Case {
817 name: "user types partial, prediction extends in multiple places",
818 original: "",
819 current: "test_my_function",
820 predicted: "a_test_for_my_special_function_plz",
821 expected_reversal_chars: 0,
822 expected_total_chars: 18,
823 },
824 // Edge cases for subsequence matching
825 Case {
826 name: "subsequence with interleaved underscores",
827 original: "",
828 current: "a_b_c",
829 predicted: "_a__b__c__",
830 expected_reversal_chars: 0,
831 expected_total_chars: 5,
832 },
833 Case {
834 name: "not a subsequence - different characters",
835 original: "",
836 current: "abc",
837 predicted: "xyz",
838 expected_reversal_chars: 3,
839 expected_total_chars: 6,
840 },
841 Case {
842 name: "not a subsequence - wrong order",
843 original: "",
844 current: "abc",
845 predicted: "cba",
846 expected_reversal_chars: 3,
847 expected_total_chars: 6,
848 },
849 Case {
850 name: "partial subsequence - only some chars match",
851 original: "",
852 current: "abcd",
853 predicted: "axbx",
854 expected_reversal_chars: 4,
855 expected_total_chars: 8,
856 },
857 // Common completion patterns
858 Case {
859 name: "completing a method call",
860 original: "",
861 current: "vec.pu",
862 predicted: "vec.push(item)",
863 expected_reversal_chars: 0,
864 expected_total_chars: 8,
865 },
866 Case {
867 name: "completing an import statement",
868 original: "",
869 current: "use std::col",
870 predicted: "use std::collections::HashMap",
871 expected_reversal_chars: 0,
872 expected_total_chars: 17,
873 },
874 Case {
875 name: "completing a struct field",
876 original: "",
877 current: "name: St",
878 predicted: "name: String",
879 expected_reversal_chars: 0,
880 expected_total_chars: 4,
881 },
882 Case {
883 name: "prediction replaces with completely different text",
884 original: "",
885 current: "hello",
886 predicted: "world",
887 expected_reversal_chars: 5,
888 expected_total_chars: 10,
889 },
890 Case {
891 name: "empty prediction removes user text",
892 original: "",
893 current: "mistake",
894 predicted: "",
895 expected_reversal_chars: 7,
896 expected_total_chars: 7,
897 },
898 Case {
899 name: "fixing typo is not reversal",
900 original: "",
901 current: "<dv",
902 predicted: "<div>",
903 expected_reversal_chars: 0,
904 expected_total_chars: 2,
905 },
906 Case {
907 name: "infix insertion not reversal",
908 original: indoc! {"
909 from my_project import Foo
910 "},
911 current: indoc! {"
912 ifrom my_project import Foo
913 "},
914 predicted: indoc! {"
915 import
916 from my_project import Foo
917 "},
918 expected_reversal_chars: 0,
919 expected_total_chars: 6,
920 },
921 Case {
922 name: "non-word based reversal",
923 original: "from",
924 current: "ifrom",
925 predicted: "from",
926 expected_reversal_chars: 1,
927 expected_total_chars: 1,
928 },
929 Case {
930 name: "multiple insertions no reversal",
931 original: "print(\"Hello, World!\")",
932 current: "sys.(\"Hello, World!\")",
933 predicted: "sys.stdout.write(\"Hello, World!\\n\")",
934 expected_reversal_chars: 0,
935 expected_total_chars: 14,
936 },
937 ];
938
939 for case in &cases {
940 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
941 assert_eq!(
942 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
943 "Test '{}': expected {} reversal chars, got {}",
944 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
945 );
946 assert_eq!(
947 overlap.total_chars_in_prediction, case.expected_total_chars,
948 "Test '{}': expected {} total chars, got {}",
949 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
950 );
951 }
952 }
953
954 #[test]
955 fn test_reverse_diff() {
956 let forward_diff = indoc! {"
957 --- a/file.rs
958 +++ b/file.rs
959 @@ -1,3 +1,4 @@
960 fn main() {
961 + let x = 42;
962 println!(\"hello\");
963 }"};
964
965 let reversed = reverse_diff(forward_diff);
966
967 assert!(
968 reversed.contains("+++ a/file.rs"),
969 "Should have +++ for old path"
970 );
971 assert!(
972 reversed.contains("--- b/file.rs"),
973 "Should have --- for new path"
974 );
975 assert!(
976 reversed.contains("- let x = 42;"),
977 "Added line should become deletion"
978 );
979 assert!(
980 reversed.contains(" fn main()"),
981 "Context lines should be unchanged"
982 );
983 }
984
985 #[test]
986 fn test_reverse_diff_roundtrip() {
987 // Applying a diff and then its reverse should get back to original
988 let original = indoc! {"
989 first line
990 hello world
991 last line
992 "};
993 let modified = indoc! {"
994 first line
995 hello beautiful world
996 last line
997 "};
998
999 // unified_diff doesn't include file headers, but apply_diff_to_string needs them
1000 let diff_body = language::unified_diff(original, modified);
1001 let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
1002 let reversed_diff = reverse_diff(&forward_diff);
1003
1004 // Apply forward diff to original
1005 let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
1006 assert_eq!(after_forward, modified);
1007
1008 // Apply reversed diff to modified
1009 let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
1010 assert_eq!(after_reverse, original);
1011 }
1012
1013 #[test]
1014 fn test_filter_edit_history_by_path() {
1015 // Test that filter_edit_history_by_path correctly matches paths when
1016 // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
1017 // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
1018 let events = vec![
1019 Arc::new(zeta_prompt::Event::BufferChange {
1020 path: Arc::from(Path::new("myrepo/src/file.rs")),
1021 old_path: Arc::from(Path::new("myrepo/src/file.rs")),
1022 diff: indoc! {"
1023 @@ -1 +1 @@
1024 -old
1025 +new"}
1026 .into(),
1027 predicted: false,
1028 in_open_source_repo: true,
1029 }),
1030 Arc::new(zeta_prompt::Event::BufferChange {
1031 path: Arc::from(Path::new("myrepo/other.rs")),
1032 old_path: Arc::from(Path::new("myrepo/other.rs")),
1033 diff: indoc! {"
1034 @@ -1 +1 @@
1035 -a
1036 +b"}
1037 .into(),
1038 predicted: false,
1039 in_open_source_repo: true,
1040 }),
1041 Arc::new(zeta_prompt::Event::BufferChange {
1042 path: Arc::from(Path::new("src/file.rs")),
1043 old_path: Arc::from(Path::new("src/file.rs")),
1044 diff: indoc! {"
1045 @@ -1 +1 @@
1046 -x
1047 +y"}
1048 .into(),
1049 predicted: false,
1050 in_open_source_repo: true,
1051 }),
1052 ];
1053
1054 // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
1055 // "src/file.rs" exact match
1056 let cursor_path = Path::new("src/file.rs");
1057 let filtered = filter_edit_history_by_path(&events, cursor_path);
1058 assert_eq!(
1059 filtered.len(),
1060 2,
1061 "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
1062 );
1063
1064 // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
1065 // "src/file.rs" stripped -> "file.rs" == "file.rs"
1066 let cursor_path = Path::new("file.rs");
1067 let filtered = filter_edit_history_by_path(&events, cursor_path);
1068 assert_eq!(
1069 filtered.len(),
1070 1,
1071 "Should only match src/file.rs (stripped to file.rs)"
1072 );
1073
1074 // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
1075 let cursor_path = Path::new("other.rs");
1076 let filtered = filter_edit_history_by_path(&events, cursor_path);
1077 assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
1078 }
1079
1080 #[test]
1081 fn test_reverse_diff_preserves_trailing_newline() {
1082 let diff_with_trailing_newline = indoc! {"
1083 --- a/file
1084 +++ b/file
1085 @@ -1 +1 @@
1086 -old
1087 +new
1088 "};
1089 let reversed = reverse_diff(diff_with_trailing_newline);
1090 assert!(
1091 reversed.ends_with('\n'),
1092 "Reversed diff should preserve trailing newline"
1093 );
1094
1095 let diff_without_trailing_newline = indoc! {"
1096 --- a/file
1097 +++ b/file
1098 @@ -1 +1 @@
1099 -old
1100 +new"};
1101 let reversed = reverse_diff(diff_without_trailing_newline);
1102 assert!(
1103 !reversed.ends_with('\n'),
1104 "Reversed diff should not add trailing newline if original didn't have one"
1105 );
1106 }
1107
1108 #[test]
1109 fn test_filter_hunks_by_excerpt_region() {
1110 struct Case {
1111 name: &'static str,
1112 diff: &'static str,
1113 excerpt_start_row: u32,
1114 excerpt_row_count: u32,
1115 expected_filtered_diff: &'static str,
1116 expected_line_offset: i32,
1117 }
1118
1119 let cases = [
1120 Case {
1121 name: "hunk_entirely_before_excerpt",
1122 diff: indoc! {"
1123 @@ -1,3 +1,4 @@
1124 line1
1125 +inserted
1126 line2
1127 line3
1128 "},
1129 excerpt_start_row: 10,
1130 excerpt_row_count: 5,
1131 expected_filtered_diff: "",
1132 expected_line_offset: 1,
1133 },
1134 Case {
1135 name: "hunk_entirely_inside_excerpt",
1136 diff: indoc! {"
1137 @@ -12,3 +12,4 @@
1138 line12
1139 +inserted
1140 line13
1141 line14
1142 "},
1143 excerpt_start_row: 10,
1144 excerpt_row_count: 10,
1145 expected_filtered_diff: indoc! {"
1146 @@ -2,3 +2,4 @@
1147 line12
1148 +inserted
1149 line13
1150 line14
1151 "},
1152 expected_line_offset: 1,
1153 },
1154 Case {
1155 name: "hunk_entirely_after_excerpt",
1156 diff: indoc! {"
1157 @@ -50,3 +50,4 @@
1158 line50
1159 +inserted
1160 line51
1161 line52
1162 "},
1163 excerpt_start_row: 10,
1164 excerpt_row_count: 5,
1165 expected_filtered_diff: "",
1166 expected_line_offset: 0,
1167 },
1168 Case {
1169 name: "hunk_straddles_excerpt_start",
1170 diff: indoc! {"
1171 @@ -8,5 +8,6 @@
1172 line8
1173 line9
1174 +inserted
1175 line10
1176 line11
1177 line12
1178 "},
1179 excerpt_start_row: 10,
1180 excerpt_row_count: 10,
1181 expected_filtered_diff: indoc! {"
1182 @@ -1,3 +1,3 @@
1183 line10
1184 line11
1185 line12
1186 "},
1187 expected_line_offset: 1,
1188 },
1189 Case {
1190 name: "hunk_straddles_excerpt_end",
1191 diff: indoc! {"
1192 @@ -18,5 +18,6 @@
1193 line18
1194 line19
1195 +inserted
1196 line20
1197 line21
1198 line22
1199 "},
1200 excerpt_start_row: 10,
1201 excerpt_row_count: 10,
1202 expected_filtered_diff: indoc! {"
1203 @@ -8,2 +8,3 @@
1204 line18
1205 line19
1206 +inserted
1207 "},
1208 expected_line_offset: 1,
1209 },
1210 Case {
1211 name: "multiple_hunks_mixed",
1212 diff: indoc! {"
1213 @@ -1,2 +1,3 @@
1214 line1
1215 +before_excerpt
1216 line2
1217 @@ -12,2 +13,3 @@
1218 line12
1219 +inside_excerpt
1220 line13
1221 @@ -50,2 +52,3 @@
1222 line50
1223 +after_excerpt
1224 line51
1225 "},
1226 excerpt_start_row: 10,
1227 excerpt_row_count: 10,
1228 expected_filtered_diff: indoc! {"
1229 @@ -3,2 +3,3 @@
1230 line12
1231 +inside_excerpt
1232 line13
1233 "},
1234 expected_line_offset: 2,
1235 },
1236 Case {
1237 name: "deletion_before_excerpt",
1238 diff: indoc! {"
1239 @@ -1,4 +1,3 @@
1240 line1
1241 -deleted
1242 line2
1243 line3
1244 "},
1245 excerpt_start_row: 10,
1246 excerpt_row_count: 5,
1247 expected_filtered_diff: "",
1248 expected_line_offset: -1,
1249 },
1250 Case {
1251 name: "deletion_inside_excerpt",
1252 diff: indoc! {"
1253 @@ -12,4 +12,3 @@
1254 line12
1255 -deleted
1256 line13
1257 line14
1258 "},
1259 excerpt_start_row: 10,
1260 excerpt_row_count: 10,
1261 expected_filtered_diff: indoc! {"
1262 @@ -2,4 +2,3 @@
1263 line12
1264 -deleted
1265 line13
1266 line14
1267 "},
1268 expected_line_offset: -1,
1269 },
1270 Case {
1271 name: "empty_diff",
1272 diff: "",
1273 excerpt_start_row: 10,
1274 excerpt_row_count: 5,
1275 expected_filtered_diff: "",
1276 expected_line_offset: 0,
1277 },
1278 Case {
1279 name: "hunk_spans_entire_excerpt",
1280 diff: indoc! {"
1281 @@ -8,10 +8,12 @@
1282 line8
1283 line9
1284 line10
1285 line11
1286 +inserted1
1287 line12
1288 line13
1289 +inserted2
1290 line14
1291 line15
1292 line16
1293 line17
1294 "},
1295 excerpt_start_row: 10,
1296 excerpt_row_count: 5,
1297 expected_filtered_diff: indoc! {"
1298 @@ -1,3 +1,5 @@
1299 line11
1300 +inserted1
1301 line12
1302 line13
1303 +inserted2
1304 "},
1305 expected_line_offset: 2,
1306 },
1307 Case {
1308 name: "replacement_inside_excerpt",
1309 diff: indoc! {"
1310 @@ -12,3 +12,3 @@
1311 line12
1312 -old_text
1313 +new_text
1314 line14
1315 "},
1316 excerpt_start_row: 10,
1317 excerpt_row_count: 10,
1318 expected_filtered_diff: indoc! {"
1319 @@ -2,3 +2,3 @@
1320 line12
1321 -old_text
1322 +new_text
1323 line14
1324 "},
1325 expected_line_offset: 0,
1326 },
1327 ];
1328
1329 for case in &cases {
1330 let (filtered, line_offset) = filter_diff_hunks_by_excerpt(
1331 case.diff,
1332 case.excerpt_start_row,
1333 case.excerpt_row_count,
1334 );
1335 assert_eq!(
1336 filtered, case.expected_filtered_diff,
1337 "Test '{}': filtered diff mismatch.\nExpected:\n{}\nGot:\n{}",
1338 case.name, case.expected_filtered_diff, filtered
1339 );
1340 assert_eq!(
1341 line_offset, case.expected_line_offset,
1342 "Test '{}': line offset mismatch. Expected {}, got {}",
1343 case.name, case.expected_line_offset, line_offset
1344 );
1345 }
1346 }
1347
1348 #[test]
1349 fn test_excerpt_aware_reversal_tracking() {
1350 struct Case {
1351 name: &'static str,
1352 edit_history_diffs: Vec<&'static str>,
1353 excerpt_content: &'static str,
1354 excerpt_start_row: u32,
1355 predicted_content: &'static str,
1356 expected_reversal_chars: usize,
1357 expected_total_chars: usize,
1358 }
1359
1360 let cases = [
1361 Case {
1362 name: "edit_outside_excerpt_no_reversal",
1363 edit_history_diffs: vec![indoc! {"
1364 @@ -1,2 +1,3 @@
1365 line1
1366 +added_outside
1367 line2
1368 "}],
1369 excerpt_content: indoc! {"
1370 line10
1371 line11
1372 line12
1373 "},
1374 excerpt_start_row: 10,
1375 predicted_content: indoc! {"
1376 line10
1377 modified
1378 line12
1379 "},
1380 expected_reversal_chars: 0,
1381 expected_total_chars: 14,
1382 },
1383 Case {
1384 name: "edit_inside_excerpt_with_reversal",
1385 edit_history_diffs: vec![indoc! {"
1386 @@ -10,3 +10,4 @@
1387 line10
1388 +user_added
1389 line11
1390 line12
1391 "}],
1392 excerpt_content: indoc! {"
1393 line10
1394 user_added
1395 line11
1396 line12
1397 "},
1398 excerpt_start_row: 10,
1399 predicted_content: indoc! {"
1400 line10
1401 line11
1402 line12
1403 "},
1404 expected_reversal_chars: 11,
1405 expected_total_chars: 11,
1406 },
1407 Case {
1408 name: "straddling_edit_partial_reversal",
1409 edit_history_diffs: vec![indoc! {"
1410 @@ -8,6 +8,8 @@
1411 line8
1412 line9
1413 +before_excerpt
1414 line10
1415 +inside_excerpt
1416 line11
1417 line12
1418 line13
1419 "}],
1420 excerpt_content: indoc! {"
1421 line10
1422 inside_excerpt
1423 line11
1424 line12
1425 line13
1426 "},
1427 excerpt_start_row: 10,
1428 predicted_content: indoc! {"
1429 line10
1430 line11
1431 line12
1432 line13
1433 "},
1434 expected_reversal_chars: 15,
1435 expected_total_chars: 15,
1436 },
1437 Case {
1438 name: "multiple_edits_mixed_locations",
1439 edit_history_diffs: vec![
1440 indoc! {"
1441 @@ -1,2 +1,3 @@
1442 line1
1443 +outside1
1444 line2
1445 "},
1446 indoc! {"
1447 @@ -11,2 +12,3 @@
1448 line11
1449 +inside1
1450 line12
1451 "},
1452 ],
1453 excerpt_content: indoc! {"
1454 line10
1455 line11
1456 inside1
1457 line12
1458 line13
1459 "},
1460 excerpt_start_row: 10,
1461 predicted_content: indoc! {"
1462 line10
1463 line11
1464 line12
1465 line13
1466 "},
1467 expected_reversal_chars: 8,
1468 expected_total_chars: 8,
1469 },
1470 Case {
1471 name: "no_edit_history",
1472 edit_history_diffs: vec![],
1473 excerpt_content: indoc! {"
1474 line10
1475 line11
1476 line12
1477 "},
1478 excerpt_start_row: 10,
1479 predicted_content: indoc! {"
1480 line10
1481 modified
1482 line12
1483 "},
1484 expected_reversal_chars: 0,
1485 expected_total_chars: 14,
1486 },
1487 Case {
1488 name: "edit_after_excerpt_no_effect",
1489 edit_history_diffs: vec![indoc! {"
1490 @@ -50,2 +50,3 @@
1491 line50
1492 +added_after
1493 line51
1494 "}],
1495 excerpt_content: indoc! {"
1496 line10
1497 line11
1498 line12
1499 "},
1500 excerpt_start_row: 10,
1501 predicted_content: indoc! {"
1502 line10
1503 changed
1504 line12
1505 "},
1506 expected_reversal_chars: 0,
1507 expected_total_chars: 13,
1508 },
1509 Case {
1510 name: "line_offset_tracking_across_hunks",
1511 edit_history_diffs: vec![
1512 indoc! {"
1513 @@ -1,2 +1,4 @@
1514 line1
1515 +added1
1516 +added2
1517 line2
1518 "},
1519 indoc! {"
1520 @@ -12,2 +14,3 @@
1521 line12
1522 +inside_after_offset
1523 line13
1524 "},
1525 ],
1526 excerpt_content: indoc! {"
1527 line10
1528 line11
1529 line12
1530 inside_after_offset
1531 line13
1532 "},
1533 excerpt_start_row: 10,
1534 predicted_content: indoc! {"
1535 line10
1536 line11
1537 line12
1538 line13
1539 "},
1540 expected_reversal_chars: 20,
1541 expected_total_chars: 20,
1542 },
1543 ];
1544
1545 for case in &cases {
1546 let overlap = compute_excerpt_aware_reversal_overlap(
1547 &case.edit_history_diffs,
1548 case.excerpt_content,
1549 case.excerpt_start_row,
1550 case.predicted_content,
1551 );
1552 assert_eq!(
1553 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1554 "Test '{}': expected {} reversal chars, got {}",
1555 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1556 );
1557 assert_eq!(
1558 overlap.total_chars_in_prediction, case.expected_total_chars,
1559 "Test '{}': expected {} total chars, got {}",
1560 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1561 );
1562 }
1563 }
1564
1565 #[test]
1566 fn test_lenient_diff_application() {
1567 struct Case {
1568 name: &'static str,
1569 diff: &'static str,
1570 content: &'static str,
1571 expected_result: &'static str,
1572 }
1573
1574 let cases = [
1575 Case {
1576 name: "hunk_context_not_found_skipped",
1577 diff: indoc! {"
1578 @@ -1,3 +1,4 @@
1579 context_not_in_content
1580 +added_line
1581 more_context
1582 final_context
1583 "},
1584 content: indoc! {"
1585 completely
1586 different
1587 content
1588 "},
1589 expected_result: indoc! {"
1590 completely
1591 different
1592 content
1593 "},
1594 },
1595 Case {
1596 name: "hunk_context_found_applied",
1597 diff: indoc! {"
1598 @@ -1,3 +1,4 @@
1599 line1
1600 +inserted
1601 line2
1602 line3
1603 "},
1604 content: indoc! {"
1605 line1
1606 line2
1607 line3
1608 "},
1609 expected_result: indoc! {"
1610 line1
1611 inserted
1612 line2
1613 line3
1614 "},
1615 },
1616 Case {
1617 name: "multiple_hunks_partial_match",
1618 diff: indoc! {"
1619 @@ -1,2 +1,3 @@
1620 not_found
1621 +skipped
1622 also_not_found
1623 @@ -5,2 +6,3 @@
1624 line5
1625 +applied
1626 line6
1627 "},
1628 content: indoc! {"
1629 line1
1630 line2
1631 line3
1632 line4
1633 line5
1634 line6
1635 "},
1636 expected_result: indoc! {"
1637 line1
1638 line2
1639 line3
1640 line4
1641 line5
1642 applied
1643 line6
1644 "},
1645 },
1646 Case {
1647 name: "empty_diff",
1648 diff: "",
1649 content: indoc! {"
1650 unchanged
1651 content
1652 "},
1653 expected_result: indoc! {"
1654 unchanged
1655 content
1656 "},
1657 },
1658 ];
1659
1660 for case in &cases {
1661 let result = apply_diff_to_string_lenient(case.diff, case.content);
1662 assert_eq!(
1663 result, case.expected_result,
1664 "Test '{}': expected:\n{}\ngot:\n{}",
1665 case.name, case.expected_result, result
1666 );
1667 }
1668 }
1669
1670 #[test]
1671 fn test_unicode_reversal_overlap() {
1672 struct Case {
1673 name: &'static str,
1674 original: &'static str,
1675 current: &'static str,
1676 predicted: &'static str,
1677 expected_reversal_chars: usize,
1678 expected_total_chars: usize,
1679 }
1680
1681 let cases = [
1682 Case {
1683 name: "unicode_extension_cjk",
1684 original: "",
1685 current: "日", // 1 char
1686 predicted: "日本語", // 3 chars, adds 2 chars
1687 expected_reversal_chars: 0,
1688 expected_total_chars: 2, // "本語" = 2 chars added
1689 },
1690 Case {
1691 name: "unicode_extension_emoji",
1692 original: "",
1693 current: "🎉", // 1 char
1694 predicted: "🎉🎊🎈", // 3 chars, adds 2 chars
1695 expected_reversal_chars: 0,
1696 expected_total_chars: 2, // "🎊🎈" = 2 chars added
1697 },
1698 Case {
1699 name: "unicode_deletion_restored",
1700 original: "héllo wörld", // 11 chars
1701 current: "héllo", // 5 chars
1702 predicted: "héllo wörld", // restores " wörld" = 6 chars
1703 expected_reversal_chars: 6, // LCS(" wörld", " wörld") = 6 chars
1704 expected_total_chars: 6,
1705 },
1706 Case {
1707 name: "unicode_addition_reversed",
1708 original: "café", // 4 chars
1709 current: "café latté", // 10 chars, added " latté" = 6 chars
1710 predicted: "café", // removes " latté"
1711 expected_reversal_chars: 6, // 6 chars removed
1712 expected_total_chars: 6,
1713 },
1714 Case {
1715 name: "mixed_ascii_unicode",
1716 original: "",
1717 current: "test日本", // 6 chars
1718 predicted: "test日本語です", // 9 chars
1719 expected_reversal_chars: 0,
1720 expected_total_chars: 3, // 3 new chars after subsequence normalization
1721 },
1722 Case {
1723 name: "unicode_replacement_not_subsequence",
1724 original: "",
1725 current: "日本", // 2 chars
1726 predicted: "中国", // 2 chars, different
1727 expected_reversal_chars: 2, // removes "日本" = 2 chars
1728 expected_total_chars: 4, // 2 removed + 2 added
1729 },
1730 ];
1731
1732 for case in &cases {
1733 let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
1734 assert_eq!(
1735 overlap.chars_reversing_user_edits, case.expected_reversal_chars,
1736 "Test '{}': expected {} reversal chars, got {}",
1737 case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
1738 );
1739 assert_eq!(
1740 overlap.total_chars_in_prediction, case.expected_total_chars,
1741 "Test '{}': expected {} total chars, got {}",
1742 case.name, case.expected_total_chars, overlap.total_chars_in_prediction
1743 );
1744 }
1745 }
1746
1747 #[test]
1748 fn test_compute_lcs_length() {
1749 assert_eq!(compute_lcs_length("", ""), 0);
1750 assert_eq!(compute_lcs_length("abc", ""), 0);
1751 assert_eq!(compute_lcs_length("", "abc"), 0);
1752 assert_eq!(compute_lcs_length("abc", "abc"), 3);
1753 assert_eq!(compute_lcs_length("abc", "def"), 0);
1754 assert_eq!(compute_lcs_length("abcdef", "ace"), 3);
1755 assert_eq!(compute_lcs_length("AGGTAB", "GXTXAYB"), 4);
1756 assert_eq!(compute_lcs_length("日本語", "日語"), 2);
1757 }
1758
1759 #[test]
1760 fn test_compute_prediction_reversal_ratio_full_file() {
1761 let prompt_inputs = make_test_prompt_inputs(
1762 indoc! {"
1763 line1
1764 user_added
1765 line2
1766 "},
1767 vec![Arc::new(zeta_prompt::Event::BufferChange {
1768 path: Arc::from(Path::new("src/test.rs")),
1769 old_path: Arc::from(Path::new("src/test.rs")),
1770 diff: indoc! {"
1771 @@ -1,2 +1,3 @@
1772 line1
1773 +user_added
1774 line2
1775 "}
1776 .into(),
1777 predicted: false,
1778 in_open_source_repo: false,
1779 })],
1780 None,
1781 );
1782
1783 let predicted = indoc! {"
1784 line1
1785 line2
1786 "};
1787 let ratio =
1788 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1789
1790 assert!(
1791 ratio > 0.9,
1792 "Expected high reversal ratio when prediction removes user addition, got {}",
1793 ratio
1794 );
1795 }
1796
1797 #[test]
1798 fn test_compute_prediction_reversal_ratio_with_excerpt() {
1799 let prompt_inputs = make_test_prompt_inputs(
1800 indoc! {"
1801 line10
1802 user_added
1803 line11
1804 "},
1805 vec![Arc::new(zeta_prompt::Event::BufferChange {
1806 path: Arc::from(Path::new("src/test.rs")),
1807 old_path: Arc::from(Path::new("src/test.rs")),
1808 diff: indoc! {"
1809 @@ -10,2 +10,3 @@
1810 line10
1811 +user_added
1812 line11
1813 "}
1814 .into(),
1815 predicted: false,
1816 in_open_source_repo: false,
1817 })],
1818 Some(10),
1819 );
1820
1821 let predicted = indoc! {"
1822 line10
1823 line11
1824 "};
1825 let ratio =
1826 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1827
1828 assert!(
1829 ratio > 0.9,
1830 "Expected high reversal ratio for excerpt-aware computation, got {}",
1831 ratio
1832 );
1833 }
1834
1835 #[test]
1836 fn test_compute_prediction_reversal_ratio_no_history() {
1837 let prompt_inputs = make_test_prompt_inputs(
1838 indoc! {"
1839 original content
1840 "},
1841 vec![],
1842 None,
1843 );
1844
1845 let predicted = indoc! {"
1846 completely different
1847 "};
1848 let ratio =
1849 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1850
1851 assert_eq!(
1852 ratio, 0.0,
1853 "Expected zero reversal ratio with no edit history"
1854 );
1855 }
1856
1857 #[test]
1858 fn test_compute_prediction_reversal_ratio_path_filtering() {
1859 let prompt_inputs = make_test_prompt_inputs(
1860 indoc! {"
1861 line1
1862 user_added
1863 line2
1864 "},
1865 vec![Arc::new(zeta_prompt::Event::BufferChange {
1866 path: Arc::from(Path::new("src/other.rs")),
1867 old_path: Arc::from(Path::new("src/other.rs")),
1868 diff: indoc! {"
1869 @@ -1,2 +1,3 @@
1870 line1
1871 +user_added
1872 line2
1873 "}
1874 .into(),
1875 predicted: false,
1876 in_open_source_repo: false,
1877 })],
1878 None,
1879 );
1880
1881 let predicted = indoc! {"
1882 line1
1883 line2
1884 "};
1885 let ratio =
1886 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1887
1888 assert_eq!(
1889 ratio, 0.0,
1890 "Expected zero reversal when edit history is for different file"
1891 );
1892 }
1893
1894 #[test]
1895 fn test_compute_prediction_reversal_ratio_lenient_fallback() {
1896 let prompt_inputs = make_test_prompt_inputs(
1897 indoc! {"
1898 actual_line1
1899 user_added
1900 actual_line2
1901 "},
1902 vec![Arc::new(zeta_prompt::Event::BufferChange {
1903 path: Arc::from(Path::new("src/test.rs")),
1904 old_path: Arc::from(Path::new("src/test.rs")),
1905 diff: indoc! {"
1906 @@ -1,2 +1,3 @@
1907 wrong_context
1908 +user_added
1909 more_wrong
1910 "}
1911 .into(),
1912 predicted: false,
1913 in_open_source_repo: false,
1914 })],
1915 None,
1916 );
1917
1918 let predicted = indoc! {"
1919 actual_line1
1920 actual_line2
1921 "};
1922 let ratio =
1923 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
1924
1925 assert!(
1926 ratio >= 0.0 && ratio <= 1.0,
1927 "Ratio should be valid even with lenient fallback, got {}",
1928 ratio
1929 );
1930 }
1931
1932 #[test]
1933 fn test_excerpt_aware_reversal_error_recovery() {
1934 let diffs = vec![indoc! {"
1935 @@ -1,2 +1,3 @@
1936 nonexistent_context
1937 +added
1938 more_nonexistent
1939 "}];
1940 let excerpt_content = indoc! {"
1941 completely
1942 different
1943 content
1944 "};
1945 let predicted_content = indoc! {"
1946 completely
1947 modified
1948 content
1949 "};
1950
1951 let overlap =
1952 compute_excerpt_aware_reversal_overlap(&diffs, excerpt_content, 0, predicted_content);
1953
1954 assert!(
1955 overlap.ratio() >= 0.0 && overlap.ratio() <= 1.0,
1956 "Should handle failed diff application gracefully"
1957 );
1958 }
1959
1960 #[test]
1961 fn test_only_most_recent_edit_tracked() {
1962 let prompt_inputs = make_test_prompt_inputs(
1963 indoc! {"
1964 line1
1965 first_add
1966 second_add
1967 line2
1968 "},
1969 vec![
1970 Arc::new(zeta_prompt::Event::BufferChange {
1971 path: Arc::from(Path::new("src/test.rs")),
1972 old_path: Arc::from(Path::new("src/test.rs")),
1973 diff: indoc! {"
1974 @@ -1,2 +1,3 @@
1975 line1
1976 +first_add
1977 line2
1978 "}
1979 .into(),
1980 predicted: false,
1981 in_open_source_repo: false,
1982 }),
1983 Arc::new(zeta_prompt::Event::BufferChange {
1984 path: Arc::from(Path::new("src/test.rs")),
1985 old_path: Arc::from(Path::new("src/test.rs")),
1986 diff: indoc! {"
1987 @@ -2,2 +2,3 @@
1988 first_add
1989 +second_add
1990 line2
1991 "}
1992 .into(),
1993 predicted: false,
1994 in_open_source_repo: false,
1995 }),
1996 ],
1997 None,
1998 );
1999
2000 let predicted = indoc! {"
2001 line1
2002 first_add
2003 line2
2004 "};
2005 let ratio =
2006 compute_prediction_reversal_ratio(&prompt_inputs, predicted, Path::new("src/test.rs"));
2007
2008 assert!(
2009 ratio > 0.9,
2010 "Expected high reversal ratio when prediction exactly reverses the most recent edit, got {}",
2011 ratio
2012 );
2013 }
2014}