reversal_tracking.rs

  1use std::ops::Range;
  2use std::path::Path;
  3use std::sync::Arc;
  4
  5use edit_prediction::udiff::apply_diff_to_string;
  6use language::text_diff;
  7
  8use crate::example::ExamplePromptInputs;
  9
 10pub fn reverse_diff(diff: &str) -> String {
 11    let mut result: String = diff
 12        .lines()
 13        .map(|line| {
 14            if line.starts_with("--- ") {
 15                line.replacen("--- ", "+++ ", 1)
 16            } else if line.starts_with("+++ ") {
 17                line.replacen("+++ ", "--- ", 1)
 18            } else if line.starts_with('+') && !line.starts_with("+++") {
 19                format!("-{}", &line[1..])
 20            } else if line.starts_with('-') && !line.starts_with("---") {
 21                format!("+{}", &line[1..])
 22            } else {
 23                line.to_string()
 24            }
 25        })
 26        .collect::<Vec<_>>()
 27        .join("\n");
 28    if diff.ends_with('\n') {
 29        result.push('\n');
 30    }
 31    result
 32}
 33
/// A single replacement obtained by diffing two texts.
///
/// As produced by `compute_granular_edits`, `range` indexes into the pre-edit text and
/// `old_text` is exactly the slice it covers. `normalize_extension_edits` may later rewrite
/// all three fields.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GranularEdit {
    /// Byte range in the pre-edit text that this edit replaces.
    range: Range<usize>,
    /// Text removed by the edit (empty for a pure insertion).
    old_text: String,
    /// Text inserted by the edit (empty for a pure deletion).
    new_text: String,
}
 40
 41fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
 42    text_diff(old_text, new_text)
 43        .into_iter()
 44        .map(|(range, new_text)| GranularEdit {
 45            old_text: old_text[range.clone()].to_string(),
 46            range,
 47            new_text: new_text.to_string(),
 48        })
 49        .collect()
 50}
 51
/// Location of text that a history edit inserted.
#[derive(Debug, Clone)]
struct HistoryAdditionRange {
    /// Range covered by the inserted text: the edit's start offset, shifted by the length
    /// changes of all earlier history edits, extended by the inserted text's length.
    range_in_current: Range<usize>,
}
 56
/// Text that a history edit removed.
#[derive(Debug, Clone)]
struct HistoryDeletionRange {
    /// The removed text (the `old_text` of the history edit).
    deleted_text: String,
}
 61
 62fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
 63    let mut result = Vec::new();
 64    let mut offset_delta: isize = 0;
 65
 66    for edit in history_edits {
 67        if !edit.new_text.is_empty() {
 68            let new_start = (edit.range.start as isize + offset_delta) as usize;
 69            let new_end = new_start + edit.new_text.len();
 70            result.push(HistoryAdditionRange {
 71                range_in_current: new_start..new_end,
 72            });
 73        }
 74
 75        offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
 76    }
 77
 78    result
 79}
 80
 81fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
 82    history_edits
 83        .iter()
 84        .filter(|edit| !edit.old_text.is_empty())
 85        .map(|edit| HistoryDeletionRange {
 86            deleted_text: edit.old_text.clone(),
 87        })
 88        .collect()
 89}
 90
