reversal_tracking.rs

  1use std::ops::Range;
  2use std::path::Path;
  3use std::sync::Arc;
  4
  5use edit_prediction::udiff::apply_diff_to_string;
  6use language::text_diff;
  7
  8use crate::example::ExamplePromptInputs;
  9
 10pub fn reverse_diff(diff: &str) -> String {
 11    let mut result: String = diff
 12        .lines()
 13        .map(|line| {
 14            if line.starts_with("--- ") {
 15                line.replacen("--- ", "+++ ", 1)
 16            } else if line.starts_with("+++ ") {
 17                line.replacen("+++ ", "--- ", 1)
 18            } else if line.starts_with('+') && !line.starts_with("+++") {
 19                format!("-{}", &line[1..])
 20            } else if line.starts_with('-') && !line.starts_with("---") {
 21                format!("+{}", &line[1..])
 22            } else {
 23                line.to_string()
 24            }
 25        })
 26        .collect::<Vec<_>>()
 27        .join("\n");
 28    if diff.ends_with('\n') {
 29        result.push('\n');
 30    }
 31    result
 32}
 33
/// A single contiguous replacement turning part of an old text into new text.
///
/// Built from `text_diff` output by `compute_granular_edits`, and rewritten into pure
/// insertions by `normalize_extension_edits`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GranularEdit {
    /// Byte range in the old text that is replaced (empty for a pure insertion).
    range: Range<usize>,
    /// The old text covered by `range`.
    old_text: String,
    /// The text that replaces `old_text`.
    new_text: String,
}
 40
 41fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
 42    text_diff(old_text, new_text)
 43        .into_iter()
 44        .map(|(range, new_text)| GranularEdit {
 45            old_text: old_text[range.clone()].to_string(),
 46            range,
 47            new_text: new_text.to_string(),
 48        })
 49        .collect()
 50}
 51
/// Location of text the user inserted during the edit history.
#[derive(Debug, Clone)]
struct HistoryAdditionRange {
    /// Byte range within the current (post-history) content occupied by the inserted text.
    range_in_current: Range<usize>,
}
 56
