1use std::ops::Range;
2use std::path::Path;
3use std::sync::Arc;
4
5use edit_prediction::udiff::apply_diff_to_string;
6use language::text_diff;
7
8use crate::example::ExamplePromptInputs;
9
/// Produces the inverse of a unified diff, i.e. a diff that turns the "new"
/// text back into the "old" text.
///
/// * File header lines have their markers exchanged (`--- path` becomes
///   `+++ path` and vice versa); their relative order is left as-is.
/// * Hunk headers of the form `@@ -old +new @@` have their two ranges swapped.
/// * Inside a hunk, added lines become removed lines and removed lines become
///   added lines; context and `\ No newline…` lines are copied verbatim.
///
/// Hunk bodies are recognised by the line counts in their header, so a body
/// line whose *content* starts with `--` or `++` (for example a deleted
/// `-- comment` line, which shows up as `--- comment`) is flipped like any
/// other body line instead of being mistaken for a file header.
///
/// Lines that are not covered by a parsable hunk header fall back to plain
/// prefix matching. Line terminators are normalised to `\n`, and a trailing
/// newline is emitted only if `diff` ended with one.
pub fn reverse_diff(diff: &str) -> String {
    // Body lines the current hunk still owes on the old (`-`) and new (`+`) side.
    let mut old_remaining = 0usize;
    let mut new_remaining = 0usize;
    let mut result = String::with_capacity(diff.len());

    for (index, line) in diff.lines().enumerate() {
        if index > 0 {
            result.push('\n');
        }

        // Hunk body lines never start with `@`, so this check is safe even
        // when the previous hunk's counts were overstated.
        if line.starts_with("@@") {
            match parse_hunk_header(line) {
                Some(header) => {
                    old_remaining = header.old_len;
                    new_remaining = header.new_len;
                    result.push_str("@@ -");
                    result.push_str(header.new_range);
                    result.push_str(" +");
                    result.push_str(header.old_range);
                    result.push_str(" @@");
                    result.push_str(header.rest);
                }
                None => {
                    // Unknown header shape: copy it and rely on prefix matching below.
                    old_remaining = 0;
                    new_remaining = 0;
                    result.push_str(line);
                }
            }
            continue;
        }

        if old_remaining > 0 || new_remaining > 0 {
            match line.as_bytes().first().copied() {
                Some(b'+') => {
                    new_remaining = new_remaining.saturating_sub(1);
                    result.push('-');
                    result.push_str(&line[1..]);
                    continue;
                }
                Some(b'-') => {
                    old_remaining = old_remaining.saturating_sub(1);
                    result.push('+');
                    result.push_str(&line[1..]);
                    continue;
                }
                Some(b'\\') => {
                    // "\ No newline at end of file" counts toward neither side.
                    result.push_str(line);
                    continue;
                }
                // An empty line is a context line whose single space was stripped.
                Some(b' ') | None => {
                    old_remaining = old_remaining.saturating_sub(1);
                    new_remaining = new_remaining.saturating_sub(1);
                    result.push_str(line);
                    continue;
                }
                Some(_) => {
                    // Not a valid body line: the hunk ended early, so stop trusting its counts.
                    old_remaining = 0;
                    new_remaining = 0;
                }
            }
        }

        push_reversed_by_prefix(&mut result, line);
    }

    if diff.ends_with('\n') {
        result.push('\n');
    }
    result
}

/// The pieces of a `@@ -old +new @@rest` hunk header line.
struct HunkHeader<'a> {
    /// Old-side range text without its leading `-`, e.g. `1,3`.
    old_range: &'a str,
    /// Number of old-side lines in the hunk body.
    old_len: usize,
    /// New-side range text without its leading `+`, e.g. `1,4`.
    new_range: &'a str,
    /// Number of new-side lines in the hunk body.
    new_len: usize,
    /// Everything after the closing `@@` (usually a section heading).
    rest: &'a str,
}