/// How much of a prediction undoes earlier edits, as computed by `compute_reversal_overlap`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ReversalOverlap {
    /// Amount of the prediction that removes text added by history edits or re-adds text
    /// deleted by them.
    chars_reversing_user_edits: usize,
    /// Sum of `old_text.len() + new_text.len()` over all prediction edits.
    total_chars_in_prediction: usize,
}
 96
 97impl ReversalOverlap {
 98    fn ratio(&self) -> f32 {
 99        if self.total_chars_in_prediction == 0 {
100            0.0
101        } else {
102            self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32
103        }
104    }
105}
106
/// Returns `true` when every character of `needle` occurs in `haystack` in the same order,
/// though not necessarily next to each other. An empty `needle` always matches.
fn is_subsequence(needle: &str, haystack: &str) -> bool {
    // The haystack iterator is shared across needle characters, so each search resumes
    // right after the previous match.
    let mut remaining = haystack.chars();
    needle
        .chars()
        .all(|wanted| remaining.any(|candidate| candidate == wanted))
}
117
118/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
119/// or where `new_text` appears as a subsequence within `old_text` (reduction).
120///
121/// For extensions: when the user's text is preserved (in order) within the prediction,
122/// we only count the newly inserted characters, not the preserved ones.
123/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
124/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
125///
126/// For reductions: when the prediction's text is preserved (in order) within the original,
127/// we only count the deleted characters, not the preserved ones.
128/// E.g., "ifrom" → "from" becomes 1 deleted char ("i")
129fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
130    edits
131        .into_iter()
132        .map(|edit| {
133            if edit.old_text.is_empty() || edit.new_text.is_empty() {
134                return edit;
135            }
136
137            if is_subsequence(&edit.old_text, &edit.new_text) {
138                let inserted_len = edit.new_text.len() - edit.old_text.len();
139                GranularEdit {
140                    range: edit.range.start..edit.range.start,
141                    old_text: String::new(),
142                    new_text: edit.new_text.chars().take(inserted_len).collect(),
143                }
144            } else if is_subsequence(&edit.new_text, &edit.old_text) {
145                let deleted_len = edit.old_text.len() - edit.new_text.len();
146                GranularEdit {
147                    range: edit.range.start..edit.range.start + deleted_len,
148                    old_text: edit.old_text.chars().take(deleted_len).collect(),
149                    new_text: String::new(),
150                }
151            } else {
152                edit
153            }
154        })
155        .collect()
156}
157
158fn compute_reversal_overlap(
159    original_content: &str,
160    current_content: &str,
161    predicted_content: &str,
162) -> ReversalOverlap {
163    let history_edits =
164        normalize_extension_edits(compute_granular_edits(original_content, current_content));
165    let prediction_edits =
166        normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
167
168    let history_addition_ranges = compute_history_addition_ranges(&history_edits);
169    let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
170
171    let reversed_additions =
172        compute_reversed_additions(&history_addition_ranges, &prediction_edits);
173    let restored_deletions =
174        compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
175
176    let total_chars_in_prediction: usize = prediction_edits
177        .iter()
178        .map(|e| e.new_text.len() + e.old_text.len())
179        .sum();
180
181    ReversalOverlap {
182        chars_reversing_user_edits: reversed_additions + restored_deletions,
183        total_chars_in_prediction,
184    }
185}
186
187fn compute_reversed_additions(
188    history_addition_ranges: &[HistoryAdditionRange],
189    prediction_edits: &[GranularEdit],
190) -> usize {
191    let mut reversed_chars = 0;
192
193    for pred_edit in prediction_edits {
194        for history_addition in history_addition_ranges {
195            let overlap_start = pred_edit
196                .range
197                .start
198                .max(history_addition.range_in_current.start);
199            let overlap_end = pred_edit
200                .range
201                .end
202                .min(history_addition.range_in_current.end);
203
204            if overlap_start < overlap_end {
205                reversed_chars += overlap_end - overlap_start;
206            }
207        }
208    }
209
210    reversed_chars
211}
212
213fn compute_restored_deletions(
214    history_deletion_ranges: &[HistoryDeletionRange],
215    prediction_edits: &[GranularEdit],
216) -> usize {
217    let history_deleted_text: String = history_deletion_ranges
218        .iter()
219        .map(|r| r.deleted_text.as_str())
220        .collect();
221
222    let prediction_added_text: String = prediction_edits
223        .iter()
224        .map(|e| e.new_text.as_str())
225        .collect();
226
227    compute_lcs_length(&history_deleted_text, &prediction_added_text)
228}
229
/// Returns the length, in characters, of the longest common subsequence of `a` and `b`.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let columns: Vec<char> = b.chars().collect();
    if a.is_empty() || columns.is_empty() {
        return 0;
    }

    // `row[j]` is the LCS length of the part of `a` processed so far and the first `j`
    // characters of `b`. A single row is updated in place.
    let mut row = vec![0usize; columns.len() + 1];
    for a_char in a.chars() {
        // Value of `row[j]` before this pass overwrote it, i.e. the diagonal neighbour.
        let mut diagonal = 0;
        for (j, &b_char) in columns.iter().enumerate() {
            let above = row[j + 1];
            row[j + 1] = if a_char == b_char {
                diagonal + 1
            } else {
                above.max(row[j])
            };
            diagonal = above;
        }
    }

    row[columns.len()]
}
257
258pub fn filter_edit_history_by_path<'a>(
259    edit_history: &'a [Arc<zeta_prompt::Event>],
260    cursor_path: &std::path::Path,
261) -> Vec<&'a zeta_prompt::Event> {
262    edit_history
263        .iter()
264        .filter(|event| match event.as_ref() {
265            zeta_prompt::Event::BufferChange { path, .. } => {
266                let event_path = path.as_ref();
267                if event_path == cursor_path {
268                    return true;
269                }
270                let stripped = event_path
271                    .components()
272                    .skip(1)
273                    .collect::<std::path::PathBuf>();
274                stripped == cursor_path
275            }
276        })
277        .map(|arc| arc.as_ref())
278        .collect()
279}
280
281pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
282    match event {
283        zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
284    }
285}
286
287pub fn compute_prediction_reversal_ratio(
288    prompt_inputs: &ExamplePromptInputs,
289    predicted_content: &str,
290    cursor_path: &Path,
291) -> f32 {
292    let current_content = &prompt_inputs.content;
293
294    let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
295    let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
296
297    let mut original_content = current_content.to_string();
298    for event in relevant_events.into_iter().rev() {
299        let diff = extract_diff_from_event(event);
300        if diff.is_empty() {
301            continue;
302        }
303        let reversed = reverse_diff(diff);
304        let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
305        match apply_diff_to_string(&with_headers, &original_content) {
306            Ok(updated_content) => original_content = updated_content,
307            Err(err) => {
308                log::warn!(
309                    "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}",
310                    err
311                );
312                return 0.0;
313            }
314        }
315    }
316
317    let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
318    overlap.ratio()
319}
320
321#[cfg(test)]
322mod tests {
323    use super::*;
324    use edit_prediction::udiff::apply_diff_to_string;
325
    /// Table-driven test for `compute_reversal_overlap`: each case supplies the three
    /// versions of a text and the expected fields of the resulting `ReversalOverlap`.
    #[test]
    fn test_reversal_overlap() {
        struct Case {
            // Label printed in assertion failures.
            name: &'static str,
            // Text before the history edit.
            original: &'static str,
            // Text after the history edit; the prediction is made against it.
            current: &'static str,
            // Text after applying the prediction.
            predicted: &'static str,
            // Expected `chars_reversing_user_edits`.
            expected_reversal_chars: usize,
            // Expected `total_chars_in_prediction`.
            expected_total_chars: usize,
        }

        let cases = [
            Case {
                name: "user_adds_line_prediction_removes_it",
                original: "a\nb\nc",
                current: "a\nnew line\nb\nc",
                predicted: "a\nb\nc",
                expected_reversal_chars: 9,
                expected_total_chars: 9,
            },
            Case {
                name: "user_deletes_line_prediction_restores_it",
                original: "a\ndeleted\nb",
                current: "a\nb",
                predicted: "a\ndeleted\nb",
                expected_reversal_chars: 8,
                expected_total_chars: 8,
            },
            Case {
                name: "user_deletes_text_prediction_restores_partial",
                original: "hello beautiful world",
                current: "hello world",
                predicted: "hello beautiful world",
                expected_reversal_chars: 10,
                expected_total_chars: 10,
            },
            Case {
                name: "user_deletes_foo_prediction_adds_bar",
                original: "foo",
                current: "",
                predicted: "bar",
                expected_reversal_chars: 0,
                expected_total_chars: 3,
            },
            Case {
                name: "independent_edits_different_locations",
                original: "line1\nline2\nline3",
                current: "LINE1\nline2\nline3",
                predicted: "LINE1\nline2\nLINE3",
                expected_reversal_chars: 0,
                expected_total_chars: 10,
            },
            Case {
                name: "no_history_edits",
                original: "same",
                current: "same",
                predicted: "different",
                expected_reversal_chars: 0,
                expected_total_chars: 13,
            },
            Case {
                name: "user_replaces_text_prediction_reverses",
                original: "keep\ndelete_me\nkeep2",
                current: "keep\nadded\nkeep2",
                predicted: "keep\ndelete_me\nkeep2",
                expected_reversal_chars: 14,
                expected_total_chars: 14,
            },
            Case {
                name: "user_modifies_word_prediction_modifies_differently",
                original: "the quick brown fox",
                current: "the slow brown fox",
                predicted: "the fast brown fox",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
            Case {
                name: "user finishes function name (suffix)",
                original: "",
                current: "epr",
                predicted: "eprintln!()",
                expected_reversal_chars: 0,
                expected_total_chars: 8,
            },
            Case {
                name: "user starts function name (prefix)",
                original: "",
                current: "my_function()",
                predicted: "test_my_function()",
                expected_reversal_chars: 0,
                expected_total_chars: 5,
            },
            Case {
                name: "user types partial, prediction extends in multiple places",
                original: "",
                current: "test_my_function",
                predicted: "a_test_for_my_special_function_plz",
                expected_reversal_chars: 0,
                expected_total_chars: 18,
            },
            // Edge cases for subsequence matching
            Case {
                name: "subsequence with interleaved underscores",
                original: "",
                current: "a_b_c",
                predicted: "_a__b__c__",
                expected_reversal_chars: 0,
                expected_total_chars: 5,
            },
            Case {
                name: "not a subsequence - different characters",
                original: "",
                current: "abc",
                predicted: "xyz",
                expected_reversal_chars: 3,
                expected_total_chars: 6,
            },
            Case {
                name: "not a subsequence - wrong order",
                original: "",
                current: "abc",
                predicted: "cba",
                expected_reversal_chars: 3,
                expected_total_chars: 6,
            },
            Case {
                name: "partial subsequence - only some chars match",
                original: "",
                current: "abcd",
                predicted: "axbx",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
            // Common completion patterns
            Case {
                name: "completing a method call",
                original: "",
                current: "vec.pu",
                predicted: "vec.push(item)",
                expected_reversal_chars: 0,
                expected_total_chars: 8,
            },
            Case {
                name: "completing an import statement",
                original: "",
                current: "use std::col",
                predicted: "use std::collections::HashMap",
                expected_reversal_chars: 0,
                expected_total_chars: 17,
            },
            Case {
                name: "completing a struct field",
                original: "",
                current: "name: St",
                predicted: "name: String",
                expected_reversal_chars: 0,
                expected_total_chars: 4,
            },
            Case {
                name: "prediction replaces with completely different text",
                original: "",
                current: "hello",
                predicted: "world",
                expected_reversal_chars: 5,
                expected_total_chars: 10,
            },
            Case {
                name: "empty prediction removes user text",
                original: "",
                current: "mistake",
                predicted: "",
                expected_reversal_chars: 7,
                expected_total_chars: 7,
            },
            Case {
                name: "fixing typo is not reversal",
                original: "",
                current: "<dv",
                predicted: "<div>",
                expected_reversal_chars: 0,
                expected_total_chars: 2,
            },
            Case {
                name: "infix insertion not reversal",
                original: "from my_project import Foo\n",
                current: "ifrom my_project import Foo\n",
                predicted: indoc::indoc! {"
                    import
                    from my_project import Foo
                "},
                expected_reversal_chars: 0,
                expected_total_chars: 6,
            },
            Case {
                name: "non-word based reversal",
                original: "from",
                current: "ifrom",
                predicted: "from",
                expected_reversal_chars: 1,
                expected_total_chars: 1,
            },
        ];

        // Both fields are checked separately so a failure names the case and the field.
        for case in &cases {
            let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
            assert_eq!(
                overlap.chars_reversing_user_edits, case.expected_reversal_chars,
                "Test '{}': expected {} reversal chars, got {}",
                case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
            );
            assert_eq!(
                overlap.total_chars_in_prediction, case.expected_total_chars,
                "Test '{}': expected {} total chars, got {}",
                case.name, case.expected_total_chars, overlap.total_chars_in_prediction
            );
        }
    }
543
    /// Checks that `reverse_diff` swaps the file-header markers, turns an added line into a
    /// deletion, and leaves context lines alone.
    #[test]
    fn test_reverse_diff() {
        // The diff lines start at column 0 on purpose: they are part of the string literal.
        let forward_diff = "\
--- a/file.rs
+++ b/file.rs
@@ -1,3 +1,4 @@
 fn main() {
+    let x = 42;
     println!(\"hello\");
}";

        let reversed = reverse_diff(forward_diff);

        assert!(
            reversed.contains("+++ a/file.rs"),
            "Should have +++ for old path"
        );
        assert!(
            reversed.contains("--- b/file.rs"),
            "Should have --- for new path"
        );
        assert!(
            reversed.contains("-    let x = 42;"),
            "Added line should become deletion"
        );
        assert!(
            reversed.contains(" fn main()"),
            "Context lines should be unchanged"
        );
    }
574
    /// Checks that applying a diff and then its `reverse_diff` restores the original text.
    #[test]
    fn test_reverse_diff_roundtrip() {
        // Applying a diff and then its reverse should get back to original
        let original = "first line\nhello world\nlast line\n";
        let modified = "first line\nhello beautiful world\nlast line\n";

        // unified_diff doesn't include file headers, but apply_diff_to_string needs them
        let diff_body = language::unified_diff(original, modified);
        let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
        let reversed_diff = reverse_diff(&forward_diff);

        // Apply forward diff to original
        let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
        assert_eq!(after_forward, modified);

        // Apply reversed diff to modified
        let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
        assert_eq!(after_reverse, original);
    }
594
    /// Checks both matching rules of `filter_edit_history_by_path`: exact path equality and
    /// equality after dropping the first component of the event path.
    #[test]
    fn test_filter_edit_history_by_path() {
        // Test that filter_edit_history_by_path correctly matches paths when
        // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
        // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
        let events = vec![
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/src/file.rs")),
                old_path: Arc::from(Path::new("myrepo/src/file.rs")),
                diff: "@@ -1 +1 @@\n-old\n+new".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/other.rs")),
                old_path: Arc::from(Path::new("myrepo/other.rs")),
                diff: "@@ -1 +1 @@\n-a\n+b".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("src/file.rs")),
                old_path: Arc::from(Path::new("src/file.rs")),
                diff: "@@ -1 +1 @@\n-x\n+y".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
        ];

        // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
        // "src/file.rs" exact match
        let cursor_path = Path::new("src/file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            2,
            "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
        );

        // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
        // "src/file.rs" stripped -> "file.rs" == "file.rs"
        let cursor_path = Path::new("file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            1,
            "Should only match src/file.rs (stripped to file.rs)"
        );

        // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
        let cursor_path = Path::new("other.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
    }
649
    /// Checks that `reverse_diff` ends its output with a newline exactly when the input did.
    #[test]
    fn test_reverse_diff_preserves_trailing_newline() {
        // Input ends with '\n': the output must as well.
        let diff_with_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n";
        let reversed = reverse_diff(diff_with_trailing_newline);
        assert!(
            reversed.ends_with('\n'),
            "Reversed diff should preserve trailing newline"
        );

        // Input without a final '\n': none may be added.
        let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new";
        let reversed = reverse_diff(diff_without_trailing_newline);
        assert!(
            !reversed.ends_with('\n'),
            "Reversed diff should not add trailing newline if original didn't have one"
        );
    }
665    }
666}