reversal_tracking.rs

  1use std::ops::Range;
  2use std::path::Path;
  3use std::sync::Arc;
  4
  5use edit_prediction::udiff::apply_diff_to_string;
  6use language::text_diff;
  7
  8use crate::example::ExamplePromptInputs;
  9
 10pub fn reverse_diff(diff: &str) -> String {
 11    let mut result: String = diff
 12        .lines()
 13        .map(|line| {
 14            if line.starts_with("--- ") {
 15                line.replacen("--- ", "+++ ", 1)
 16            } else if line.starts_with("+++ ") {
 17                line.replacen("+++ ", "--- ", 1)
 18            } else if line.starts_with('+') && !line.starts_with("+++") {
 19                format!("-{}", &line[1..])
 20            } else if line.starts_with('-') && !line.starts_with("---") {
 21                format!("+{}", &line[1..])
 22            } else {
 23                line.to_string()
 24            }
 25        })
 26        .collect::<Vec<_>>()
 27        .join("\n");
 28    if diff.ends_with('\n') {
 29        result.push('\n');
 30    }
 31    result
 32}
 33
 34#[derive(Debug, Clone, PartialEq, Eq)]
 35pub struct GranularEdit {
 36    pub range: Range<usize>,
 37    pub old_text: String,
 38    pub new_text: String,
 39}
 40
 41pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
 42    text_diff(old_text, new_text)
 43        .into_iter()
 44        .map(|(range, new_text)| GranularEdit {
 45            old_text: old_text[range.clone()].to_string(),
 46            range,
 47            new_text: new_text.to_string(),
 48        })
 49        .collect()
 50}
 51
 52#[derive(Debug, Clone)]
 53pub struct HistoryAdditionRange {
 54    pub range_in_current: Range<usize>,
 55}
 56
 57#[derive(Debug, Clone)]
 58pub struct HistoryDeletionRange {
 59    pub deleted_text: String,
 60}
 61
 62pub fn compute_history_addition_ranges(
 63    history_edits: &[GranularEdit],
 64) -> Vec<HistoryAdditionRange> {
 65    let mut result = Vec::new();
 66    let mut offset_delta: isize = 0;
 67
 68    for edit in history_edits {
 69        if !edit.new_text.is_empty() {
 70            let new_start = (edit.range.start as isize + offset_delta) as usize;
 71            let new_end = new_start + edit.new_text.len();
 72            result.push(HistoryAdditionRange {
 73                range_in_current: new_start..new_end,
 74            });
 75        }
 76
 77        offset_delta += edit.new_text.len() as isize - edit.old_text.len() as isize;
 78    }
 79
 80    result
 81}
 82
 83pub fn compute_history_deletion_ranges(
 84    history_edits: &[GranularEdit],
 85) -> Vec<HistoryDeletionRange> {
 86    history_edits
 87        .iter()
 88        .filter(|edit| !edit.old_text.is_empty())
 89        .map(|edit| HistoryDeletionRange {
 90            deleted_text: edit.old_text.clone(),
 91        })
 92        .collect()
 93}
 94
 95#[derive(Debug, Clone, Default, PartialEq, Eq)]
 96pub struct ReversalOverlap {
 97    pub chars_reversing_user_edits: usize,
 98    pub total_chars_in_prediction: usize,
 99}
