diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index a23343dea3dec4d18b6b24833c50efe85014e247..a014539650748188e19e2c3c8a687ed3a4b0ac6f 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -32,13 +32,13 @@ pub fn reverse_diff(diff: &str) -> String { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct GranularEdit { - pub range: Range, - pub old_text: String, - pub new_text: String, +struct GranularEdit { + range: Range, + old_text: String, + new_text: String, } -pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec { +fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec { text_diff(old_text, new_text) .into_iter() .map(|(range, new_text)| GranularEdit { @@ -50,18 +50,16 @@ pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec, +struct HistoryAdditionRange { + range_in_current: Range, } #[derive(Debug, Clone)] -pub struct HistoryDeletionRange { - pub deleted_text: String, +struct HistoryDeletionRange { + deleted_text: String, } -pub fn compute_history_addition_ranges( - history_edits: &[GranularEdit], -) -> Vec { +fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec { let mut result = Vec::new(); let mut offset_delta: isize = 0; @@ -80,9 +78,7 @@ pub fn compute_history_addition_ranges( result } -pub fn compute_history_deletion_ranges( - history_edits: &[GranularEdit], -) -> Vec { +fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec { history_edits .iter() .filter(|edit| !edit.old_text.is_empty()) @@ -93,13 +89,13 @@ pub fn compute_history_deletion_ranges( } #[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct ReversalOverlap { - pub chars_reversing_user_edits: usize, - pub total_chars_in_prediction: usize, +struct ReversalOverlap { + chars_reversing_user_edits: usize, + total_chars_in_prediction: usize, } impl ReversalOverlap { - pub fn ratio(&self) -> f32 { + fn ratio(&self) -> f32 { if self.total_chars_in_prediction == 0 { 0.0 } else { @@ -108,14 +104,52 @@ impl ReversalOverlap { } } -/// Compute how much of a prediction reverses recent user edits. -pub fn compute_reversal_overlap( +/// Check if `needle` is a subsequence of `haystack` (characters appear in order, not necessarily contiguous). +fn is_subsequence(needle: &str, haystack: &str) -> bool { + let mut needle_chars = needle.chars().peekable(); + for c in haystack.chars() { + if needle_chars.peek() == Some(&c) { + needle_chars.next(); + } + } + needle_chars.peek().is_none() +} + +/// Normalize edits where `old_text` appears as a subsequence within `new_text`. +/// When the user's text is preserved (in order) within the prediction, we only +/// count the newly inserted characters, not the preserved ones. +/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()") +/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars +fn normalize_extension_edits(edits: Vec) -> Vec { + edits + .into_iter() + .map(|edit| { + if edit.old_text.is_empty() { + return edit; + } + + if is_subsequence(&edit.old_text, &edit.new_text) { + let inserted_len = edit.new_text.len() - edit.old_text.len(); + GranularEdit { + range: edit.range.start..edit.range.start, + old_text: String::new(), + new_text: edit.new_text.chars().take(inserted_len).collect(), + } + } else { + edit + } + }) + .collect() +} + +fn compute_reversal_overlap( original_content: &str, current_content: &str, predicted_content: &str, ) -> ReversalOverlap { let history_edits = compute_granular_edits(original_content, current_content); - let prediction_edits = compute_granular_edits(current_content, predicted_content); + let prediction_edits = + normalize_extension_edits(compute_granular_edits(current_content, predicted_content)); let history_addition_ranges = compute_history_addition_ranges(&history_edits); let history_deletion_ranges = compute_history_deletion_ranges(&history_edits); @@ -125,16 +159,18 @@ pub fn compute_reversal_overlap( let restored_deletions = compute_restored_deletions(&history_deletion_ranges, &prediction_edits); - let prediction_added_chars: usize = prediction_edits.iter().map(|e| e.new_text.len()).sum(); - let prediction_deleted_chars: usize = prediction_edits.iter().map(|e| e.old_text.len()).sum(); + let total_chars_in_prediction: usize = prediction_edits + .iter() + .map(|e| e.new_text.len() + e.old_text.len()) + .sum(); ReversalOverlap { chars_reversing_user_edits: reversed_additions + restored_deletions, - total_chars_in_prediction: prediction_added_chars + prediction_deleted_chars, + total_chars_in_prediction, } } -pub fn compute_reversed_additions( +fn compute_reversed_additions( history_addition_ranges: &[HistoryAdditionRange], prediction_edits: &[GranularEdit], ) -> usize { @@ -160,7 +196,7 @@ pub fn compute_reversed_additions( reversed_chars } -pub fn compute_restored_deletions( +fn compute_restored_deletions( history_deletion_ranges: &[HistoryDeletionRange], prediction_edits: &[GranularEdit], ) -> usize { @@ -349,6 +385,104 @@ mod tests { expected_reversal_chars: 4, expected_total_chars: 8, }, + Case { + name: "user finishes function name (suffix)", + original: "", + current: "epr", + predicted: "eprintln!()", + expected_reversal_chars: 0, + expected_total_chars: 8, + }, + Case { + name: "user starts function name (prefix)", + original: "", + current: "my_function()", + predicted: "test_my_function()", + expected_reversal_chars: 0, + expected_total_chars: 5, + }, + Case { + name: "user types partial, prediction extends in multiple places", + original: "", + current: "test_my_function", + predicted: "a_test_for_my_special_function_plz", + expected_reversal_chars: 0, + expected_total_chars: 18, + }, + // Edge cases for subsequence matching + Case { + name: "subsequence with interleaved underscores", + original: "", + current: "a_b_c", + predicted: "_a__b__c__", + expected_reversal_chars: 0, + expected_total_chars: 5, + }, + Case { + name: "not a subsequence - different characters", + original: "", + current: "abc", + predicted: "xyz", + expected_reversal_chars: 3, + expected_total_chars: 6, + }, + Case { + name: "not a subsequence - wrong order", + original: "", + current: "abc", + predicted: "cba", + expected_reversal_chars: 3, + expected_total_chars: 6, + }, + Case { + name: "partial subsequence - only some chars match", + original: "", + current: "abcd", + predicted: "axbx", + expected_reversal_chars: 4, + expected_total_chars: 8, + }, + // Common completion patterns + Case { + name: "completing a method call", + original: "", + current: "vec.pu", + predicted: "vec.push(item)", + expected_reversal_chars: 0, + expected_total_chars: 8, + }, + Case { + name: "completing an import statement", + original: "", + current: "use std::col", + predicted: "use std::collections::HashMap", + expected_reversal_chars: 0, + expected_total_chars: 17, + }, + Case { + name: "completing a struct field", + original: "", + current: "name: St", + predicted: "name: String", + expected_reversal_chars: 0, + expected_total_chars: 4, + }, + Case { + name: "prediction replaces with completely different text", + original: "", + current: "hello", + predicted: "world", + expected_reversal_chars: 5, + expected_total_chars: 10, + }, + Case { + name: "empty prediction removes user text", + original: "", + current: "mistake", + predicted: "", + expected_reversal_chars: 7, + expected_total_chars: 7, + }, ]; for case in &cases {