/// Text the user removed during the edit history.
#[derive(Debug, Clone)]
struct HistoryDeletionRange {
    /// The removed text, as it appeared in the original (pre-history) content.
    deleted_text: String,
}
 61
 62fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
 63    let mut result = Vec::new();
 64    let mut offset_delta: isize = 0;
 65
 66    for edit in history_edits {
 67        if !edit.new_text.is_empty() {
 68            let new_start = (edit.range.start as isize + offset_delta) as usize;
 69            let new_end = new_start + edit.new_text.len();
 70            result.push(HistoryAdditionRange {
 71                range_in_current: new_start..new_end,
 72            });
 73        }
 74
 75        offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
 76    }
 77
 78    result
 79}
 80
 81fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
 82    history_edits
 83        .iter()
 84        .filter(|edit| !edit.old_text.is_empty())
 85        .map(|edit| HistoryDeletionRange {
 86            deleted_text: edit.old_text.clone(),
 87        })
 88        .collect()
 89}
 90
 91#[derive(Debug, Clone, Default, PartialEq, Eq)]
 92struct ReversalOverlap {
 93    chars_reversing_user_edits: usize,
 94    total_chars_in_prediction: usize,
 95}
 96
 97impl ReversalOverlap {
 98    fn ratio(&self) -> f32 {
 99        if self.total_chars_in_prediction == 0 {
100            0.0
101        } else {
102            self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32
103        }
104    }
105}
106
/// Returns `true` when every character of `needle` occurs in `haystack` in the same order,
/// though not necessarily adjacently. An empty `needle` always matches.
fn is_subsequence(needle: &str, haystack: &str) -> bool {
    // Shared cursor: each needle char resumes scanning where the previous match stopped,
    // which gives the same greedy left-to-right matching as a single pass over `haystack`.
    let mut unscanned = haystack.chars();
    needle
        .chars()
        .all(|wanted| unscanned.any(|candidate| candidate == wanted))
}
117
118/// Normalize edits where `old_text` appears as a subsequence within `new_text`.
119/// When the user's text is preserved (in order) within the prediction, we only
120/// count the newly inserted characters, not the preserved ones.
121/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
122/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
123fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
124    edits
125        .into_iter()
126        .map(|edit| {
127            if edit.old_text.is_empty() {
128                return edit;
129            }
130
131            if is_subsequence(&edit.old_text, &edit.new_text) {
132                let inserted_len = edit.new_text.len() - edit.old_text.len();
133                GranularEdit {
134                    range: edit.range.start..edit.range.start,
135                    old_text: String::new(),
136                    new_text: edit.new_text.chars().take(inserted_len).collect(),
137                }
138            } else {
139                edit
140            }
141        })
142        .collect()
143}
144
145fn compute_reversal_overlap(
146    original_content: &str,
147    current_content: &str,
148    predicted_content: &str,
149) -> ReversalOverlap {
150    let history_edits = compute_granular_edits(original_content, current_content);
151    let prediction_edits =
152        normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
153
154    let history_addition_ranges = compute_history_addition_ranges(&history_edits);
155    let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
156
157    let reversed_additions =
158        compute_reversed_additions(&history_addition_ranges, &prediction_edits);
159    let restored_deletions =
160        compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
161
162    let total_chars_in_prediction: usize = prediction_edits
163        .iter()
164        .map(|e| e.new_text.len() + e.old_text.len())
165        .sum();
166
167    ReversalOverlap {
168        chars_reversing_user_edits: reversed_additions + restored_deletions,
169        total_chars_in_prediction,
170    }
171}
172
173fn compute_reversed_additions(
174    history_addition_ranges: &[HistoryAdditionRange],
175    prediction_edits: &[GranularEdit],
176) -> usize {
177    let mut reversed_chars = 0;
178
179    for pred_edit in prediction_edits {
180        for history_addition in history_addition_ranges {
181            let overlap_start = pred_edit
182                .range
183                .start
184                .max(history_addition.range_in_current.start);
185            let overlap_end = pred_edit
186                .range
187                .end
188                .min(history_addition.range_in_current.end);
189
190            if overlap_start < overlap_end {
191                reversed_chars += overlap_end - overlap_start;
192            }
193        }
194    }
195
196    reversed_chars
197}
198
199fn compute_restored_deletions(
200    history_deletion_ranges: &[HistoryDeletionRange],
201    prediction_edits: &[GranularEdit],
202) -> usize {
203    let history_deleted_text: String = history_deletion_ranges
204        .iter()
205        .map(|r| r.deleted_text.as_str())
206        .collect();
207
208    let prediction_added_text: String = prediction_edits
209        .iter()
210        .map(|e| e.new_text.as_str())
211        .collect();
212
213    compute_lcs_length(&history_deleted_text, &prediction_added_text)
214}
215
/// Returns the length, in chars, of the longest common subsequence of `a` and `b`.
///
/// Uses the classic dynamic program in O(|a|·|b|) time, keeping only two rows of the table
/// (O(|b|) memory).
fn compute_lcs_length(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();
    if b_chars.is_empty() {
        return 0;
    }

    // `above[j]` is the LCS of the processed prefix of `a` with `b_chars[..j]`.
    let mut above = vec![0usize; b_chars.len() + 1];
    let mut below = vec![0usize; b_chars.len() + 1];

    for a_char in a.chars() {
        // `below[0]` is never written, so it stays 0 (empty prefix of `b`). Every other entry
        // is overwritten before it is read, so the row needs no reset between passes.
        for (j, &b_char) in b_chars.iter().enumerate() {
            below[j + 1] = if a_char == b_char {
                above[j] + 1
            } else {
                above[j + 1].max(below[j])
            };
        }
        std::mem::swap(&mut above, &mut below);
    }

    above[b_chars.len()]
}
243
244pub fn filter_edit_history_by_path<'a>(
245    edit_history: &'a [Arc<zeta_prompt::Event>],
246    cursor_path: &std::path::Path,
247) -> Vec<&'a zeta_prompt::Event> {
248    edit_history
249        .iter()
250        .filter(|event| match event.as_ref() {
251            zeta_prompt::Event::BufferChange { path, .. } => {
252                let event_path = path.as_ref();
253                if event_path == cursor_path {
254                    return true;
255                }
256                let stripped = event_path
257                    .components()
258                    .skip(1)
259                    .collect::<std::path::PathBuf>();
260                stripped == cursor_path
261            }
262        })
263        .map(|arc| arc.as_ref())
264        .collect()
265}
266
267pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
268    match event {
269        zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
270    }
271}
272
273pub fn compute_prediction_reversal_ratio(
274    prompt_inputs: &ExamplePromptInputs,
275    predicted_content: &str,
276    cursor_path: &Path,
277) -> f32 {
278    let current_content = &prompt_inputs.content;
279
280    let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
281    let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
282
283    let mut original_content = current_content.to_string();
284    for event in relevant_events.into_iter().rev() {
285        let diff = extract_diff_from_event(event);
286        if diff.is_empty() {
287            continue;
288        }
289        let reversed = reverse_diff(diff);
290        let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
291        match apply_diff_to_string(&with_headers, &original_content) {
292            Ok(updated_content) => original_content = updated_content,
293            Err(err) => {
294                log::warn!(
295                    "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}",
296                    err
297                );
298                return 0.0;
299            }
300        }
301    }
302
303    let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
304    overlap.ratio()
305}
306
/// Unit tests for reversal tracking: the overlap metric, diff reversal, and path filtering.
#[cfg(test)]
mod tests {
    use super::*;
    use edit_prediction::udiff::apply_diff_to_string;