100
101impl ReversalOverlap {
102    pub fn ratio(&self) -> f32 {
103        if self.total_chars_in_prediction == 0 {
104            0.0
105        } else {
106            self.chars_reversing_user_edits as f32 / self.total_chars_in_prediction as f32
107        }
108    }
109}
110
111/// Compute how much of a prediction reverses recent user edits.
112pub fn compute_reversal_overlap(
113    original_content: &str,
114    current_content: &str,
115    predicted_content: &str,
116) -> ReversalOverlap {
117    let history_edits = compute_granular_edits(original_content, current_content);
118    let prediction_edits = compute_granular_edits(current_content, predicted_content);
119
120    let history_addition_ranges = compute_history_addition_ranges(&history_edits);
121    let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
122
123    let reversed_additions =
124        compute_reversed_additions(&history_addition_ranges, &prediction_edits);
125    let restored_deletions =
126        compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
127
128    let prediction_added_chars: usize = prediction_edits.iter().map(|e| e.new_text.len()).sum();
129    let prediction_deleted_chars: usize = prediction_edits.iter().map(|e| e.old_text.len()).sum();
130
131    ReversalOverlap {
132        chars_reversing_user_edits: reversed_additions + restored_deletions,
133        total_chars_in_prediction: prediction_added_chars + prediction_deleted_chars,
134    }
135}
136
137pub fn compute_reversed_additions(
138    history_addition_ranges: &[HistoryAdditionRange],
139    prediction_edits: &[GranularEdit],
140) -> usize {
141    let mut reversed_chars = 0;
142
143    for pred_edit in prediction_edits {
144        for history_addition in history_addition_ranges {
145            let overlap_start = pred_edit
146                .range
147                .start
148                .max(history_addition.range_in_current.start);
149            let overlap_end = pred_edit
150                .range
151                .end
152                .min(history_addition.range_in_current.end);
153
154            if overlap_start < overlap_end {
155                reversed_chars += overlap_end - overlap_start;
156            }
157        }
158    }
159
160    reversed_chars
161}
162
163pub fn compute_restored_deletions(
164    history_deletion_ranges: &[HistoryDeletionRange],
165    prediction_edits: &[GranularEdit],
166) -> usize {
167    let history_deleted_text: String = history_deletion_ranges
168        .iter()
169        .map(|r| r.deleted_text.as_str())
170        .collect();
171
172    let prediction_added_text: String = prediction_edits
173        .iter()
174        .map(|e| e.new_text.as_str())
175        .collect();
176
177    compute_lcs_length(&history_deleted_text, &prediction_added_text)
178}
179
/// Returns the size of the longest common subsequence of `a` and `b`, measured in UTF-8
/// bytes.
///
/// The subsequence is matched character by character (never splitting a code point), but
/// each matched character contributes its `len_utf8()` rather than 1. This keeps the
/// result in the same unit as the `str::len`-based counts it is combined with in
/// `compute_reversal_overlap`. For ASCII input the result is the plain LCS length.
///
/// Runs the classic two-row dynamic program: O(|a| * |b|) time, O(|b|) memory.
fn compute_lcs_length(a: &str, b: &str) -> usize {
    if a.is_empty() || b.is_empty() {
        return 0;
    }

    let b_chars: Vec<char> = b.chars().collect();
    let width = b_chars.len();

    // `prev[j]` / `curr[j]` hold the best weight for the first `j` chars of `b` against
    // the chars of `a` seen so far. Index 0 is never written and stays 0.
    let mut prev = vec![0usize; width + 1];
    let mut curr = vec![0usize; width + 1];

    for a_char in a.chars() {
        for (j, &b_char) in b_chars.iter().enumerate() {
            curr[j + 1] = if a_char == b_char {
                // Equal chars have equal weight, so taking the match is always optimal.
                prev[j] + a_char.len_utf8()
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        // Every cell of the recycled row (except index 0) is overwritten before it is
        // read again, so it does not need to be cleared.
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[width]
}
207
208pub fn filter_edit_history_by_path<'a>(
209    edit_history: &'a [Arc<zeta_prompt::Event>],
210    cursor_path: &std::path::Path,
211) -> Vec<&'a zeta_prompt::Event> {
212    edit_history
213        .iter()
214        .filter(|event| match event.as_ref() {
215            zeta_prompt::Event::BufferChange { path, .. } => {
216                let event_path = path.as_ref();
217                if event_path == cursor_path {
218                    return true;
219                }
220                let stripped = event_path
221                    .components()
222                    .skip(1)
223                    .collect::<std::path::PathBuf>();
224                stripped == cursor_path
225            }
226        })
227        .map(|arc| arc.as_ref())
228        .collect()
229}
230
231pub fn extract_diff_from_event(event: &zeta_prompt::Event) -> &str {
232    match event {
233        zeta_prompt::Event::BufferChange { diff, .. } => diff.as_str(),
234    }
235}
236
237pub fn compute_prediction_reversal_ratio(
238    prompt_inputs: &ExamplePromptInputs,
239    predicted_content: &str,
240    cursor_path: &Path,
241) -> f32 {
242    let current_content = &prompt_inputs.content;
243
244    let edit_history: &[Arc<zeta_prompt::Event>] = &prompt_inputs.edit_history;
245    let relevant_events = filter_edit_history_by_path(edit_history, cursor_path);
246
247    let mut original_content = current_content.to_string();
248    for event in relevant_events.into_iter().rev() {
249        let diff = extract_diff_from_event(event);
250        if diff.is_empty() {
251            continue;
252        }
253        let reversed = reverse_diff(diff);
254        let with_headers = format!("--- a/file\n+++ b/file\n{}", reversed);
255        match apply_diff_to_string(&with_headers, &original_content) {
256            Ok(updated_content) => original_content = updated_content,
257            Err(err) => {
258                log::warn!(
259                    "Failed to reconstruct original content for reversal tracking: Failed to apply reversed diff: {:#}",
260                    err
261                );
262                return 0.0;
263            }
264        }
265    }
266
267    let overlap = compute_reversal_overlap(&original_content, current_content, predicted_content);
268    overlap.ratio()
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use edit_prediction::udiff::apply_diff_to_string;

    /// Table-driven check of the counts returned by `compute_reversal_overlap`.
    #[test]
    fn test_reversal_overlap() {
        // Columns: name, original, current, predicted,
        // expected reversal chars, expected total chars.
        let cases: [(&str, &str, &str, &str, usize, usize); 8] = [
            (
                "user_adds_line_prediction_removes_it",
                "a\nb\nc",
                "a\nnew line\nb\nc",
                "a\nb\nc",
                9,
                9,
            ),
            (
                "user_deletes_line_prediction_restores_it",
                "a\ndeleted\nb",
                "a\nb",
                "a\ndeleted\nb",
                8,
                8,
            ),
            (
                "user_deletes_text_prediction_restores_partial",
                "hello beautiful world",
                "hello world",
                "hello beautiful world",
                10,
                10,
            ),
            ("user_deletes_foo_prediction_adds_bar", "foo", "", "bar", 0, 3),
            (
                "independent_edits_different_locations",
                "line1\nline2\nline3",
                "LINE1\nline2\nline3",
                "LINE1\nline2\nLINE3",
                0,
                10,
            ),
            ("no_history_edits", "same", "same", "different", 0, 13),
            (
                "user_replaces_text_prediction_reverses",
                "keep\ndelete_me\nkeep2",
                "keep\nadded\nkeep2",
                "keep\ndelete_me\nkeep2",
                14,
                14,
            ),
            (
                "user_modifies_word_prediction_modifies_differently",
                "the quick brown fox",
                "the slow brown fox",
                "the fast brown fox",
                4,
                8,
            ),
        ];

        for &(name, original, current, predicted, expected_reversal_chars, expected_total_chars) in
            cases.iter()
        {
            let overlap = compute_reversal_overlap(original, current, predicted);
            assert_eq!(
                overlap.chars_reversing_user_edits, expected_reversal_chars,
                "Test '{}': expected {} reversal chars, got {}",
                name, expected_reversal_chars, overlap.chars_reversing_user_edits
            );
            assert_eq!(
                overlap.total_chars_in_prediction, expected_total_chars,
                "Test '{}': expected {} total chars, got {}",
                name, expected_total_chars, overlap.total_chars_in_prediction
            );
        }
    }

    /// Headers swap their markers, added lines become removals, context stays as-is.
    #[test]
    fn test_reverse_diff() {
        let forward_diff = [
            "--- a/file.rs",
            "+++ b/file.rs",
            "@@ -1,3 +1,4 @@",
            " fn main() {",
            "+    let x = 42;",
            "     println!(\"hello\");",
            "}",
        ]
        .join("\n");

        let reversed = reverse_diff(&forward_diff);

        let expectations = [
            ("+++ a/file.rs", "Should have +++ for old path"),
            ("--- b/file.rs", "Should have --- for new path"),
            ("-    let x = 42;", "Added line should become deletion"),
            (" fn main()", "Context lines should be unchanged"),
        ];
        for &(needle, message) in expectations.iter() {
            assert!(reversed.contains(needle), "{}", message);
        }
    }

    /// Applying a diff and then its reverse should get back to the original text.
    #[test]
    fn test_reverse_diff_roundtrip() {
        let original = "first line\nhello world\nlast line\n";
        let modified = "first line\nhello beautiful world\nlast line\n";

        // unified_diff doesn't include file headers, but apply_diff_to_string needs them.
        let forward_diff = format!(
            "--- a/file\n+++ b/file\n{}",
            language::unified_diff(original, modified)
        );
        let reversed_diff = reverse_diff(&forward_diff);

        // Forward: original -> modified.
        let after_forward = apply_diff_to_string(&forward_diff, original).unwrap();
        assert_eq!(after_forward, modified);

        // Backward: modified -> original.
        let after_reverse = apply_diff_to_string(&reversed_diff, &after_forward).unwrap();
        assert_eq!(after_reverse, original);
    }

    /// `filter_edit_history_by_path` must match paths when the edit history has a repo
    /// prefix (e.g. "repo/src/file.rs") but the cursor path does not (e.g. "src/file.rs").
    #[test]
    fn test_filter_edit_history_by_path() {
        fn buffer_change(path: &'static str, diff: &'static str) -> Arc<zeta_prompt::Event> {
            Arc::new(zeta_prompt::Event::BufferChange {
                path: Arc::from(Path::new(path)),
                old_path: Arc::from(Path::new(path)),
                diff: diff.into(),
                predicted: false,
                in_open_source_repo: true,
            })
        }

        let events = vec![
            buffer_change("myrepo/src/file.rs", "@@ -1 +1 @@\n-old\n+new"),
            buffer_change("myrepo/other.rs", "@@ -1 +1 @@\n-a\n+b"),
            buffer_change("src/file.rs", "@@ -1 +1 @@\n-x\n+y"),
        ];

        let expectations = [
            // "myrepo/src/file.rs" stripped -> "src/file.rs" matches the cursor path;
            // "src/file.rs" is an exact match.
            (
                "src/file.rs",
                2,
                "Should match myrepo/src/file.rs (stripped) and src/file.rs (exact)",
            ),
            // "myrepo/src/file.rs" stripped -> "src/file.rs" != "file.rs";
            // "src/file.rs" stripped -> "file.rs" == "file.rs".
            (
                "file.rs",
                1,
                "Should only match src/file.rs (stripped to file.rs)",
            ),
            // "myrepo/other.rs" stripped -> "other.rs" == "other.rs".
            ("other.rs", 1, "Should match only myrepo/other.rs"),
        ];

        for &(cursor, expected_matches, message) in expectations.iter() {
            let filtered = filter_edit_history_by_path(&events, Path::new(cursor));
            assert_eq!(filtered.len(), expected_matches, "{}", message);
        }
    }

    /// The reversed diff ends with a newline exactly when the input does.
    #[test]
    fn test_reverse_diff_preserves_trailing_newline() {
        let diff_without_trailing_newline = "--- a/file\n+++ b/file\n@@ -1 +1 @@\n-old\n+new";
        let diff_with_trailing_newline = format!("{}\n", diff_without_trailing_newline);

        assert!(
            reverse_diff(&diff_with_trailing_newline).ends_with('\n'),
            "Reversed diff should preserve trailing newline"
        );
        assert!(
            !reverse_diff(diff_without_trailing_newline).ends_with('\n'),
            "Reversed diff should not add trailing newline if original didn't have one"
        );
    }
}