@@ -32,13 +32,13 @@ pub fn reverse_diff(diff: &str) -> String {
}
#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct GranularEdit {
- pub range: Range<usize>,
- pub old_text: String,
- pub new_text: String,
+struct GranularEdit {
+ range: Range<usize>,
+ old_text: String,
+ new_text: String,
}
-pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
+fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdit> {
text_diff(old_text, new_text)
.into_iter()
.map(|(range, new_text)| GranularEdit {
@@ -50,18 +50,16 @@ pub fn compute_granular_edits(old_text: &str, new_text: &str) -> Vec<GranularEdi
}
#[derive(Debug, Clone)]
-pub struct HistoryAdditionRange {
- pub range_in_current: Range<usize>,
+struct HistoryAdditionRange {
+ range_in_current: Range<usize>,
}
#[derive(Debug, Clone)]
-pub struct HistoryDeletionRange {
- pub deleted_text: String,
+struct HistoryDeletionRange {
+ deleted_text: String,
}
-pub fn compute_history_addition_ranges(
- history_edits: &[GranularEdit],
-) -> Vec<HistoryAdditionRange> {
+fn compute_history_addition_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryAdditionRange> {
let mut result = Vec::new();
let mut offset_delta: isize = 0;
@@ -80,9 +78,7 @@ pub fn compute_history_addition_ranges(
result
}
-pub fn compute_history_deletion_ranges(
- history_edits: &[GranularEdit],
-) -> Vec<HistoryDeletionRange> {
+fn compute_history_deletion_ranges(history_edits: &[GranularEdit]) -> Vec<HistoryDeletionRange> {
history_edits
.iter()
.filter(|edit| !edit.old_text.is_empty())
@@ -93,13 +89,13 @@ pub fn compute_history_deletion_ranges(
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
-pub struct ReversalOverlap {
- pub chars_reversing_user_edits: usize,
- pub total_chars_in_prediction: usize,
+struct ReversalOverlap {
+ chars_reversing_user_edits: usize,
+ total_chars_in_prediction: usize,
}
impl ReversalOverlap {
- pub fn ratio(&self) -> f32 {
+ fn ratio(&self) -> f32 {
if self.total_chars_in_prediction == 0 {
0.0
} else {
@@ -108,14 +104,52 @@ impl ReversalOverlap {
}
}
-/// Compute how much of a prediction reverses recent user edits.
-pub fn compute_reversal_overlap(
+/// Check if `needle` is a subsequence of `haystack` (characters appear in order, not necessarily contiguous).
+fn is_subsequence(needle: &str, haystack: &str) -> bool {
+ let mut needle_chars = needle.chars().peekable();
+ for c in haystack.chars() {
+ if needle_chars.peek() == Some(&c) {
+ needle_chars.next();
+ }
+ }
+ needle_chars.peek().is_none()
+}
+
+/// Normalize edits where `old_text` appears as a subsequence within `new_text`.
+/// When the user's text is preserved (in order) within the prediction, we only
+/// count the newly inserted characters, not the preserved ones.
+/// E.g., "epr" → "eprintln!()" becomes 8 inserted chars ("intln!()")
+/// E.g., "test_my_function" → "a_test_for_my_special_function_plz" becomes 18 inserted chars
+fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
+ edits
+ .into_iter()
+ .map(|edit| {
+ if edit.old_text.is_empty() {
+ return edit;
+ }
+
+ if is_subsequence(&edit.old_text, &edit.new_text) {
+ let inserted_len = edit.new_text.len() - edit.old_text.len();
+ GranularEdit {
+ range: edit.range.start..edit.range.start,
+ old_text: String::new(),
+ new_text: edit.new_text.chars().take(inserted_len).collect(),
+ }
+ } else {
+ edit
+ }
+ })
+ .collect()
+}
+
+fn compute_reversal_overlap(
original_content: &str,
current_content: &str,
predicted_content: &str,
) -> ReversalOverlap {
let history_edits = compute_granular_edits(original_content, current_content);
- let prediction_edits = compute_granular_edits(current_content, predicted_content);
+ let prediction_edits =
+ normalize_extension_edits(compute_granular_edits(current_content, predicted_content));
let history_addition_ranges = compute_history_addition_ranges(&history_edits);
let history_deletion_ranges = compute_history_deletion_ranges(&history_edits);
@@ -125,16 +159,18 @@ pub fn compute_reversal_overlap(
let restored_deletions =
compute_restored_deletions(&history_deletion_ranges, &prediction_edits);
- let prediction_added_chars: usize = prediction_edits.iter().map(|e| e.new_text.len()).sum();
- let prediction_deleted_chars: usize = prediction_edits.iter().map(|e| e.old_text.len()).sum();
+ let total_chars_in_prediction: usize = prediction_edits
+ .iter()
+ .map(|e| e.new_text.len() + e.old_text.len())
+ .sum();
ReversalOverlap {
chars_reversing_user_edits: reversed_additions + restored_deletions,
- total_chars_in_prediction: prediction_added_chars + prediction_deleted_chars,
+ total_chars_in_prediction,
}
}
-pub fn compute_reversed_additions(
+fn compute_reversed_additions(
history_addition_ranges: &[HistoryAdditionRange],
prediction_edits: &[GranularEdit],
) -> usize {
@@ -160,7 +196,7 @@ pub fn compute_reversed_additions(
reversed_chars
}
-pub fn compute_restored_deletions(
+fn compute_restored_deletions(
history_deletion_ranges: &[HistoryDeletionRange],
prediction_edits: &[GranularEdit],
) -> usize {
@@ -349,6 +385,104 @@ mod tests {
expected_reversal_chars: 4,
expected_total_chars: 8,
},
+ Case {
+ name: "user finishes function name (suffix)",
+ original: "",
+ current: "epr",
+ predicted: "eprintln!()",
+ expected_reversal_chars: 0,
+ expected_total_chars: 8,
+ },
+ Case {
+ name: "user starts function name (prefix)",
+ original: "",
+ current: "my_function()",
+ predicted: "test_my_function()",
+ expected_reversal_chars: 0,
+ expected_total_chars: 5,
+ },
+ Case {
+ name: "user types partial, prediction extends in multiple places",
+ original: "",
+ current: "test_my_function",
+ predicted: "a_test_for_my_special_function_plz",
+ expected_reversal_chars: 0,
+ expected_total_chars: 18,
+ },
+ // Edge cases for subsequence matching
+ Case {
+ name: "subsequence with interleaved underscores",
+ original: "",
+ current: "a_b_c",
+ predicted: "_a__b__c__",
+ expected_reversal_chars: 0,
+ expected_total_chars: 5,
+ },
+ Case {
+ name: "not a subsequence - different characters",
+ original: "",
+ current: "abc",
+ predicted: "xyz",
+ expected_reversal_chars: 3,
+ expected_total_chars: 6,
+ },
+ Case {
+ name: "not a subsequence - wrong order",
+ original: "",
+ current: "abc",
+ predicted: "cba",
+ expected_reversal_chars: 3,
+ expected_total_chars: 6,
+ },
+ Case {
+ name: "partial subsequence - only some chars match",
+ original: "",
+ current: "abcd",
+ predicted: "axbx",
+ expected_reversal_chars: 4,
+ expected_total_chars: 8,
+ },
+ // Common completion patterns
+ Case {
+ name: "completing a method call",
+ original: "",
+ current: "vec.pu",
+ predicted: "vec.push(item)",
+ expected_reversal_chars: 0,
+ expected_total_chars: 8,
+ },
+ Case {
+ name: "completing an import statement",
+ original: "",
+ current: "use std::col",
+ predicted: "use std::collections::HashMap",
+ expected_reversal_chars: 0,
+ expected_total_chars: 17,
+ },
+ Case {
+ name: "completing a struct field",
+ original: "",
+ current: "name: St",
+ predicted: "name: String",
+ expected_reversal_chars: 0,
+ expected_total_chars: 4,
+ },
+ Case {
+ name: "prediction replaces with completely different text",
+ original: "",
+ current: "hello",
+ predicted: "world",
+ expected_reversal_chars: 5,
+ expected_total_chars: 10,
+ },
+ Case {
+ name: "empty prediction removes user text",
+ original: "",
+ current: "mistake",
+ predicted: "",
+ expected_reversal_chars: 7,
+ expected_total_chars: 7,
+ },
];
for case in &cases {