/// Parses a `@@ -start[,len] +start[,len] @@…` line; returns `None` for any other shape.
fn parse_hunk_header(line: &str) -> Option<HunkHeader<'_>> {
    let after_marker = line.strip_prefix("@@ -")?;
    let (ranges, rest) = after_marker.split_once(" @@")?;
    let (old_range, new_range) = ranges.split_once(" +")?;
    Some(HunkHeader {
        old_range,
        old_len: hunk_range_len(old_range)?,
        new_range,
        new_len: hunk_range_len(new_range)?,
        rest,
    })
}

/// Returns the line count of a `start[,len]` range; an omitted length means one line.
fn hunk_range_len(range: &str) -> Option<usize> {
    match range.split_once(',') {
        Some((start, len)) => {
            start.parse::<usize>().ok()?;
            len.parse::<usize>().ok()
        }
        None => {
            range.parse::<usize>().ok()?;
            Some(1)
        }
    }
}

/// Reverses a line that is not inside a counted hunk, judging by its prefix alone.
fn push_reversed_by_prefix(out: &mut String, line: &str) {
    if let Some(path) = line.strip_prefix("--- ") {
        out.push_str("+++ ");
        out.push_str(path);
    } else if let Some(path) = line.strip_prefix("+++ ") {
        out.push_str("--- ");
        out.push_str(path);
    } else if line.starts_with('+') && !line.starts_with("+++") {
        out.push('-');
        out.push_str(&line[1..]);
    } else if line.starts_with('-') && !line.starts_with("---") {
        out.push('+');
        out.push_str(&line[1..]);
    } else {
        out.push_str(line);
    }
}
33
/// One replacement produced by diffing two versions of a text.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GranularEdit {
    /// Byte range in the old text that this edit replaces.
    range: Range<usize>,
    /// The old text that `range` covered (empty for a pure insertion).
    old_text: String,
    /// The text that takes its place (empty for a pure deletion).
    new_text: String,
}
40
41fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
42 text_diff(old_text, new_text)
43 .into_iter()
44 .map(|(range, new_text)| GranularEdit {
45 old_text: old_text[range.clone()].to_string(),
46 range,
47 new_text: new_text.to_string(),
48 })
49 .collect()
50}
51
/// Where text inserted by a history edit ended up in the current content.
#[derive(Debug, Clone)]
struct HistoryAdditionRange {
    /// Byte range of the inserted text, expressed in offsets of the text
    /// that results from applying the history edits.
    range_in_current: Range<usize>,
}
56
/// Text that a history edit removed.
#[derive(Debug, Clone)]
struct HistoryDeletionRange {
    /// The removed text itself (its position is not tracked).
    deleted_text: String,
}
61
62fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
63 let mut result = Vec::new();
64 let mut offset_delta: isize = 0;
65
66 for edit in history_edits {
67 if !edit.new_text.is_empty() {
68 let new_start = (edit.range.start as isize + offset_delta) as usize;
69 let new_end = new_start + edit.new_text.len();
70 result.push(HistoryAdditionRange {
71 range_in_current: new_start..new_end,
72 });
73 }
74
75 offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
76 }
77
78 result
79}
80
81fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
82 history_edits
83 .iter()
84 .filter(|edit| !edit.old_text.is_empty())
85 .map(|edit| HistoryDeletionRange {
86 deleted_text: edit.old_text.clone(),
87 })
88 .collect()
89}
90
/// How much of a prediction undoes edits the user made earlier.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ReversalOverlap {
    /// Amount of the prediction counted as reversing user edits.
    chars_reversing_user_edits: usize,
    /// Total amount of text the prediction adds or removes.
    total_chars_in_prediction: usize,
}