    /// Table-driven check of `compute_reversal_overlap` on small original/current/predicted
    /// triples.
    #[test]
    fn test_reversal_overlap() {
        struct Case {
            // Label reported in assertion failures.
            name: &'static str,
            // Content before the user's edits.
            original: &'static str,
            // Content after the user's edits (what the model sees).
            current: &'static str,
            // Content after applying the prediction to `current`.
            predicted: &'static str,
            // Expected `chars_reversing_user_edits`.
            expected_reversal_chars: usize,
            // Expected `total_chars_in_prediction`.
            expected_total_chars: usize,
        }

        let cases = [
            Case {
                name: "user_adds_line_prediction_removes_it",
                original: "a\nb\nc",
                current: "a\nnew line\nb\nc",
                predicted: "a\nb\nc",
                expected_reversal_chars: 9,
                expected_total_chars: 9,
            },
            Case {
                name: "user_deletes_line_prediction_restores_it",
                original: "a\ndeleted\nb",
                current: "a\nb",
                predicted: "a\ndeleted\nb",
                expected_reversal_chars: 8,
                expected_total_chars: 8,
            },
            Case {
                name: "user_deletes_text_prediction_restores_partial",
                original: "hello beautiful world",
                current: "hello world",
                predicted: "hello beautiful world",
                expected_reversal_chars: 10,
                expected_total_chars: 10,
            },
            Case {
                name: "user_deletes_foo_prediction_adds_bar",
                original: "foo",
                current: "",
                predicted: "bar",
                expected_reversal_chars: 0,
                expected_total_chars: 3,
            },
            Case {
                name: "independent_edits_different_locations",
                original: "line1\nline2\nline3",
                current: "LINE1\nline2\nline3",
                predicted: "LINE1\nline2\nLINE3",
                expected_reversal_chars: 0,
                expected_total_chars: 10,
            },
            Case {
                name: "no_history_edits",
                original: "same",
                current: "same",
                predicted: "different",
                expected_reversal_chars: 0,
                expected_total_chars: 13,
            },
            Case {
                name: "user_replaces_text_prediction_reverses",
                original: "keep\ndelete_me\nkeep2",
                current: "keep\nadded\nkeep2",
                predicted: "keep\ndelete_me\nkeep2",
                expected_reversal_chars: 14,
                expected_total_chars: 14,
            },
            Case {
                name: "user_modifies_word_prediction_modifies_differently",
                original: "the quick brown fox",
                current: "the slow brown fox",
                predicted: "the fast brown fox",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
            // The cases below start from empty `original` content, so every character of
            // `current` is a user addition and there are no user deletions to restore.
            Case {
                name: "user finishes function name (suffix)",
                original: "",
                current: "epr",
                predicted: "eprintln!()",
                expected_reversal_chars: 0,
                expected_total_chars: 8,
            },
            Case {
                name: "user starts function name (prefix)",
                original: "",
                current: "my_function()",
                predicted: "test_my_function()",
                expected_reversal_chars: 0,
                expected_total_chars: 5,
            },
            Case {
                name: "user types partial, prediction extends in multiple places",
                original: "",
                current: "test_my_function",
                predicted: "a_test_for_my_special_function_plz",
                expected_reversal_chars: 0,
                expected_total_chars: 18,
            },
            // Edge cases for subsequence matching
            Case {
                name: "subsequence with interleaved underscores",
                original: "",
                current: "a_b_c",
                predicted: "_a__b__c__",
                expected_reversal_chars: 0,
                expected_total_chars: 5,
            },
            Case {
                name: "not a subsequence - different characters",
                original: "",
                current: "abc",
                predicted: "xyz",
                expected_reversal_chars: 3,
                expected_total_chars: 6,
            },
            Case {
                name: "not a subsequence - wrong order",
                original: "",
                current: "abc",
                predicted: "cba",
                expected_reversal_chars: 3,
                expected_total_chars: 6,
            },
            Case {
                name: "partial subsequence - only some chars match",
                original: "",
                current: "abcd",
                predicted: "axbx",
                expected_reversal_chars: 4,
                expected_total_chars: 8,
            },
            // Common completion patterns
            Case {
                name: "completing a method call",
                original: "",
                current: "vec.pu",
                predicted: "vec.push(item)",
                expected_reversal_chars: 0,
                expected_total_chars: 8,
            },
            Case {
                name: "completing an import statement",
                original: "",
                current: "use std::col",
                predicted: "use std::collections::HashMap",
                expected_reversal_chars: 0,
                expected_total_chars: 17,
            },
            Case {
                name: "completing a struct field",
                original: "",
                current: "name: St",
                predicted: "name: String",
                expected_reversal_chars: 0,
                expected_total_chars: 4,
            },
            Case {
                name: "prediction replaces with completely different text",
                original: "",
                current: "hello",
                predicted: "world",
                expected_reversal_chars: 5,
                expected_total_chars: 10,
            },
            Case {
                name: "empty prediction removes user text",
                original: "",
                current: "mistake",
                predicted: "",
                expected_reversal_chars: 7,
                expected_total_chars: 7,
            },
        ];

        // Check both counters for every case, naming the failing case in the message.
        for case in &cases {
            let overlap = compute_reversal_overlap(case.original, case.current, case.predicted);
            assert_eq!(
                overlap.chars_reversing_user_edits, case.expected_reversal_chars,
                "Test '{}': expected {} reversal chars, got {}",
                case.name, case.expected_reversal_chars, overlap.chars_reversing_user_edits
            );
            assert_eq!(
                overlap.total_chars_in_prediction, case.expected_total_chars,
                "Test '{}': expected {} total chars, got {}",
                case.name, case.expected_total_chars, overlap.total_chars_in_prediction
            );
        }
    }

    /// `reverse_diff` swaps the file headers, turns additions into deletions, and leaves
    /// context lines alone.
    #[test]
    fn test_reverse_diff() {
        let forward_diff = "\
--- a/file.rs
+++ b/file.rs
@@ -1,3 +1,4 @@
 fn main() {
+    let x = 42;
     println!(\"hello\");
}";

        let reversed = reverse_diff(forward_diff);

        assert!(
            reversed.contains("+++ a/file.rs"),
            "Should have +++ for old path"
        );
        assert!(
            reversed.contains("--- b/file.rs"),
            "Should have --- for new path"
        );
        assert!(
            reversed.contains("-    let x = 42;"),
            "Added line should become deletion"
        );
        assert!(
            reversed.contains(" fn main()"),
            "Context lines should be unchanged"
        );
    }

    /// Round trip through `apply_diff_to_string`: forward diff then reversed diff restores
    /// the original text.
    #[test]
    fn test_reverse_diff_roundtrip() {
        // Applying a diff and then its reverse should get back to original
        let original = "first line\nhello world\nlast line\n";
        let modified = "first line\nhello beautiful world\nlast line\n";

        // unified_diff doesn't include file headers, but apply_diff_to_string needs them
        let diff_body = language::unified_diff(original, modified);
        let forward_diff = format!("--- a/file\n+++ b/file\n{}", diff_body);
        let reversed_diff = reverse_diff(&forward_diff);

        // Apply forward diff to original
        let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
        assert_eq!(after_forward, modified);

        // Apply reversed diff to modified
        let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
        assert_eq!(after_reverse, original);
    }

    /// Path filtering accepts exact matches and matches after dropping the event path's
    /// first component.
    #[test]
    fn test_filter_edit_history_by_path() {
        // Test that filter_edit_history_by_path correctly matches paths when
        // the edit history has paths with a repo prefix (e.g., "repo/src/file.rs")
        // but the cursor_path doesn't have the repo prefix (e.g., "src/file.rs")
        let events = vec![
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/src/file.rs")),
                old_path: Arc::from(Path::new("myrepo/src/file.rs")),
                diff: "@@ -1 +1 @@\n-old\n+new".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("myrepo/other.rs")),
                old_path: Arc::from(Path::new("myrepo/other.rs")),
                diff: "@@ -1 +1 @@\n-a\n+b".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new("src/file.rs")),
                old_path: Arc::from(Path::new("src/file.rs")),
                diff: "@@ -1 +1 @@\n-x\n+y".into(),
                predicted: false,
                in_open_source_repo: true,
            }),
        ];

        // "myrepo/src/file.rs" stripped -> "src/file.rs" matches cursor_path
        // "src/file.rs" exact match
        let cursor_path = Path::new("src/file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            2,
            "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)"
        );

        // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs"
        // "src/file.rs" stripped -> "file.rs" == "file.rs"
        let cursor_path = Path::new("file.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(
            filtered.len(),
            1,
            "Should only match src/file.rs (stripped to file.rs)"
        );

        // "myrepo/other.rs" stripped -> "other.rs" == "other.rs"
        let cursor_path = Path::new("other.rs");
        let filtered = filter_edit_history_by_path(&events, cursor_path);
        assert_eq!(filtered.len(), 1, "Should match only myrepo/other.rs");
    }

    /// `reverse_diff` keeps a trailing newline exactly when the input had one.
    #[test]
    fn test_reverse_diff_preserves_trailing_newline() {
        let diff_with_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new\n";
        let reversed = reverse_diff(diff_with_trailing_newline);
        assert!(
            reversed.ends_with('\n'),
            "Reversed diff should preserve trailing newline"
        );

        let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new";
        let reversed = reverse_diff(diff_without_trailing_newline);
        assert!(
            !reversed.ends_with('\n'),
            "Reversed diff should not add trailing newline if original didn't have one"
        );
    }
}