diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index a014539650748188e19e2c3c8a687ed3a4b0ac6f..6f1be51c79df3d002004a09c5d0977d5f2a215eb 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -115,16 +115,22 @@ fn is_subsequence(needle: &str, haystack: &str) -> bool { needle_chars.peek().is_none() } -/// Normalize edits where `old_text` appears as a subsequence within `new_text`. -/// When the user's text is preserved (in order) within the prediction, we only -/// count the newly inserted characters, not the preserved ones. +/// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension), +/// or where `new_text` appears as a subsequence within `old_text` (reduction). +/// +/// For extensions: when the user's text is preserved (in order) within the prediction, +/// we only count the newly inserted characters, not the preserved ones. /// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()") /// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars +/// +/// For reductions: when the prediction's text is preserved (in order) within the original, +/// we only count the deleted characters, not the preserved ones. +/// E.g., "ifrom" → "from" becomes 1 deleted char ("i") fn normalize_extension_edits(edits: Vec) -> Vec { edits .into_iter() .map(|edit| { - if edit.old_text.is_empty() { + if edit.old_text.is_empty() || edit.new_text.is_empty() { return edit; } @@ -135,6 +141,13 @@ fn normalize_extension_edits(edits: Vec) -> Vec { old_text: String::new(), new_text: edit.new_text.chars().take(inserted_len).collect(), } + } else if is_subsequence(&edit.new_text, &edit.old_text) { + let deleted_len = edit.old_text.len() - edit.new_text.len(); + GranularEdit { + range: edit.range.start..edit.range.start + deleted_len, + old_text: edit.old_text.chars().take(deleted_len).collect(), + new_text: String::new(), + } } else { edit } @@ -147,7 +160,8 @@ fn compute_reversal_overlap( current_content: &str, predicted_content: &str, ) -> ReversalOverlap { - let history_edits = compute_granular_edits(original_content, current_content); + let history_edits = + normalize_extension_edits(compute_granular_edits(original_content, current_content)); let prediction_edits = normalize_extension_edits(compute_granular_edits(current_content, predicted_content)); @@ -483,6 +497,33 @@ mod tests { expected_reversal_chars: 7, expected_total_chars: 7, }, + Case { + name: "fixing typo is not reversal", + original: "", + current: "", + expected_reversal_chars: 0, + expected_total_chars: 2, + }, + Case { + name: "infix insertion not reversal", + original: "from my_project import Foo\n", + current: "ifrom my_project import Foo\n", + predicted: indoc::indoc! {" + import + from my_project import Foo + "}, + expected_reversal_chars: 0, + expected_total_chars: 6, + }, + Case { + name: "non-word based reversal", + original: "from", + current: "ifrom", + predicted: "from", + expected_reversal_chars: 1, + expected_total_chars: 1, + }, ]; for case in &cases {