impl ReversalOverlap {
    /// Fraction of the prediction that reverses user edits; `0.0` when the
    /// prediction changes nothing.
    fn ratio(&self) -> f32 {
        match self.total_chars_in_prediction {
            0 => 0.0,
            total => self.chars_reversing_user_edits as f32 / total as f32,
        }
    }
}
106
/// Returns `true` when every character of `needle` can be found in `haystack`
/// in the same order, with any number of other characters in between.
/// An empty `needle` is a subsequence of everything.
fn is_subsequence(needle: &str, haystack: &str) -> bool {
    // A single forward pass over `haystack`: each needle character resumes
    // the search where the previous one was found.
    let mut remaining = haystack.chars();
    needle
        .chars()
        .all(|wanted| remaining.any(|candidate| candidate == wanted))
}
117
118/// Normalize edits where `old_text` appears as a subsequence within `new_text`.
119/// When the user's text is preserved (in order) within the prediction, we only
120/// count the newly inserted characters, not the preserved ones.
121/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
122/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
123fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
124 edits
125 .into_iter()
126 .map(|edit| {
127 if edit.old_text.is_empty() {
128 return edit;
129 }
130
131 if is_subsequence(&edit.old_text, &edit.new_text) {
132 let inserted_len = edit.new_text.len() - edit.old_text.len();
133 GranularEdit {
134 range: edit.range.start..edit.range.start,
135 old_text: String::new(),
136 new_text: edit.new_text.chars().take(inserted_len).collect(),
137 }
138 } else {
139 edit
140 }
141 })
142 .collect()
143}
144
145fn compute_reversal_overlap(
146 original_content: &str,
147 current_content: &str,
148 predicted_content: &str,
149) -> ReversalOverlap {
150 let history_edits = compute_granular_edits(original_content, current_content);
151 let prediction_edits =
152 normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
153
154 let history_addition_ranges = compute_history_addition_ranges(&history_edits);
155 let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
156
157 let reversed_additions =
158 compute_reversed_additions(&history_addition_ranges, &prediction_edits);
159 let restored_deletions =
160 compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
161
162 let total_chars_in_prediction: usize = prediction_edits
163 .iter()
164 .map(|e| e.new_text.len() + e.old_text.len())
165 .sum();
166
167 ReversalOverlap {
168 chars_reversing_user_edits: reversed_additions + restored_deletions,
169 total_chars_in_prediction,
170 }
171}
172
173fn compute_reversed_additions(
174 history_addition_ranges: &[HistoryAdditionRange],
175 prediction_edits: &[GranularEdit],
176) -> usize {
177 let mut reversed_chars = 0;
178
179 for pred_edit in prediction_edits {
180 for history_addition in history_addition_ranges {
181 let overlap_start = pred_edit
182 .range
183 .start
184 .max(history_addition.range_in_current.start);
185 let overlap_end = pred_edit
186 .range
187 .end
188 .min(history_addition.range_in_current.end);
189
190 if overlap_start < overlap_end {
191 reversed_chars += overlap_end - overlap_start;
192 }
193 }
194 }
195
196 reversed_chars
197}
198
199fn compute_restored_deletions(
200 history_deletion_ranges: &[HistoryDeletionRange],
201 prediction_edits: &[GranularEdit],
202) -> usize {
203 let history_deleted_text: String = history_deletion_ranges
204 .iter()
205 .map(|r| r.deleted_text.as_str())
206 .collect();
207
208 let prediction_added_text: String = prediction_edits
209 .iter()
210 .map(|e| e.new_text.as_str())
211 .collect();
212
213 compute_lcs_length(&history_deleted_text, &prediction_added_text)
214}
215
/// Returns the length, counted in `char`s, of the longest common subsequence
/// of `a` and `b`.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    if a_chars.is_empty() || b_chars.is_empty() {
        return 0;
    }

    // `row[j]` holds the LCS length of the processed prefix of `a` and the
    // first `j` characters of `b`; it is updated in place for each `a` char.
    let mut row = vec![0usize; b_chars.len() + 1];
    for &a_char in &a_chars {
        // Value `row[j]` had before this pass, i.e. the diagonal neighbour.
        let mut diagonal = 0usize;
        for (j, &b_char) in b_chars.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if a_char == b_char {
                diagonal + 1
            } else {
                above.max(row[j])
            };
            diagonal = above;
        }
    }

    row[b_chars.len()]
}
243
244pub fn filter_edit_history_by_path<'a>(
245 edit_history: &'a [Arc<zeta_prompt::Event>],
246 cursor_path: &std::path::Path,
247) -> Vec<&'a zeta_prompt::Event> {
248 edit_history
249 .iter()
250 .filter(|event| match event.as_ref() {
251 zeta_prompt::Event::BufferChange { path, .. } => {
252 let event_path = path.as_ref();
253 if event_path == cursor_path {
254 return true;
255 }
256 let stripped = event_path
257 .components()
258 .skip(1)
259 .collect::<std::path::PathBuf>();
260 stripped == cursor_path
261 }
262 })
263 .map(|arc| arc.as_ref())
264 .collect()
265}
266
267pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
268 match event {
269 zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
270 }
271}
272
273pub fn compute_prediction_reversal_ratio(
274 prompt_inputs: &ExamplePromptInputs,
275 predicted_content: &str,
276 cursor_path: &Path,
277) -> f32 {
278 let current_content = &prompt_inputs.content;
279
280 let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
281 let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
282
283 let mut original_content = current_content.to_string();
284 for event in relevant_events.into_iter().rev() {
285 let diff = extract_diff_from_event(event);
286 if diff.is_empty() {
287 continue;
288 }
289 let reversed = reverse_diff(diff);
290 let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
291 match apply_diff_to_string(&with_headers, &original_content) {
292 Ok(updated_content) => original_content = updated_content,
293 Err(err) => {
294 log::warn!(
295 "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}",
296 err
297 );
298 return 0.0;
299 }
300 }
301 }
302
303 let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
304 overlap.ratio()
305}
306
307#[cfg(test)]
308mod tests {
309 use super::*;
310 use edit_prediction::udiff::apply_diff_to_string;
311
/// Table-driven check of `compute_reversal_overlap`: every case supplies the
/// original, current and predicted text together with the expected reversal
/// and total counts.
#[test]
fn test_reversal_overlap() {
    /// One row of the table below.
    struct Case {
        /// Label included in assertion failure messages.
        name: &'static str,
        /// Text before the user's edits.
        original: &'static str,
        /// Text after the user's edits.
        current: &'static str,
        /// Text the prediction would produce.
        predicted: &'static str,
        /// Expected `chars_reversing_user_edits`.
        expected_reversal_chars: usize,
        /// Expected `total_chars_in_prediction`.
        expected_total_chars: usize,
    }

    let cases = [
        Case {
            name: "user_adds_line_prediction_removes_it",
            original: "a\nb\nc",
            current: "a\nnew line\nb\nc",
            predicted: "a\nb\nc",
            expected_reversal_chars: 9,
            expected_total_chars: 9,
        },
        Case {
            name: "user_deletes_line_prediction_restores_it",
            original: "a\ndeleted\nb",
            current: "a\nb",
            predicted: "a\ndeleted\nb",
            expected_reversal_chars: 8,
            expected_total_chars: 8,
        },
        Case {
            name: "user_deletes_text_prediction_restores_partial",
            original: "hello beautiful world",
            current: "hello world",
            predicted: "hello beautiful world",
            expected_reversal_chars: 10,
            expected_total_chars: 10,
        },
        Case {
            name: "user_deletes_foo_prediction_adds_bar",
            original: "foo",
            current: "",
            predicted: "bar",
            expected_reversal_chars: 0,
            expected_total_chars: 3,
        },
        Case {
            name: "independent_edits_different_locations",
            original: "line1\nline2\nline3",
            current: "LINE1\nline2\nline3",
            predicted: "LINE1\nline2\nLINE3",
            expected_reversal_chars: 0,
            expected_total_chars: 10,
        },
        Case {
            name: "no_history_edits",
            original: "same",
            current: "same",
            predicted: "different",
            expected_reversal_chars: 0,
            expected_total_chars: 13,
        },
        Case {
            name: "user_replaces_text_prediction_reverses",
            original: "keep\ndelete_me\nkeep2",
            current: "keep\nadded\nkeep2",
            predicted: "keep\ndelete_me\nkeep2",
            expected_reversal_chars: 14,
            expected_total_chars: 14,
        },
        Case {
            name: "user_modifies_word_prediction_modifies_differently",
            original: "the quick brown fox",
            current: "the slow brown fox",
            predicted: "the fast brown fox",
            expected_reversal_chars: 4,
            expected_total_chars: 8,
        },
        Case {
            name: "user finishes function name (suffix)",
            original: "",
            current: "epr",
            predicted: "eprintln!()",
            expected_reversal_chars: 0,
            expected_total_chars: 8,
        },
        Case {
            name: "user starts function name (prefix)",
            original: "",
            current: "my_function()",
            predicted: "test_my_function()",
            expected_reversal_chars: 0,
            expected_total_chars: 5,
        },
        Case {
            name: "user types partial, prediction extends in multiple places",
            original: "",
            current: "test_my_function",
            predicted: "a_test_for_my_special_function_plz",
            expected_reversal_chars: 0,
            expected_total_chars: 18,
        },
        // Edge cases for subsequence matching
        Case {
            name: "subsequence with interleaved underscores",
            original: "",
            current: "a_b_c",
            predicted: "_a__b__c__",
            expected_reversal_chars: 0,
            expected_total_chars: 5,
        },
        Case {
            name: "not a subsequence - different characters",
            original: "",
            current: "abc",
            predicted: "xyz",
            expected_reversal_chars: 3,
            expected_total_chars: 6,
        },
        Case {
            name: "not a subsequence - wrong order",
            original: "",
            current: "abc",
            predicted: "cba",
            expected_reversal_chars: 3,
            expected_total_chars: 6,
        },
        Case {
            name: "partial subsequence - only some chars match",
            original: "",
            current: "abcd",
            predicted: "axbx",
            expected_reversal_chars: 4,
            expected_total_chars: 8,
        },
        // Common completion patterns
        Case {
            name: "completing a method call",
            original: "",
            current: "vec.pu",
            predicted: "vec.push(item)",
            expected_reversal_chars: 0,
            expected_total_chars: 8,
        },
        Case {
            name: "completing an import statement",
            original: "",
            current: "use std::col",
            predicted: "use std::collections::HashMap",
            expected_reversal_chars: 0,
            expected_total_chars: 17,
        },
        Case {
            name: "completing a struct field",
            original: "",
            current: "name: St",
            predicted: "name: String",
            expected_reversal_chars: 0,
            expected_total_chars: 4,
        },
        Case {
            name: "prediction replaces with completely different text",
            original: "",
            current: "hello",
            predicted: "world",
            expected_reversal_chars: 5,
            expected_total_chars: 10,
        },
        Case {
            name: "empty prediction removes user text",
            original: "",
            current: "mistake",
            predicted: "",
            expected_reversal_chars: 7,
            expected_total_chars: 7,
        },
    ];

    // Both counters are checked per case; the case name identifies a failure.
    for case in &cases {
        let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
        assert_eq!(
            overlap.chars_reversing_user_edits, case.expected_reversal_chars,
            "Test '{}': expected {} reversal chars, got {}",
            case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
        );
        assert_eq!(
            overlap.total_chars_in_prediction, case.expected_total_chars,
            "Test '{}': expected {} total chars, got {}",
            case.name, case.expected_total_chars, overlap.total_chars_in_prediction
        );
    }
}
502
/// Checks that `reverse_diff` exchanges the file-header markers, turns an
/// added line into a removed line, and leaves context lines untouched.
#[test]
fn test_reverse_diff() {
    // The literal's lines must stay at column 0: their leading characters are diff markers.
    let forward_diff = "\
--- a/file.rs
+++ b/file.rs
@@ -1,3 +1,4 @@
 fn main() {
+ let x = 42;
 println!(\"hello\");
}";

    let reversed = reverse_diff(forward_diff);

    assert!(
        reversed.contains("+++ a/file.rs"),
        "Should have +++ for old path"
    );
    assert!(
        reversed.contains("--- b/file.rs"),
        "Should have --- for new path"
    );
    assert!(
        reversed.contains("- let x = 42;"),
        "Added line should become deletion"
    );
    assert!(
        reversed.contains(" fn main()"),
        "Context lines should be unchanged"
    );
}
533
/// Applying a diff and then its reversal must reproduce the starting text.
#[test]
fn test_reverse_diff_roundtrip() {
    let before = "first line\nhello world\nlast line\n";
    let after = "first line\nhello beautiful world\nlast line\n";

    // `unified_diff` emits no file headers, but `apply_diff_to_string` needs them.
    let forward = format!(
        "--- a/file\n+++ b/file\n{}",
        language::unified_diff(before, after)
    );
    let backward = reverse_diff(&forward);

    // Forward: before -> after.
    let applied_forward = apply_diff_to_string(&forward, before).unwrap();
    assert_eq!(applied_forward, after);

    // Backward: after -> before.
    let applied_backward = apply_diff_to_string(&backward, &applied_forward).unwrap();
    assert_eq!(applied_backward, before);
}
553
/// `filter_edit_history_by_path` must match event paths both exactly and
/// after dropping their first component (e.g. history path
/// "repo/src/file.rs" against cursor path "src/file.rs").
#[test]
fn test_filter_edit_history_by_path() {
    // Builds a buffer-change event whose old and new paths are identical.
    let buffer_change = |path: &'static str, diff: &'static str| {
        Arc::new(zeta_prompt::Event::BufferChange {
            path: Arc::from(Path::new(path)),
            old_path: Arc::from(Path::new(path)),
            diff: diff.into(),
            predicted: false,
            in_open_source_repo: true,
        })
    };

    let events = vec![
        buffer_change("myrepo/src/file.rs", "@@ -1 +1 @@\n-old\n+new"),
        buffer_change("myrepo/other.rs", "@@ -1 +1 @@\n-a\n+b"),
        buffer_change("src/file.rs", "@@ -1 +1 @@\n-x\n+y"),
    ];

    // "myrepo/src/file.rs" minus its first component is "src/file.rs";
    // "src/file.rs" matches exactly.
    let with_directory = filter_edit_history_by_path(&events, Path::new("src/file.rs"));
    assert_eq!(
        with_directory.len(),
        2,
        "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
    );

    // Only "src/file.rs" reduces to "file.rs"; "myrepo/src/file.rs" reduces to "src/file.rs".
    let bare_file = filter_edit_history_by_path(&events, Path::new("file.rs"));
    assert_eq!(
        bare_file.len(),
        1,
        "Should only match src/file.rs (stripped to file.rs)"
    );

    // "myrepo/other.rs" reduces to "other.rs".
    let other_file = filter_edit_history_by_path(&events, Path::new("other.rs"));
    assert_eq!(other_file.len(), 1, "Should match only myrepo/other.rs");
}
608
/// `reverse_diff` must end with a newline exactly when its input does.
#[test]
fn test_reverse_diff_preserves_trailing_newline() {
    let terminated = reverse_diff("--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n");
    assert!(
        terminated.ends_with('\n'),
        "Reversed diff should preserve trailing newline"
    );

    let unterminated = reverse_diff("--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new");
    assert!(
        !unterminated.ends_with('\n'),
        "Reversed diff should not add trailing newline if original didn't have one"
    );
}
625}